refactor contest to bot folder, work on generic sqlite update query generator, work on new submissions/voting period logic commands, get away from get_submission_channel

This commit is contained in:
Xevion
2021-02-12 17:46:46 -06:00
parent 5defafa36e
commit 89b16fdc04
7 changed files with 100 additions and 22 deletions

View File

View File

@@ -4,8 +4,8 @@ from typing import Optional
import discord
from discord.ext import commands
from contest import constants
from contest.db import ContestDatabase
from bot import constants
from bot.db import ContestDatabase
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)

View File

View File

@@ -1,11 +1,13 @@
import logging
from datetime import datetime
import discord
from discord.ext import commands
from discord.ext.commands import Context
from contest import checks, constants
from contest.bot import ContestBot
from bot import checks, constants
from bot.bot import ContestBot
from bot.db import Period
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)
@@ -22,9 +24,10 @@ class ContestCog(commands.Cog):
@checks.privileged()
async def prefix(self, ctx, new_prefix: str):
"""Changes the bot's saved prefix."""
cur_prefix = await self.bot.db.get_prefix(ctx.guild.id)
guild = await self.bot.db.get_guild(ctx.guild.id)
if 1 <= len(new_prefix) <= 2:
if cur_prefix == new_prefix:
if guild.prefix == new_prefix:
return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.')
else:
await self.bot.db.set_prefix(ctx.guild.id, new_prefix)
@@ -35,23 +38,76 @@ class ContestCog(commands.Cog):
@commands.command()
@commands.guild_only()
@checks.privileged()
async def submission(self, ctx: Context, new_submission: discord.TextChannel):
async def submission(self, ctx: Context, new_submission: discord.TextChannel) -> None:
"""Changes the bot's saved submission channel."""
cur_submission = await self.bot.db.get_submission_channel(ctx.guild.id)
if cur_submission == new_submission.id:
return await ctx.send(
guild = await self.bot.db.get_guild(ctx.guild.id)
if guild.submission is not None and guild.submission == new_submission.id:
await ctx.send(
f':no_entry_sign: The submission channel is already set to {new_submission.mention}.')
else:
await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id)
return await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
@commands.command()
@commands.guild_only()
@checks.privileged()
async def submissions(self, ctx: Context, duration: float = None) -> None:
"""Opens up the submissions channel."""
assert duration == -1 or duration >= 0, "Duration must"
cur = await self.bot.db.conn.cursor()
try:
period = await self.bot.db.get_current_period(ctx.guild.id)
# Handle non-existent or final-state period
if period is None:
new_period = Period(guild=ctx.guild.id, current_state=0, started_at=datetime.now(), voting_at=None, finished_at=None)
await self.bot.db.new_period(new_period)
# Handle submissions state
elif period.current_state == 0:
await self.bot.db.update_period(period)
return
# Handle voting state
elif period.current_state == 1:
return
# Print period submissions
elif period.current_state == 2:
# TODO: Fetch all submissions related to this period
# TODO: Create new period for Guild at
return
finally:
await cur.close()
@commands.command()
@commands.guild_only()
@checks.privileged()
async def voting(self, ctx: Context, duration: float = None) -> None:
"""Closes submissions and sets up the voting period."""
if duration < 0:
await ctx.send('Invalid duration - must be non-negative.')
@commands.command()
@commands.guild_only()
@checks.privileged()
async def close(self, ctx: Context) -> None:
"""Closes the voting period."""
pass
@commands.command()
@commands.guild_only()
async def status(self, ctx: Context) -> None:
"""Provides the bot's current state in relation to internal configuration and the server's contest, if active."""
pass
@commands.Cog.listener()
async def on_message(self, message: discord.Message):
if message.author == self.bot.user or message.author.bot or not message.guild: return
cur_submission = await self.bot.db.get_submission_channel(message.guild.id)
guild = await self.bot.db.get_guild(message.guild.id)
channel: discord.TextChannel = message.channel
if channel.id == cur_submission:
if channel.id == guild.submission:
attachments = message.attachments
if len(attachments) == 0:
await message.delete(delay=1)
@@ -94,8 +150,8 @@ class ContestCog(commands.Cog):
# If the message was cached, check that it's in the correct channel.
if payload.cached_message is not None:
cur_submission = await self.bot.db.get_submission_channel(payload.guild_id)
if payload.cached_message.channel.id != cur_submission:
guild = await self.bot.db.get_guild(payload.guild_id)
if payload.cached_message.channel.id != guild.submission:
return
cur = await self.bot.db.conn.cursor()

View File

View File

@@ -3,11 +3,11 @@ import os
import sqlite3
from collections import namedtuple
from datetime import datetime
from typing import Optional, List
from typing import Optional, List, Union
import aiosqlite
from contest import constants
from bot import constants
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)
@@ -154,7 +154,14 @@ class ContestDatabase(object):
@staticmethod
async def generate_insert_query(table: str, columns: List[str]) -> str:
return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
"""
Generate a basic limited insert query based on a destination table and number of named arguments.
Does NOT execute the query or insert any values - run this with Connection.execute()!
"""
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
logger.debug(query)
return query
async def new_period(self, period: Period) -> None:
"""Given a period, adds the period to the table and updates the associated guild."""
@@ -172,7 +179,7 @@ class ContestDatabase(object):
items = filter(lambda item: item[1] is not None, period._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query('period', list(columns))
logger.debug(f'Generated Insert Query: {query}')
await cur.execute(query, values)
await self.conn.commit()
@@ -182,4 +189,19 @@ class ContestDatabase(object):
finally:
await cur.close()
async def update_period(self, ):
async def update(self, obj: Union[Guild, Submission, Period]) -> None:
"""Using the objects's ID, updates a row in the database."""
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
assert obj.id is not None, ""
cur = await self.conn.cursor()
try:
# Add the period to the table
items = filter(lambda item: item[1] is not None, obj._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
await cur.execute(query, values)
await self.conn.commit()
finally:
await cur.close()

View File

@@ -1,7 +1,7 @@
import logging
from contest import constants
from contest.bot import ContestBot
from bot import constants
from bot.bot import ContestBot
if __name__ == "__main__":
logger = logging.getLogger(__file__)