diff --git a/src/api/dependencies.py b/src/api/dependencies.py index cd22980..7eba554 100644 --- a/src/api/dependencies.py +++ b/src/api/dependencies.py @@ -59,14 +59,14 @@ def get_account_from_bearer( if user is None: raise credentials_exception - if not user.is_active: - raise HTTPException(status_code=400, detail="Inactive account") + if not user.is_active: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="account deactivated") return user def get_active_account(account: Account = Depends(get_account_from_bearer)) -> Account: if not account.is_active: - raise HTTPException(status_code=400, detail="Inactive account") + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="account deactivated") return account def get_account_even_if_inactive( diff --git a/src/api/roles/client/client.py b/src/api/roles/client/client.py index 7f2df20..1cdcabf 100644 --- a/src/api/roles/client/client.py +++ b/src/api/roles/client/client.py @@ -9,7 +9,7 @@ from sqlalchemy import func, desc, asc, delete from src import config -from src.api.dependencies import get_account_from_bearer, get_client_account, PaginationParams +from src.api.dependencies import get_account_from_bearer, get_client_account, get_active_account, PaginationParams #models from src.api.roles.client.domain import ( @@ -39,8 +39,18 @@ from src.database.coach_client_relationship.models import ClientCoachRequest, ClientCoachRelationship from src.database.account.models import Account, Availability, Notification from src.database.client.models import Client, ClientAvailability, FitnessGoals -from src.database.telemetry.models import HealthMetrics, ClientTelemetry -from src.database.telemetry.models import ClientTelemetry +from src.database.telemetry.models import ( + HealthMetrics, + ClientTelemetry, + StepCount, + DailyMoodSurvey, + DailyWorkoutSurvey, + DailyBodyMetricsSurvey, + DailyStepsSurvey, + DailyMealSurvey, + CompletedMealActivity, + CompletedWorkout, +) from src.database.reports.models import CoachReport, CoachReviews from src.database.payment.models import PaymentInformation, Invoice, BillingCycle, Subscription, PricingPlan @@ -189,11 +199,21 @@ def create_coach_request(coach_id: int, db = Depends(get_session), acc: Account client = db.get(Client, acc.client_id) coach = db.get(Coach, coach_id) - if coach is None: - raise HTTPException(404, detail="Coach not found") - - existing_request = db.query(ClientCoachRequest).filter_by( - client_id=client.id, coach_id=coach.id, is_accepted=None + if coach is None: + raise HTTPException(404, detail="Coach not found") + + coach_account = db.exec( + select(Account).where( + Account.coach_id == coach.id, + Account.is_active == True, + ) + ).first() + + if coach_account is None or not coach.verified: + raise HTTPException(404, detail="Coach not available") + + existing_request = db.query(ClientCoachRequest).filter_by( + client_id=client.id, coach_id=coach.id, is_accepted=None ).first() if existing_request: @@ -207,10 +227,9 @@ def create_coach_request(coach_id: int, db = Depends(get_session), acc: Account db.refresh(request) # notify the coach's account that a new request was created - coach_account = db.exec(select(Account).where(Account.coach_id == coach.id)).first() - if coach_account and coach_account.id is not None: - n = Notification( - account_id=coach_account.id, + if coach_account and coach_account.id is not None: + n = Notification( + account_id=coach_account.id, fav_category="relationship_request_creation", message=f"{acc.name} has requested to hire you.", details=f"Request {request.id} from client {client.id} to coach {coach.id}.", @@ -515,12 +534,24 @@ def get_review(coach_id: int, db = Depends(get_session), acc: Account = Depends( if acc.id is None: raise HTTPException(404, detail="Account not found") - if acc.client_id is None: - raise HTTPException(403, detail="You are not authorized to view this content") - - reviews = db.query(CoachReviews).filter(CoachReviews.coach_id == coach_id).all() - - return ReviewsResponse(reviews=reviews) + if acc.client_id is None: + raise HTTPException(403, detail="You are not authorized to view this content") + + coach_account = db.exec( + select(Account).where( + Account.coach_id == coach_id, + Account.is_active == True, + ) + ).first() + + if coach_account is None: + return ReviewsResponse(reviews=[]) + + reviews = db.exec( + select(CoachReviews).where(CoachReviews.coach_id == coach_id) + ).all() + + return ReviewsResponse(reviews=reviews) @router.get("/my_coach", response_model=MyCoachResponse) def get_my_coach(db = Depends(get_session), acc: Account = Depends(get_client_account)): @@ -533,8 +564,9 @@ def get_my_coach(db = Depends(get_session), acc: Account = Depends(get_client_ac coach_request = db.query(ClientCoachRequest).filter(ClientCoachRequest.client_id == acc.client_id).first() - if not coach_request.is_accepted: - raise HTTPException(403, detail="You are not authorized to see this coach until the request is accepted") + # If no request found or the request hasn't been accepted, surface as not found. + if coach_request is None or not getattr(coach_request, "is_accepted", False): + raise HTTPException(404, detail="No active coach relationship found") relationship = db.query(ClientCoachRelationship).filter(ClientCoachRelationship.request_id == coach_request.id).first() @@ -557,3 +589,150 @@ def get_my_coach_requests(db = Depends(get_session), acc: Account = Depends(get_ requests = db.get(ClientCoachRequest).filter(ClientCoachRequest.client_id == acc.client_id).all() return MyCoachRequestsResponse(requests = requests) + +@router.get("/coach_profile/{coach_id}") +def get_coach_profile(coach_id: int, db = Depends(get_session), acc: Account = Depends(get_client_account)): + """ + Allows a client to view a coach's profile given their ID. + Returns account basics, specialties, certifications, experiences, + pricing/payment plan, availability, and rating summary. + """ + + coach = db.get(Coach, coach_id) + + if coach is None: + raise HTTPException(404, detail="Coach not found") + + coach_account = db.exec( + select(Account).where(Account.coach_id == coach_id) + ).first() + + if coach_account is None: + raise HTTPException(404, detail="Coach account not found") + + if not coach_account.is_active or not coach.verified: + raise HTTPException(404, detail="Coach not available") + + certifications = db.exec( + select(Certifications) + .join(CoachCertifications, CoachCertifications.certification_id == Certifications.id) + .where(CoachCertifications.coach_id == coach_id) + ).all() + + experiences = db.exec( + select(Experience) + .join(CoachExperience, CoachExperience.experience_id == Experience.id) + .where(CoachExperience.coach_id == coach_id) + ).all() + + availability = db.exec( + select(Availability).where( + Availability.coach_availability_id == coach.coach_availability + ) + ).all() + + pricing_plan = db.exec( + select(PricingPlan).where(PricingPlan.coach_id == coach_id) + ).first() + + rating_summary = db.exec( + select( + func.count(CoachReviews.id).label("rating_count"), + func.avg(CoachReviews.rating).label("avg_rating"), + ).where(CoachReviews.coach_id == coach_id) + ).first() + + return { + "base_account": { + "id": coach_account.id, + "name": coach_account.name, + "email": coach_account.email, + "is_active": coach_account.is_active, + "gender": coach_account.gender, + "bio": coach_account.bio, + "age": coach_account.age, + "pfp_url": coach_account.pfp_url, + "client_id": coach_account.client_id, + "coach_id": coach_account.coach_id, + "admin_id": coach_account.admin_id, + "created_at": coach_account.created_at, + }, + "coach_account": coach, + "specialties": coach.specialties, + "certifications": certifications, + "experiences": experiences, + "pricing_plan": pricing_plan, + "availability": availability, + "rating_summary": { + "rating_count": int(rating_summary.rating_count or 0), + "avg_rating": float(rating_summary.avg_rating) if rating_summary.avg_rating is not None else None, + }, + } + + +@router.get("/progress_pictures") +def get_progress_pictures(db = Depends(get_session), acc: Account = Depends(get_client_account)): + """ + Queries progress picture URLs for the logged-in client. + Progress pictures are stored in HealthMetrics.progress_pic_url. + """ + + if acc.client_id is None: + raise HTTPException(403, detail="Client profile required") + + pictures = db.exec( + select( + ClientTelemetry.date, + HealthMetrics.progress_pic_url, + ) + .join(HealthMetrics, HealthMetrics.client_telemetry_id == ClientTelemetry.id) + .where( + ClientTelemetry.client_id == acc.client_id, + HealthMetrics.progress_pic_url.is_not(None), + ) + .order_by(ClientTelemetry.date.desc()) + ).all() + + return [ + { + "date": pic.date, + "progress_pic_url": pic.progress_pic_url, + } + for pic in pictures + ] + + +@router.get("/my_coach") +def get_my_coach(db = Depends(get_session), acc: Account = Depends(get_client_account)): + """ + Returns the active coach relationship for the logged-in client. + """ + + if acc.client_id is None: + raise HTTPException(403, detail="Client profile required") + + result = db.exec( + select(ClientCoachRequest, ClientCoachRelationship) + .join(ClientCoachRelationship, ClientCoachRelationship.request_id == ClientCoachRequest.id) + .where( + ClientCoachRequest.client_id == acc.client_id, + ClientCoachRequest.is_accepted == True, + ClientCoachRelationship.is_active == True, + ClientCoachRelationship.client_blocked == False, + ClientCoachRelationship.coach_blocked == False, + ) + ).first() + + if result is None: + raise HTTPException(404, detail="No active coach relationship found") + + request, relationship = result + + return { + "relationship_id": relationship.id, + "request_id": request.id, + "client_id": request.client_id, + "coach_id": request.coach_id, + "created_at": relationship.created_at, + "is_active": relationship.is_active, + } diff --git a/src/api/roles/coach/coach.py b/src/api/roles/coach/coach.py index 5bcaaa7..a9222f2 100644 --- a/src/api/roles/coach/coach.py +++ b/src/api/roles/coach/coach.py @@ -37,7 +37,18 @@ from src.database.coach_client_relationship.models import ClientCoachRequest, ClientCoachRelationship from src.database.session import get_session from src.database.account.models import Account, Availability, Notification -from src.database.telemetry.models import HealthMetrics, ClientTelemetry +from src.database.telemetry.models import ( + HealthMetrics, + ClientTelemetry, + StepCount, + DailyMoodSurvey, + DailyWorkoutSurvey, + DailyBodyMetricsSurvey, + DailyStepsSurvey, + DailyMealSurvey, + CompletedMealActivity, + CompletedWorkout, +) from src.database.coach.models import Coach, CoachCertifications, CoachExperience, CoachAvailability, Experience, Certifications from src.database.client.models import Client, FitnessGoals from src.database.role_management.models import CoachRequest @@ -534,3 +545,121 @@ def get_coach_earnings( total = float(result) if result is not None else 0.0 return CoachEarningsResponse(total_earnings=total, since=since) + +@router.get("/my_clients") +def get_my_clients(db = Depends(get_session), acc: Account = Depends(get_coach_account)): + """ + Returns all active clients for the logged-in coach. + Includes Client, Account without password, and telemetry objects. + """ + + if acc.coach_id is None: + raise HTTPException(403, detail="Coach profile required") + + relationships = db.exec( + select(ClientCoachRequest, ClientCoachRelationship) + .join(ClientCoachRelationship, ClientCoachRelationship.request_id == ClientCoachRequest.id) + .where( + ClientCoachRequest.coach_id == acc.coach_id, + ClientCoachRequest.is_accepted == True, + ClientCoachRelationship.is_active == True, + ClientCoachRelationship.client_blocked == False, + ClientCoachRelationship.coach_blocked == False, + ) + ).all() + + clients = [] + + for request, relationship in relationships: + client = db.get(Client, request.client_id) + + account = db.exec( + select(Account).where(Account.client_id == request.client_id) + ).first() + + telemetry_records = db.exec( + select(ClientTelemetry) + .where(ClientTelemetry.client_id == request.client_id) + .order_by(ClientTelemetry.date.desc()) + ).all() + + telemetry = [] + + for t in telemetry_records: + health_metrics = db.exec( + select(HealthMetrics).where(HealthMetrics.client_telemetry_id == t.id) + ).all() + + step_counts = db.exec( + select(StepCount).where(StepCount.client_telemetry_id == t.id) + ).all() + + mood_surveys = db.exec( + select(DailyMoodSurvey).where(DailyMoodSurvey.client_telemetry_id == t.id) + ).all() + + workout_surveys = db.exec( + select(DailyWorkoutSurvey).where(DailyWorkoutSurvey.client_telemetry_id == t.id) + ).all() + + body_metrics_surveys = db.exec( + select(DailyBodyMetricsSurvey).where(DailyBodyMetricsSurvey.client_telemetry_id == t.id) + ).all() + + steps_surveys = db.exec( + select(DailyStepsSurvey).where(DailyStepsSurvey.client_telemetry_id == t.id) + ).all() + + meal_surveys = db.exec( + select(DailyMealSurvey).where(DailyMealSurvey.client_telemetry_id == t.id) + ).all() + + completed_meals = db.exec( + select(CompletedMealActivity).where(CompletedMealActivity.client_telemetry_id == t.id) + ).all() + + completed_workouts = db.exec( + select(CompletedWorkout).where(CompletedWorkout.client_telemetry_id == t.id) + ).all() + + telemetry.append({ + "client_telemetry": t, + "health_metrics": health_metrics, + "step_counts": step_counts, + "mood_surveys": mood_surveys, + "workout_surveys": workout_surveys, + "body_metrics_surveys": body_metrics_surveys, + "steps_surveys": steps_surveys, + "meal_surveys": meal_surveys, + "completed_meals": completed_meals, + "completed_workouts": completed_workouts, + }) + + safe_account = None + + if account: + safe_account = { + "id": account.id, + "name": account.name, + "email": account.email, + "is_active": account.is_active, + "status": account.status, + "gender": account.gender, + "bio": account.bio, + "age": account.age, + "pfp_url": account.pfp_url, + "client_id": account.client_id, + "coach_id": account.coach_id, + "admin_id": account.admin_id, + "created_at": account.created_at, + } + + clients.append({ + "relationship_id": relationship.id, + "request_id": request.id, + "client": client, + "account": safe_account, + "telemetry": telemetry, + }) + + return clients \ No newline at end of file diff --git a/src/api/roles/shared/account.py b/src/api/roles/shared/account.py index 3b76661..12e29d7 100644 --- a/src/api/roles/shared/account.py +++ b/src/api/roles/shared/account.py @@ -3,9 +3,14 @@ from src import config from src.database.session import get_session -from src.database.account.models import Account +from src.database.account.models import Account, Availability, Notification +from src.database.client.models import Client, FitnessGoals +from src.database.coach.models import Coach, Experience, Certifications, CoachExperience, CoachCertifications +from src.database.payment.models import PricingPlan, PaymentInformation, Subscription, BillingCycle, Invoice +from src.database.telemetry.models import HealthMetrics, ClientTelemetry +from src.database.coach_client_relationship.models import ClientCoachRelationship, ClientCoachRequest from src.api.dependencies import get_account_from_bearer, get_active_account, get_account_even_if_inactive -from sqlmodel import Session +from sqlmodel import Session, select from pydantic import BaseModel, EmailStr from typing import Optional from datetime import datetime @@ -92,25 +97,159 @@ class ActivateAccountResponse(BaseModel): message: str +def get_affected_accounts(db: Session, account: Account) -> list[Account]: + affected_accounts_by_id: dict[int, Account] = {} + + def add_affected_account(affected_account: Account | None): + if affected_account is None: + return + if affected_account.id is None or affected_account.id == account.id: + return + affected_accounts_by_id[affected_account.id] = affected_account + + # If the deactivated account is a client, notify their active coach(es) + if account.client_id is not None: + relationships = db.exec( + select(ClientCoachRequest, ClientCoachRelationship) + .join( + ClientCoachRelationship, + ClientCoachRelationship.request_id == ClientCoachRequest.id, + ) + .where( + ClientCoachRequest.client_id == account.client_id, + ClientCoachRelationship.is_active == True, + ) + ).all() + + for request, relationship in relationships: + coach_account = db.exec( + select(Account).where(Account.coach_id == request.coach_id) + ).first() + + add_affected_account(coach_account) + + # If the deactivated account is a coach, notify their active client(s) + if account.coach_id is not None: + relationships = db.exec( + select(ClientCoachRequest, ClientCoachRelationship) + .join( + ClientCoachRelationship, + ClientCoachRelationship.request_id == ClientCoachRequest.id, + ) + .where( + ClientCoachRequest.coach_id == account.coach_id, + ClientCoachRelationship.is_active == True, + ) + ).all() + + for request, relationship in relationships: + client_account = db.exec( + select(Account).where(Account.client_id == request.client_id) + ).first() + + add_affected_account(client_account) + + return list(affected_accounts_by_id.values()) + + +def notify_affected_accounts( + db: Session, + deactivated_account: Account, + affected_accounts: list[Account], +): + for affected_account in affected_accounts: + if affected_account.id is None: + continue + + role = "account" + if deactivated_account.client_id is not None: + role = "client" + elif deactivated_account.coach_id is not None: + role = "coach" + + db.add( + Notification( + account_id=affected_account.id, + fav_category="account_deactivated", + message=f"{deactivated_account.name} has deactivated their account.", + details=f"{role.capitalize()} account {deactivated_account.id} was deactivated.", + is_read=False, + ) + ) + + +def delete_client_coach_mappings(db: Session, account: Account): + if account.client_id is not None: + requests = db.exec( + select(ClientCoachRequest) + .where(ClientCoachRequest.client_id == account.client_id) + ).all() + + for request in requests: + relationships = db.exec( + select(ClientCoachRelationship) + .where(ClientCoachRelationship.request_id == request.id) + ).all() + + for relationship in relationships: + db.delete(relationship) + + db.delete(request) + + if account.coach_id is not None: + requests = db.exec( + select(ClientCoachRequest) + .where(ClientCoachRequest.coach_id == account.coach_id) + ).all() + + for request in requests: + relationships = db.exec( + select(ClientCoachRelationship) + .where(ClientCoachRelationship.request_id == request.id) + ).all() + + for relationship in relationships: + db.delete(relationship) + + db.delete(request) + @router.post("/deactivate", response_model=DeactivateAccountResponse) def deactivate_account( db: Session = Depends(get_session), acc: Account = Depends(get_active_account), ): """ - Deactivate the current user's account. This sets is_active to False and prevents login/access. + Deactivate the current user's account. + This sets is_active to False and prevents access to protected routes. + It also notifies affected coaches/clients. """ account = db.get(Account, acc.id) + if account is None: - raise HTTPException(404, detail="Account not found") + raise HTTPException(status_code=404, detail="Account not found") + if not account.is_active: - return DeactivateAccountResponse(success=False, message="Account is already deactivated.") + return DeactivateAccountResponse( + success=False, + message="Account is already deactivated.", + ) + + affected_accounts = get_affected_accounts(db, account) + account.is_active = False db.add(account) - db.commit() + + notify_affected_accounts(db, account, affected_accounts) + + delete_client_coach_mappings(db, account) + + db.commit() db.refresh(account) - return DeactivateAccountResponse(success=True, message="Account deactivated successfully.") + return DeactivateAccountResponse( + success=True, + message="Account deactivated successfully.", + ) @router.post("/activate", response_model=ActivateAccountResponse) def activate_account( diff --git a/src/database/account/models.py b/src/database/account/models.py index d72dada..31958ac 100644 --- a/src/database/account/models.py +++ b/src/database/account/models.py @@ -15,6 +15,7 @@ class Account(SQLModelLU, table=True): name: str email: EmailStr = Field(index=True) is_active: bool = Field(default=True) + # status: str = Field(default="active") # auth, ONE of these needs to be here hashed_password: Optional[str] = Field(default=None) diff --git a/tests/test_client_routes.py b/tests/test_client_routes.py new file mode 100644 index 0000000..6cfd8ba --- /dev/null +++ b/tests/test_client_routes.py @@ -0,0 +1,71 @@ +def make_client_profile(test_client, auth_header): + payload = { + "fitness_goals": { + "goal_enum": "weight loss" + }, + "payment_information": { + "ccnum": "4111111111111111", + "cv": "123", + "exp_date": "2026-12-31" + }, + "availabilities": [ + { + "weekday": "monday", + "start_time": "08:00:00", + "end_time": "10:00:00" + } + ], + "initial_health_metric": { + "weight": 180 + } + } + + response = test_client.post( + "/roles/client/initial_survey", + json=payload, + headers=auth_header + ) + + assert response.status_code in (200, 409) + + +def test_get_my_coach(test_client, auth_header): + make_client_profile(test_client, auth_header) + + response = test_client.get( + "/roles/client/my_coach", + headers=auth_header + ) + + assert response.status_code in (200, 404) + + +def test_get_coach_profile(test_client, auth_header): + make_client_profile(test_client, auth_header) + + response = test_client.get( + "/roles/client/coach_profile/1", + headers=auth_header + ) + + assert response.status_code in (200, 404) + + +def test_get_progress_pictures(test_client, auth_header): + make_client_profile(test_client, auth_header) + + response = test_client.get( + "/roles/client/progress_pictures", + headers=auth_header + ) + + assert response.status_code == 200 + + +def test_get_my_clients(test_client, coach_auth_header): + response = test_client.get( + "/roles/coach/my_clients", + headers=coach_auth_header + ) + + assert response.status_code == 200 \ No newline at end of file diff --git a/tests/test_hirable_coaches.py b/tests/test_hirable_coaches.py index d443fd9..917c675 100644 --- a/tests/test_hirable_coaches.py +++ b/tests/test_hirable_coaches.py @@ -1,9 +1,11 @@ -from tests.payload_tools.auth import build_signup_payload, build_login_payload -from tests.payload_tools.client import build_client_init_payload -from tests.payload_tools.coach import build_coach_request_payload - -from src.database.reports.models import CoachReviews -from src.database.coach.models import Coach +from tests.payload_tools.auth import build_signup_payload, build_login_payload +from tests.payload_tools.client import build_client_init_payload +from tests.payload_tools.coach import build_coach_request_payload +from sqlmodel import select + +from src.database.account.models import Account +from src.database.reports.models import CoachReviews +from src.database.coach.models import Coach def _create_and_verify_coach(test_client, db_session, admin_auth_header, name, email_prefix, age=30, gender="non-binary", specialties=None): @@ -101,7 +103,7 @@ def test_hirable_coaches_privacy_and_empty_reviews(test_client, db_session, admi assert dana["avg_rating"] is None -def test_hirable_coaches_pagination_and_unauthorized(test_client, client_auth_header): +def test_hirable_coaches_pagination_and_unauthorized(test_client, client_auth_header): # Ensure endpoint accepts explicit skip/limit pagination params resp = test_client.get( "/roles/client/query/hirable_coaches?sort_by=avg_rating&order=desc&skip=0&limit=100", @@ -110,9 +112,79 @@ def test_hirable_coaches_pagination_and_unauthorized(test_client, client_auth_he assert resp.status_code == 200 # Unauthorized access returns 401 and clear message - resp2 = test_client.get( - "/roles/client/query/hirable_coaches?skip=0&limit=10", - headers={"Authorization": "Bearer invalid.token"}, - ) - assert resp2.status_code == 401 - assert resp2.json().get("detail") is not None + resp2 = test_client.get( + "/roles/client/query/hirable_coaches?skip=0&limit=10", + headers={"Authorization": "Bearer invalid.token"}, + ) + assert resp2.status_code == 401 + assert resp2.json().get("detail") is not None + + +def test_deactivated_coach_is_hidden_and_cannot_be_requested( + test_client, + db_session, + admin_auth_header, + client_auth_header, + create_client, +): + coach_id, coach_header = _create_and_verify_coach( + test_client, + db_session, + admin_auth_header, + "Hidden Coach", + "coach_hidden", + specialties="mobility", + ) + reviewer_header, reviewer_client_id = create_client(email_prefix="hidden_review") + _add_review(db_session, coach_id, reviewer_client_id, 5.0) + + deactivate_resp = test_client.post( + "/roles/shared/account/deactivate", + headers=coach_header, + ) + assert deactivate_resp.status_code == 200, deactivate_resp.text + + search_resp = test_client.get( + "/roles/client/query/hirable_coaches?specialty=mobility", + headers=client_auth_header, + ) + assert search_resp.status_code == 200 + assert all(item["coach_id"] != coach_id for item in search_resp.json()) + + request_resp = test_client.post( + f"/roles/client/request_coach/{coach_id}", + headers=client_auth_header, + ) + assert request_resp.status_code == 404 + + profile_resp = test_client.get( + f"/roles/client/coach_profile/{coach_id}", + headers=client_auth_header, + ) + assert profile_resp.status_code == 404 + + reviews_resp = test_client.get( + f"/roles/client/review/{coach_id}", + headers=client_auth_header, + ) + assert reviews_resp.status_code == 200 + assert reviews_resp.json()["reviews"] == [] + + activate_resp = test_client.post( + "/roles/shared/account/activate", + headers=coach_header, + ) + assert activate_resp.status_code == 200, activate_resp.text + + reviews_after_reactivate = test_client.get( + f"/roles/client/review/{coach_id}", + headers=client_auth_header, + ) + assert reviews_after_reactivate.status_code == 200 + assert len(reviews_after_reactivate.json()["reviews"]) == 1 + + coach_account = db_session.exec( + select(Account).where(Account.coach_id == coach_id) + ).first() + assert coach_account is not None + assert coach_account.is_active is True diff --git a/tests/test_shared_account_activation.py b/tests/test_shared_account_activation.py index 607567f..cd4980f 100644 --- a/tests/test_shared_account_activation.py +++ b/tests/test_shared_account_activation.py @@ -8,8 +8,8 @@ def test_account_deactivate_and_activate(test_client, auth_header): # Try to access a protected endpoint (should fail) resp2 = test_client.patch("/roles/shared/account/update", json={}, headers=auth_header) - assert resp2.status_code in (400, 401) - assert "inactive account" in resp2.text.lower() + assert resp2.status_code == 403 + assert "account deactivated" in resp2.text.lower() # Activate resp3 = test_client.post("/roles/shared/account/activate", headers=auth_header) diff --git a/tests/test_shared_account_notifications.py b/tests/test_shared_account_notifications.py new file mode 100644 index 0000000..ed20853 --- /dev/null +++ b/tests/test_shared_account_notifications.py @@ -0,0 +1,165 @@ +from sqlmodel import select +from datetime import datetime +from src.api.dependencies import create_jwt_token +from src.database.account.models import Notification, Account +from src.database.coach.models import Coach +from src.database.coach_client_relationship.models import ( + ClientCoachRequest, + ClientCoachRelationship, +) + + +def create_client_coach_relationship(db_session): + client = db_session.exec( + select(Account).where( + Account.client_id.is_not(None), + Account.is_active == True, + ) + ).first() + + assert client is not None + + coach = db_session.exec( + select(Account).where( + Account.coach_id.is_not(None), + Account.is_active == True, + Account.id != client.id, + ) + ).first() + + if coach is None: + coach_profile = Coach(verified=True) + db_session.add(coach_profile) + db_session.commit() + db_session.refresh(coach_profile) + + coach = Account( + name="Notification Test Coach", + email=f"notification_coach_{coach_profile.id}@example.com", + hashed_password="test-hash", + coach_id=coach_profile.id, + is_active=True, + ) + db_session.add(coach) + db_session.commit() + db_session.refresh(coach) + + assert coach is not None + + request = ClientCoachRequest( + client_id=client.client_id, + coach_id=coach.coach_id, + ) + + db_session.add(request) + db_session.commit() + db_session.refresh(request) + + relationship = ClientCoachRelationship( + request_id=request.id, + created_at=datetime.utcnow(), + is_active=True, + coach_blocked=False, + client_blocked=False, + ) + + db_session.add(relationship) + db_session.commit() + + return client, coach, request, relationship + + +def test_account_deactivate_sends_notification( + test_client, + db_session, + client_auth_header, + coach_auth_header, +): + client, coach, request, relationship = create_client_coach_relationship(db_session) + + client_auth_header = { + "Authorization": f"Bearer {create_jwt_token(client)}" + } + + resp = test_client.post( + "/roles/shared/account/deactivate", + headers=client_auth_header, + ) + + assert resp.status_code == 200 + assert resp.json()["success"] is True + + db_session.expire_all() + + notifications = list( + db_session.exec( + select(Notification).where(Notification.account_id == coach.id) + ) + ) + + remaining_relationship = db_session.get(ClientCoachRelationship, relationship.id) + remaining_request = db_session.get(ClientCoachRequest, request.id) + + assert notifications, "No notifications found for coach" + assert any( + n.details and "deactivated" in n.details.lower() + for n in notifications + ) + assert remaining_relationship is None + assert remaining_request is None + + activate_resp = test_client.post( + "/roles/shared/account/activate", + headers=client_auth_header, + ) + assert activate_resp.status_code == 200, activate_resp.text + assert db_session.get(ClientCoachRelationship, relationship.id) is None + assert db_session.get(ClientCoachRequest, request.id) is None + + +def test_account_deactivate_coach_notifies_client( + test_client, + db_session, + client_auth_header, + coach_auth_header, +): + client, coach, request, relationship = create_client_coach_relationship(db_session) + + coach_auth_header = { + "Authorization": f"Bearer {create_jwt_token(coach)}" + } + + resp = test_client.post( + "/roles/shared/account/deactivate", + headers=coach_auth_header, + ) + + assert resp.status_code == 200, resp.text + assert resp.json()["success"] is True + + db_session.expire_all() + + notifications = list( + db_session.exec( + select(Notification).where(Notification.account_id == client.id) + ) + ) + + remaining_relationship = db_session.get(ClientCoachRelationship, relationship.id) + remaining_request = db_session.get(ClientCoachRequest, request.id) + + assert notifications, "No notifications found for client" + assert any( + n.details and "deactivated" in n.details.lower() + for n in notifications + ) + assert remaining_relationship is None + assert remaining_request is None + + activate_resp = test_client.post( + "/roles/shared/account/activate", + headers=coach_auth_header, + ) + assert activate_resp.status_code == 200, activate_resp.text + assert db_session.get(ClientCoachRelationship, relationship.id) is None + assert db_session.get(ClientCoachRequest, request.id) is None