mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-07 05:14:42 -06:00
refactor base ContestBot logic to use new sqlalchemy models, commit exceptions.py
This commit is contained in:
87
bot/bot.py
87
bot/bot.py
@@ -1,44 +1,91 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from bot import constants
|
from bot import constants
|
||||||
from bot.db import ContestDatabase
|
from bot.models import Guild, Period
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
logger.setLevel(constants.LOGGING_LEVEL)
|
logger.setLevel(constants.LOGGING_LEVEL)
|
||||||
|
|
||||||
|
|
||||||
async def fetch_prefix(bot: 'ContestBot', message: discord.Message):
|
|
||||||
"""Fetches the prefix used by the relevant guild."""
|
|
||||||
user_id = bot.user.id
|
|
||||||
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
|
|
||||||
|
|
||||||
if message.guild:
|
|
||||||
if bot.db is not None:
|
|
||||||
base.append(await bot.db.get_prefix(message.guild.id))
|
|
||||||
return base
|
|
||||||
|
|
||||||
|
|
||||||
class ContestBot(commands.Bot):
|
class ContestBot(commands.Bot):
|
||||||
def __init__(self, **options):
|
def __init__(self, engine: Engine, **options):
|
||||||
self.db: Optional[ContestDatabase] = None
|
super().__init__(self.fetch_prefix, **options)
|
||||||
super().__init__(fetch_prefix, **options)
|
|
||||||
|
self.engine = engine
|
||||||
|
self.Session = sessionmaker(bind=engine)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def autocommit(self):
|
||||||
|
"""Provides automatic commit and closing of Session with exception rollback."""
|
||||||
|
session = self.Session()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def autoclose(self):
|
||||||
|
"""Provides automatic closing of Session."""
|
||||||
|
session = self.Session()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
|
||||||
|
"""Fetches the prefix used by the relevant guild."""
|
||||||
|
user_id = bot.user.id
|
||||||
|
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
|
||||||
|
|
||||||
|
if message.guild:
|
||||||
|
with self.autocommit() as session:
|
||||||
|
guild = session.query(Guild).filter_by(id=message.guild.id).first()
|
||||||
|
base.append(guild.prefix)
|
||||||
|
return base
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
if self.db is None:
|
|
||||||
self.db = await ContestDatabase.create()
|
|
||||||
logger.info('Bot is now ready and connected to Discord.')
|
logger.info('Bot is now ready and connected to Discord.')
|
||||||
guild_count = len(self.guilds)
|
guild_count = len(self.guilds)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
|
f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
|
||||||
|
|
||||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||||
|
"""Handles adding or reactivating a Guild in the database."""
|
||||||
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
||||||
await self.db.setup_guild(guild.id)
|
|
||||||
|
with self.autocommit() as session:
|
||||||
|
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
|
||||||
|
if _guild is None:
|
||||||
|
session.add(Guild(id=guild.id))
|
||||||
|
else:
|
||||||
|
# Guild has been seen before. Update last_joined and set as active again.
|
||||||
|
_guild.active = True
|
||||||
|
_guild.last_joined = datetime.utcnow()
|
||||||
|
|
||||||
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
||||||
|
"""Handles disabling the guild in the database, as well."""
|
||||||
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
|
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
|
||||||
await self.db.teardown_guild(guild.id)
|
|
||||||
|
with self.autocommit() as session:
|
||||||
|
# Get the associated Guild and mark it as disabled.
|
||||||
|
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
|
||||||
|
_guild.active = False
|
||||||
|
|
||||||
|
# Shut down any current running Period objects if possible.
|
||||||
|
period: Period = _guild.current_period
|
||||||
|
if period is not None and period.active:
|
||||||
|
period.deactivate()
|
||||||
|
|||||||
3
bot/exceptions.py
Normal file
3
bot/exceptions.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
class FinishedPeriod(Exception):
|
||||||
|
"""A inactive period, or a period in it's final state cannot be advanced or further modified."""
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user