Remove deprecated startup/shutdown events into proper applicaiton Lifespan definition

This commit is contained in:
2024-11-01 02:40:39 -05:00
parent f8b76c757c
commit 91cc8e24b6
2 changed files with 22 additions and 21 deletions
+8 -2
View File
@@ -13,22 +13,28 @@ def main(*args):
asyncio.run(serve(app, config)) asyncio.run(serve(app, config))
elif args[0] == "migrate": elif args[0] == "migrate":
from linkpulse.migrate import main from linkpulse.migrate import main
main(*args[1:]) main(*args[1:])
elif args[0] == "repl": elif args[0] == "repl":
import linkpulse import linkpulse
lp = linkpulse
# import most useful objects, models, and functions
lp = linkpulse # alias
from linkpulse.app import app, db from linkpulse.app import app, db
from linkpulse.models import BaseModel, IPAddress from linkpulse.models import BaseModel, IPAddress
# start REPL
from bpython import embed from bpython import embed
embed(locals()) embed(locals())
else: else:
print("Invalid command: {}".format(args[0])) print("Invalid command: {}".format(args[0]))
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) == 1: if len(sys.argv) == 1:
main("serve") main("serve")
else: else:
# Check that args after aren't all whitespace # Check that args after aren't all whitespace
remaining_args = ' '.join(sys.argv[1:]).strip() remaining_args = " ".join(sys.argv[1:]).strip()
if len(remaining_args) > 0: if len(remaining_args) > 0:
main(*sys.argv[1:]) main(*sys.argv[1:])
+12 -17
View File
@@ -23,8 +23,9 @@ load_dotenv(dotenv_path=".env")
from linkpulse import models, responses # type: ignore from linkpulse import models, responses # type: ignore
# global variables
is_development = os.getenv("ENVIRONMENT") == "development" is_development = os.getenv("ENVIRONMENT") == "development"
db: PostgresqlDatabase = models.BaseModel._meta.database db: PostgresqlDatabase = models.BaseModel._meta.database # type: ignore
def flush_ips(): def flush_ips():
@@ -65,10 +66,6 @@ scheduler.add_job(flush_ips, IntervalTrigger(seconds=5))
@asynccontextmanager @asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]: async def lifespan(_: FastAPI) -> AsyncIterator[None]:
FastAPICache.init(
backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache"
)
if is_development: if is_development:
# 42 is the answer to everything # 42 is the answer to everything
random.seed(42) random.seed(42)
@@ -77,6 +74,13 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50) ".".join(str(random.randint(0, 255)) for _ in range(4)) for _ in range(50)
] ]
# Connect to database, ensure specific tables exist
db.connect()
db.create_tables([models.IPAddress])
FastAPICache.init(
backend=InMemoryBackend(), prefix="fastapi-cache", cache_status_header="X-Cache"
)
app.state.buffered_updates = defaultdict(IPCounter) app.state.buffered_updates = defaultdict(IPCounter)
scheduler.start() scheduler.start()
@@ -86,6 +90,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
scheduler.shutdown() scheduler.shutdown()
flush_ips() flush_ips()
if not db.is_closed():
db.close()
@dataclass @dataclass
class IPCounter: class IPCounter:
@@ -114,18 +121,6 @@ if is_development:
) )
@app.on_event("startup")
def startup():
db.connect()
db.create_tables([models.IPAddress])
@app.on_event("shutdown")
def shutdown():
if not db.is_closed():
db.close()
@app.get("/health") @app.get("/health")
async def health(): async def health():
return "OK" return "OK"