diff --git a/club/api.py b/club/api.py index 2e59d3e5..9bf7bcad 100644 --- a/club/api.py +++ b/club/api.py @@ -1,7 +1,5 @@ -from typing import Annotated - -from annotated_types import MinLen from django.db.models import Prefetch +from ninja import Query from ninja.security import SessionAuth from ninja_extra import ControllerBase, api_controller, paginate, route from ninja_extra.pagination import PageNumberPaginationExtra @@ -10,7 +8,7 @@ from ninja_extra.schemas import PaginatedResponseSchema from api.auth import ApiKeyAuth from api.permissions import CanAccessLookup, HasPerm from club.models import Club, Membership -from club.schemas import ClubSchema, SimpleClubSchema +from club.schemas import ClubSchema, ClubSearchFilterSchema, SimpleClubSchema @api_controller("/club") @@ -23,8 +21,8 @@ class ClubController(ControllerBase): url_name="search_club", ) @paginate(PageNumberPaginationExtra, page_size=50) - def search_club(self, search: Annotated[str, MinLen(1)]): - return Club.objects.filter(name__icontains=search).values() + def search_club(self, filters: Query[ClubSearchFilterSchema]): + return filters.filter(Club.objects.all()) @route.get( "/{int:club_id}", diff --git a/club/schemas.py b/club/schemas.py index b0601af8..5a7ccccb 100644 --- a/club/schemas.py +++ b/club/schemas.py @@ -1,9 +1,26 @@ -from ninja import ModelSchema +from typing import Annotated + +from annotated_types import MinLen +from django.db.models import Q +from ninja import Field, FilterSchema, ModelSchema from club.models import Club, Membership from core.schemas import SimpleUserSchema +class ClubSearchFilterSchema(FilterSchema): + search: Annotated[str, MinLen(1)] | None = Field(None, q="name__icontains") + is_active: bool | None = None + parent_id: int | None = None + parent_name: str | None = Field(None, q="parent__name__icontains") + exclude_ids: set[int] | None = None + + def filter_exclude_ids(self, value: set[int] | None): + if value is None: + return Q() + return ~Q(id__in=value) + + class SimpleClubSchema(ModelSchema): class Meta: model = Club diff --git a/club/tests/test_club_controller.py b/club/tests/test_club_controller.py index 18a3aef1..500dbe6a 100644 --- a/club/tests/test_club_controller.py +++ b/club/tests/test_club_controller.py @@ -1,7 +1,8 @@ from datetime import date, timedelta import pytest -from django.test import Client +from django.contrib.auth.models import Permission +from django.test import Client, TestCase from django.urls import reverse from model_bakery import baker from model_bakery.recipe import Recipe @@ -9,6 +10,54 @@ from pytest_django.asserts import assertNumQueries from club.models import Club, Membership from core.baker_recipes import subscriber_user +from core.models import Group, Page, User + + +class TestClubSearch(TestCase): + @classmethod + def setUpTestData(cls): + cls.url = reverse("api:search_club") + cls.user = baker.make( + User, user_permissions=[Permission.objects.get(codename="access_lookup")] + ) + # delete existing clubs to avoid side effect + groups = list( + Group.objects.exclude(club=None, club_board=None).values_list( + "id", flat=True + ) + ) + Page.objects.exclude(club=None).delete() + Club.objects.all().delete() + Group.objects.filter(id__in=groups).delete() + + cls.clubs = baker.make( + Club, + _quantity=5, + name=iter(["AE", "ae 1", "Troll", "Dev AE", "pdf"]), + is_active=True, + ) + + def test_inactive_club(self): + self.client.force_login(self.user) + inactive_ids = {self.clubs[0].id, self.clubs[2].id} + Club.objects.filter(id__in=inactive_ids).update(is_active=False) + response = self.client.get(self.url, {"is_active": False}) + assert response.status_code == 200 + assert {d["id"] for d in response.json()["results"]} == inactive_ids + + def test_excluded_id(self): + self.client.force_login(self.user) + response = self.client.get(self.url, {"exclude_ids": [self.clubs[1].id]}) + assert response.status_code == 200 + ids = {d["id"] for d in response.json()["results"]} + assert ids == {c.id for c in [self.clubs[0], *self.clubs[2:]]} + + def test_club_search(self): + self.client.force_login(self.user) + response = self.client.get(self.url, {"search": "AE"}) + assert response.status_code == 200 + ids = {d["id"] for d in response.json()["results"]} + assert ids == {c.id for c in [self.clubs[0], self.clubs[1], self.clubs[3]]} @pytest.mark.django_db