diff --git a/core/tests.py b/core/tests.py index c220a5c..98b84aa 100644 --- a/core/tests.py +++ b/core/tests.py @@ -964,8 +964,12 @@ class AdjustmentsTabTests(TestCase): self.w2 = Worker.objects.create( name='Bob', id_number='B1', monthly_salary=Decimal('4000') ) + # Two teams, BOTH workers in BOTH teams, so the naive M2M JOIN + # multiplies rows by team count. Exercises the subquery fix. self.team = Team.objects.create(name='Alpha', supervisor=self.admin) + self.team2 = Team.objects.create(name='Beta', supervisor=self.admin) self.team.workers.add(self.w1, self.w2) + self.team2.workers.add(self.w1, self.w2) self.proj = Project.objects.create(name='Site X') # 3 unpaid adjustments — 1 bonus Alice, 1 bonus Bob, 1 deduction Alice self.a1 = PayrollAdjustment.objects.create( @@ -1003,12 +1007,13 @@ class AdjustmentsTabTests(TestCase): self.assertEqual(resp.status_code, 302) def test_type_multi_filter(self): + """?type=Bonus&type=Deduction returns the UNION (3 rows: 2 bonuses + 1 + deduction), not the intersection.""" self._login_admin() - resp = self.client.get(self.url + '&type=Bonus') + resp = self.client.get(self.url + '&type=Bonus&type=Deduction') + self.assertEqual(resp.context['adj_total_count'], 3) ids = {a.id for a in resp.context['adj_page'].object_list} - self.assertIn(self.a1.id, ids) - self.assertIn(self.a2.id, ids) - self.assertNotIn(self.a3.id, ids) + self.assertEqual(ids, {self.a1.id, self.a2.id, self.a3.id}) def test_worker_multi_filter(self): self._login_admin() @@ -1019,10 +1024,17 @@ class AdjustmentsTabTests(TestCase): self.assertIn(self.a3.id, ids) def test_team_filter_uses_subquery_no_inflation(self): - """Filtering by team must NOT multiply rows (M2M JOIN inflation - would give 6 rows for 3 adjustments x 2 workers on team Alpha).""" + """Filtering by team must NOT multiply rows. With 2 teams x 2 workers x 3 + adjustments, a naive worker__teams__id__in filter would return 6 inflated + rows; the subquery pattern returns the true 3. See CLAUDE.md ORM gotcha.""" self._login_admin() - resp = self.client.get(self.url + f'&team={self.team.id}') + resp = self.client.get( + self.url + f'&team={self.team.id}&team={self.team2.id}' + ) + # .count() at the queryset level would blow up under inflation — + # asserting it guards against regressions more strictly than checking + # the paginator's object_list length. + self.assertEqual(resp.context['adj_total_count'], 3) self.assertEqual(len(resp.context['adj_page'].object_list), 3) def test_status_filter_unpaid(self):