finish login function, true hash of user fixture

This commit is contained in:
2024-11-10 00:16:37 -06:00
parent 10919d0333
commit cb8dd80f33
5 changed files with 67 additions and 39 deletions

View File

@@ -4,20 +4,13 @@ It also provides a base model with database connection details.
""" """
import datetime import datetime
import secrets
from os import getenv from os import getenv
from typing import Optional from typing import Optional
import structlog import structlog
from linkpulse.utilities import utc_now from linkpulse.utilities import utc_now
from peewee import ( from peewee import AutoField, BitField, CharField, Check, DateTimeField, ForeignKeyField, Model
AutoField,
BitField,
CharField,
Check,
DateTimeField,
ForeignKeyField,
Model,
)
from playhouse.db_url import connect from playhouse.db_url import connect
logger = structlog.get_logger() logger = structlog.get_logger()
@@ -81,13 +74,16 @@ class Session(BaseModel):
), ),
] ]
@classmethod
def generate_token(cls) -> str:
alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
return "".join(secrets.choice(alphabet) for _ in range(32))
@property @property
def expiry_utc(self) -> datetime.datetime: def expiry_utc(self) -> datetime.datetime:
return self.expiry.replace(tzinfo=datetime.timezone.utc) # type: ignore return self.expiry.replace(tzinfo=datetime.timezone.utc) # type: ignore
def is_expired( def is_expired(self, revoke: bool = True, now: Optional[datetime.datetime] = None) -> bool:
self, revoke: bool = True, now: Optional[datetime.datetime] = None
) -> bool:
""" """
Check if the session is expired. If `revoke` is True, the session will be automatically revoked if it is expired. Check if the session is expired. If `revoke` is True, the session will be automatically revoked if it is expired.
""" """

View File

@@ -1,23 +1,32 @@
from typing import Tuple, Optional from datetime import datetime, timedelta
from math import e
from typing import Optional, Tuple
from fastapi import status import structlog
from fastapi.responses import ORJSONResponse from fastapi import APIRouter, Depends, Response, status
from fastapi.responses import JSONResponse
from linkpulse.dependencies import RateLimiter
from linkpulse.models import Session, User
from linkpulse.utilities import utc_now
from pwdlib import PasswordHash from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher from pwdlib.hashers.argon2 import Argon2Hasher
from fastapi import APIRouter, Depends
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from linkpulse.dependencies import RateLimiter
from linkpulse.models import User, Session logger = structlog.get_logger()
router = APIRouter() router = APIRouter()
hasher = PasswordHash([Argon2Hasher()]) hasher = PasswordHash([Argon2Hasher()])
dummy_hash = "$argon2id$v=19$m=65536,t=3,p=4$Ii3hm5/NqcJddQDFK24Wtw$I99xV/qkaLROo0VZcvaZrYMAD9RTcWzxY5/RbMoRLQ4" dummy_hash = (
"$argon2id$v=19$m=65536,t=3,p=4$Ii3hm5/NqcJddQDFK24Wtw$I99xV/qkaLROo0VZcvaZrYMAD9RTcWzxY5/RbMoRLQ4"
)
# Session expiry times
default_session_expiry = timedelta(hours=12)
remember_me_session_expiry = timedelta(days=14)
def validate_session( def validate_session(token: str, user: bool = True) -> Tuple[bool, bool, Optional[User]]:
token: str, user: bool = True
) -> Tuple[bool, bool, Optional[User]]:
"""Given a token, validate that the session exists and is not expired. """Given a token, validate that the session exists and is not expired.
This function has side effects: This function has side effects:
@@ -49,8 +58,8 @@ def validate_session(
class LoginBody(BaseModel): class LoginBody(BaseModel):
email: EmailStr email: EmailStr # May be a heavy check; profiling could determine if this is necessary
password: str = Field(min_length=1) password: str = Field(min_length=1) # Basic check, registration will have more stringent requirements
remember_me: bool = False remember_me: bool = False
@@ -58,27 +67,50 @@ class LoginError(BaseModel):
error: str error: str
@router.post("/api/login", dependencies=[Depends(RateLimiter("6/minute"))]) class LoginSuccess(BaseModel):
async def login(body: LoginBody): email: EmailStr
expiry: datetime
@router.post(
"/api/login",
responses={200: {"model": LoginSuccess}, 401: {"model": LoginError}},
dependencies=[Depends(RateLimiter("6/minute"))],
)
async def login(body: LoginBody, response: Response):
# Acquire user by email # Acquire user by email
user = User.get_or_none(User.email == body.email) user = User.get_or_none(User.email == body.email)
if user is None: if user is None:
# Hash regardless of user existence to prevent timing attacks # Hash regardless of user existence to prevent timing attacks
hasher.verify(body.password, dummy_hash) hasher.verify(body.password, dummy_hash)
return ORJSONResponse( response.status_code = status.HTTP_401_UNAUTHORIZED
status_code=status.HTTP_401_UNAUTHORIZED, return LoginError(error="Invalid email or password")
content=LoginError(error="Invalid email or password"),
)
# valid, updated_hash = hasher.verify_and_update(body.password, existing_hash) logger.warning("Hash", hash=user.password_hash)
valid, updated_hash = hasher.verify_and_update(body.password, user.password_hash)
# Check if user exists, return 401 if not
# Check if password matches, return 401 if not # Check if password matches, return 401 if not
if not valid:
response.status_code = status.HTTP_401_UNAUTHORIZED
return LoginError(error="Invalid email or password")
# Update password hash if necessary
if updated_hash:
user.password_hash = updated_hash
user.save()
# Create session # Create session
token = Session.generate_token()
session = Session.create(
token=token,
user=user,
expiry=utc_now() + (remember_me_session_expiry if body.remember_me else default_session_expiry),
)
# Set Cookie of session token # Set Cookie of session token
# Return 200 with mild user information response.set_cookie("session", token, samesite="strict")
pass return {"email": user.email, "expiry": session.expiry}
@router.post("/api/logout") @router.post("/api/logout")

View File

@@ -5,12 +5,11 @@ from linkpulse.tests.test_user import user
def test_auth_login(user): def test_auth_login(user):
args = {"email": "test@test.com", "password": "test"} args = {"email": user.email, "password": "password"}
with TestClient(app) as client: with TestClient(app) as client:
response = client.post("/api/login", json=args) response = client.post("/api/login", json=args)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
# assert response.json()["token"] is not None
response = client.post("/api/login", json={**args, "email": "invalid_email"}) response = client.post("/api/login", json={**args, "email": "invalid_email"})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

View File

@@ -2,12 +2,13 @@ import structlog
from fastapi import status from fastapi import status
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from linkpulse.app import app from linkpulse.app import app
from linkpulse.tests.test_user import user
logger = structlog.get_logger() logger = structlog.get_logger()
def test_rate_limit(): def test_rate_limit(user):
args = {"email": "test@test.com", "password": "test"} args = {"email": user.email, "password": "password"}
with TestClient(app) as client: with TestClient(app) as client:
for _ in range(6): for _ in range(6):

View File

@@ -2,7 +2,7 @@ import pytest
import structlog import structlog
from linkpulse.models import User from linkpulse.models import User
from linkpulse.routers.auth import hasher from linkpulse.routers.auth import hasher
from linkpulse.tests.random import random_email, random_string from linkpulse.tests.random import random_email
logger = structlog.get_logger() logger = structlog.get_logger()