diff --git a/backend/apps/github/api/internal/dataloaders/__init__.py b/backend/apps/github/api/internal/dataloaders/__init__.py index 23368246eb..b8815d5124 100644 --- a/backend/apps/github/api/internal/dataloaders/__init__.py +++ b/backend/apps/github/api/internal/dataloaders/__init__.py @@ -2,6 +2,7 @@ from apps.github.api.internal.dataloaders.release import get_release_loaders from apps.github.api.internal.dataloaders.repository import get_repository_loaders +from apps.github.api.internal.dataloaders.user import get_user_loaders def get_github_dataloaders() -> dict[str, object]: @@ -9,4 +10,5 @@ def get_github_dataloaders() -> dict[str, object]: loaders: dict[str, object] = {} loaders.update(get_repository_loaders()) loaders.update(get_release_loaders()) + loaders.update(get_user_loaders()) return loaders diff --git a/backend/apps/github/api/internal/dataloaders/user.py b/backend/apps/github/api/internal/dataloaders/user.py new file mode 100644 index 0000000000..c08c6f9210 --- /dev/null +++ b/backend/apps/github/api/internal/dataloaders/user.py @@ -0,0 +1,65 @@ +"""DataLoaders for users.""" + +from django.db.models import Count +from strawberry.dataloader import DataLoader + +from apps.common.api.internal.dataloaders.utils import get_result_by_keys, get_results_by_keys +from apps.github.models.user import User +from apps.nest.models.badge import Badge +from apps.nest.models.user_badge import UserBadge + +USER_BADGES_BY_USER_ID_LOADER = "user_badges_by_user_id" +USER_ISSUES_COUNT_LOADER = "user_issues_count" +USER_RELEASES_COUNT_LOADER = "user_releases_count" + + +async def load_user_badges_by_user_id(user_ids: list[int]) -> list[list[Badge]]: + """Batch-load badges for the given user IDs in a single query.""" + user_badges = ( + UserBadge.objects.select_related("badge") + .filter(user_id__in=user_ids, is_active=True) + .order_by( + "badge__weight", + "badge__name", + ) + ) + return await get_results_by_keys( + user_badges, user_ids, key_field="user_id", value_field="badge" + ) + + +async def load_user_issues_count(user_ids: list[int]) -> list[int]: + """Batch-load issues count for the given user IDs in a single query.""" + users = User.objects.filter(pk__in=user_ids).annotate(items_count=Count("created_issues")) + return [ + result or 0 + for result in await get_result_by_keys( + users, user_ids, key_field="pk", value_field="items_count" + ) + ] + + +async def load_user_releases_count(user_ids: list[int]) -> list[int]: + """Batch-load releases count for the given user IDs in a single query.""" + users = User.objects.filter(pk__in=user_ids).annotate(items_count=Count("created_releases")) + return [ + result or 0 + for result in await get_result_by_keys( + users, user_ids, key_field="pk", value_field="items_count" + ) + ] + + +def get_user_loaders() -> dict[str, DataLoader[int, int] | DataLoader[int, list[Badge]]]: + """Return a mapping of per-request DataLoader instances.""" + return { + USER_BADGES_BY_USER_ID_LOADER: DataLoader[int, list[Badge]]( + load_fn=load_user_badges_by_user_id, + ), + USER_ISSUES_COUNT_LOADER: DataLoader[int, int]( + load_fn=load_user_issues_count, + ), + USER_RELEASES_COUNT_LOADER: DataLoader[int, int]( + load_fn=load_user_releases_count, + ), + } diff --git a/backend/apps/github/api/internal/nodes/issue.py b/backend/apps/github/api/internal/nodes/issue.py index 11a0378125..aa3aa77b8a 100644 --- a/backend/apps/github/api/internal/nodes/issue.py +++ b/backend/apps/github/api/internal/nodes/issue.py @@ -44,7 +44,7 @@ class IssueNode(strawberry.relay.Node): """GitHub issue node.""" assignees: list[UserNode] = strawberry_django.field() - author: UserNode | None = strawberry_django.field() + author: UserNode | None = strawberry_django.field(select_related=["author"]) @strawberry_django.field(prefetch_related=["pull_requests"]) def pull_requests(self, limit: int = 4, offset: int = 0) -> list[PullRequestNode]: @@ -57,7 +57,7 @@ def pull_requests(self, limit: int = 4, offset: int = 0) -> list[PullRequestNode self.pull_requests.all().order_by("-created_at")[offset : offset + normalized_limit] ) - @strawberry_django.field(select_related=["repository__organization", "repository"]) + @strawberry_django.field(select_related=["repository__organization"]) def organization_name(self, root: Issue) -> str | None: """Resolve organization name.""" return ( diff --git a/backend/apps/github/api/internal/nodes/user.py b/backend/apps/github/api/internal/nodes/user.py index a3929b5133..a650d65a28 100644 --- a/backend/apps/github/api/internal/nodes/user.py +++ b/backend/apps/github/api/internal/nodes/user.py @@ -2,7 +2,13 @@ import strawberry_django from django.db.models.query import Prefetch +from strawberry.types.info import Info +from apps.github.api.internal.dataloaders.user import ( + USER_BADGES_BY_USER_ID_LOADER, + USER_ISSUES_COUNT_LOADER, + USER_RELEASES_COUNT_LOADER, +) from apps.github.models.user import User from apps.nest.api.internal.nodes.badge import BadgeNode from apps.nest.models.user_badge import UserBadge @@ -41,10 +47,10 @@ class UserNode: """GitHub user node.""" - @strawberry_django.field(prefetch_related=[USER_BADGES_PREFETCH]) - def badges(self, root: User) -> list[BadgeNode]: + @strawberry_django.field + async def badges(self, root: User, info: Info) -> list[BadgeNode]: """Return user badges.""" - return [user_badge.badge for user_badge in getattr(root, "user_badges_list", [])] + return await info.context.github_dataloaders[USER_BADGES_BY_USER_ID_LOADER].load(root.pk) @strawberry_django.field def created_at(self, root: User) -> str: @@ -80,9 +86,9 @@ def is_gsoc_mentor(self, root: User) -> bool: return root.owasp_profile.is_gsoc_mentor if hasattr(root, "owasp_profile") else False @strawberry_django.field - def issues_count(self, root: User) -> int: + async def issues_count(self, root: User, info: Info) -> int: """Resolve issues count.""" - return root.idx_issues_count + return await info.context.github_dataloaders[USER_ISSUES_COUNT_LOADER].load(root.pk) @strawberry_django.field(select_related=["owasp_profile"]) def linkedin_page_id(self, root: User) -> str: @@ -94,9 +100,9 @@ def linkedin_page_id(self, root: User) -> str: ) @strawberry_django.field - def releases_count(self, root: User) -> int: + async def releases_count(self, root: User, info: Info) -> int: """Resolve releases count.""" - return root.idx_releases_count + return await info.context.github_dataloaders[USER_RELEASES_COUNT_LOADER].load(root.pk) @strawberry_django.field def updated_at(self, root: User) -> str: diff --git a/backend/tests/unit/apps/github/api/internal/dataloaders/user_test.py b/backend/tests/unit/apps/github/api/internal/dataloaders/user_test.py new file mode 100644 index 0000000000..dafabe0d36 --- /dev/null +++ b/backend/tests/unit/apps/github/api/internal/dataloaders/user_test.py @@ -0,0 +1,235 @@ +"""Tests for user dataloaders.""" + +from unittest.mock import MagicMock, patch + +import pytest +from strawberry.dataloader import DataLoader + +from apps.github.api.internal.dataloaders.user import ( + USER_BADGES_BY_USER_ID_LOADER, + USER_ISSUES_COUNT_LOADER, + USER_RELEASES_COUNT_LOADER, + get_user_loaders, + load_user_badges_by_user_id, + load_user_issues_count, + load_user_releases_count, +) + + +class TestLoadUserBadgesByUserId: + """Tests for load_user_badges_by_user_id.""" + + @patch("apps.github.api.internal.dataloaders.user.UserBadge") + @pytest.mark.asyncio + async def test_builds_queryset_with_correct_chain(self, mock_user_badge): + """Queryset uses filter, select_related, and order_by.""" + mock_qs = MagicMock() + mock_user_badge.objects.select_related.return_value = mock_qs + mock_qs.filter.return_value = mock_qs + mock_qs.order_by.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter([]) + + await load_user_badges_by_user_id([1, 2]) + + mock_user_badge.objects.select_related.assert_called_once_with("badge") + mock_qs.filter.assert_called_once_with(user_id__in=[1, 2], is_active=True) + mock_qs.order_by.assert_called_once_with("badge__weight", "badge__name") + + @patch("apps.github.api.internal.dataloaders.user.UserBadge") + @pytest.mark.asyncio + async def test_returns_badges_grouped_by_user_id(self, mock_user_badge): + """Badges are grouped by user ID in the correct order.""" + badge_1 = MagicMock() + badge_2 = MagicMock() + badge_3 = MagicMock() + + user_badges = [ + MagicMock(user_id=1, badge=badge_1), + MagicMock(user_id=1, badge=badge_2), + MagicMock(user_id=2, badge=badge_3), + ] + + mock_qs = mock_user_badge.objects.select_related.return_value + mock_qs.filter.return_value = mock_qs + mock_qs.order_by.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(user_badges) + + result = await load_user_badges_by_user_id([1, 2]) + + assert result == [[badge_1, badge_2], [badge_3]] + + @patch("apps.github.api.internal.dataloaders.user.UserBadge") + @pytest.mark.asyncio + async def test_empty_user_ids(self, mock_user_badge): + """An empty user_ids list returns an empty list.""" + mock_qs = mock_user_badge.objects.select_related.return_value + mock_qs.filter.return_value = mock_qs + mock_qs.order_by.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter([]) + + result = await load_user_badges_by_user_id([]) + + assert result == [] + + @patch("apps.github.api.internal.dataloaders.user.UserBadge") + @pytest.mark.asyncio + async def test_user_with_no_badges_returns_empty_list(self, mock_user_badge): + """A user with no badges yields an empty list.""" + mock_qs = mock_user_badge.objects.select_related.return_value + mock_qs.filter.return_value = mock_qs + mock_qs.order_by.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter([]) + + result = await load_user_badges_by_user_id([1]) + + assert result == [[]] + + @patch("apps.github.api.internal.dataloaders.user.UserBadge") + @pytest.mark.asyncio + async def test_order_matches_keys_not_queryset(self, mock_user_badge): + """The output order follows user_ids, not the queryset iteration order.""" + badge_a = MagicMock() + badge_b = MagicMock() + + user_badges = [ + MagicMock(user_id=2, badge=badge_a), + MagicMock(user_id=1, badge=badge_b), + ] + + mock_qs = mock_user_badge.objects.select_related.return_value + mock_qs.filter.return_value = mock_qs + mock_qs.order_by.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(user_badges) + + result = await load_user_badges_by_user_id([1, 2]) + + assert result == [[badge_b], [badge_a]] + + +class TestLoadUserIssuesCount: + """Tests for load_user_issues_count.""" + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_returns_counts_grouped_by_user_id(self, mock_user): + """Issues counts are grouped by user ID in the correct order.""" + users_data = [ + MagicMock(pk=1, items_count=5), + MagicMock(pk=2, items_count=3), + ] + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(users_data) + + result = await load_user_issues_count([1, 2]) + + assert result == [5, 3] + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_empty_user_ids(self, mock_user): + """An empty user_ids list returns an empty list.""" + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter([]) + + result = await load_user_issues_count([]) + + assert result == [] + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_user_with_no_issues_returns_zero(self, mock_user): + """A user with no issues yields 0.""" + users_data = [MagicMock(pk=1, items_count=None)] + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(users_data) + + result = await load_user_issues_count([1]) + + assert result == [0] + + +class TestLoadUserReleasesCount: + """Tests for load_user_releases_count.""" + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_returns_counts_grouped_by_user_id(self, mock_user): + """Releases counts are grouped by user ID in the correct order.""" + users_data = [ + MagicMock(pk=1, items_count=2), + MagicMock(pk=2, items_count=7), + ] + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(users_data) + + result = await load_user_releases_count([1, 2]) + + assert result == [2, 7] + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_empty_user_ids(self, mock_user): + """An empty user_ids list returns an empty list.""" + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter([]) + + result = await load_user_releases_count([]) + + assert result == [] + + @patch("apps.github.api.internal.dataloaders.user.User") + @pytest.mark.asyncio + async def test_user_with_no_releases_returns_zero(self, mock_user): + """A user with no releases yields 0.""" + users_data = [MagicMock(pk=1, items_count=None)] + mock_qs = MagicMock() + mock_user.objects.filter.return_value = mock_qs + mock_qs.annotate.return_value = qs = MagicMock() + qs.__aiter__.return_value = iter(users_data) + + result = await load_user_releases_count([1]) + + assert result == [0] + + +class TestGetUserLoaders: + """Tests for get_user_loaders.""" + + def test_returns_mapping_with_all_loaders(self): + """Factory returns a mapping with all loaders.""" + loaders = get_user_loaders() + assert USER_BADGES_BY_USER_ID_LOADER in loaders + assert USER_ISSUES_COUNT_LOADER in loaders + assert USER_RELEASES_COUNT_LOADER in loaders + assert isinstance(loaders[USER_BADGES_BY_USER_ID_LOADER], DataLoader) + assert isinstance(loaders[USER_ISSUES_COUNT_LOADER], DataLoader) + assert isinstance(loaders[USER_RELEASES_COUNT_LOADER], DataLoader) + + def test_returns_new_instances_on_each_call(self): + """Each call produces distinct DataLoader instances for per-request isolation.""" + loaders1 = get_user_loaders() + loaders2 = get_user_loaders() + assert loaders1 is not loaders2 + for key in ( + USER_BADGES_BY_USER_ID_LOADER, + USER_ISSUES_COUNT_LOADER, + USER_RELEASES_COUNT_LOADER, + ): + assert loaders1[key] is not loaders2[key] + + def test_load_fn_is_correct(self): + """Each loader is wired to its correct load function.""" + loaders = get_user_loaders() + assert loaders[USER_BADGES_BY_USER_ID_LOADER].load_fn is load_user_badges_by_user_id + assert loaders[USER_ISSUES_COUNT_LOADER].load_fn is load_user_issues_count + assert loaders[USER_RELEASES_COUNT_LOADER].load_fn is load_user_releases_count diff --git a/backend/tests/unit/apps/github/api/internal/nodes/user_test.py b/backend/tests/unit/apps/github/api/internal/nodes/user_test.py index 5d32d44e4c..0eede1014b 100644 --- a/backend/tests/unit/apps/github/api/internal/nodes/user_test.py +++ b/backend/tests/unit/apps/github/api/internal/nodes/user_test.py @@ -2,8 +2,15 @@ import math from datetime import UTC, datetime -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock +import pytest + +from apps.github.api.internal.dataloaders.user import ( + USER_BADGES_BY_USER_ID_LOADER, + USER_ISSUES_COUNT_LOADER, + USER_RELEASES_COUNT_LOADER, +) from apps.github.api.internal.nodes.user import UserNode from apps.nest.api.internal.nodes.badge import BadgeNode from tests.unit.apps.common.graphql_node_base_test import GraphQLNodeBaseTest @@ -58,22 +65,30 @@ def test_created_at_field(self): result = field.base_resolver.wrapped_func(None, mock_user) assert math.isclose(result, 1234567890.0) - def test_issues_count_field(self): + @pytest.mark.asyncio + async def test_issues_count_field(self): """Test issues_count field resolution.""" - mock_user = Mock() - mock_user.idx_issues_count = 42 + mock_user = Mock(pk=1) + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=42) + mock_info = Mock() + mock_info.context.github_dataloaders = {USER_ISSUES_COUNT_LOADER: mock_loader} field = self._get_field_by_name("issues_count", UserNode) - result = field.base_resolver.wrapped_func(None, mock_user) + result = await field.base_resolver.wrapped_func(None, mock_user, mock_info) assert result == 42 - def test_releases_count_field(self): + @pytest.mark.asyncio + async def test_releases_count_field(self): """Test releases_count field resolution.""" - mock_user = Mock() - mock_user.idx_releases_count = 15 + mock_user = Mock(pk=1) + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=15) + mock_info = Mock() + mock_info.context.github_dataloaders = {USER_RELEASES_COUNT_LOADER: mock_loader} field = self._get_field_by_name("releases_count", UserNode) - result = field.base_resolver.wrapped_func(None, mock_user) + result = await field.base_resolver.wrapped_func(None, mock_user, mock_info) assert result == 15 def test_updated_at_field(self): @@ -94,30 +109,38 @@ def test_url_field(self): result = field.base_resolver.wrapped_func(None, mock_user) assert result == "https://github.com/testuser" - def test_badges_field_empty(self): + @pytest.mark.asyncio + async def test_badges_field_empty(self): """Test badges field resolution with no badges.""" - mock_user = Mock() - mock_user.user_badges_list = [] + mock_user = Mock(pk=1) + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=[]) + mock_info = Mock() + mock_info.context.github_dataloaders = {USER_BADGES_BY_USER_ID_LOADER: mock_loader} field = self._get_field_by_name("badges", UserNode) - result = field.base_resolver.wrapped_func(None, mock_user) + result = await field.base_resolver.wrapped_func(None, mock_user, mock_info) assert result == [] - def test_badges_field_single_badge(self): + @pytest.mark.asyncio + async def test_badges_field_single_badge(self): """Test badges field resolution with single badge.""" - mock_user = Mock() + mock_user = Mock(pk=1) mock_badge = Mock(spec=BadgeNode) - mock_user_badge = Mock() - mock_user_badge.badge = mock_badge + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=[mock_badge]) + mock_info = Mock() + mock_info.context.github_dataloaders = {USER_BADGES_BY_USER_ID_LOADER: mock_loader} - mock_user.user_badges_list = [mock_user_badge] field = self._get_field_by_name("badges", UserNode) - result = field.base_resolver.wrapped_func(None, mock_user) + result = await field.base_resolver.wrapped_func(None, mock_user, mock_info) assert result == [mock_badge] - def test_badges_field_sorted_by_weight_and_name(self): + @pytest.mark.asyncio + async def test_badges_field_sorted_by_weight_and_name(self): """Test badges field resolution with multiple badges sorted by weight and name.""" - # Create mock badges with different weights and names + mock_user = Mock(pk=1) + mock_badge_high_weight = Mock(spec=BadgeNode) mock_badge_high_weight.weight = 100 mock_badge_high_weight.name = "High Weight Badge" @@ -134,39 +157,21 @@ def test_badges_field_sorted_by_weight_and_name(self): mock_badge_low_weight.weight = 10 mock_badge_low_weight.name = "Low Weight Badge" - # Create mock user badges - mock_user_badge_high = Mock() - mock_user_badge_high.badge = mock_badge_high_weight - - mock_user_badge_medium_a = Mock() - mock_user_badge_medium_a.badge = mock_badge_medium_weight_a - - mock_user_badge_medium_b = Mock() - mock_user_badge_medium_b.badge = mock_badge_medium_weight_b - - mock_user_badge_low = Mock() - mock_user_badge_low.badge = mock_badge_low_weight - - # Set up the mock queryset to return badges in the expected sorted order - # (lowest weight first, then by name for same weight) - mock_user = Mock() - mock_user.user_badges_list = [ - mock_user_badge_low, # weight 10 - mock_user_badge_medium_a, # weight 50, name "Medium Weight A" - mock_user_badge_medium_b, # weight 50, name "Medium Weight B" - mock_user_badge_high, # weight 100 - ] - - field = self._get_field_by_name("badges", UserNode) - result = field.base_resolver.wrapped_func(None, mock_user) - - # Verify the badges are returned in the correct order expected_badges = [ mock_badge_low_weight, mock_badge_medium_weight_a, mock_badge_medium_weight_b, mock_badge_high_weight, ] + + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=expected_badges) + mock_info = Mock() + mock_info.context.github_dataloaders = {USER_BADGES_BY_USER_ID_LOADER: mock_loader} + + field = self._get_field_by_name("badges", UserNode) + result = await field.base_resolver.wrapped_func(None, mock_user, mock_info) + assert result == expected_badges def test_first_owasp_contribution_at_with_profile(self):