diff --git a/bot/bot.py b/bot/bot.py index 9fdcc38..118774b 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,11 +1,12 @@ import logging from contextlib import contextmanager from datetime import datetime +from typing import ContextManager import discord from discord.ext import commands from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from bot import constants from bot.models import Guild, Period @@ -22,29 +23,17 @@ class ContestBot(commands.Bot): self.Session = sessionmaker(bind=engine) @contextmanager - def autocommit(self): + def get_session(self, autocommit=True, autoclose=True, rollback=True) -> ContextManager[Session]: """Provides automatic commit and closing of Session with exception rollback.""" session = self.Session() try: yield session - session.commit() + if autocommit: session.commit() except Exception: - session.rollback() + if rollback: session.rollback() raise finally: - session.close() - - @contextmanager - def autoclose(self): - """Provides automatic closing of Session.""" - session = self.Session() - try: - yield session - except Exception: - session.rollback() - raise - finally: - session.close() + if autoclose: session.close() async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message): """Fetches the prefix used by the relevant guild.""" @@ -52,7 +41,7 @@ class ContestBot(commands.Bot): base = [f'<@!{user_id}> ', f'<@{user_id}> '] if message.guild: - with self.autocommit() as session: + with self.get_session() as session: guild = session.query(Guild).filter_by(id=message.guild.id).first() base.append(guild.prefix) return base @@ -67,7 +56,7 @@ class ContestBot(commands.Bot): """Handles adding or reactivating a Guild in the database.""" logger.info(f'Added to new guild: {guild.name} ({guild.id})') - with self.autocommit() as session: + with self.get_session() as session: _guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first() if _guild is None: session.add(Guild(id=guild.id)) @@ -80,7 +69,7 @@ class ContestBot(commands.Bot): """Handles disabling the guild in the database, as well.""" logger.info(f'Removed from guild: {guild.name} ({guild.id})') - with self.autocommit() as session: + with self.get_session() as session: # Get the associated Guild and mark it as disabled. _guild = session.query(Guild).filter_by(active=True, id=guild.id).first() _guild.active = False diff --git a/main.py b/main.py index 37c44d9..823dcf4 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,10 @@ import logging +from sqlalchemy import create_engine + from bot import constants from bot.bot import ContestBot +from bot.models import Base if __name__ == "__main__": logger = logging.getLogger(__file__) @@ -16,7 +19,9 @@ if __name__ == "__main__": initial_extensions = ['contest.cogs.contest'] - bot = ContestBot(description='A assistant for the Photography Lounge\'s monday contests') + engine = create_engine(constants.DATABASE) + Base.metadata.create_all(engine) + bot = ContestBot(engine, description='A assistant for the Photography Lounge\'s monday contests') for extension in initial_extensions: bot.load_extension(extension)