mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-06 03:14:41 -06:00
Added and implemented helper functions for getting and fetching messages given a payload's integer channel and message IDs. Implemented Submission.votes list modifications and Submission.update in code (tested with real bot). Committed constants.py containing ReactionMarker used in models.py.
117 lines
4.9 KiB
Python
117 lines
4.9 KiB
Python
import logging
|
|
from contextlib import contextmanager
|
|
from datetime import datetime
|
|
from typing import ContextManager, List, Optional
|
|
|
|
import discord
|
|
from discord.ext import commands
|
|
# noinspection PyProtectedMember
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from bot import constants
|
|
from bot.models import Guild, Period, Submission
|
|
|
|
# Module-level logger. It is named after the module's import path rather than
# its file path so that it participates in the standard dotted logger
# hierarchy and can be configured from a parent logger.
logger = logging.getLogger(__name__)
logger.setLevel(constants.LOGGING_LEVEL)
|
|
|
|
|
|
class ContestBot(commands.Bot):
    """Discord bot that keeps per-guild state in a SQLAlchemy-backed database.

    The bot owns a ``sessionmaker`` bound to the engine it is constructed with;
    all database access in this class goes through :meth:`get_session`.
    """

    def __init__(self, engine: Engine, **options):
        """Create the bot.

        :param engine: SQLAlchemy engine that every session of this bot is bound to.
        :param options: Forwarded unchanged to ``commands.Bot``.
        """
        # The prefix is resolved per message by fetch_prefix.
        super().__init__(self.fetch_prefix, **options)

        self.engine = engine
        self.Session = sessionmaker(bind=engine)

    @contextmanager
    def get_session(self, autocommit=True, autoclose=True, rollback=True) -> ContextManager[Session]:
        """Provides automatic commit and closing of Session with exception rollback.

        :param autocommit: Commit the session when the ``with`` body finishes without raising.
        :param autoclose: Close the session when the ``with`` block is left, whether or not it raised.
        :param rollback: Roll the session back when the ``with`` body (or the commit) raises an ``Exception``.
        :raises Exception: Whatever was raised inside the ``with`` body is re-raised after the rollback.
        """
        session = self.Session()
        try:
            yield session
            if autocommit:
                session.commit()
        except Exception:
            if rollback:
                session.rollback()
            raise
        finally:
            if autoclose:
                session.close()

    async def fetch_prefix(self, bot: 'ContestBot', message: discord.Message) -> List[str]:
        """Fetches the prefix used by the relevant guild.

        :param bot: The bot whose user ID is used to build the mention prefixes.
        :param message: The message a prefix is being resolved for.
        :return: Both mention forms of the bot user, followed by the guild's stored
            prefix when the message was sent in a guild that has one.
        """
        user_id = bot.user.id
        prefixes = [f'<@!{user_id}> ', f'<@{user_id}> ']

        if message.guild:
            with self.get_session() as session:
                guild: Optional[Guild] = session.query(Guild).get(message.guild.id)
                # The row is missing until on_ready/on_guild_join has added it; in that
                # case only the mention prefixes are usable instead of crashing here.
                if guild is not None and guild.prefix is not None:
                    prefixes.append(guild.prefix)
        return prefixes

    async def on_ready(self):
        """Communicate that the bot is online now.

        Also adds a database row for every connected guild that does not have one yet.
        """
        logger.info('Bot is now ready and connected to Discord.')
        guild_count = len(self.guilds)
        # Pluralise for every count except exactly one (including zero).
        logger.info(f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count != 1 else ""}.')

        with self.get_session() as session:
            for guild in self.guilds:
                _guild: Optional[Guild] = session.query(Guild).get(guild.id)
                if _guild is None:
                    logger.warning(
                        f'Guild {guild.name} ({guild.id}) was not inside database on ready. Bot was disconnected or did not add it properly...')
                    session.add(Guild(id=guild.id))

    async def on_guild_join(self, guild: discord.Guild) -> None:
        """Handles adding or reactivating a Guild in the database.

        :param guild: The Discord guild the bot was just added to.
        """
        logger.info(f'Added to new guild: {guild.name} ({guild.id})')

        with self.get_session() as session:
            _guild: Optional[Guild] = session.query(Guild).get(guild.id)
            if _guild is None:
                session.add(Guild(id=guild.id))
            else:
                # Guild has been seen before. Update last_joined and set as active again.
                _guild.active = True
                _guild.last_joined = datetime.utcnow()

    async def on_guild_remove(self, guild: discord.Guild) -> None:
        """Handles disabling the guild in the database, as well.

        :param guild: The Discord guild the bot was just removed from.
        """
        logger.info(f'Removed from guild: {guild.name} ({guild.id})')

        with self.get_session() as session:
            # Get the associated Guild and mark it as disabled.
            _guild: Optional[Guild] = session.query(Guild).filter_by(active=True, id=guild.id).first()
            if _guild is None:
                # No active row matches (never added, or already disabled); nothing to disable.
                logger.warning(f'Guild {guild.name} ({guild.id}) was not active inside database on removal.')
                return
            _guild.active = False

            # Shut down any current running Period objects if possible.
            period: Optional[Period] = _guild.current_period
            if period is not None and period.active:
                period.deactivate()

    async def add_voting_reactions(self, channel: discord.TextChannel, submissions: Optional[List[Submission]] = None) -> None:
        """Adds reactions to all valid submissions in the given channel.

        :param channel: Channel containing the submission messages.
        :param submissions: Submissions to react to. When omitted, the submissions of
            the current period of the channel's guild are used.
        """
        if submissions is None:
            with self.get_session() as session:
                guild: Optional[Guild] = session.query(Guild).get(channel.guild.id)
                period: Optional[Period] = guild.current_period if guild is not None else None
                if period is None:
                    logger.error('No valid submissions - current period is not set for the Guild this channel belongs to.')
                    return
                # Copy the plain IDs while the session is still open: once it has been
                # committed and closed the ORM objects may no longer be readable.
                message_ids = [submission.id for submission in period.submissions]
        else:
            message_ids = [submission.id for submission in submissions]

        if not message_ids:
            logger.warning('Attempted to add voting reactions to submissions, but none were given or could be found.')
            return

        # The emoji is the same for every submission, so resolve it once.
        upvote = self.get_emoji(constants.Emoji.UPVOTE)
        if upvote is None:
            logger.error('Upvote emoji could not be found; unable to add voting reactions.')
            return

        for message_id in message_ids:
            # A submission's ID is used directly as the ID of its message in the channel.
            message: discord.PartialMessage = channel.get_partial_message(message_id)
            await message.add_reaction(upvote)

    def get_message(self, channel_id: int, message_id: int) -> discord.PartialMessage:
        """Build a partial message from integer channel and message IDs without an API call.

        :param channel_id: ID of the channel the message belongs to.
        :param message_id: ID of the message.
        :return: A partial message object for the given IDs.
        """
        # NOTE(review): get_channel is presumably a cache lookup that yields None for
        # unknown channels, in which case the next line raises - confirm callers only
        # pass channels the bot can see.
        channel: discord.TextChannel = self.get_channel(channel_id)
        return channel.get_partial_message(message_id)

    async def fetch_message(self, channel_id: int, message_id: int) -> discord.Message:
        """Fetch the full message for integer channel and message IDs.

        :param channel_id: ID of the channel the message belongs to.
        :param message_id: ID of the message.
        :return: The fetched message.
        """
        channel: Optional[discord.TextChannel] = self.get_channel(channel_id)
        if channel is None:
            # Not available from get_channel; ask Discord for the channel instead of
            # failing on a None channel.
            channel = await self.fetch_channel(channel_id)
        return await channel.fetch_message(message_id)
|