mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-06 07:14:36 -06:00
refactor contest to bot folder, work on generic sqlite update query generator, work on new submissions/voting period logic commands, get away from get_submission_channel
This commit is contained in:
@@ -4,8 +4,8 @@ from typing import Optional
|
|||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from contest import constants
|
from bot import constants
|
||||||
from contest.db import ContestDatabase
|
from bot.db import ContestDatabase
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
logger.setLevel(constants.LOGGING_LEVEL)
|
logger.setLevel(constants.LOGGING_LEVEL)
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
|
|
||||||
from contest import checks, constants
|
from bot import checks, constants
|
||||||
from contest.bot import ContestBot
|
from bot.bot import ContestBot
|
||||||
|
from bot.db import Period
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
logger.setLevel(constants.LOGGING_LEVEL)
|
logger.setLevel(constants.LOGGING_LEVEL)
|
||||||
@@ -22,9 +24,10 @@ class ContestCog(commands.Cog):
|
|||||||
@checks.privileged()
|
@checks.privileged()
|
||||||
async def prefix(self, ctx, new_prefix: str):
|
async def prefix(self, ctx, new_prefix: str):
|
||||||
"""Changes the bot's saved prefix."""
|
"""Changes the bot's saved prefix."""
|
||||||
cur_prefix = await self.bot.db.get_prefix(ctx.guild.id)
|
guild = await self.bot.db.get_guild(ctx.guild.id)
|
||||||
|
|
||||||
if 1 <= len(new_prefix) <= 2:
|
if 1 <= len(new_prefix) <= 2:
|
||||||
if cur_prefix == new_prefix:
|
if guild.prefix == new_prefix:
|
||||||
return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.')
|
return await ctx.send(f':no_entry_sign: The prefix is already `{new_prefix}`.')
|
||||||
else:
|
else:
|
||||||
await self.bot.db.set_prefix(ctx.guild.id, new_prefix)
|
await self.bot.db.set_prefix(ctx.guild.id, new_prefix)
|
||||||
@@ -35,23 +38,76 @@ class ContestCog(commands.Cog):
|
|||||||
@commands.command()
|
@commands.command()
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
@checks.privileged()
|
@checks.privileged()
|
||||||
async def submission(self, ctx: Context, new_submission: discord.TextChannel):
|
async def submission(self, ctx: Context, new_submission: discord.TextChannel) -> None:
|
||||||
"""Changes the bot's saved submission channel."""
|
"""Changes the bot's saved submission channel."""
|
||||||
cur_submission = await self.bot.db.get_submission_channel(ctx.guild.id)
|
guild = await self.bot.db.get_guild(ctx.guild.id)
|
||||||
if cur_submission == new_submission.id:
|
|
||||||
return await ctx.send(
|
if guild.submission is not None and guild.submission == new_submission.id:
|
||||||
|
await ctx.send(
|
||||||
f':no_entry_sign: The submission channel is already set to {new_submission.mention}.')
|
f':no_entry_sign: The submission channel is already set to {new_submission.mention}.')
|
||||||
else:
|
else:
|
||||||
await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id)
|
await self.bot.db.set_submission_channel(ctx.guild.id, new_submission.id)
|
||||||
return await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
|
await ctx.send(f':white_check_mark: Submission channel changed to {new_submission.mention}.')
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
@commands.guild_only()
|
||||||
|
@checks.privileged()
|
||||||
|
async def submissions(self, ctx: Context, duration: float = None) -> None:
|
||||||
|
"""Opens up the submissions channel."""
|
||||||
|
assert duration == -1 or duration >= 0, "Duration must"
|
||||||
|
|
||||||
|
cur = await self.bot.db.conn.cursor()
|
||||||
|
try:
|
||||||
|
period = await self.bot.db.get_current_period(ctx.guild.id)
|
||||||
|
|
||||||
|
# Handle non-existent or final-state period
|
||||||
|
if period is None:
|
||||||
|
new_period = Period(guild=ctx.guild.id, current_state=0, started_at=datetime.now(), voting_at=None, finished_at=None)
|
||||||
|
await self.bot.db.new_period(new_period)
|
||||||
|
# Handle submissions state
|
||||||
|
elif period.current_state == 0:
|
||||||
|
await self.bot.db.update_period(period)
|
||||||
|
return
|
||||||
|
# Handle voting state
|
||||||
|
elif period.current_state == 1:
|
||||||
|
return
|
||||||
|
# Print period submissions
|
||||||
|
elif period.current_state == 2:
|
||||||
|
# TODO: Fetch all submissions related to this period
|
||||||
|
# TODO: Create new period for Guild at
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
await cur.close()
|
||||||
|
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
@commands.guild_only()
|
||||||
|
@checks.privileged()
|
||||||
|
async def voting(self, ctx: Context, duration: float = None) -> None:
|
||||||
|
"""Closes submissions and sets up the voting period."""
|
||||||
|
if duration < 0:
|
||||||
|
await ctx.send('Invalid duration - must be non-negative.')
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
@commands.guild_only()
|
||||||
|
@checks.privileged()
|
||||||
|
async def close(self, ctx: Context) -> None:
|
||||||
|
"""Closes the voting period."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
@commands.guild_only()
|
||||||
|
async def status(self, ctx: Context) -> None:
|
||||||
|
"""Provides the bot's current state in relation to internal configuration and the server's contest, if active."""
|
||||||
|
pass
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
if message.author == self.bot.user or message.author.bot or not message.guild: return
|
if message.author == self.bot.user or message.author.bot or not message.guild: return
|
||||||
cur_submission = await self.bot.db.get_submission_channel(message.guild.id)
|
guild = await self.bot.db.get_guild(message.guild.id)
|
||||||
|
|
||||||
channel: discord.TextChannel = message.channel
|
channel: discord.TextChannel = message.channel
|
||||||
if channel.id == cur_submission:
|
if channel.id == guild.submission:
|
||||||
attachments = message.attachments
|
attachments = message.attachments
|
||||||
if len(attachments) == 0:
|
if len(attachments) == 0:
|
||||||
await message.delete(delay=1)
|
await message.delete(delay=1)
|
||||||
@@ -94,8 +150,8 @@ class ContestCog(commands.Cog):
|
|||||||
|
|
||||||
# If the message was cached, check that it's in the correct channel.
|
# If the message was cached, check that it's in the correct channel.
|
||||||
if payload.cached_message is not None:
|
if payload.cached_message is not None:
|
||||||
cur_submission = await self.bot.db.get_submission_channel(payload.guild_id)
|
guild = await self.bot.db.get_guild(payload.guild_id)
|
||||||
if payload.cached_message.channel.id != cur_submission:
|
if payload.cached_message.channel.id != guild.submission:
|
||||||
return
|
return
|
||||||
|
|
||||||
cur = await self.bot.db.conn.cursor()
|
cur = await self.bot.db.conn.cursor()
|
||||||
@@ -3,11 +3,11 @@ import os
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from contest import constants
|
from bot import constants
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
logger.setLevel(constants.LOGGING_LEVEL)
|
logger.setLevel(constants.LOGGING_LEVEL)
|
||||||
@@ -154,7 +154,14 @@ class ContestDatabase(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def generate_insert_query(table: str, columns: List[str]) -> str:
|
async def generate_insert_query(table: str, columns: List[str]) -> str:
|
||||||
return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
|
"""
|
||||||
|
Generate a basic limited insert query based on a destination table and number of named arguments.
|
||||||
|
|
||||||
|
Does NOT execute the query or insert any values - run this with Connection.execute()!
|
||||||
|
"""
|
||||||
|
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
|
||||||
|
logger.debug(query)
|
||||||
|
return query
|
||||||
|
|
||||||
async def new_period(self, period: Period) -> None:
|
async def new_period(self, period: Period) -> None:
|
||||||
"""Given a period, adds the period to the table and updates the associated guild."""
|
"""Given a period, adds the period to the table and updates the associated guild."""
|
||||||
@@ -172,7 +179,7 @@ class ContestDatabase(object):
|
|||||||
items = filter(lambda item: item[1] is not None, period._asdict().items())
|
items = filter(lambda item: item[1] is not None, period._asdict().items())
|
||||||
columns, values = zip(*items)
|
columns, values = zip(*items)
|
||||||
query = await self.generate_insert_query('period', list(columns))
|
query = await self.generate_insert_query('period', list(columns))
|
||||||
logger.debug(f'Generated Insert Query: {query}')
|
|
||||||
await cur.execute(query, values)
|
await cur.execute(query, values)
|
||||||
await self.conn.commit()
|
await self.conn.commit()
|
||||||
|
|
||||||
@@ -182,4 +189,19 @@ class ContestDatabase(object):
|
|||||||
finally:
|
finally:
|
||||||
await cur.close()
|
await cur.close()
|
||||||
|
|
||||||
async def update_period(self, ):
|
async def update(self, obj: Union[Guild, Submission, Period]) -> None:
|
||||||
|
"""Using the objects's ID, updates a row in the database."""
|
||||||
|
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
|
||||||
|
assert obj.id is not None, ""
|
||||||
|
|
||||||
|
cur = await self.conn.cursor()
|
||||||
|
try:
|
||||||
|
# Add the period to the table
|
||||||
|
items = filter(lambda item: item[1] is not None, obj._asdict().items())
|
||||||
|
columns, values = zip(*items)
|
||||||
|
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
|
||||||
|
|
||||||
|
await cur.execute(query, values)
|
||||||
|
await self.conn.commit()
|
||||||
|
finally:
|
||||||
|
await cur.close()
|
||||||
4
main.py
4
main.py
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from contest import constants
|
from bot import constants
|
||||||
from contest.bot import ContestBot
|
from bot.bot import ContestBot
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
|
|||||||
Reference in New Issue
Block a user