add sqlalchemy to bot startup, fix PyCharm typehinting for Session context manager, simplify autocommit and autoclose separate context managers into singular with kwargs

This commit is contained in:
Xevion
2021-02-13 05:23:12 -06:00
parent c8e0ae1bf2
commit 5978068e9b
2 changed files with 15 additions and 21 deletions

View File

@@ -1,11 +1,12 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import ContextManager
import discord
from discord.ext import commands
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from bot import constants
from bot.models import Guild, Period
@@ -22,29 +23,17 @@ class ContestBot(commands.Bot):
self.Session = sessionmaker(bind=engine)
@contextmanager
def autocommit(self):
def get_session(self, autocommit=True, autoclose=True, rollback=True) -> ContextManager[Session]:
"""Provides automatic commit and closing of Session with exception rollback."""
session = self.Session()
try:
yield session
session.commit()
if autocommit: session.commit()
except Exception:
session.rollback()
if rollback: session.rollback()
raise
finally:
session.close()
@contextmanager
def autoclose(self):
"""Provides automatic closing of Session."""
session = self.Session()
try:
yield session
except Exception:
session.rollback()
raise
finally:
session.close()
if autoclose: session.close()
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
"""Fetches the prefix used by the relevant guild."""
@@ -52,7 +41,7 @@ class ContestBot(commands.Bot):
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
if message.guild:
with self.autocommit() as session:
with self.get_session() as session:
guild = session.query(Guild).filter_by(id=message.guild.id).first()
base.append(guild.prefix)
return base
@@ -67,7 +56,7 @@ class ContestBot(commands.Bot):
"""Handles adding or reactivating a Guild in the database."""
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
with self.autocommit() as session:
with self.get_session() as session:
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
if _guild is None:
session.add(Guild(id=guild.id))
@@ -80,7 +69,7 @@ class ContestBot(commands.Bot):
"""Handles disabling the guild in the database, as well."""
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
with self.autocommit() as session:
with self.get_session() as session:
# Get the associated Guild and mark it as disabled.
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
_guild.active = False