From 43917317b46bce67865b022b63894945a77a11cd Mon Sep 17 00:00:00 2001 From: thomas girod Date: Tue, 24 Sep 2024 12:38:10 +0200 Subject: [PATCH] optimize file recursive rights --- core/models.py | 34 ++++++++++++++++++++++++++-------- core/tests/test_family.py | 2 +- core/tests/test_files.py | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/core/models.py b/core/models.py index a0c8c794..41087c51 100644 --- a/core/models.py +++ b/core/models.py @@ -1057,20 +1057,38 @@ class SithFile(models.Model): if self.is_file and (self.file is None or self.file == ""): raise ValidationError(_("You must provide a file")) - def apply_rights_recursively(self, *, only_folders=False): - children = self.children.all() - if only_folders: - children = children.filter(is_folder=True) - for c in children: - c.copy_rights() - c.apply_rights_recursively(only_folders=only_folders) + def apply_rights_recursively(self, *, only_folders: bool = False) -> None: + """Apply the rights of this file to all children recursively. + + Args: + only_folders: If True, only apply the rights to SithFiles that are folders. + """ + file_ids = [] + explored_ids = [self.id] + while len(explored_ids) > 0: # find all children recursively + file_ids.extend(explored_ids) + next_level = SithFile.objects.filter(parent_id__in=explored_ids) + if only_folders: + next_level = next_level.filter(is_folder=True) + explored_ids = list(next_level.values_list("id", flat=True)) + for through in (SithFile.view_groups.through, SithFile.edit_groups.through): + # force evaluation. Without this, the iterator yields nothing + groups = list( + through.objects.filter(sithfile_id=self.id).values_list( + "group_id", flat=True + ) + ) + # delete previous rights + through.objects.filter(sithfile_id__in=file_ids).delete() + through.objects.bulk_create( # create new rights + [through(sithfile_id=f, group_id=g) for f in file_ids for g in groups] + ) def copy_rights(self): """Copy, if possible, the rights of the parent folder.""" if self.parent is not None: self.edit_groups.set(self.parent.edit_groups.all()) self.view_groups.set(self.parent.view_groups.all()) - self.save() def move_to(self, parent): """Move a file to a new parent. diff --git a/core/tests/test_family.py b/core/tests/test_family.py index 842ff12a..58d90a92 100644 --- a/core/tests/test_family.py +++ b/core/tests/test_family.py @@ -22,7 +22,7 @@ class TestFetchFamilyApi(TestCase): # <- user5 cls.main_user = baker.make(User) - cls.users = baker.make(User, _quantity=17) + cls.users = baker.make(User, _quantity=17, _bulk_create=True) cls.main_user.godfathers.add(*cls.users[0:3]) cls.main_user.godchildren.add(*cls.users[3:6]) cls.users[1].godfathers.add(cls.users[6]) diff --git a/core/tests/test_files.py b/core/tests/test_files.py index f750c0b0..a887b93c 100644 --- a/core/tests/test_files.py +++ b/core/tests/test_files.py @@ -1,4 +1,5 @@ from io import BytesIO +from itertools import cycle from typing import Callable from uuid import uuid4 @@ -10,9 +11,10 @@ from django.urls import reverse from model_bakery import baker from model_bakery.recipe import Recipe, foreign_key from PIL import Image +from pytest_django.asserts import assertNumQueries from core.baker_recipes import board_user, subscriber_user -from core.models import SithFile, User +from core.models import Group, SithFile, User @pytest.mark.django_db @@ -184,3 +186,36 @@ class TestUserProfilePicture: assert user.profile_pict is not None # uploaded images should be converted to WEBP assert Image.open(user.profile_pict.file).format == "WEBP" + + +@pytest.mark.django_db +def test_apply_rights_recursively(): + """Test that the apply_rights_recursively method works as intended.""" + files = [baker.make(SithFile)] + files.extend(baker.make(SithFile, _quantity=3, parent=files[0], _bulk_create=True)) + files.extend( + baker.make(SithFile, _quantity=3, parent=iter(files[1:4]), _bulk_create=True) + ) + files.extend( + baker.make(SithFile, _quantity=6, parent=cycle(files[4:7]), _bulk_create=True) + ) + + groups = list(baker.make(Group, _quantity=7)) + files[0].view_groups.set(groups[:3]) + files[0].edit_groups.set(groups[2:6]) + + # those groups should be erased after the function call + files[1].view_groups.set(groups[6:]) + + with assertNumQueries(10): + # 1 query for each level of depth (here 4) + # 1 query to get the view_groups of the first file + # 1 query to delete the previous view_groups + # 1 query apply the new view_groups + # same 3 queries for the edit_groups + files[0].apply_rights_recursively() + for file in SithFile.objects.filter(pk__in=[f.pk for f in files]).prefetch_related( + "view_groups", "edit_groups" + ): + assert set(file.view_groups.all()) == set(groups[:3]) + assert set(file.edit_groups.all()) == set(groups[2:6])