Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ POSTHOG_HOST='https://eu.posthog.com'
POSTHOG_KEY=
SUPPORT_EMAIL=
TELEGRAM_TOKEN=
VAPID_PRIVATE_KEY=
VAPID_PUBLIC_KEY=
VAPID_SUBJECT=
PUSH_NOTIFICATIONS_TTL_SECONDS=

# Risk API (daily fire-weather index per camera)
RISK_API_URL=
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage-src.xml
files: ./coverage-src.xml
flags: backend
fail_ci_if_error: true

Expand Down Expand Up @@ -76,7 +76,7 @@ jobs:
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./client/coverage.xml
files: ./client/coverage.xml
flags: client
fail_ci_if_error: true

Expand Down
2,423 changes: 1,615 additions & 808 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ sqlmodel = "^0.0.24"
pydantic = ">=2.0.0,<3.0.0"
pydantic-settings = ">=2.0.0,<3.0.0"
requests = "^2.32.0"
pywebpush = "^2.3.0"
PyJWT = "^2.8.0"
passlib = { version = "^1.7.4", extras = ["bcrypt"] }
bcrypt = "3.1.7"
Expand Down
28 changes: 27 additions & 1 deletion src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,22 @@
get_jwt,
get_organization_crud,
get_pose_crud,
get_push_subscription_crud,
get_sequence_crud,
get_webhook_crud,
)
from app.core.config import settings
from app.core.time import utcnow
from app.crud import AlertCRUD, CameraCRUD, DetectionCRUD, OrganizationCRUD, PoseCRUD, SequenceCRUD, WebhookCRUD
from app.crud import (
AlertCRUD,
CameraCRUD,
DetectionCRUD,
OrganizationCRUD,
PoseCRUD,
SequenceCRUD,
WebhookCRUD,
)
from app.crud.crud_push_subscription import PushSubscriptionCRUD
from app.models import Alert, AlertSequence, Camera, Detection, Organization, Pose, Role, Sequence, UserRole
from app.schemas.alerts import AlertCreate, AlertUpdate
from app.schemas.detections import (
Expand All @@ -57,6 +67,7 @@
from app.schemas.sequences import SequenceUpdate
from app.services.cones import resolve_cone
from app.services.overlap import compute_overlap, haversine_km
from app.services.push_notifications import push_notification_client
from app.services.risk import risk_service
from app.services.sequence_confidence import max_conf_from_bboxes
from app.services.slack import slack_client
Expand Down Expand Up @@ -362,6 +373,7 @@ async def create_detection(
detections: DetectionCRUD = Depends(get_detection_crud),
webhooks: WebhookCRUD = Depends(get_webhook_crud),
organizations: OrganizationCRUD = Depends(get_organization_crud),
push_subscriptions_crud: PushSubscriptionCRUD = Depends(get_push_subscription_crud),
sequences: SequenceCRUD = Depends(get_sequence_crud),
alerts: AlertCRUD = Depends(get_alert_crud),
cameras: CameraCRUD = Depends(get_camera_crud),
Expand Down Expand Up @@ -516,6 +528,20 @@ async def create_detection(
min_conf,
)

if push_notification_client.is_enabled and alert_id is not None:
subscriptions = await push_subscriptions_crud.fetch_all(
filters=[("organization_id", token_payload.organization_id)]
)
if any(subscriptions):
background_tasks.add_task(
push_notification_client.notify_many,
subscriptions,
alert_id=alert_id,
camera_name=camera.name,
created_at=det.created_at,
sequence_azimuth=sequence_.sequence_azimuth,
)

created.append(det)

first_det = cast(Detection, await detections.get(created[0].id, strict=True))
Expand Down
85 changes: 85 additions & 0 deletions src/app/api/api_v1/endpoints/push_subscriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (C) 2024-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from typing import List, cast

from fastapi import APIRouter, Depends, HTTPException, Path, Security, status

from app.api.dependencies import get_jwt, get_push_subscription_crud
from app.crud.crud_push_subscription import PushSubscriptionCRUD
from app.models import PushSubscription, UserRole
from app.schemas.login import TokenPayload
from app.schemas.push_subscriptions import (
PushSubscriptionCreate,
PushSubscriptionRead,
PushSubscriptionUpsert,
PushSubscriptionVapidPublicKey,
)
from app.services.push_notifications import push_notification_client
from app.services.telemetry import telemetry_client

router = APIRouter()


@router.get("/public-key", status_code=status.HTTP_200_OK, summary="Fetch the VAPID public key")
async def get_public_key(
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> PushSubscriptionVapidPublicKey:
telemetry_client.capture(token_payload.sub, event="push-subscriptions-public-key")
if not push_notification_client.is_enabled:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Push notifications are disabled")
return PushSubscriptionVapidPublicKey(public_key=push_notification_client.get_public_key())


@router.post("/", status_code=status.HTTP_200_OK, summary="Register or update a push subscription")
async def register_push_subscription(
payload: PushSubscriptionUpsert,
subscriptions: PushSubscriptionCRUD = Depends(get_push_subscription_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> PushSubscriptionRead:
telemetry_client.capture(token_payload.sub, event="push-subscriptions-upsert")
subscription = await subscriptions.upsert_for_user(
token_payload.sub,
token_payload.organization_id,
PushSubscriptionCreate(
user_id=token_payload.sub,
organization_id=token_payload.organization_id,
endpoint=payload.endpoint,
auth=payload.keys.auth,
p256dh=payload.keys.p256dh,
expiration_time=payload.expiration_time,
user_agent=payload.user_agent,
),
)
return PushSubscriptionRead(**subscription.model_dump())


@router.get("/", status_code=status.HTTP_200_OK, summary="Fetch current user's push subscriptions")
async def fetch_push_subscriptions(
subscriptions: PushSubscriptionCRUD = Depends(get_push_subscription_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> List[PushSubscriptionRead]:
telemetry_client.capture(token_payload.sub, event="push-subscriptions-fetch")
return [
PushSubscriptionRead(**elt.model_dump())
for elt in await subscriptions.fetch_all(filters=[("user_id", token_payload.sub)], order_by="created_at")
]


@router.delete("/{subscription_id}", status_code=status.HTTP_200_OK, summary="Delete a push subscription")
async def delete_push_subscription(
subscription_id: int = Path(..., gt=0),
subscriptions: PushSubscriptionCRUD = Depends(get_push_subscription_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> None:
telemetry_client.capture(
token_payload.sub,
event="push-subscriptions-delete",
properties={"subscription_id": subscription_id},
)
subscription = cast(PushSubscription, await subscriptions.get(subscription_id, strict=True))
if subscription.user_id != token_payload.sub:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.")
await subscriptions.delete(subscription_id)
2 changes: 2 additions & 0 deletions src/app/api/api_v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
occlusion_masks,
organizations,
poses,
push_subscriptions,
sequences,
users,
webhooks,
Expand All @@ -31,3 +32,4 @@
api_router.include_router(alerts.router, prefix="/alerts", tags=["alerts"])
api_router.include_router(organizations.router, prefix="/organizations", tags=["organizations"])
api_router.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
api_router.include_router(push_subscriptions.router, prefix="/push-subscriptions", tags=["push-subscriptions"])
15 changes: 14 additions & 1 deletion src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings
from app.crud import AlertCRUD, CameraCRUD, DetectionCRUD, OrganizationCRUD, SequenceCRUD, UserCRUD, WebhookCRUD
from app.crud import (
AlertCRUD,
CameraCRUD,
DetectionCRUD,
OrganizationCRUD,
SequenceCRUD,
UserCRUD,
WebhookCRUD,
)
from app.crud.crud_occlusion_mask import OcclusionMaskCRUD
from app.crud.crud_pose import PoseCRUD
from app.crud.crud_push_subscription import PushSubscriptionCRUD
from app.db import get_session
from app.models import User, UserRole
from app.schemas.login import TokenPayload
Expand Down Expand Up @@ -74,6 +83,10 @@ def get_alert_crud(session: AsyncSession = Depends(get_session)) -> AlertCRUD:
return AlertCRUD(session=session)


def get_push_subscription_crud(session: AsyncSession = Depends(get_session)) -> PushSubscriptionCRUD:
return PushSubscriptionCRUD(session=session)


def decode_token(token: str, authenticate_value: Union[str, None] = None) -> Dict[str, str]:
try:
payload = jwt_decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
Expand Down
4 changes: 4 additions & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def sqlachmey_uri(cls, v: str) -> str:
# Notifications
TELEGRAM_TOKEN: Union[str, None] = os.environ.get("TELEGRAM_TOKEN")
PLATFORM_URL: str = os.environ.get("PLATFORM_URL", "")
VAPID_PRIVATE_KEY: Union[str, None] = os.environ.get("VAPID_PRIVATE_KEY")
VAPID_PUBLIC_KEY: Union[str, None] = os.environ.get("VAPID_PUBLIC_KEY")
VAPID_SUBJECT: Union[str, None] = os.environ.get("VAPID_SUBJECT")
PUSH_NOTIFICATIONS_TTL_SECONDS: int = int(os.environ.get("PUSH_NOTIFICATIONS_TTL_SECONDS") or 300)

# Risk API (daily fire-weather index per camera)
RISK_API_URL: Union[str, None] = os.environ.get("RISK_API_URL")
Expand Down
53 changes: 53 additions & 0 deletions src/app/crud/crud_push_subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2024-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.time import utcnow
from app.crud.base import BaseCRUD
from app.models import PushSubscription
from app.schemas.push_subscriptions import PushSubscriptionCreate, PushSubscriptionUpdate

__all__ = ["PushSubscriptionCRUD"]


class PushSubscriptionCRUD(BaseCRUD[PushSubscription, PushSubscriptionCreate, PushSubscriptionUpdate]):
def __init__(self, session: AsyncSession) -> None:
"""Initialize push subscription CRUD."""
super().__init__(session, PushSubscription)

async def upsert_for_user(
self,
user_id: int,
organization_id: int,
payload: PushSubscriptionCreate,
) -> PushSubscription:
existing = await self.get_by("endpoint", payload.endpoint, strict=False)
if existing is None:
return await self.create(
PushSubscriptionCreate(
user_id=user_id,
organization_id=organization_id,
endpoint=payload.endpoint,
auth=payload.auth,
p256dh=payload.p256dh,
expiration_time=payload.expiration_time,
user_agent=payload.user_agent,
)
)

return await self.update(
existing.id,
PushSubscriptionUpdate(
user_id=user_id,
organization_id=organization_id,
endpoint=payload.endpoint,
auth=payload.auth,
p256dh=payload.p256dh,
expiration_time=payload.expiration_time,
user_agent=payload.user_agent,
updated_at=utcnow(),
),
)
28 changes: 26 additions & 2 deletions src/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,24 @@

from datetime import datetime
from enum import Enum
from typing import Union
from typing import Optional, Union

from sqlmodel import Field, SQLModel

from app.core.config import settings
from app.core.time import utcnow

__all__ = ["Alert", "AlertSequence", "Camera", "Detection", "Organization", "Pose", "Sequence", "User"]
__all__ = [
"Alert",
"AlertSequence",
"Camera",
"Detection",
"Organization",
"Pose",
"PushSubscription",
"Sequence",
"User",
]


class UserRole(str, Enum):
Expand Down Expand Up @@ -142,3 +152,17 @@ class Webhook(SQLModel, table=True):
__tablename__ = "webhooks"
id: int = Field(None, primary_key=True)
url: str = Field(..., nullable=False, unique=True)


class PushSubscription(SQLModel, table=True):
__tablename__ = "push_subscriptions"
id: int = Field(None, primary_key=True)
user_id: int = Field(..., foreign_key="users.id", nullable=False)
organization_id: int = Field(..., foreign_key="organizations.id", nullable=False)
endpoint: str = Field(..., nullable=False, unique=True)
auth: str = Field(..., min_length=1, max_length=255, nullable=False)
p256dh: str = Field(..., min_length=1, max_length=255, nullable=False)
expiration_time: Optional[datetime] = Field(default=None, nullable=True)
user_agent: Union[str, None] = Field(default=None, max_length=512, nullable=True)
created_at: datetime = Field(default_factory=utcnow, nullable=False)
updated_at: datetime = Field(default_factory=utcnow, nullable=False)
Loading
Loading