refactor contest to bot folder, work on generic sqlite update query generator, work on new submissions/voting period logic commands, get away from get_submission_channel

This commit is contained in:
Xevion
2021-02-12 17:46:46 -06:00
parent 5defafa36e
commit 89b16fdc04
7 changed files with 100 additions and 22 deletions

View File

View File

@@ -4,8 +4,8 @@ from typing import Optional
import discord import discord
from discord.ext import commands from discord.ext import commands
from contest import constants from bot import constants
from contest.db import ContestDatabase from bot.db import ContestDatabase
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL) logger.setLevel(constants.LOGGING_LEVEL)

View File

View File

@@ -1,11 +1,13 @@
import logging import logging
from datetime import datetime
import discord import discord
from discord.ext import commands from discord.ext import commands
from discord.ext.commands import Context from discord.ext.commands import Context
from contest import checks, constants from bot import checks, constants
from contest.bot import ContestBot from bot.bot import ContestBot
from bot.db import Period
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL) logger.setLevel(constants.LOGGING_LEVEL)
@@ -22,9 +24,10 @@ class ContestCog(commands.Cog):
@checks.privileged() @checks.privileged()
async def prefix(self, ctx, new_prefix: str): async def prefix(self, ctx, new_prefix: str):
"""Changes the bot's saved prefix.""" """Changes the bot's saved prefix."""
cur_prefix = await self.bot.db.get_prefix(ctx.guild.id) guild = await self.bot.db.get_guild(ctx.guild.id)
if 1 <= len(new_prefix) <= 2: if 1 <= len(new_prefix) <= 2:
if cur_prefix == new_prefix: if guild.prefix == new_prefix:
return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.') return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.')
else: else:
await self.bot.db.set_prefix(ctx.guild.id, new_prefix) await self.bot.db.set_prefix(ctx.guild.id, new_prefix)
@@ -35,23 +38,76 @@ class ContestCog(commands.Cog):
@commands.command() @commands.command()
@commands.guild_only() @commands.guild_only()
@checks.privileged() @checks.privileged()
async def submission(self, ctx: Context, new_submission: discord.TextChannel): async def submission(self, ctx: Context, new_submission: discord.TextChannel) -> None:
"""Changes the bot's saved submission channel.""" """Changes the bot's saved submission channel."""
cur_submission = await self.bot.db.get_submission_channel(ctx.guild.id) guild = await self.bot.db.get_guild(ctx.guild.id)
if cur_submission == new_submission.id:
return await ctx.send( if guild.submission is not None and guild.submission == new_submission.id:
await ctx.send(
f':no_entry_sign: The submission channel is already set to {new_submission.mention}.') f':no_entry_sign: The submission channel is already set to {new_submission.mention}.')
else: else:
await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id) await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id)
return await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.') await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
@commands.command()
@commands.guild_only()
@checks.privileged()
async def submissions(self, ctx: Context, duration: float = None) -> None:
"""Opens up the submissions channel."""
assert duration == -1 or duration >= 0, "Duration must"
cur = await self.bot.db.conn.cursor()
try:
period = await self.bot.db.get_current_period(ctx.guild.id)
# Handle non-existent or final-state period
if period is None:
new_period = Period(guild=ctx.guild.id, current_state=0, started_at=datetime.now(), voting_at=None, finished_at=None)
await self.bot.db.new_period(new_period)
# Handle submissions state
elif period.current_state == 0:
await self.bot.db.update_period(period)
return
# Handle voting state
elif period.current_state == 1:
return
# Print period submissions
elif period.current_state == 2:
# TODO: Fetch all submissions related to this period
# TODO: Create new period for Guild at
return
finally:
await cur.close()
@commands.command()
@commands.guild_only()
@checks.privileged()
async def voting(self, ctx: Context, duration: float = None) -> None:
"""Closes submissions and sets up the voting period."""
if duration < 0:
await ctx.send('Invalid duration - must be non-negative.')
@commands.command()
@commands.guild_only()
@checks.privileged()
async def close(self, ctx: Context) -> None:
"""Closes the voting period."""
pass
@commands.command()
@commands.guild_only()
async def status(self, ctx: Context) -> None:
"""Provides the bot's current state in relation to internal configuration and the server's contest, if active."""
pass
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
if message.author == self.bot.user or message.author.bot or not message.guild: return if message.author == self.bot.user or message.author.bot or not message.guild: return
cur_submission = await self.bot.db.get_submission_channel(message.guild.id) guild = await self.bot.db.get_guild(message.guild.id)
channel: discord.TextChannel = message.channel channel: discord.TextChannel = message.channel
if channel.id == cur_submission: if channel.id == guild.submission:
attachments = message.attachments attachments = message.attachments
if len(attachments) == 0: if len(attachments) == 0:
await message.delete(delay=1) await message.delete(delay=1)
@@ -94,8 +150,8 @@ class ContestCog(commands.Cog):
# If the message was cached, check that it's in the correct channel. # If the message was cached, check that it's in the correct channel.
if payload.cached_message is not None: if payload.cached_message is not None:
cur_submission = await self.bot.db.get_submission_channel(payload.guild_id) guild = await self.bot.db.get_guild(payload.guild_id)
if payload.cached_message.channel.id != cur_submission: if payload.cached_message.channel.id != guild.submission:
return return
cur = await self.bot.db.conn.cursor() cur = await self.bot.db.conn.cursor()

View File

View File

@@ -3,11 +3,11 @@ import os
import sqlite3 import sqlite3
from collections import namedtuple from collections import namedtuple
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import Optional, List, Union
import aiosqlite import aiosqlite
from contest import constants from bot import constants
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL) logger.setLevel(constants.LOGGING_LEVEL)
@@ -154,7 +154,14 @@ class ContestDatabase(object):
@staticmethod @staticmethod
async def generate_insert_query(table: str, columns: List[str]) -> str: async def generate_insert_query(table: str, columns: List[str]) -> str:
return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})''' """
Generate a basic limited insert query based on a destination table and number of named arguments.
Does NOT execute the query or insert any values - run this with Connection.execute()!
"""
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
logger.debug(query)
return query
async def new_period(self, period: Period) -> None: async def new_period(self, period: Period) -> None:
"""Given a period, adds the period to the table and updates the associated guild.""" """Given a period, adds the period to the table and updates the associated guild."""
@@ -172,7 +179,7 @@ class ContestDatabase(object):
items = filter(lambda item: item[1] is not None, period._asdict().items()) items = filter(lambda item: item[1] is not None, period._asdict().items())
columns, values = zip(*items) columns, values = zip(*items)
query = await self.generate_insert_query('period', list(columns)) query = await self.generate_insert_query('period', list(columns))
logger.debug(f'Generated Insert Query: {query}')
await cur.execute(query, values) await cur.execute(query, values)
await self.conn.commit() await self.conn.commit()
@@ -182,4 +189,19 @@ class ContestDatabase(object):
finally: finally:
await cur.close() await cur.close()
async def update_period(self, ): async def update(self, obj: Union[Guild, Submission, Period]) -> None:
"""Using the objects's ID, updates a row in the database."""
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
assert obj.id is not None, ""
cur = await self.conn.cursor()
try:
# Add the period to the table
items = filter(lambda item: item[1] is not None, obj._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
await cur.execute(query, values)
await self.conn.commit()
finally:
await cur.close()

View File

@@ -1,7 +1,7 @@
import logging import logging
from contest import constants from bot import constants
from contest.bot import ContestBot from bot.bot import ContestBot
if __name__ == "__main__": if __name__ == "__main__":
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)