mirror of
https://github.com/ae-utbm/sith.git
synced 2025-10-30 00:23:54 +00:00
112 lines
4.2 KiB
Python
112 lines
4.2 KiB
Python
from unittest import mock
|
|
from unittest.mock import Mock
|
|
|
|
from django.db.models import Max
|
|
from django.test import TestCase
|
|
from django.urls import reverse
|
|
from model_bakery import baker
|
|
from pytest_django.asserts import assertRedirects
|
|
|
|
from api.models import ApiClient, get_hmac_key
|
|
from core.baker_recipes import subscriber_user
|
|
from core.utils import hmac_hexdigest
|
|
|
|
|
|
def mocked_post(*, ok: bool):
    """Build a ``new_callable`` suitable for ``mock.patch("requests.post", ...)``.

    Calling the returned factory (with no arguments) creates a ``Mock``
    subclass instance meant to stand in for ``requests.post``. Its ``ok``
    attribute is pinned to the given value. ``unittest.mock`` builds child
    mocks (such as the value returned when the mock is called) from the same
    class, so the fake "response" exposes the same ``ok`` value.

    Args:
        ok: value reported by the ``ok`` attribute of the mocked response.

    Returns:
        A zero-argument factory producing the replacement for ``requests.post``.
    """
    # Alias the argument so the property below does not appear to read a
    # class attribute named ``ok``.
    response_ok = ok

    class MockedResponse(Mock):
        @property
        def ok(self) -> bool:
            return response_ok

    return lambda: MockedResponse()
|
|
|
|
|
|
class TestThirdPartyAuth(TestCase):
    """Tests for the third-party authentication view (``api-link:third-party-auth``)."""

    @classmethod
    def setUpTestData(cls):
        cls.user = subscriber_user.make()
        cls.api_client = baker.make(ApiClient)

    def setUp(self):
        # Parameters a third-party app sends to the view. The signature is
        # computed over every other field, before the "signature" key exists.
        self.query = {
            "client_id": self.api_client.id,
            "third_party_app": "app",
            "cgu_link": "https://foobar.fr/",
            "username": "bibou",
            "callback_url": "https://callback.fr/",
        }
        self.query["signature"] = hmac_hexdigest(self.api_client.hmac_key, self.query)
        # Data the tests expect to be POSTed to ``callback_url``, signed the
        # same way with the client's HMAC key.
        self.callback_data = {"user_id": self.user.id}
        self.callback_data["signature"] = hmac_hexdigest(
            self.api_client.hmac_key, self.callback_data
        )

    def test_auth_ok(self):
        """Test that accepting the CGU calls the callback and shows the success page."""
        self.client.force_login(self.user)
        res = self.client.get(reverse("api-link:third-party-auth", query=self.query))
        assert res.status_code == 200
        with mock.patch("requests.post", new_callable=mocked_post(ok=True)) as mocked:
            res = self.client.post(
                reverse("api-link:third-party-auth"),
                data={"cgu_accepted": True, "is_username_valid": True, **self.query},
            )
            mocked.assert_called_once_with(
                self.query["callback_url"], data=self.callback_data
            )
        assertRedirects(
            res,
            reverse("api-link:third-party-auth-result", kwargs={"result": "success"}),
        )

    def test_callback_error(self):
        """Test that the user see the failure page if the callback request failed."""
        self.client.force_login(self.user)
        with mock.patch("requests.post", new_callable=mocked_post(ok=False)) as mocked:
            res = self.client.post(
                reverse("api-link:third-party-auth"),
                data={"cgu_accepted": True, "is_username_valid": True, **self.query},
            )
            mocked.assert_called_once_with(
                self.query["callback_url"], data=self.callback_data
            )
        assertRedirects(
            res,
            reverse("api-link:third-party-auth-result", kwargs={"result": "failure"}),
        )

    def test_wrong_signature(self):
        """Test that a 403 is raised if the signature of the query is wrong."""
        self.client.force_login(subscriber_user.make())
        new_key = get_hmac_key()
        # Sign the other fields only (like setUp does), but with a foreign key.
        del self.query["signature"]
        self.query["signature"] = hmac_hexdigest(new_key, self.query)
        res = self.client.get(reverse("api-link:third-party-auth", query=self.query))
        assert res.status_code == 403

    def test_cgu_not_accepted(self):
        """Test that the form is rejected when the CGU are not accepted.

        No callback request must be sent in that case.
        """
        self.client.force_login(self.user)
        res = self.client.get(reverse("api-link:third-party-auth", query=self.query))
        assert res.status_code == 200
        # Patch requests.post so that a wrongly accepted form can neither
        # reach the network nor go unnoticed.
        with mock.patch("requests.post", new_callable=mocked_post(ok=True)) as mocked:
            res = self.client.post(
                reverse("api-link:third-party-auth"), data=self.query
            )
            assert res.status_code == 200  # no redirect means invalid form
            res = self.client.post(
                reverse("api-link:third-party-auth"),
                data={"cgu_accepted": False, "is_username_valid": False, **self.query},
            )
            assert res.status_code == 200
        mocked.assert_not_called()

    def test_invalid_client(self):
        """Test that a 403 is raised if the client id matches no ApiClient."""
        # Log in so that the 403 cannot come from the request being anonymous.
        self.client.force_login(self.user)
        # setUpTestData created a client, so the aggregate is never None.
        self.query["client_id"] = ApiClient.objects.aggregate(res=Max("id"))["res"] + 1
        res = self.client.get(reverse("api-link:third-party-auth", query=self.query))
        assert res.status_code == 403

    def test_missing_parameter(self):
        """Test that a 403 is raised if there is a missing parameter."""
        # Log in so that the 403 cannot come from the request being anonymous.
        self.client.force_login(self.user)
        del self.query["username"]
        # Drop the stale signature before re-signing: otherwise it would be
        # hashed too, the new signature would be invalid, and the test would
        # pass because of the signature rather than the missing parameter.
        del self.query["signature"]
        self.query["signature"] = hmac_hexdigest(self.api_client.hmac_key, self.query)
        res = self.client.get(reverse("api-link:third-party-auth", query=self.query))
        assert res.status_code == 403
|