refactor base ContestBot logic to use new sqlalchemy models, commit exceptions.py

This commit is contained in:
Xevion
2021-02-13 05:18:34 -06:00
parent 91594646c3
commit c8e0ae1bf2
2 changed files with 70 additions and 20 deletions

View File

@@ -1,44 +1,91 @@
import logging
from typing import Optional
from contextlib import contextmanager
from datetime import datetime
import discord
from discord.ext import commands
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from bot import constants
from bot.db import ContestDatabase
from bot.models import Guild, Period
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)
async def fetch_prefix(bot: 'ContestBot', message: discord.Message):
"""Fetches the prefix used by the relevant guild."""
user_id = bot.user.id
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
if message.guild:
if bot.db is not None:
base.append(await bot.db.get_prefix(message.guild.id))
return base
class ContestBot(commands.Bot):
def __init__(self, **options):
self.db: Optional[ContestDatabase] = None
super().__init__(fetch_prefix, **options)
def __init__(self, engine: Engine, **options):
super().__init__(self.fetch_prefix, **options)
self.engine = engine
self.Session = sessionmaker(bind=engine)
@contextmanager
def autocommit(self):
"""Provides automatic commit and closing of Session with exception rollback."""
session = self.Session()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
@contextmanager
def autoclose(self):
"""Provides automatic closing of Session."""
session = self.Session()
try:
yield session
except Exception:
session.rollback()
raise
finally:
session.close()
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
"""Fetches the prefix used by the relevant guild."""
user_id = bot.user.id
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
if message.guild:
with self.autocommit() as session:
guild = session.query(Guild).filter_by(id=message.guild.id).first()
base.append(guild.prefix)
return base
async def on_ready(self):
if self.db is None:
self.db = await ContestDatabase.create()
logger.info('Bot is now ready and connected to Discord.')
guild_count = len(self.guilds)
logger.info(
f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
async def on_guild_join(self, guild: discord.Guild) -> None:
"""Handles adding or reactivating a Guild in the database."""
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
await self.db.setup_guild(guild.id)
with self.autocommit() as session:
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
if _guild is None:
session.add(Guild(id=guild.id))
else:
# Guild has been seen before. Update last_joined and set as active again.
_guild.active = True
_guild.last_joined = datetime.utcnow()
async def on_guild_remove(self, guild: discord.Guild) -> None:
"""Handles disabling the guild in the database, as well."""
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
await self.db.teardown_guild(guild.id)
with self.autocommit() as session:
# Get the associated Guild and mark it as disabled.
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
_guild.active = False
# Shut down any current running Period objects if possible.
period: Period = _guild.current_period
if period is not None and period.active:
period.deactivate()

3
bot/exceptions.py Normal file
View File

@@ -0,0 +1,3 @@
class FinishedPeriod(Exception):
"""A inactive period, or a period in it's final state cannot be advanced or further modified."""
pass