mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-06 05:14:41 -06:00
add sqlalchemy to bot startup, fix PyCharm typehinting for Session context manager, simplify autocommit and autoclose separate context managers into singular with kwargs
This commit is contained in:
29
bot/bot.py
29
bot/bot.py
@@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import ContextManager
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from bot import constants
|
from bot import constants
|
||||||
from bot.models import Guild, Period
|
from bot.models import Guild, Period
|
||||||
@@ -22,29 +23,17 @@ class ContestBot(commands.Bot):
|
|||||||
self.Session = sessionmaker(bind=engine)
|
self.Session = sessionmaker(bind=engine)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def autocommit(self):
|
def get_session(self, autocommit=True, autoclose=True, rollback=True) -> ContextManager[Session]:
|
||||||
"""Provides automatic commit and closing of Session with exception rollback."""
|
"""Provides automatic commit and closing of Session with exception rollback."""
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
session.commit()
|
if autocommit: session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
session.rollback()
|
if rollback: session.rollback()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
if autoclose: session.close()
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def autoclose(self):
|
|
||||||
"""Provides automatic closing of Session."""
|
|
||||||
session = self.Session()
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
|
async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message):
|
||||||
"""Fetches the prefix used by the relevant guild."""
|
"""Fetches the prefix used by the relevant guild."""
|
||||||
@@ -52,7 +41,7 @@ class ContestBot(commands.Bot):
|
|||||||
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
|
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
|
||||||
|
|
||||||
if message.guild:
|
if message.guild:
|
||||||
with self.autocommit() as session:
|
with self.get_session() as session:
|
||||||
guild = session.query(Guild).filter_by(id=message.guild.id).first()
|
guild = session.query(Guild).filter_by(id=message.guild.id).first()
|
||||||
base.append(guild.prefix)
|
base.append(guild.prefix)
|
||||||
return base
|
return base
|
||||||
@@ -67,7 +56,7 @@ class ContestBot(commands.Bot):
|
|||||||
"""Handles adding or reactivating a Guild in the database."""
|
"""Handles adding or reactivating a Guild in the database."""
|
||||||
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
||||||
|
|
||||||
with self.autocommit() as session:
|
with self.get_session() as session:
|
||||||
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
|
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
|
||||||
if _guild is None:
|
if _guild is None:
|
||||||
session.add(Guild(id=guild.id))
|
session.add(Guild(id=guild.id))
|
||||||
@@ -80,7 +69,7 @@ class ContestBot(commands.Bot):
|
|||||||
"""Handles disabling the guild in the database, as well."""
|
"""Handles disabling the guild in the database, as well."""
|
||||||
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
|
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
|
||||||
|
|
||||||
with self.autocommit() as session:
|
with self.get_session() as session:
|
||||||
# Get the associated Guild and mark it as disabled.
|
# Get the associated Guild and mark it as disabled.
|
||||||
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
|
_guild = session.query(Guild).filter_by(active=True, id=guild.id).first()
|
||||||
_guild.active = False
|
_guild.active = False
|
||||||
|
|||||||
7
main.py
7
main.py
@@ -1,7 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
from bot import constants
|
from bot import constants
|
||||||
from bot.bot import ContestBot
|
from bot.bot import ContestBot
|
||||||
|
from bot.models import Base
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
@@ -16,7 +19,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
initial_extensions = ['contest.cogs.contest']
|
initial_extensions = ['contest.cogs.contest']
|
||||||
|
|
||||||
bot = ContestBot(description='A assistant for the Photography Lounge\'s monday contests')
|
engine = create_engine(constants.DATABASE)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
bot = ContestBot(engine, description='A assistant for the Photography Lounge\'s monday contests')
|
||||||
|
|
||||||
for extension in initial_extensions:
|
for extension in initial_extensions:
|
||||||
bot.load_extension(extension)
|
bot.load_extension(extension)
|
||||||
|
|||||||
Reference in New Issue
Block a user