diff --git a/contest/__init__.py b/bot/__init__.py similarity index 100% rename from contest/__init__.py rename to bot/__init__.py diff --git a/contest/bot.py b/bot/bot.py similarity index 95% rename from contest/bot.py rename to bot/bot.py index a5aa34e..3c35965 100644 --- a/contest/bot.py +++ b/bot/bot.py @@ -4,8 +4,8 @@ from typing import Optional import discord from discord.ext import commands -from contest import constants -from contest.db import ContestDatabase +from bot import constants +from bot.db import ContestDatabase logger = logging.getLogger(__file__) logger.setLevel(constants.LOGGING_LEVEL) diff --git a/contest/checks.py b/bot/checks.py similarity index 100% rename from contest/checks.py rename to bot/checks.py diff --git a/contest/cogs/contest.py b/bot/cogs/contest.py similarity index 65% rename from contest/cogs/contest.py rename to bot/cogs/contest.py index 82fbb90..b21a16b 100644 --- a/contest/cogs/contest.py +++ b/bot/cogs/contest.py @@ -1,11 +1,13 @@ import logging +from datetime import datetime import discord from discord.ext import commands from discord.ext.commands import Context -from contest import checks, constants -from contest.bot import ContestBot +from bot import checks, constants +from bot.bot import ContestBot +from bot.db import Period logger = logging.getLogger(__file__) logger.setLevel(constants.LOGGING_LEVEL) @@ -22,9 +24,10 @@ class ContestCog(commands.Cog): @checks.privileged() async def prefix(self, ctx, new_prefix: str): """Changes the bot's saved prefix.""" - cur_prefix = await self.bot.db.get_prefix(ctx.guild.id) + guild = await self.bot.db.get_guild(ctx.guild.id) + if 1 <= len(new_prefix) <= 2: - if cur_prefix == new_prefix: + if guild.prefix == new_prefix: return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.') else: await self.bot.db.set_prefix(ctx.guild.id, new_prefix) @@ -35,23 +38,76 @@ class ContestCog(commands.Cog): @commands.command() @commands.guild_only() @checks.privileged() - async def submission(self, ctx: Context, new_submission: discord.TextChannel): + async def submission(self, ctx: Context, new_submission: discord.TextChannel) -> None: """Changes the bot's saved submission channel.""" - cur_submission = await self.bot.db.get_submission_channel(ctx.guild.id) - if cur_submission == new_submission.id: - return await ctx.send( + guild = await self.bot.db.get_guild(ctx.guild.id) + + if guild.submission is not None and guild.submission == new_submission.id: + await ctx.send( f':no_entry_sign: The submission channel is already set to {new_submission.mention}.') else: await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id) - return await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.') + await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.') + + @commands.command() + @commands.guild_only() + @checks.privileged() + async def submissions(self, ctx: Context, duration: float = None) -> None: + """Opens up the submissions channel.""" + assert duration == -1 or duration >= 0, "Duration must" + + cur = await self.bot.db.conn.cursor() + try: + period = await self.bot.db.get_current_period(ctx.guild.id) + + # Handle non-existent or final-state period + if period is None: + new_period = Period(guild=ctx.guild.id, current_state=0, started_at=datetime.now(), voting_at=None, finished_at=None) + await self.bot.db.new_period(new_period) + # Handle submissions state + elif period.current_state == 0: + await self.bot.db.update_period(period) + return + # Handle voting state + elif period.current_state == 1: + return + # Print period submissions + elif period.current_state == 2: + # TODO: Fetch all submissions related to this period + # TODO: Create new period for Guild at + return + finally: + await cur.close() + + + @commands.command() + @commands.guild_only() + @checks.privileged() + async def voting(self, ctx: Context, duration: float = None) -> None: + """Closes submissions and sets up the voting period.""" + if duration < 0: + await ctx.send('Invalid duration - must be non-negative.') + + @commands.command() + @commands.guild_only() + @checks.privileged() + async def close(self, ctx: Context) -> None: + """Closes the voting period.""" + pass + + @commands.command() + @commands.guild_only() + async def status(self, ctx: Context) -> None: + """Provides the bot's current state in relation to internal configuration and the server's contest, if active.""" + pass @commands.Cog.listener() async def on_message(self, message: discord.Message): if message.author == self.bot.user or message.author.bot or not message.guild: return - cur_submission = await self.bot.db.get_submission_channel(message.guild.id) + guild = await self.bot.db.get_guild(message.guild.id) channel: discord.TextChannel = message.channel - if channel.id == cur_submission: + if channel.id == guild.submission: attachments = message.attachments if len(attachments) == 0: await message.delete(delay=1) @@ -94,8 +150,8 @@ class ContestCog(commands.Cog): # If the message was cached, check that it's in the correct channel. if payload.cached_message is not None: - cur_submission = await self.bot.db.get_submission_channel(payload.guild_id) - if payload.cached_message.channel.id != cur_submission: + guild = await self.bot.db.get_guild(payload.guild_id) + if payload.cached_message.channel.id != guild.submission: return cur = await self.bot.db.conn.cursor() diff --git a/contest/constants.py b/bot/constants.py similarity index 100% rename from contest/constants.py rename to bot/constants.py diff --git a/contest/db.py b/bot/db.py similarity index 87% rename from contest/db.py rename to bot/db.py index a6ad818..176d958 100644 --- a/contest/db.py +++ b/bot/db.py @@ -3,11 +3,11 @@ import os import sqlite3 from collections import namedtuple from datetime import datetime -from typing import Optional, List +from typing import Optional, List, Union import aiosqlite -from contest import constants +from bot import constants logger = logging.getLogger(__file__) logger.setLevel(constants.LOGGING_LEVEL) @@ -154,7 +154,14 @@ class ContestDatabase(object): @staticmethod async def generate_insert_query(table: str, columns: List[str]) -> str: - return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})''' + """ + Generate a basic limited insert query based on a destination table and number of named arguments. + + Does NOT execute the query or insert any values - run this with Connection.execute()! + """ + query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})''' + logger.debug(query) + return query async def new_period(self, period: Period) -> None: """Given a period, adds the period to the table and updates the associated guild.""" @@ -172,7 +179,7 @@ class ContestDatabase(object): items = filter(lambda item: item[1] is not None, period._asdict().items()) columns, values = zip(*items) query = await self.generate_insert_query('period', list(columns)) - logger.debug(f'Generated Insert Query: {query}') + await cur.execute(query, values) await self.conn.commit() @@ -182,4 +189,19 @@ class ContestDatabase(object): finally: await cur.close() - async def update_period(self, ): + async def update(self, obj: Union[Guild, Submission, Period]) -> None: + """Using the objects's ID, updates a row in the database.""" + assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, "" + assert obj.id is not None, "" + + cur = await self.conn.cursor() + try: + # Add the period to the table + items = filter(lambda item: item[1] is not None, obj._asdict().items()) + columns, values = zip(*items) + query = await self.generate_insert_query(type(object).__name__.lower(), list(columns)) + + await cur.execute(query, values) + await self.conn.commit() + finally: + await cur.close() diff --git a/main.py b/main.py index 6722686..37c44d9 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import logging -from contest import constants -from contest.bot import ContestBot +from bot import constants +from bot.bot import ContestBot if __name__ == "__main__": logger = logging.getLogger(__file__)