diff --git a/club/forms.py b/club/forms.py index da0f5fb7..6bb09fc4 100644 --- a/club/forms.py +++ b/club/forms.py @@ -24,13 +24,15 @@ from django import forms from django.conf import settings +from django.db.models import Exists, OuterRef, Q +from django.db.models.functions import Lower from django.utils.translation import gettext_lazy as _ from club.models import Club, Mailing, MailingSubscription, Membership from core.models import User from core.views.forms import SelectDate, SelectDateTime from core.views.widgets.ajax_select import AutoCompleteSelectMultipleUser -from counter.models import Counter +from counter.models import Counter, Selling class ClubEditForm(forms.ModelForm): @@ -159,12 +161,20 @@ class SellingsForm(forms.Form): label=_("End date"), widget=SelectDateTime, required=False ) - counters = forms.ModelMultipleChoiceField( - Counter.objects.order_by("name").all(), label=_("Counter"), required=False - ) - def __init__(self, club, *args, **kwargs): super().__init__(*args, **kwargs) + counters_qs = ( + Counter.objects.filter( + Q(club=club) + | Q(products__club=club) + | Exists(Selling.objects.filter(counter=OuterRef("pk"), club=club)) + ) + .distinct() + .order_by(Lower("name")) + ) + self.fields["counters"] = forms.ModelMultipleChoiceField( + counters_qs, label=_("Counter"), required=False + ) self.fields["products"] = forms.ModelMultipleChoiceField( club.products.order_by("name").filter(archived=False).all(), label=_("Products"), diff --git a/club/tests/test_sales.py b/club/tests/test_sales.py index ad4733de..6e734f80 100644 --- a/club/tests/test_sales.py +++ b/club/tests/test_sales.py @@ -3,8 +3,11 @@ from django.test import Client from django.urls import reverse from model_bakery import baker +from club.forms import SellingsForm from club.models import Club from core.models import User +from counter.baker_recipes import product_recipe, sale_recipe +from counter.models import Counter, Customer @pytest.mark.django_db @@ -14,3 +17,22 @@ def test_sales_page_doesnt_crash(client: Client): client.force_login(admin) response = client.get(reverse("club:club_sellings", kwargs={"club_id": club.id})) assert response.status_code == 200 + + +@pytest.mark.django_db +def test_sales_form_counter_filter(): + """Test that counters are properly filtered in SellingsForm""" + club = baker.make(Club) + counters = baker.make( + Counter, _quantity=5, _bulk_create=True, name=iter(["Z", "a", "B", "e", "f"]) + ) + counters[0].club = club + counters[0].save() + sale_recipe.make( + counter=counters[1], club=club, unit_price=0, customer=baker.make(Customer) + ) + product_recipe.make(counters=[counters[2]], club=club) + + form = SellingsForm(club) + form_counters = list(form.fields["counters"].queryset) + assert form_counters == [counters[1], counters[2], counters[0]]