Reformat all python files, remove unused imports

This commit is contained in:
2024-11-01 16:13:01 -05:00
parent 4b85153065
commit 10b93d41d6
7 changed files with 77 additions and 59 deletions

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import random import random
@@ -18,7 +17,7 @@ from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache from fastapi_cache.decorator import cache
import structlog import structlog
from linkpulse.utilities import get_ip, hide_ip, pluralize from linkpulse.utilities import get_ip, hide_ip
from linkpulse.middleware import LoggingMiddleware from linkpulse.middleware import LoggingMiddleware
from peewee import PostgresqlDatabase from peewee import PostgresqlDatabase
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
@@ -87,7 +86,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
# Delete all randomly generated IP addresses # Delete all randomly generated IP addresses
with db.atomic(): with db.atomic():
logger.info("Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)) logger.info(
"Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)
)
query = models.IPAddress.delete().where( query = models.IPAddress.delete().where(
models.IPAddress.ip << app.state.ip_pool models.IPAddress.ip << app.state.ip_pool
) )

View File

@@ -24,35 +24,42 @@ def drop_color_message_key(_, __, event_dict: EventDict) -> EventDict:
return event_dict return event_dict
def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = None) -> None: def setup_logging(
json_logs: Optional[bool] = None, log_level: Optional[str] = None
) -> None:
json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true" json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true"
log_level = log_level or os.getenv("LOG_LEVEL", "INFO") log_level = log_level or os.getenv("LOG_LEVEL", "INFO")
def flatten(n): def flatten(n):
match n: match n:
case []: return [] case []:
case [[*hd], *tl]: return [*flatten(hd), *flatten(tl)] return []
case [hd, *tl]: return [hd, *flatten(tl)] case [[*hd], *tl]:
return [*flatten(hd), *flatten(tl)]
case [hd, *tl]:
return [hd, *flatten(tl)]
shared_processors: List[Processor] = flatten([ shared_processors: List[Processor] = flatten(
structlog.contextvars.merge_contextvars, [
structlog.stdlib.add_logger_name, structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level, structlog.stdlib.add_logger_name,
structlog.stdlib.PositionalArgumentsFormatter(), structlog.stdlib.add_log_level,
structlog.stdlib.ExtraAdder(), structlog.stdlib.PositionalArgumentsFormatter(),
drop_color_message_key, structlog.stdlib.ExtraAdder(),
structlog.processors.TimeStamper(fmt="iso"), drop_color_message_key,
structlog.processors.StackInfoRenderer(), structlog.processors.TimeStamper(fmt="iso"),
( structlog.processors.StackInfoRenderer(),
[ (
rename_event_key, [
# Format the exception only for JSON logs, as we want to pretty-print them when using the ConsoleRenderer rename_event_key,
structlog.processors.format_exc_info, # Format the exception only for JSON logs, as we want to pretty-print them when using the ConsoleRenderer
] structlog.processors.format_exc_info,
if json_logs ]
else [] if json_logs
), else []
]) ),
]
)
structlog.configure( structlog.configure(
processors=[ processors=[
@@ -88,18 +95,22 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
root_logger.addHandler(handler) root_logger.addHandler(handler)
root_logger.setLevel(log_level.upper()) root_logger.setLevel(log_level.upper())
def configure_logger(name: str, level: Optional[str] = None, clear: Optional[bool] = None, propagate: Optional[bool] = None) -> None: def configure_logger(
name: str,
level: Optional[str] = None,
clear: Optional[bool] = None,
propagate: Optional[bool] = None,
) -> None:
logger = logging.getLogger(name) logger = logging.getLogger(name)
if level is not None: if level is not None:
logger.setLevel(level.upper()) logger.setLevel(level.upper())
if clear is True: if clear is True:
logger.handlers.clear() logger.handlers.clear()
if propagate is not None: if propagate is not None:
logger.propagate = propagate logger.propagate = propagate
# Clear the log handlers for uvicorn loggers, and enable propagation # Clear the log handlers for uvicorn loggers, and enable propagation
# so the messages are caught by our root logger and formatted correctly # so the messages are caught by our root logger and formatted correctly
@@ -109,7 +120,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
configure_logger("apscheduler.executors.default", level="WARNING") configure_logger("apscheduler.executors.default", level="WARNING")
# Since we re-create the access logs ourselves, to add all information # Since we re-create the access logs ourselves, to add all information
# in the structured log (see the `logging_middleware` in main.py), we clear # in the structured log (see the `logging_middleware` in main.py), we clear
# the handlers and prevent the logs to propagate to a logger higher up in the # the handlers and prevent the logs to propagate to a logger higher up in the
@@ -130,4 +140,4 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback) "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
) )
sys.excepthook = handle_exception sys.excepthook = handle_exception

View File

@@ -6,7 +6,6 @@ from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
class LoggingMiddleware(BaseHTTPMiddleware): class LoggingMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
super().__init__(app) super().__init__(app)
@@ -30,7 +29,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
structlog.stdlib.get_logger("api.error").exception("Uncaught exception") structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
raise raise
finally: finally:
process_time_ms = (time.perf_counter_ns() - start_time) / 10 ** 6 process_time_ms = (time.perf_counter_ns() - start_time) / 10**6
self.access_logger.debug( self.access_logger.debug(
"Request", "Request",
@@ -42,10 +41,14 @@ class LoggingMiddleware(BaseHTTPMiddleware):
"request_id": request_id, "request_id": request_id,
"version": request.scope["http_version"], "version": request.scope["http_version"],
}, },
client={"ip": request.client.host, "port": request.client.port} if request.client else None, client=(
{"ip": request.client.host, "port": request.client.port}
if request.client
else None
),
duration="{:.2f}ms".format(process_time_ms), duration="{:.2f}ms".format(process_time_ms),
) )
# response.headers["X-Process-Time"] = str(process_time / 10 ** 9) # response.headers["X-Process-Time"] = str(process_time / 10 ** 9)
return response return response

View File

@@ -1,4 +1,3 @@
import os
import pkgutil import pkgutil
import re import re
import sys import sys

View File

@@ -2,12 +2,13 @@ from peewee import Model, CharField, DateTimeField, IntegerField
from playhouse.db_url import connect from playhouse.db_url import connect
from os import environ from os import environ
class BaseModel(Model): class BaseModel(Model):
class Meta: class Meta:
database = connect(url=environ.get('DATABASE_URL')) database = connect(url=environ.get("DATABASE_URL"))
class IPAddress(BaseModel): class IPAddress(BaseModel):
ip = CharField(primary_key=True) ip = CharField(primary_key=True)
last_seen = DateTimeField() last_seen = DateTimeField()
count = IntegerField(default=0) count = IntegerField(default=0)

View File

@@ -1,8 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime
class SeenIP(BaseModel): class SeenIP(BaseModel):
ip: str ip: str
last_seen: str last_seen: str
count: int count: int

View File

@@ -7,33 +7,35 @@ def pluralize(count: int, word: Optional[str] = None) -> str:
Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise. Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise.
""" """
if word: if word:
return word + 's' if count != 1 else word return word + "s" if count != 1 else word
return 's' if count != 1 else '' return "s" if count != 1 else ""
def get_ip(request: Request) -> Optional[str]: def get_ip(request: Request) -> Optional[str]:
""" """
This function attempts to retrieve the client's IP address from the request headers. This function attempts to retrieve the client's IP address from the request headers.
It first checks the 'X-Forwarded-For' header, which is commonly used in proxy setups. It first checks the 'X-Forwarded-For' header, which is commonly used in proxy setups.
If the header is present, it returns the first IP address in the list. If the header is present, it returns the first IP address in the list.
If the header is not present, it falls back to the client's direct connection IP address. If the header is not present, it falls back to the client's direct connection IP address.
If neither is available, it returns None. If neither is available, it returns None.
Args: Args:
request (Request): The request object containing headers and client information. request (Request): The request object containing headers and client information.
Returns: Returns:
Optional[str]: The client's IP address if available, otherwise None. Optional[str]: The client's IP address if available, otherwise None.
""" """
x_forwarded_for = request.headers.get('X-Forwarded-For') x_forwarded_for = request.headers.get("X-Forwarded-For")
if x_forwarded_for: if x_forwarded_for:
return x_forwarded_for.split(',')[0] return x_forwarded_for.split(",")[0]
if request.client: if request.client:
return request.client.host return request.client.host
return None return None
def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str: def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
""" """
Hide the last octet(s) of an IP address. Hide the last octet(s) of an IP address.
@@ -48,26 +50,29 @@ def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
Examples: Examples:
>>> hide_ip("192.168.1.1") >>> hide_ip("192.168.1.1")
'192.168.1.X' '192.168.1.X'
>>> hide_ip("192.168.1.1", 2) >>> hide_ip("192.168.1.1", 2)
'192.168.X.X' '192.168.X.X'
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334") >>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334")
'2001:0db8:85a3:0000:0000:XXXX:XXXX:XXXX' '2001:0db8:85a3:0000:0000:XXXX:XXXX:XXXX'
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4) >>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4)
'2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX' '2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX'
""" """
ipv6 = ':' in ip ipv6 = ":" in ip
# Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check. # Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check.
if ipv6 == ('.' in ip): if ipv6 == ("." in ip):
raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.") raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.")
total_octets = 8 if ipv6 else 4 total_octets = 8 if ipv6 else 4
separator = ':' if ipv6 else '.' separator = ":" if ipv6 else "."
replacement = 'XXXX' if ipv6 else 'X' replacement = "XXXX" if ipv6 else "X"
if hidden_octets is None: if hidden_octets is None:
hidden_octets = 3 if ipv6 else 1 hidden_octets = 3 if ipv6 else 1
return separator.join(ip.split(separator, total_octets - hidden_octets)[:-1]) + (separator + replacement) * hidden_octets return (
separator.join(ip.split(separator, total_octets - hidden_octets)[:-1])
+ (separator + replacement) * hidden_octets
)