development mode checks

This commit is contained in:
2024-10-24 04:18:08 -05:00
parent c0d135d8a8
commit 4267d40611

View File

@@ -1,4 +1,5 @@
import logging import logging
import os
import random import random
from collections import defaultdict from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -12,7 +13,6 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
from fastapi_utils.tasks import repeat_every
import human_readable import human_readable
from linkpulse.utilities import get_ip, hide_ip, pluralize from linkpulse.utilities import get_ip, hide_ip, pluralize
from peewee import PostgresqlDatabase from peewee import PostgresqlDatabase
@@ -24,6 +24,7 @@ load_dotenv(dotenv_path=".env")
from linkpulse import models, responses # type: ignore from linkpulse import models, responses # type: ignore
is_development = os.getenv("ENVIRONMENT") == "development"
db: PostgresqlDatabase = models.BaseModel._meta.database db: PostgresqlDatabase = models.BaseModel._meta.database
@@ -69,10 +70,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache" backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache"
) )
random.seed(42) if is_development:
app.state.ip_pool = [ random.seed(42)
".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50) app.state.ip_pool = [
] ".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50)
]
app.state.buffered_updates = defaultdict(IPCounter) app.state.buffered_updates = defaultdict(IPCounter)
scheduler.start() scheduler.start()
@@ -92,18 +94,20 @@ class IPCounter:
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
origins = [
"http://localhost",
"http://localhost:5173",
]
app.add_middleware( if is_development:
CORSMiddleware, origins = [
allow_origins=origins, "http://localhost",
allow_credentials=True, "http://localhost:5173",
allow_methods=["*"], ]
allow_headers=["*"],
) app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.on_event("startup") @app.on_event("startup")
@@ -143,8 +147,8 @@ async def get_ips(request: Request, response: Response):
Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered). Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered).
""" """
now = datetime.now() now = datetime.now()
# user_ip = get_ip(request)
user_ip = random.choice(app.state.ip_pool) user_ip = get_ip(request) if not is_development else random.choice(app.state.ip_pool)
if user_ip is None: if user_ip is None:
print("No IP found!") print("No IP found!")
response.status_code = status.HTTP_403_FORBIDDEN response.status_code = status.HTTP_403_FORBIDDEN
@@ -154,7 +158,7 @@ async def get_ips(request: Request, response: Response):
app.state.buffered_updates[user_ip].count += 1 app.state.buffered_updates[user_ip].count += 1
app.state.buffered_updates[user_ip].last_seen = now app.state.buffered_updates[user_ip].last_seen = now
# Return the IP addresses # Return the IP addresses
latest_ips = ( latest_ips = (
models.IPAddress.select(models.IPAddress.ip, models.IPAddress.last_seen, models.IPAddress.count) models.IPAddress.select(models.IPAddress.ip, models.IPAddress.last_seen, models.IPAddress.count)
.order_by(models.IPAddress.last_seen.desc()) .order_by(models.IPAddress.last_seen.desc())