diff --git a/backend/linkpulse/app.py b/backend/linkpulse/app.py index 69b9bf2..2d238d5 100644 --- a/backend/linkpulse/app.py +++ b/backend/linkpulse/app.py @@ -1,4 +1,5 @@ import logging +import os import random from collections import defaultdict from contextlib import asynccontextmanager @@ -12,7 +13,6 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi_cache import FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.decorator import cache -from fastapi_utils.tasks import repeat_every import human_readable from linkpulse.utilities import get_ip, hide_ip, pluralize from peewee import PostgresqlDatabase @@ -24,6 +24,7 @@ load_dotenv(dotenv_path=".env") from linkpulse import models, responses # type: ignore +is_development = os.getenv("ENVIRONMENT") == "development" db: PostgresqlDatabase = models.BaseModel._meta.database @@ -69,10 +70,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]: backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache" ) - random.seed(42) - app.state.ip_pool = [ - ".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50) - ] + if is_development: + random.seed(42) + app.state.ip_pool = [ + ".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50) + ] app.state.buffered_updates = defaultdict(IPCounter) scheduler.start() @@ -92,18 +94,20 @@ class IPCounter: app = FastAPI(lifespan=lifespan) -origins = [ - "http://localhost", - "http://localhost:5173", -] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +if is_development: + origins = [ + "http://localhost", + "http://localhost:5173", + ] + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) @app.on_event("startup") @@ -143,8 +147,8 @@ async def get_ips(request: Request, response: Response): Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered). """ now = datetime.now() - # user_ip = get_ip(request) - user_ip = random.choice(app.state.ip_pool) + + user_ip = get_ip(request) if not is_development else random.choice(app.state.ip_pool) if user_ip is None: print("No IP found!") response.status_code = status.HTTP_403_FORBIDDEN @@ -154,7 +158,7 @@ async def get_ips(request: Request, response: Response): app.state.buffered_updates[user_ip].count += 1 app.state.buffered_updates[user_ip].last_seen = now - # Return the IP addresses + # Return the IP addresses latest_ips = ( models.IPAddress.select(models.IPAddress.ip, models.IPAddress.last_seen, models.IPAddress.count) .order_by(models.IPAddress.last_seen.desc())