development mode checks

This commit is contained in:
2024-10-24 04:18:08 -05:00
parent c0d135d8a8
commit 4267d40611

View File

@@ -1,4 +1,5 @@
import logging import logging
import os
import random import random
from collections import defaultdict from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -12,7 +13,6 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
from fastapi_utils.tasks import repeat_every
import human_readable import human_readable
from linkpulse.utilities import get_ip, hide_ip, pluralize from linkpulse.utilities import get_ip, hide_ip, pluralize
from peewee import PostgresqlDatabase from peewee import PostgresqlDatabase
@@ -24,6 +24,7 @@ load_dotenv(dotenv_path=".env")
from linkpulse import models, responses # type: ignore from linkpulse import models, responses # type: ignore
is_development = os.getenv("ENVIRONMENT") == "development"
db: PostgresqlDatabase = models.BaseModel._meta.database db: PostgresqlDatabase = models.BaseModel._meta.database
@@ -69,6 +70,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache" backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache"
) )
if is_development:
random.seed(42) random.seed(42)
app.state.ip_pool = [ app.state.ip_pool = [
".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50) ".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50)
@@ -92,18 +94,20 @@ class IPCounter:
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
origins = [
if is_development:
origins = [
"http://localhost", "http://localhost",
"http://localhost:5173", "http://localhost:5173",
] ]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
@app.on_event("startup") @app.on_event("startup")
@@ -143,8 +147,8 @@ async def get_ips(request: Request, response: Response):
Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered). Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered).
""" """
now = datetime.now() now = datetime.now()
# user_ip = get_ip(request)
user_ip = random.choice(app.state.ip_pool) user_ip = get_ip(request) if not is_development else random.choice(app.state.ip_pool)
if user_ip is None: if user_ip is None:
print("No IP found!") print("No IP found!")
response.status_code = status.HTTP_403_FORBIDDEN response.status_code = status.HTTP_403_FORBIDDEN