mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-10 06:06:48 -06:00
refactor contest to bot folder, work on generic sqlite update query generator, work on new submissions/voting period logic commands, get away from get_submission_channel
This commit is contained in:
0
bot/__init__.py
Normal file
0
bot/__init__.py
Normal file
44
bot/bot.py
Normal file
44
bot/bot.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from bot import constants
|
||||
from bot.db import ContestDatabase
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(constants.LOGGING_LEVEL)
|
||||
|
||||
|
||||
async def fetch_prefix(bot: 'ContestBot', message: discord.Message):
|
||||
"""Fetches the prefix used by the relevant guild."""
|
||||
user_id = bot.user.id
|
||||
base = [f'<@!{user_id}> ', f'<@{user_id}> ']
|
||||
|
||||
if message.guild:
|
||||
if bot.db is not None:
|
||||
base.append(await bot.db.get_prefix(message.guild.id))
|
||||
return base
|
||||
|
||||
|
||||
class ContestBot(commands.Bot):
|
||||
def __init__(self, **options):
|
||||
self.db: Optional[ContestDatabase] = None
|
||||
super().__init__(fetch_prefix, **options)
|
||||
|
||||
async def on_ready(self):
|
||||
if self.db is None:
|
||||
self.db = await ContestDatabase.create()
|
||||
logger.info('Bot is now ready and connected to Discord.')
|
||||
guild_count = len(self.guilds)
|
||||
logger.info(
|
||||
f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
|
||||
|
||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
||||
await self.db.setup_guild(guild.id)
|
||||
|
||||
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
||||
logger.info(f'Removed from guild: {guild.name} ({guild.id})')
|
||||
await self.db.teardown_guild(guild.id)
|
||||
15
bot/checks.py
Normal file
15
bot/checks.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from discord.ext import commands
|
||||
|
||||
|
||||
def check_permissions(ctx, perms, *, check=all):
|
||||
resolved = ctx.channel.permissions_for(ctx.author)
|
||||
return check(getattr(resolved, name, None) == value for name, value in perms.items())
|
||||
|
||||
|
||||
def privileged():
|
||||
def predicate(ctx):
|
||||
return (ctx.guild is not None and ctx.guild.owner_id == ctx.author.id) \
|
||||
or check_permissions(ctx, {'manage_guild': True}) \
|
||||
or check_permissions(ctx, {'administrator': True})
|
||||
|
||||
return commands.check(predicate)
|
||||
190
bot/cogs/contest.py
Normal file
190
bot/cogs/contest.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
|
||||
from bot import checks, constants
|
||||
from bot.bot import ContestBot
|
||||
from bot.db import Period
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(constants.LOGGING_LEVEL)
|
||||
|
||||
expected_deletions = []
|
||||
|
||||
|
||||
class ContestCog(commands.Cog):
|
||||
def __init__(self, bot: ContestBot):
|
||||
self.bot = bot
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
@checks.privileged()
|
||||
async def prefix(self, ctx, new_prefix: str):
|
||||
"""Changes the bot's saved prefix."""
|
||||
guild = await self.bot.db.get_guild(ctx.guild.id)
|
||||
|
||||
if 1 <= len(new_prefix) <= 2:
|
||||
if guild.prefix == new_prefix:
|
||||
return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.')
|
||||
else:
|
||||
await self.bot.db.set_prefix(ctx.guild.id, new_prefix)
|
||||
return await ctx.send(f':white_check_mark: Prefix changed to `{new_prefix}`.')
|
||||
else:
|
||||
return await ctx.send(':no_entry_sign: Invalid argument. Prefix must be 1 or 2 characters long.')
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
@checks.privileged()
|
||||
async def submission(self, ctx: Context, new_submission: discord.TextChannel) -> None:
|
||||
"""Changes the bot's saved submission channel."""
|
||||
guild = await self.bot.db.get_guild(ctx.guild.id)
|
||||
|
||||
if guild.submission is not None and guild.submission == new_submission.id:
|
||||
await ctx.send(
|
||||
f':no_entry_sign: The submission channel is already set to {new_submission.mention}.')
|
||||
else:
|
||||
await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id)
|
||||
await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
@checks.privileged()
|
||||
async def submissions(self, ctx: Context, duration: float = None) -> None:
|
||||
"""Opens up the submissions channel."""
|
||||
assert duration == -1 or duration >= 0, "Duration must"
|
||||
|
||||
cur = await self.bot.db.conn.cursor()
|
||||
try:
|
||||
period = await self.bot.db.get_current_period(ctx.guild.id)
|
||||
|
||||
# Handle non-existent or final-state period
|
||||
if period is None:
|
||||
new_period = Period(guild=ctx.guild.id, current_state=0, started_at=datetime.now(), voting_at=None, finished_at=None)
|
||||
await self.bot.db.new_period(new_period)
|
||||
# Handle submissions state
|
||||
elif period.current_state == 0:
|
||||
await self.bot.db.update_period(period)
|
||||
return
|
||||
# Handle voting state
|
||||
elif period.current_state == 1:
|
||||
return
|
||||
# Print period submissions
|
||||
elif period.current_state == 2:
|
||||
# TODO: Fetch all submissions related to this period
|
||||
# TODO: Create new period for Guild at
|
||||
return
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
@checks.privileged()
|
||||
async def voting(self, ctx: Context, duration: float = None) -> None:
|
||||
"""Closes submissions and sets up the voting period."""
|
||||
if duration < 0:
|
||||
await ctx.send('Invalid duration - must be non-negative.')
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
@checks.privileged()
|
||||
async def close(self, ctx: Context) -> None:
|
||||
"""Closes the voting period."""
|
||||
pass
|
||||
|
||||
@commands.command()
|
||||
@commands.guild_only()
|
||||
async def status(self, ctx: Context) -> None:
|
||||
"""Provides the bot's current state in relation to internal configuration and the server's contest, if active."""
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message):
|
||||
if message.author == self.bot.user or message.author.bot or not message.guild: return
|
||||
guild = await self.bot.db.get_guild(message.guild.id)
|
||||
|
||||
channel: discord.TextChannel = message.channel
|
||||
if channel.id == guild.submission:
|
||||
attachments = message.attachments
|
||||
if len(attachments) == 0:
|
||||
await message.delete(delay=1)
|
||||
warning = await channel.send(
|
||||
f':no_entry_sign: {message.author.mention} Each submission must contain exactly one image.')
|
||||
await warning.delete(delay=5)
|
||||
elif len(attachments) > 1:
|
||||
await message.delete(delay=1)
|
||||
warning = await channel.send(
|
||||
f':no_entry_sign: {message.author.mention} Each submission must contain exactly one image.')
|
||||
await warning.delete(delay=5)
|
||||
else:
|
||||
last_submission = await self.bot.db.get_submission(message.guild.id, message.author.id)
|
||||
if last_submission is not None:
|
||||
# delete last submission
|
||||
submission_msg = await channel.fetch_message(last_submission)
|
||||
if submission_msg is None:
|
||||
logger.error(f'Unexpected: submission message {last_submission} could not be found.')
|
||||
else:
|
||||
await submission_msg.delete()
|
||||
logger.info(f'Old submission deleted. {last_submission} (Old) -> {message.id} (New)')
|
||||
|
||||
# Delete the old submission row
|
||||
await self.bot.db.conn.execute('''DELETE FROM submission WHERE id = ?''', [last_submission])
|
||||
await self.bot.db.conn.commit()
|
||||
|
||||
# Add the new submission row
|
||||
await self.bot.db.add_submission(message.id, channel.guild.id, message.author.id, message.created_at)
|
||||
logger.info(f'New submission created ({message.id}).')
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_message_delete(self, payload: discord.RawMessageDeleteEvent) -> None:
|
||||
"""Handles submission deletions by the users, moderators or other bots for any reason."""
|
||||
await self.bot.wait_until_ready()
|
||||
|
||||
# Ignore messages we delete
|
||||
if payload.message_id in expected_deletions:
|
||||
expected_deletions.remove(payload.message_id)
|
||||
return
|
||||
|
||||
# If the message was cached, check that it's in the correct channel.
|
||||
if payload.cached_message is not None:
|
||||
guild = await self.bot.db.get_guild(payload.guild_id)
|
||||
if payload.cached_message.channel.id != guild.submission:
|
||||
return
|
||||
|
||||
cur = await self.bot.db.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''DELETE FROM submission WHERE id = ? AND guild = ?''',
|
||||
[payload.message_id, payload.guild_id])
|
||||
if cur.rowcount > 0:
|
||||
author = payload.cached_message.author.display_name if payload.cached_message is not None else 'Unknown'
|
||||
logger.info(f'Submission {payload.message_id} by {author} deleted by outside source.')
|
||||
await self.bot.db.conn.commit()
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_bulk_message_delete(self, payload: discord.RawBulkMessageDeleteEvent) -> None:
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None:
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_reaction_remove(self, payload: discord.RawReactionActionEvent) -> None:
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_reaction_clear(self, payload: discord.RawReactionActionEvent) -> None:
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_raw_reaction_clear_emoji(self, payload: discord.RawReactionClearEmojiEvent) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def setup(bot) -> None:
|
||||
bot.add_cog(ContestCog(bot))
|
||||
10
bot/constants.py
Normal file
10
bot/constants.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Path Constants
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, '..')))
|
||||
TOKEN = os.path.join(BASE_DIR, 'token.dat')
|
||||
DATABASE = os.path.join(BASE_DIR, 'database.db')
|
||||
|
||||
# Other constants
|
||||
LOGGING_LEVEL = logging.DEBUG
|
||||
207
bot/db.py
Normal file
207
bot/db.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from bot import constants
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(constants.LOGGING_LEVEL)
|
||||
|
||||
Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period'])
|
||||
Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp'])
|
||||
Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at', ''])
|
||||
|
||||
|
||||
class ContestDatabase(object):
|
||||
"""
|
||||
A handler class for a SQLite3 database used by the bot with Async support.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: aiosqlite.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase':
|
||||
"""
|
||||
Constructs a ContestDatabase object connecting to the default database location with the proper connection settings.
|
||||
:return: A fully realized ContestDatabase object.
|
||||
"""
|
||||
|
||||
conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES)
|
||||
if dest.startswith(':memory:'):
|
||||
logger.info('Asynchronous SQLite3 connection started in memory.')
|
||||
else:
|
||||
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
|
||||
|
||||
db = ContestDatabase(conn)
|
||||
await db.setup()
|
||||
await conn.commit()
|
||||
|
||||
logger.info('ContestDatabase instance created, database setup.')
|
||||
|
||||
return db
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Sets up the tables for initial database creation"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['guild'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild
|
||||
(id INTEGER PRIMARY KEY,
|
||||
prefix TEXT DEFAULT '$',
|
||||
submission INTEGER NULLABLE,
|
||||
period INTEGER)''')
|
||||
logger.info(f"'guild' table created.")
|
||||
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS submission
|
||||
(id INTEGER PRIMARY KEY,
|
||||
user INTEGER,
|
||||
guild INTEGER,
|
||||
timestamp DATETIME)''')
|
||||
logger.info(f"'submission' table created.")
|
||||
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS period
|
||||
(id INTEGER PRIMARY KEY,
|
||||
guild INTEGER,
|
||||
current_state INTEGER,
|
||||
started_at TIMESTAMP,
|
||||
voting_at TIMESTAMP DEFAULT NULL,
|
||||
finished_at TIMESTAMP DEFAULT NULL)''')
|
||||
logger.info(f"'period' table created.")
|
||||
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def setup_guild(self, guild_id: int) -> None:
|
||||
"""Sets up a guild in the database."""
|
||||
await self.conn.execute('''INSERT INTO guild (id) VALUES (?)''', [guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def set_prefix(self, guild_id: int, new_prefix: str) -> None:
|
||||
"""Updates the prefix for a specific guild in the database"""
|
||||
await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]:
|
||||
"""Retrieves a row from the submission table by the associated unique Guild ID and User ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Submission._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_guild(self, guild_id: int) -> Optional[Guild]:
|
||||
"""Retrieves a row from the Guild table by the Guild ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Guild._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_period(self, period_id) -> Optional[Period]:
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Period._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_current_period(self, guild_id: int) -> Optional[Period]:
|
||||
"""Retrieves a row from the Guild table by the Guild ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
guild = await self.get_guild(guild_id)
|
||||
if guild is None:
|
||||
logger.debug(f'Guild {guild_id} does not exist.')
|
||||
return None
|
||||
|
||||
if guild.period is not None:
|
||||
return await self.get_period(guild.period)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def set_submission_channel(self, guild_id: int, new_submission: int) -> None:
|
||||
"""Updates the submission channel for a specific guild in the database"""
|
||||
await self.conn.execute('''UPDATE guild SET submission = ? WHERE id = ?''', [new_submission, guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def teardown_guild(self, guild_id: int) -> None:
|
||||
"""Removes a guild from the database while completing appropriate teardown actions."""
|
||||
await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None:
|
||||
await self.conn.execute(
|
||||
'''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''',
|
||||
[submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()]
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
async def generate_insert_query(table: str, columns: List[str]) -> str:
|
||||
"""
|
||||
Generate a basic limited insert query based on a destination table and number of named arguments.
|
||||
|
||||
Does NOT execute the query or insert any values - run this with Connection.execute()!
|
||||
"""
|
||||
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
|
||||
logger.debug(query)
|
||||
return query
|
||||
|
||||
async def new_period(self, period: Period) -> None:
|
||||
"""Given a period, adds the period to the table and updates the associated guild."""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
# Ensure the associated guild exists
|
||||
if period.guild is None:
|
||||
return logger.error(f'Period {period} did not include a guild to associate with.')
|
||||
|
||||
guild = await self.get_guild(period.guild)
|
||||
if guild is None:
|
||||
return logger.error(f'Specified guild {period.guild} does not exist.')
|
||||
|
||||
# Add the period to the table
|
||||
items = filter(lambda item: item[1] is not None, period._asdict().items())
|
||||
columns, values = zip(*items)
|
||||
query = await self.generate_insert_query('period', list(columns))
|
||||
|
||||
await cur.execute(query, values)
|
||||
await self.conn.commit()
|
||||
|
||||
# Update the associated guild's period
|
||||
await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild])
|
||||
await self.conn.commit()
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def update(self, obj: Union[Guild, Submission, Period]) -> None:
|
||||
"""Using the objects's ID, updates a row in the database."""
|
||||
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
|
||||
assert obj.id is not None, ""
|
||||
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
# Add the period to the table
|
||||
items = filter(lambda item: item[1] is not None, obj._asdict().items())
|
||||
columns, values = zip(*items)
|
||||
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
|
||||
|
||||
await cur.execute(query, values)
|
||||
await self.conn.commit()
|
||||
finally:
|
||||
await cur.close()
|
||||
Reference in New Issue
Block a user