Reformat all python files, remove unused imports

This commit is contained in:
2024-11-01 16:13:01 -05:00
parent 4b85153065
commit 10b93d41d6
7 changed files with 77 additions and 59 deletions

View File

@@ -1,4 +1,3 @@
import logging
import os
import random
@@ -18,7 +17,7 @@ from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
import structlog
from linkpulse.utilities import get_ip, hide_ip, pluralize
from linkpulse.utilities import get_ip, hide_ip
from linkpulse.middleware import LoggingMiddleware
from peewee import PostgresqlDatabase
from psycopg2.extras import execute_values
@@ -87,7 +86,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
# Delete all randomly generated IP addresses
with db.atomic():
logger.info("Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool))
logger.info(
"Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)
)
query = models.IPAddress.delete().where(
models.IPAddress.ip << app.state.ip_pool
)

View File

@@ -24,17 +24,23 @@ def drop_color_message_key(_, __, event_dict: EventDict) -> EventDict:
return event_dict
def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = None) -> None:
def setup_logging(
json_logs: Optional[bool] = None, log_level: Optional[str] = None
) -> None:
json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true"
log_level = log_level or os.getenv("LOG_LEVEL", "INFO")
def flatten(n):
match n:
case []: return []
case [[*hd], *tl]: return [*flatten(hd), *flatten(tl)]
case [hd, *tl]: return [hd, *flatten(tl)]
case []:
return []
case [[*hd], *tl]:
return [*flatten(hd), *flatten(tl)]
case [hd, *tl]:
return [hd, *flatten(tl)]
shared_processors: List[Processor] = flatten([
shared_processors: List[Processor] = flatten(
[
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
@@ -52,7 +58,8 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
if json_logs
else []
),
])
]
)
structlog.configure(
processors=[
@@ -88,7 +95,12 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
root_logger.addHandler(handler)
root_logger.setLevel(log_level.upper())
def configure_logger(name: str, level: Optional[str] = None, clear: Optional[bool] = None, propagate: Optional[bool] = None) -> None:
def configure_logger(
name: str,
level: Optional[str] = None,
clear: Optional[bool] = None,
propagate: Optional[bool] = None,
) -> None:
logger = logging.getLogger(name)
if level is not None:
@@ -100,7 +112,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
if propagate is not None:
logger.propagate = propagate
# Clear the log handlers for uvicorn loggers, and enable propagation
# so the messages are caught by our root logger and formatted correctly
# by structlog
@@ -109,7 +120,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
configure_logger("apscheduler.executors.default", level="WARNING")
# Since we re-create the access logs ourselves, to add all information
# in the structured log (see the `logging_middleware` in main.py), we clear
# the handlers and prevent the logs to propagate to a logger higher up in the

View File

@@ -6,7 +6,6 @@ from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
class LoggingMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI):
super().__init__(app)
@@ -42,7 +41,11 @@ class LoggingMiddleware(BaseHTTPMiddleware):
"request_id": request_id,
"version": request.scope["http_version"],
},
client={"ip": request.client.host, "port": request.client.port} if request.client else None,
client=(
{"ip": request.client.host, "port": request.client.port}
if request.client
else None
),
duration="{:.2f}ms".format(process_time_ms),
)

View File

@@ -1,4 +1,3 @@
import os
import pkgutil
import re
import sys

View File

@@ -2,9 +2,10 @@ from peewee import Model, CharField, DateTimeField, IntegerField
from playhouse.db_url import connect
from os import environ
class BaseModel(Model):
class Meta:
database = connect(url=environ.get('DATABASE_URL'))
database = connect(url=environ.get("DATABASE_URL"))
class IPAddress(BaseModel):

View File

@@ -1,5 +1,4 @@
from pydantic import BaseModel
from datetime import datetime
class SeenIP(BaseModel):

View File

@@ -7,8 +7,9 @@ def pluralize(count: int, word: Optional[str] = None) -> str:
Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise.
"""
if word:
return word + 's' if count != 1 else word
return 's' if count != 1 else ''
return word + "s" if count != 1 else word
return "s" if count != 1 else ""
def get_ip(request: Request) -> Optional[str]:
"""
@@ -25,15 +26,16 @@ def get_ip(request: Request) -> Optional[str]:
Returns:
Optional[str]: The client's IP address if available, otherwise None.
"""
x_forwarded_for = request.headers.get('X-Forwarded-For')
x_forwarded_for = request.headers.get("X-Forwarded-For")
if x_forwarded_for:
return x_forwarded_for.split(',')[0]
return x_forwarded_for.split(",")[0]
if request.client:
return request.client.host
return None
def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
"""
Hide the last octet(s) of an IP address.
@@ -58,16 +60,19 @@ def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4)
'2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX'
"""
ipv6 = ':' in ip
ipv6 = ":" in ip
# Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check.
if ipv6 == ('.' in ip):
if ipv6 == ("." in ip):
raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.")
total_octets = 8 if ipv6 else 4
separator = ':' if ipv6 else '.'
replacement = 'XXXX' if ipv6 else 'X'
separator = ":" if ipv6 else "."
replacement = "XXXX" if ipv6 else "X"
if hidden_octets is None:
hidden_octets = 3 if ipv6 else 1
return separator.join(ip.split(separator, total_octets - hidden_octets)[:-1]) + (separator + replacement) * hidden_octets
return (
separator.join(ip.split(separator, total_octets - hidden_octets)[:-1])
+ (separator + replacement) * hidden_octets
)