diff --git a/backend/linkpulse/dependencies.py b/backend/linkpulse/dependencies.py index db24be2..969a3e9 100644 --- a/backend/linkpulse/dependencies.py +++ b/backend/linkpulse/dependencies.py @@ -4,6 +4,8 @@ from fastapi import HTTPException, Request, Response, status from limits.aio.strategies import MovingWindowRateLimiter from limits.aio.storage import MemoryStorage from limits import parse +from linkpulse.models import Session +from dataclasses import dataclass storage = MemoryStorage() strategy = MovingWindowRateLimiter(storage) @@ -39,3 +41,37 @@ class RateLimiter: headers={"Retry-After": self.retry_after}, ) return True + + +@dataclass +class SessionModel: + user_id: str + session_id: str + expires_at: int + + +class SessionDependency: + def __init__(self, required: bool = False): + self.required = required + + async def __call__(self, request: Request, response: Response): + session_token = request.cookies.get("session") + + # If not present, raise 401 if required + if session_token is None: + if self.required: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") + return None + + # Get session from database + session = Session.get_or_none(Session.token == session_token) + + # This doesn't differentiate between expired or completely invalid sessions + if session is None or session.is_expired(revoke=True): + if self.required: + logger.debug("Session Cookie Revoked", token=session_token) + response.set_cookie("session", "", max_age=0) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") + return None + + return session diff --git a/backend/linkpulse/routers/auth.py b/backend/linkpulse/routers/auth.py index 8ab4e61..ccb247d 100644 --- a/backend/linkpulse/routers/auth.py +++ b/backend/linkpulse/routers/auth.py @@ -1,11 +1,9 @@ from datetime import datetime, timedelta -from math import e -from typing import Optional, Tuple +from typing import Annotated, Optional, Tuple import structlog from fastapi import APIRouter, Depends, Response, status -from fastapi.responses import JSONResponse -from linkpulse.dependencies import RateLimiter +from linkpulse.dependencies import SessionDependency, RateLimiter, SessionModel from linkpulse.models import Session, User from linkpulse.utilities import utc_now from pwdlib import PasswordHash @@ -114,13 +112,20 @@ async def login(body: LoginBody, response: Response): return {"email": user.email, "expiry": session.expiry} -@router.post("/api/logout") -async def logout(): - # TODO: Force logout parameter, logout ALL sessions for User - # Get session token from Cookie - # Delete session - # Return 200 - pass +@router.post("/api/logout", status_code=status.HTTP_200_OK) +async def logout( + response: Response, + session: Annotated[Session, Depends(SessionDependency(required=True))], + all: bool = False, +): + # We can assume the session is valid via the dependency + if not all: + session.delete_instance() + else: + count = Session.delete().where(Session.user == session.user).execute() + logger.debug("All sessions deleted", user=session.user.email, count=count) + + response.set_cookie("session", "", max_age=0) @router.post("/api/register") @@ -134,9 +139,16 @@ async def register(): pass +@router.get("/api/session") +async def session(session: Annotated[SessionModel, Depends(SessionDependency(required=True))]): + # Returns the session information for the current session + return {} + + @router.get("/api/sessions") -async def sessions(): - pass +async def sessions(session: Annotated[SessionModel, Depends(SessionDependency(required=True))]): + # Returns a list of all active sessions for this user + return {} # GET /api/user/{id}/sessions