Merge pull request #1232 from ae-utbm/group-simplify

simplify `User.is_in_group`
This commit is contained in:
thomas girod
2025-11-09 21:27:51 +01:00
committed by GitHub
2 changed files with 13 additions and 58 deletions

View File

@@ -30,7 +30,7 @@ import unicodedata
from datetime import timedelta
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Self
from typing import TYPE_CHECKING, Self
from uuid import uuid4
from django.conf import settings
@@ -97,48 +97,6 @@ def validate_promo(value: int) -> None:
)
def get_group(*, pk: int | None = None, name: str | None = None) -> Group | None:
"""Search for a group by its primary key or its name.
Either one of the two must be set.
The result is cached for the default duration (should be 5 minutes).
Args:
pk: The primary key of the group
name: The name of the group
Returns:
The group if it exists, else None
Raises:
ValueError: If no group matches the criteria
"""
if pk is None and name is None:
raise ValueError("Either pk or name must be set")
# replace space characters to hide warnings with memcached backend
pk_or_name: str | int = pk if pk is not None else name.replace(" ", "_")
group = cache.get(f"sith_group_{pk_or_name}")
if group == "not_found":
# Using None as a cache value is a little bit tricky,
# so we use a special string to represent None
return None
elif group is not None:
return group
# if this point is reached, the group is not in cache
if pk is not None:
group = Group.objects.filter(pk=pk).first()
else:
group = Group.objects.filter(name=name).first()
if group is not None:
name = group.name.replace(" ", "_")
cache.set_many({f"sith_group_{group.id}": group, f"sith_group_{name}": group})
else:
cache.set(f"sith_group_{pk_or_name}", "not_found")
return group
class BanGroup(AuthGroup):
"""An anti-group, that removes permissions instead of giving them.
@@ -382,19 +340,18 @@ class User(AbstractUser):
Returns:
True if the user is the group, else False
"""
if pk is not None:
group: Optional[Group] = get_group(pk=pk)
elif name is not None:
group: Optional[Group] = get_group(name=name)
else:
if not pk and not name:
raise ValueError("You must either provide the id or the name of the group")
if group is None:
group_id: int | None = (
pk or Group.objects.filter(name=name).values_list("id", flat=True).first()
)
if group_id is None:
return False
if group.id == settings.SITH_GROUP_SUBSCRIBERS_ID:
if group_id == settings.SITH_GROUP_SUBSCRIBERS_ID:
return self.is_subscribed
if group.id == settings.SITH_GROUP_ROOT_ID:
if group_id == settings.SITH_GROUP_ROOT_ID:
return self.is_root
return group in self.cached_groups
return any(g.id == group_id for g in self.cached_groups)
@cached_property
def cached_groups(self) -> list[Group]:
@@ -689,8 +646,8 @@ class AnonymousUser(AuthAnonymousUser):
if pk is not None:
return pk == allowed_id
elif name is not None:
group = get_group(name=name)
return group is not None and group.id == allowed_id
group = Group.objects.get(id=allowed_id)
return group.name == name
else:
raise ValueError("You must either provide the id or the name of the group")

View File

@@ -421,18 +421,16 @@ class TestUserIsInGroup(TestCase):
# clear the cached property `User.cached_groups`
self.public_user.__dict__.pop("cached_groups", None)
cache.clear()
# Test when the user is in the group
with self.assertNumQueries(2):
with self.assertNumQueries(1):
self.public_user.is_in_group(pk=group_in.id)
with self.assertNumQueries(0):
self.public_user.is_in_group(pk=group_in.id)
group_not_in = baker.make(Group)
self.public_user.__dict__.pop("cached_groups", None)
cache.clear()
# Test when the user is not in the group
with self.assertNumQueries(2):
with self.assertNumQueries(1):
self.public_user.is_in_group(pk=group_not_in.id)
with self.assertNumQueries(0):
self.public_user.is_in_group(pk=group_not_in.id)