Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/apps/github/api/internal/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from apps.github.api.internal.dataloaders.release import get_release_loaders
from apps.github.api.internal.dataloaders.repository import get_repository_loaders
from apps.github.api.internal.dataloaders.user import get_user_loaders


def get_github_dataloaders() -> dict[str, object]:
    """Build the combined dataloader mapping used by GitHub API resolvers.

    The repository, release, and user loader groups are merged in that order,
    so a key provided by a later group replaces one from an earlier group.
    """
    merged: dict[str, object] = {}
    for build_loaders in (get_repository_loaders, get_release_loaders, get_user_loaders):
        merged.update(build_loaders())
    return merged
65 changes: 65 additions & 0 deletions backend/apps/github/api/internal/dataloaders/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""DataLoaders for users."""

from django.db.models import Count
from strawberry.dataloader import DataLoader

from apps.common.api.internal.dataloaders.utils import get_result_by_keys, get_results_by_keys
from apps.github.models.user import User
from apps.nest.models.badge import Badge
from apps.nest.models.user_badge import UserBadge

# Keys under which get_user_loaders() registers each user DataLoader.
# Resolvers import these constants to look a loader up by the same name.
USER_BADGES_BY_USER_ID_LOADER = "user_badges_by_user_id"
USER_ISSUES_COUNT_LOADER = "user_issues_count"
USER_RELEASES_COUNT_LOADER = "user_releases_count"


async def load_user_badges_by_user_id(user_ids: list[int]) -> list[list[Badge]]:
    """Batch-load the active badges of every requested user.

    Args:
        user_ids: IDs of the users whose badges are requested.

    Returns:
        The grouping produced by ``get_results_by_keys`` for ``user_ids``,
        keyed on ``user_id`` and carrying each row's ``badge``.
    """
    # Join the badge up front so reading ``badge`` per row needs no extra query.
    with_badge = UserBadge.objects.select_related("badge")
    active_rows = with_badge.filter(user_id__in=user_ids, is_active=True)
    # Order by weight first, then by name as the tie-breaker.
    ordered_rows = active_rows.order_by("badge__weight", "badge__name")
    return await get_results_by_keys(
        ordered_rows,
        user_ids,
        key_field="user_id",
        value_field="badge",
    )


async def _load_user_related_count(user_ids: list[int], relation: str) -> list[int]:
    """Count the rows of ``relation`` for each of the given users.

    Shared implementation for the per-user count loaders, which differ only
    in the relation being counted.

    Args:
        user_ids: Primary keys of the users to count for.
        relation: Name of the ``User`` relation passed to ``Count``.

    Returns:
        The values returned by ``get_result_by_keys`` for ``user_ids``, with
        any falsy entry (for example ``None``) replaced by ``0``.
    """
    users = User.objects.filter(pk__in=user_ids).annotate(items_count=Count(relation))
    counts = await get_result_by_keys(
        users,
        user_ids,
        key_field="pk",
        value_field="items_count",
    )
    # Normalise missing values so callers always receive an integer.
    return [count or 0 for count in counts]


async def load_user_issues_count(user_ids: list[int]) -> list[int]:
    """Batch-load the created-issues count for the given user IDs.

    Args:
        user_ids: Primary keys of the users to count issues for.

    Returns:
        One integer per result from the key lookup; missing values become ``0``.
    """
    return await _load_user_related_count(user_ids, "created_issues")


async def load_user_releases_count(user_ids: list[int]) -> list[int]:
    """Batch-load the created-releases count for the given user IDs.

    Args:
        user_ids: Primary keys of the users to count releases for.

    Returns:
        One integer per result from the key lookup; missing values become ``0``.
    """
    return await _load_user_related_count(user_ids, "created_releases")


def get_user_loaders() -> dict[str, DataLoader[int, int] | DataLoader[int, list[Badge]]]:
    """Create a fresh set of user DataLoaders keyed by their loader names.

    Every call constructs new ``DataLoader`` objects, so the returned mapping
    never shares loader instances with a previous call.
    """
    badges_loader = DataLoader[int, list[Badge]](load_fn=load_user_badges_by_user_id)
    issues_count_loader = DataLoader[int, int](load_fn=load_user_issues_count)
    releases_count_loader = DataLoader[int, int](load_fn=load_user_releases_count)
    return {
        USER_BADGES_BY_USER_ID_LOADER: badges_loader,
        USER_ISSUES_COUNT_LOADER: issues_count_loader,
        USER_RELEASES_COUNT_LOADER: releases_count_loader,
    }
4 changes: 2 additions & 2 deletions backend/apps/github/api/internal/nodes/issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class IssueNode(strawberry.relay.Node):
"""GitHub issue node."""

assignees: list[UserNode] = strawberry_django.field()
author: UserNode | None = strawberry_django.field()
author: UserNode | None = strawberry_django.field(select_related=["author"])

@strawberry_django.field(prefetch_related=["pull_requests"])
def pull_requests(self, limit: int = 4, offset: int = 0) -> list[PullRequestNode]:
Expand All @@ -57,7 +57,7 @@ def pull_requests(self, limit: int = 4, offset: int = 0) -> list[PullRequestNode
self.pull_requests.all().order_by("-created_at")[offset : offset + normalized_limit]
)

@strawberry_django.field(select_related=["repository__organization", "repository"])
@strawberry_django.field(select_related=["repository__organization"])
def organization_name(self, root: Issue) -> str | None:
"""Resolve organization name."""
return (
Expand Down
20 changes: 13 additions & 7 deletions backend/apps/github/api/internal/nodes/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import strawberry_django
from django.db.models.query import Prefetch
from strawberry.types.info import Info

from apps.github.api.internal.dataloaders.user import (
USER_BADGES_BY_USER_ID_LOADER,
USER_ISSUES_COUNT_LOADER,
USER_RELEASES_COUNT_LOADER,
)
from apps.github.models.user import User
from apps.nest.api.internal.nodes.badge import BadgeNode
from apps.nest.models.user_badge import UserBadge
Expand Down Expand Up @@ -41,10 +47,10 @@
class UserNode:
"""GitHub user node."""

@strawberry_django.field(prefetch_related=[USER_BADGES_PREFETCH])
def badges(self, root: User) -> list[BadgeNode]:
@strawberry_django.field
async def badges(self, root: User, info: Info) -> list[BadgeNode]:
"""Return user badges."""
return [user_badge.badge for user_badge in getattr(root, "user_badges_list", [])]
return await info.context.github_dataloaders[USER_BADGES_BY_USER_ID_LOADER].load(root.pk)

@strawberry_django.field
def created_at(self, root: User) -> str:
Expand Down Expand Up @@ -80,9 +86,9 @@ def is_gsoc_mentor(self, root: User) -> bool:
return root.owasp_profile.is_gsoc_mentor if hasattr(root, "owasp_profile") else False

@strawberry_django.field
def issues_count(self, root: User) -> int:
async def issues_count(self, root: User, info: Info) -> int:
"""Resolve issues count."""
return root.idx_issues_count
return await info.context.github_dataloaders[USER_ISSUES_COUNT_LOADER].load(root.pk)

@strawberry_django.field(select_related=["owasp_profile"])
def linkedin_page_id(self, root: User) -> str:
Expand All @@ -94,9 +100,9 @@ def linkedin_page_id(self, root: User) -> str:
)

@strawberry_django.field
def releases_count(self, root: User) -> int:
async def releases_count(self, root: User, info: Info) -> int:
"""Resolve releases count."""
return root.idx_releases_count
return await info.context.github_dataloaders[USER_RELEASES_COUNT_LOADER].load(root.pk)

@strawberry_django.field
def updated_at(self, root: User) -> str:
Expand Down
235 changes: 235 additions & 0 deletions backend/tests/unit/apps/github/api/internal/dataloaders/user_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Tests for user dataloaders."""

from unittest.mock import MagicMock, patch

import pytest
from strawberry.dataloader import DataLoader

from apps.github.api.internal.dataloaders.user import (
USER_BADGES_BY_USER_ID_LOADER,
USER_ISSUES_COUNT_LOADER,
USER_RELEASES_COUNT_LOADER,
get_user_loaders,
load_user_badges_by_user_id,
load_user_issues_count,
load_user_releases_count,
)


class TestLoadUserBadgesByUserId:
    """Unit tests for the ``load_user_badges_by_user_id`` batch function."""

    @staticmethod
    def _wire_queryset(user_badge_cls, rows):
        """Stub the select_related -> filter -> order_by chain to yield ``rows``.

        Returns the intermediate queryset mock so its calls can be inspected.
        """
        chain = user_badge_cls.objects.select_related.return_value
        chain.filter.return_value = chain
        ordered = MagicMock()
        ordered.__aiter__.return_value = iter(rows)
        chain.order_by.return_value = ordered
        return chain

    @patch("apps.github.api.internal.dataloaders.user.UserBadge")
    @pytest.mark.asyncio
    async def test_builds_queryset_with_correct_chain(self, mock_user_badge):
        """The loader calls select_related, filter, and order_by with the expected arguments."""
        chain = self._wire_queryset(mock_user_badge, [])

        await load_user_badges_by_user_id([1, 2])

        mock_user_badge.objects.select_related.assert_called_once_with("badge")
        chain.filter.assert_called_once_with(user_id__in=[1, 2], is_active=True)
        chain.order_by.assert_called_once_with("badge__weight", "badge__name")

    @patch("apps.github.api.internal.dataloaders.user.UserBadge")
    @pytest.mark.asyncio
    async def test_returns_badges_grouped_by_user_id(self, mock_user_badge):
        """Each user ID receives its own list of badges, in row order."""
        first, second, third = MagicMock(), MagicMock(), MagicMock()
        rows = [
            MagicMock(user_id=1, badge=first),
            MagicMock(user_id=1, badge=second),
            MagicMock(user_id=2, badge=third),
        ]
        self._wire_queryset(mock_user_badge, rows)

        grouped = await load_user_badges_by_user_id([1, 2])

        assert grouped == [[first, second], [third]]

    @patch("apps.github.api.internal.dataloaders.user.UserBadge")
    @pytest.mark.asyncio
    async def test_empty_user_ids(self, mock_user_badge):
        """No requested IDs produce an empty result."""
        self._wire_queryset(mock_user_badge, [])

        grouped = await load_user_badges_by_user_id([])

        assert grouped == []

    @patch("apps.github.api.internal.dataloaders.user.UserBadge")
    @pytest.mark.asyncio
    async def test_user_with_no_badges_returns_empty_list(self, mock_user_badge):
        """A requested user without any rows maps to an empty list."""
        self._wire_queryset(mock_user_badge, [])

        grouped = await load_user_badges_by_user_id([1])

        assert grouped == [[]]

    @patch("apps.github.api.internal.dataloaders.user.UserBadge")
    @pytest.mark.asyncio
    async def test_order_matches_keys_not_queryset(self, mock_user_badge):
        """Results follow the order of the requested IDs, not the row order."""
        for_second_user, for_first_user = MagicMock(), MagicMock()
        rows = [
            MagicMock(user_id=2, badge=for_second_user),
            MagicMock(user_id=1, badge=for_first_user),
        ]
        self._wire_queryset(mock_user_badge, rows)

        grouped = await load_user_badges_by_user_id([1, 2])

        assert grouped == [[for_first_user], [for_second_user]]


class TestLoadUserIssuesCount:
    """Unit tests for the ``load_user_issues_count`` batch function."""

    @staticmethod
    def _wire_queryset(user_cls, rows):
        """Stub the filter -> annotate chain so async iteration yields ``rows``."""
        annotated = MagicMock()
        annotated.__aiter__.return_value = iter(rows)
        user_cls.objects.filter.return_value.annotate.return_value = annotated

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_returns_counts_grouped_by_user_id(self, mock_user):
        """Each requested user ID maps to its annotated issues count."""
        rows = [
            MagicMock(pk=1, items_count=5),
            MagicMock(pk=2, items_count=3),
        ]
        self._wire_queryset(mock_user, rows)

        counts = await load_user_issues_count([1, 2])

        assert counts == [5, 3]

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_empty_user_ids(self, mock_user):
        """No requested IDs produce an empty result."""
        self._wire_queryset(mock_user, [])

        counts = await load_user_issues_count([])

        assert counts == []

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_user_with_no_issues_returns_zero(self, mock_user):
        """A ``None`` count is reported as 0."""
        self._wire_queryset(mock_user, [MagicMock(pk=1, items_count=None)])

        counts = await load_user_issues_count([1])

        assert counts == [0]


class TestLoadUserReleasesCount:
    """Unit tests for the ``load_user_releases_count`` batch function."""

    @staticmethod
    def _wire_queryset(user_cls, rows):
        """Stub the filter -> annotate chain so async iteration yields ``rows``."""
        annotated = MagicMock()
        annotated.__aiter__.return_value = iter(rows)
        user_cls.objects.filter.return_value.annotate.return_value = annotated

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_returns_counts_grouped_by_user_id(self, mock_user):
        """Each requested user ID maps to its annotated releases count."""
        rows = [
            MagicMock(pk=1, items_count=2),
            MagicMock(pk=2, items_count=7),
        ]
        self._wire_queryset(mock_user, rows)

        counts = await load_user_releases_count([1, 2])

        assert counts == [2, 7]

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_empty_user_ids(self, mock_user):
        """No requested IDs produce an empty result."""
        self._wire_queryset(mock_user, [])

        counts = await load_user_releases_count([])

        assert counts == []

    @patch("apps.github.api.internal.dataloaders.user.User")
    @pytest.mark.asyncio
    async def test_user_with_no_releases_returns_zero(self, mock_user):
        """A ``None`` count is reported as 0."""
        self._wire_queryset(mock_user, [MagicMock(pk=1, items_count=None)])

        counts = await load_user_releases_count([1])

        assert counts == [0]


class TestGetUserLoaders:
    """Unit tests for the ``get_user_loaders`` factory."""

    def test_returns_mapping_with_all_loaders(self):
        """Every loader key is present and maps to a DataLoader."""
        registry = get_user_loaders()
        expected_keys = (
            USER_BADGES_BY_USER_ID_LOADER,
            USER_ISSUES_COUNT_LOADER,
            USER_RELEASES_COUNT_LOADER,
        )
        for name in expected_keys:
            assert name in registry
        for name in expected_keys:
            assert isinstance(registry[name], DataLoader)

    def test_returns_new_instances_on_each_call(self):
        """Two calls share neither the mapping nor any DataLoader instance."""
        first_call = get_user_loaders()
        second_call = get_user_loaders()
        assert first_call is not second_call
        for name in (
            USER_BADGES_BY_USER_ID_LOADER,
            USER_ISSUES_COUNT_LOADER,
            USER_RELEASES_COUNT_LOADER,
        ):
            assert first_call[name] is not second_call[name]

    def test_load_fn_is_correct(self):
        """Each loader key is bound to its matching batch function."""
        registry = get_user_loaders()
        expected_load_fns = {
            USER_BADGES_BY_USER_ID_LOADER: load_user_badges_by_user_id,
            USER_ISSUES_COUNT_LOADER: load_user_issues_count,
            USER_RELEASES_COUNT_LOADER: load_user_releases_count,
        }
        for name, load_fn in expected_load_fns.items():
            assert registry[name].load_fn is load_fn
Loading
Loading