Reformat all python files, remove unused imports

This commit is contained in:
2024-11-01 16:13:01 -05:00
parent 4b85153065
commit 10b93d41d6
7 changed files with 77 additions and 59 deletions

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import random import random
@@ -18,7 +17,7 @@ from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
import structlog import structlog
from linkpulse.utilities import get_ip, hide_ip, pluralize from linkpulse.utilities import get_ip, hide_ip
from linkpulse.middleware import LoggingMiddleware from linkpulse.middleware import LoggingMiddleware
from peewee import PostgresqlDatabase from peewee import PostgresqlDatabase
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
@@ -87,7 +86,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
# Delete all randomly generated IP addresses # Delete all randomly generated IP addresses
with db.atomic(): with db.atomic():
logger.info("Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)) logger.info(
"Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)
)
query = models.IPAddress.delete().where( query = models.IPAddress.delete().where(
models.IPAddress.ip << app.state.ip_pool models.IPAddress.ip << app.state.ip_pool
) )

View File

@@ -24,17 +24,23 @@ def drop_color_message_key(_, __, event_dict: EventDict) -> EventDict:
return event_dict return event_dict
def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = None) -> None: def setup_logging(
json_logs: Optional[bool] = None, log_level: Optional[str] = None
) -> None:
json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true" json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true"
log_level = log_level or os.getenv("LOG_LEVEL", "INFO") log_level = log_level or os.getenv("LOG_LEVEL", "INFO")
def flatten(n): def flatten(n):
match n: match n:
case []: return [] case []:
case [[*hd], *tl]: return [*flatten(hd), *flatten(tl)] return []
case [hd, *tl]: return [hd, *flatten(tl)] case [[*hd], *tl]:
return [*flatten(hd), *flatten(tl)]
case [hd, *tl]:
return [hd, *flatten(tl)]
shared_processors: List[Processor] = flatten([ shared_processors: List[Processor] = flatten(
[
structlog.contextvars.merge_contextvars, structlog.contextvars.merge_contextvars,
structlog.stdlib.add_logger_name, structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level, structlog.stdlib.add_log_level,
@@ -52,7 +58,8 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
if json_logs if json_logs
else [] else []
), ),
]) ]
)
structlog.configure( structlog.configure(
processors=[ processors=[
@@ -88,7 +95,12 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
root_logger.addHandler(handler) root_logger.addHandler(handler)
root_logger.setLevel(log_level.upper()) root_logger.setLevel(log_level.upper())
def configure_logger(name: str, level: Optional[str] = None, clear: Optional[bool] = None, propagate: Optional[bool] = None) -> None: def configure_logger(
name: str,
level: Optional[str] = None,
clear: Optional[bool] = None,
propagate: Optional[bool] = None,
) -> None:
logger = logging.getLogger(name) logger = logging.getLogger(name)
if level is not None: if level is not None:
@@ -100,7 +112,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
if propagate is not None: if propagate is not None:
logger.propagate = propagate logger.propagate = propagate
# Clear the log handlers for uvicorn loggers, and enable propagation # Clear the log handlers for uvicorn loggers, and enable propagation
# so the messages are caught by our root logger and formatted correctly # so the messages are caught by our root logger and formatted correctly
# by structlog # by structlog
@@ -109,7 +120,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
configure_logger("apscheduler.executors.default", level="WARNING") configure_logger("apscheduler.executors.default", level="WARNING")
# Since we re-create the access logs ourselves, to add all information # Since we re-create the access logs ourselves, to add all information
# in the structured log (see the `logging_middleware` in main.py), we clear # in the structured log (see the `logging_middleware` in main.py), we clear
# the handlers and prevent the logs to propagate to a logger higher up in the # the handlers and prevent the logs to propagate to a logger higher up in the

View File

@@ -6,7 +6,6 @@ from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
class LoggingMiddleware(BaseHTTPMiddleware): class LoggingMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
super().__init__(app) super().__init__(app)
@@ -42,7 +41,11 @@ class LoggingMiddleware(BaseHTTPMiddleware):
"request_id": request_id, "request_id": request_id,
"version": request.scope["http_version"], "version": request.scope["http_version"],
}, },
client={"ip": request.client.host, "port": request.client.port} if request.client else None, client=(
{"ip": request.client.host, "port": request.client.port}
if request.client
else None
),
duration="{:.2f}ms".format(process_time_ms), duration="{:.2f}ms".format(process_time_ms),
) )

View File

@@ -1,4 +1,3 @@
import os
import pkgutil import pkgutil
import re import re
import sys import sys

View File

@@ -2,9 +2,10 @@ from peewee import Model, CharField, DateTimeField, IntegerField
from playhouse.db_url import connect from playhouse.db_url import connect
from os import environ from os import environ
class BaseModel(Model): class BaseModel(Model):
class Meta: class Meta:
database = connect(url=environ.get('DATABASE_URL')) database = connect(url=environ.get("DATABASE_URL"))
class IPAddress(BaseModel): class IPAddress(BaseModel):

View File

@@ -1,5 +1,4 @@
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime
class SeenIP(BaseModel): class SeenIP(BaseModel):

View File

@@ -7,8 +7,9 @@ def pluralize(count: int, word: Optional[str] = None) -> str:
Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise. Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise.
""" """
if word: if word:
return word + 's' if count != 1 else word return word + "s" if count != 1 else word
return 's' if count != 1 else '' return "s" if count != 1 else ""
def get_ip(request: Request) -> Optional[str]: def get_ip(request: Request) -> Optional[str]:
""" """
@@ -25,15 +26,16 @@ def get_ip(request: Request) -> Optional[str]:
Returns: Returns:
Optional[str]: The client's IP address if available, otherwise None. Optional[str]: The client's IP address if available, otherwise None.
""" """
x_forwarded_for = request.headers.get('X-Forwarded-For') x_forwarded_for = request.headers.get("X-Forwarded-For")
if x_forwarded_for: if x_forwarded_for:
return x_forwarded_for.split(',')[0] return x_forwarded_for.split(",")[0]
if request.client: if request.client:
return request.client.host return request.client.host
return None return None
def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str: def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
""" """
Hide the last octet(s) of an IP address. Hide the last octet(s) of an IP address.
@@ -58,16 +60,19 @@ def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4) >>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4)
'2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX' '2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX'
""" """
ipv6 = ':' in ip ipv6 = ":" in ip
# Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check. # Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check.
if ipv6 == ('.' in ip): if ipv6 == ("." in ip):
raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.") raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.")
total_octets = 8 if ipv6 else 4 total_octets = 8 if ipv6 else 4
separator = ':' if ipv6 else '.' separator = ":" if ipv6 else "."
replacement = 'XXXX' if ipv6 else 'X' replacement = "XXXX" if ipv6 else "X"
if hidden_octets is None: if hidden_octets is None:
hidden_octets = 3 if ipv6 else 1 hidden_octets = 3 if ipv6 else 1
return separator.join(ip.split(separator, total_octets - hidden_octets)[:-1]) + (separator + replacement) * hidden_octets return (
separator.join(ip.split(separator, total_octets - hidden_octets)[:-1])
+ (separator + replacement) * hidden_octets
)