Reformat all python files, remove unused imports

This commit is contained in:
2024-11-01 16:13:01 -05:00
parent 4b85153065
commit 10b93d41d6
7 changed files with 77 additions and 59 deletions

View File

@@ -1,4 +1,3 @@
import logging
import os
import random
@@ -18,7 +17,7 @@ from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
import structlog
from linkpulse.utilities import get_ip, hide_ip, pluralize
from linkpulse.utilities import get_ip, hide_ip
from linkpulse.middleware import LoggingMiddleware
from peewee import PostgresqlDatabase
from psycopg2.extras import execute_values
@@ -87,7 +86,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
# Delete all randomly generated IP addresses
with db.atomic():
logger.info("Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool))
logger.info(
"Deleting Randomized IP Addresses", ip_pool_count=len(app.state.ip_pool)
)
query = models.IPAddress.delete().where(
models.IPAddress.ip << app.state.ip_pool
)

View File

@@ -24,35 +24,42 @@ def drop_color_message_key(_, __, event_dict: EventDict) -> EventDict:
return event_dict
def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = None) -> None:
def setup_logging(
json_logs: Optional[bool] = None, log_level: Optional[str] = None
) -> None:
json_logs = json_logs or os.getenv("LOG_JSON_FORMAT", "true").lower() == "true"
log_level = log_level or os.getenv("LOG_LEVEL", "INFO")
def flatten(n):
match n:
case []: return []
case [[*hd], *tl]: return [*flatten(hd), *flatten(tl)]
case [hd, *tl]: return [hd, *flatten(tl)]
case []:
return []
case [[*hd], *tl]:
return [*flatten(hd), *flatten(tl)]
case [hd, *tl]:
return [hd, *flatten(tl)]
shared_processors: List[Processor] = flatten([
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.stdlib.ExtraAdder(),
drop_color_message_key,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
(
[
rename_event_key,
# Format the exception only for JSON logs, as we want to pretty-print them when using the ConsoleRenderer
structlog.processors.format_exc_info,
]
if json_logs
else []
),
])
shared_processors: List[Processor] = flatten(
[
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.stdlib.ExtraAdder(),
drop_color_message_key,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
(
[
rename_event_key,
# Format the exception only for JSON logs, as we want to pretty-print them when using the ConsoleRenderer
structlog.processors.format_exc_info,
]
if json_logs
else []
),
]
)
structlog.configure(
processors=[
@@ -88,18 +95,22 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
root_logger.addHandler(handler)
root_logger.setLevel(log_level.upper())
def configure_logger(name: str, level: Optional[str] = None, clear: Optional[bool] = None, propagate: Optional[bool] = None) -> None:
def configure_logger(
name: str,
level: Optional[str] = None,
clear: Optional[bool] = None,
propagate: Optional[bool] = None,
) -> None:
logger = logging.getLogger(name)
if level is not None:
logger.setLevel(level.upper())
if clear is True:
logger.handlers.clear()
if propagate is not None:
logger.propagate = propagate
# Clear the log handlers for uvicorn loggers, and enable propagation
# so the messages are caught by our root logger and formatted correctly
@@ -109,7 +120,6 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
configure_logger("apscheduler.executors.default", level="WARNING")
# Since we re-create the access logs ourselves, to add all information
# in the structured log (see the `logging_middleware` in main.py), we clear
# the handlers and prevent the logs to propagate to a logger higher up in the
@@ -130,4 +140,4 @@ def setup_logging(json_logs: Optional[bool] = None, log_level: Optional[str] = N
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
)
sys.excepthook = handle_exception
sys.excepthook = handle_exception

View File

@@ -6,7 +6,6 @@ from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
class LoggingMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI):
super().__init__(app)
@@ -30,7 +29,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
raise
finally:
process_time_ms = (time.perf_counter_ns() - start_time) / 10 ** 6
process_time_ms = (time.perf_counter_ns() - start_time) / 10**6
self.access_logger.debug(
"Request",
@@ -42,10 +41,14 @@ class LoggingMiddleware(BaseHTTPMiddleware):
"request_id": request_id,
"version": request.scope["http_version"],
},
client={"ip": request.client.host, "port": request.client.port} if request.client else None,
client=(
{"ip": request.client.host, "port": request.client.port}
if request.client
else None
),
duration="{:.2f}ms".format(process_time_ms),
)
# response.headers["X-Process-Time"] = str(process_time / 10 ** 9)
return response
return response

View File

@@ -1,4 +1,3 @@
import os
import pkgutil
import re
import sys

View File

@@ -2,12 +2,13 @@ from peewee import Model, CharField, DateTimeField, IntegerField
from playhouse.db_url import connect
from os import environ
class BaseModel(Model):
class Meta:
database = connect(url=environ.get('DATABASE_URL'))
database = connect(url=environ.get("DATABASE_URL"))
class IPAddress(BaseModel):
ip = CharField(primary_key=True)
last_seen = DateTimeField()
count = IntegerField(default=0)
count = IntegerField(default=0)

View File

@@ -1,8 +1,7 @@
from pydantic import BaseModel
from datetime import datetime
class SeenIP(BaseModel):
ip: str
last_seen: str
count: int
count: int

View File

@@ -7,33 +7,35 @@ def pluralize(count: int, word: Optional[str] = None) -> str:
Pluralize a word based on count. Returns 's' if count is not 1, '' (empty string) otherwise.
"""
if word:
return word + 's' if count != 1 else word
return 's' if count != 1 else ''
return word + "s" if count != 1 else word
return "s" if count != 1 else ""
def get_ip(request: Request) -> Optional[str]:
"""
This function attempts to retrieve the client's IP address from the request headers.
It first checks the 'X-Forwarded-For' header, which is commonly used in proxy setups.
If the header is present, it returns the first IP address in the list.
If the header is not present, it falls back to the client's direct connection IP address.
If neither is available, it returns None.
Args:
request (Request): The request object containing headers and client information.
Returns:
Optional[str]: The client's IP address if available, otherwise None.
"""
x_forwarded_for = request.headers.get('X-Forwarded-For')
x_forwarded_for = request.headers.get("X-Forwarded-For")
if x_forwarded_for:
return x_forwarded_for.split(',')[0]
return x_forwarded_for.split(",")[0]
if request.client:
return request.client.host
return None
def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
"""
Hide the last octet(s) of an IP address.
@@ -48,26 +50,29 @@ def hide_ip(ip: str, hidden_octets: Optional[int] = None) -> str:
Examples:
>>> hide_ip("192.168.1.1")
'192.168.1.X'
>>> hide_ip("192.168.1.1", 2)
'192.168.X.X'
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334")
'2001:0db8:85a3:0000:0000:XXXX:XXXX:XXXX'
>>> hide_ip("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 4)
'2001:0db8:85a3:0000:XXXX:XXXX:XXXX:XXXX'
"""
ipv6 = ':' in ip
ipv6 = ":" in ip
# Make sure that IPv4 (dot) and IPv6 (colon) addresses are not mixed together somehow. Not a comprehensive check.
if ipv6 == ('.' in ip):
if ipv6 == ("." in ip):
raise ValueError("Invalid IP address format. Must be either IPv4 or IPv6.")
total_octets = 8 if ipv6 else 4
separator = ':' if ipv6 else '.'
replacement = 'XXXX' if ipv6 else 'X'
separator = ":" if ipv6 else "."
replacement = "XXXX" if ipv6 else "X"
if hidden_octets is None:
hidden_octets = 3 if ipv6 else 1
return separator.join(ip.split(separator, total_octets - hidden_octets)[:-1]) + (separator + replacement) * hidden_octets
return (
separator.join(ip.split(separator, total_octets - hidden_octets)[:-1])
+ (separator + replacement) * hidden_octets
)