mirror of
https://github.com/Xevion/linkpulse.git
synced 2025-12-06 01:15:30 -06:00
finish login function, true hash of user fixture
This commit is contained in:
@@ -4,20 +4,13 @@ It also provides a base model with database connection details.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from linkpulse.utilities import utc_now
|
||||
from peewee import (
|
||||
AutoField,
|
||||
BitField,
|
||||
CharField,
|
||||
Check,
|
||||
DateTimeField,
|
||||
ForeignKeyField,
|
||||
Model,
|
||||
)
|
||||
from peewee import AutoField, BitField, CharField, Check, DateTimeField, ForeignKeyField, Model
|
||||
from playhouse.db_url import connect
|
||||
|
||||
logger = structlog.get_logger()
|
||||
@@ -81,13 +74,16 @@ class Session(BaseModel):
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def generate_token(cls) -> str:
|
||||
alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
return "".join(secrets.choice(alphabet) for _ in range(32))
|
||||
|
||||
@property
|
||||
def expiry_utc(self) -> datetime.datetime:
|
||||
return self.expiry.replace(tzinfo=datetime.timezone.utc) # type: ignore
|
||||
|
||||
def is_expired(
|
||||
self, revoke: bool = True, now: Optional[datetime.datetime] = None
|
||||
) -> bool:
|
||||
def is_expired(self, revoke: bool = True, now: Optional[datetime.datetime] = None) -> bool:
|
||||
"""
|
||||
Check if the session is expired. If `revoke` is True, the session will be automatically revoked if it is expired.
|
||||
"""
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
from typing import Tuple, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from math import e
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from fastapi import status
|
||||
from fastapi.responses import ORJSONResponse
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from linkpulse.dependencies import RateLimiter
|
||||
from linkpulse.models import Session, User
|
||||
from linkpulse.utilities import utc_now
|
||||
from pwdlib import PasswordHash
|
||||
from pwdlib.hashers.argon2 import Argon2Hasher
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from linkpulse.dependencies import RateLimiter
|
||||
from linkpulse.models import User, Session
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
hasher = PasswordHash([Argon2Hasher()])
|
||||
dummy_hash = "$argon2id$v=19$m=65536,t=3,p=4$Ii3hm5/NqcJddQDFK24Wtw$I99xV/qkaLROo0VZcvaZrYMAD9RTcWzxY5/RbMoRLQ4"
|
||||
dummy_hash = (
|
||||
"$argon2id$v=19$m=65536,t=3,p=4$Ii3hm5/NqcJddQDFK24Wtw$I99xV/qkaLROo0VZcvaZrYMAD9RTcWzxY5/RbMoRLQ4"
|
||||
)
|
||||
|
||||
# Session expiry times
|
||||
default_session_expiry = timedelta(hours=12)
|
||||
remember_me_session_expiry = timedelta(days=14)
|
||||
|
||||
|
||||
def validate_session(
|
||||
token: str, user: bool = True
|
||||
) -> Tuple[bool, bool, Optional[User]]:
|
||||
def validate_session(token: str, user: bool = True) -> Tuple[bool, bool, Optional[User]]:
|
||||
"""Given a token, validate that the session exists and is not expired.
|
||||
|
||||
This function has side effects:
|
||||
@@ -49,8 +58,8 @@ def validate_session(
|
||||
|
||||
|
||||
class LoginBody(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=1)
|
||||
email: EmailStr # May be a heavy check; profiling could determine if this is necessary
|
||||
password: str = Field(min_length=1) # Basic check, registration will have more stringent requirements
|
||||
remember_me: bool = False
|
||||
|
||||
|
||||
@@ -58,27 +67,50 @@ class LoginError(BaseModel):
|
||||
error: str
|
||||
|
||||
|
||||
@router.post("/api/login", dependencies=[Depends(RateLimiter("6/minute"))])
|
||||
async def login(body: LoginBody):
|
||||
class LoginSuccess(BaseModel):
|
||||
email: EmailStr
|
||||
expiry: datetime
|
||||
|
||||
|
||||
@router.post(
|
||||
"/api/login",
|
||||
responses={200: {"model": LoginSuccess}, 401: {"model": LoginError}},
|
||||
dependencies=[Depends(RateLimiter("6/minute"))],
|
||||
)
|
||||
async def login(body: LoginBody, response: Response):
|
||||
# Acquire user by email
|
||||
user = User.get_or_none(User.email == body.email)
|
||||
|
||||
if user is None:
|
||||
# Hash regardless of user existence to prevent timing attacks
|
||||
hasher.verify(body.password, dummy_hash)
|
||||
return ORJSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content=LoginError(error="Invalid email or password"),
|
||||
)
|
||||
response.status_code = status.HTTP_401_UNAUTHORIZED
|
||||
return LoginError(error="Invalid email or password")
|
||||
|
||||
# valid, updated_hash = hasher.verify_and_update(body.password, existing_hash)
|
||||
logger.warning("Hash", hash=user.password_hash)
|
||||
valid, updated_hash = hasher.verify_and_update(body.password, user.password_hash)
|
||||
|
||||
# Check if user exists, return 401 if not
|
||||
# Check if password matches, return 401 if not
|
||||
if not valid:
|
||||
response.status_code = status.HTTP_401_UNAUTHORIZED
|
||||
return LoginError(error="Invalid email or password")
|
||||
|
||||
# Update password hash if necessary
|
||||
if updated_hash:
|
||||
user.password_hash = updated_hash
|
||||
user.save()
|
||||
|
||||
# Create session
|
||||
token = Session.generate_token()
|
||||
session = Session.create(
|
||||
token=token,
|
||||
user=user,
|
||||
expiry=utc_now() + (remember_me_session_expiry if body.remember_me else default_session_expiry),
|
||||
)
|
||||
|
||||
# Set Cookie of session token
|
||||
# Return 200 with mild user information
|
||||
pass
|
||||
response.set_cookie("session", token, samesite="strict")
|
||||
return {"email": user.email, "expiry": session.expiry}
|
||||
|
||||
|
||||
@router.post("/api/logout")
|
||||
|
||||
@@ -5,12 +5,11 @@ from linkpulse.tests.test_user import user
|
||||
|
||||
|
||||
def test_auth_login(user):
|
||||
args = {"email": "test@test.com", "password": "test"}
|
||||
args = {"email": user.email, "password": "password"}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/login", json=args)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
# assert response.json()["token"] is not None
|
||||
|
||||
response = client.post("/api/login", json={**args, "email": "invalid_email"})
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@@ -2,12 +2,13 @@ import structlog
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
from linkpulse.app import app
|
||||
from linkpulse.tests.test_user import user
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def test_rate_limit():
|
||||
args = {"email": "test@test.com", "password": "test"}
|
||||
def test_rate_limit(user):
|
||||
args = {"email": user.email, "password": "password"}
|
||||
|
||||
with TestClient(app) as client:
|
||||
for _ in range(6):
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
import structlog
|
||||
from linkpulse.models import User
|
||||
from linkpulse.routers.auth import hasher
|
||||
from linkpulse.tests.random import random_email, random_string
|
||||
from linkpulse.tests.random import random_email
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user