add sqlalchemy to bot startup, fix PyCharm typehinting for Session context manager, simplify autocommit and autoclose separate context managers into singular with kwargs

This commit is contained in:
Xevion
2021-02-13 05:23:12 -06:00
parent c8e0ae1bf2
commit 5978068e9b
2 changed files with 15 additions and 21 deletions

View File

@@ -1,11 +1,12 @@
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from typing import ContextManager
import discord import discord
from discord.ext import commands from discord.ext import commands
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import Session, sessionmaker
from bot import constants from bot import constants
from bot.models import Guild, Period from bot.models import Guild, Period
@@ -22,29 +23,17 @@ class ContestBot(commands.Bot):
self.Session = sessionmaker(bind=engine) self.Session = sessionmaker(bind=engine)
@contextmanager @contextmanager
def autocommit(self): def get_session(self, autocommit=True, autoclose=True, rollback=True) -> ContextManager[Session]:
"""Provides automatic commit and closing of Session with exception rollback.""" """Provides automatic commit and closing of Session with exception rollback."""
session = self.Session() session = self.Session()
try: try:
yield session yield session
session.commit() if autocommit: session.commit()
except Exception: except Exception:
session.rollback() if rollback: session.rollback()
raise raise
finally: finally:
session.close() if autoclose: session.close()
@contextmanager
def autoclose(self):
"""Provides automatic closing of Session."""
session = self.Session()
try:
yield session
except Exception:
session.rollback()
raise
finally:
session.close()
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message): async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
"""Fetches the prefix used by the relevant guild.""" """Fetches the prefix used by the relevant guild."""
@@ -52,7 +41,7 @@ class ContestBot(commands.Bot):
base = [f'<@!{user_id}> ', f'<@{user_id}> '] base = [f'<@!{user_id}> ', f'<@{user_id}> ']
if message.guild: if message.guild:
with self.autocommit() as session: with self.get_session() as session:
guild = session.query(Guild).filter_by(id=message.guild.id).first() guild = session.query(Guild).filter_by(id=message.guild.id).first()
base.append(guild.prefix) base.append(guild.prefix)
return base return base
@@ -67,7 +56,7 @@ class ContestBot(commands.Bot):
"""Handles adding or reactivating a Guild in the database.""" """Handles adding or reactivating a Guild in the database."""
logger.info(f'Added to new guild: {guild.name} ({guild.id})') logger.info(f'Added to new guild: {guild.name} ({guild.id})')
with self.autocommit() as session: with self.get_session() as session:
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first() _guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
if _guild is None: if _guild is None:
session.add(Guild(id=guild.id)) session.add(Guild(id=guild.id))
@@ -80,7 +69,7 @@ class ContestBot(commands.Bot):
"""Handles disabling the guild in the database, as well.""" """Handles disabling the guild in the database, as well."""
logger.info(f'Removed from guild: {guild.name} ({guild.id})') logger.info(f'Removed from guild: {guild.name} ({guild.id})')
with self.autocommit() as session: with self.get_session() as session:
# Get the associated Guild and mark it as disabled. # Get the associated Guild and mark it as disabled.
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first() _guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
_guild.active = False _guild.active = False

View File

@@ -1,7 +1,10 @@
import logging import logging
from sqlalchemy import create_engine
from bot import constants from bot import constants
from bot.bot import ContestBot from bot.bot import ContestBot
from bot.models import Base
if __name__ == "__main__": if __name__ == "__main__":
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
@@ -16,7 +19,9 @@ if __name__ == "__main__":
initial_extensions = ['contest.cogs.contest'] initial_extensions = ['contest.cogs.contest']
bot = ContestBot(description='A assistant for the Photography Lounge\'s monday contests') engine = create_engine(constants.DATABASE)
Base.metadata.create_all(engine)
bot = ContestBot(engine, description='A assistant for the Photography Lounge\'s monday contests')
for extension in initial_extensions: for extension in initial_extensions:
bot.load_extension(extension) bot.load_extension(extension)