Files
linkpulse/backend/linkpulse/app.py
2024-11-01 15:47:28 -05:00

201 lines
5.6 KiB
Python

import logging
import os
import random
from collections import defaultdict
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncIterator
import structlog
import human_readable
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
from apscheduler.triggers.interval import IntervalTrigger # type: ignore
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Response, status
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
from linkpulse.utilities import get_ip, hide_ip, pluralize
from peewee import PostgresqlDatabase
from psycopg2.extras import execute_values
if not structlog.is_configured():
import linkpulse.logging
load_dotenv(dotenv_path=".env")
from linkpulse import models, responses # type: ignore
# global variables
is_development = os.getenv("ENVIRONMENT") == "development"
db: PostgresqlDatabase = models.BaseModel._meta.database # type: ignore
logger = structlog.get_logger(__name__)
def flush_ips():
if len(app.state.buffered_updates) == 0:
return
try:
with db.atomic():
sql = """
WITH updates (ip, last_seen, increment) AS (VALUES %s)
INSERT INTO ipaddress (ip, last_seen, count)
SELECT ip, last_seen, increment
FROM updates
ON CONFLICT (ip)
DO UPDATE
SET count = ipaddress.count + EXCLUDED.count, last_seen = EXCLUDED.last_seen;
"""
rows = [
(ip, data.last_seen, data.count)
for ip, data in app.state.buffered_updates.items()
]
cur = db.cursor()
execute_values(cur, sql, rows)
except:
logging.error("Failed to flush IPs to the database.")
i = len(app.state.buffered_updates)
logging.debug("Flushed {} IP{} to the database.".format(i, pluralize(i)))
# Finish up
app.state.buffered_updates.clear()
scheduler = BackgroundScheduler()
scheduler.add_job(flush_ips, IntervalTrigger(seconds=5))
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
if is_development:
# 42 is the answer to everything
random.seed(42)
# Generate a pool of random IP addresses
app.state.ip_pool = [
".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50)
]
# Connect to database, ensure specific tables exist
db.connect()
db.create_tables([models.IPAddress])
# Delete all randomly generated IP addresses
with db.atomic():
logging.info(
"Deleting Randomized IP Addresses",
{"ip_pool_count": len(app.state.ip_pool)},
)
query = models.IPAddress.delete().where(
models.IPAddress.ip << app.state.ip_pool
)
rowcount = query.execute()
logger.info("Randomized IP Addresses deleted", {"rowcount": rowcount})
FastAPICache.init(
backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache"
)
app.state.buffered_updates = defaultdict(IPCounter)
scheduler.start()
yield
scheduler.shutdown()
flush_ips()
if not db.is_closed():
db.close()
@dataclass
class IPCounter:
# Note: This is not the true 'seen' count, but the count of how many times the IP has been seen since the last flush.
count: int = 0
last_seen: datetime = field(default_factory=datetime.now)
app = FastAPI(lifespan=lifespan)
if is_development:
from fastapi.middleware.cors import CORSMiddleware
origins = [
"http://localhost",
"http://localhost:5173",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health():
return "OK"
@app.get("/api/migration")
@cache(expire=60)
async def get_migration():
"""
Returns the details of the most recent migration.
"""
# Kind of insecure, but this is just a demo thing to show that migratehistory is available.
cursor = db.execute_sql(
"SELECT name, migrated_at FROM migratehistory ORDER BY migrated_at DESC LIMIT 1"
)
name, migrated_at = cursor.fetchone()
return {"name": name, "migrated_at": migrated_at}
@app.get("/api/ips")
async def get_ips(request: Request, response: Response):
"""
Returns a list of partially redacted IP addresses, as well as submitting the user's IP address to the database (buffered).
"""
now = datetime.now()
# Get the user's IP address
user_ip = get_ip(request)
# If the IP address is not found, return an error
if user_ip is None:
logging.info("No IP found!")
response.status_code = status.HTTP_403_FORBIDDEN
return {"error": "Unable to handle request."}
# Update the buffered updates
app.state.buffered_updates[user_ip].count += 1
app.state.buffered_updates[user_ip].last_seen = now
# Query the latest IPs
latest_ips = (
models.IPAddress.select(
models.IPAddress.ip, models.IPAddress.last_seen, models.IPAddress.count
)
.order_by(models.IPAddress.last_seen.desc())
.limit(10)
)
return {
"ips": [
responses.SeenIP(
ip=hide_ip(ip.ip) if ip.ip != user_ip else ip.ip,
last_seen=human_readable.date_time(ip.last_seen),
count=ip.count,
)
for ip in latest_ips
]
}