begin refactoring into namedtuple style, add period setup, transition to more modular and safer database state management

This commit is contained in:
Xevion
2021-02-09 14:32:32 -06:00
parent 294b34abe0
commit 5defafa36e

View File

@@ -1,8 +1,9 @@
import logging import logging
import os import os
import sqlite3 import sqlite3
from collections import namedtuple
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, List
import aiosqlite import aiosqlite
@@ -11,6 +12,10 @@ from contest import constants
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL) logger.setLevel(constants.LOGGING_LEVEL)
Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period'])
Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp'])
Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at', ''])
class ContestDatabase(object): class ContestDatabase(object):
""" """
@@ -21,17 +26,24 @@ class ContestDatabase(object):
self.conn = conn self.conn = conn
@classmethod @classmethod
async def create(cls) -> 'ContestDatabase': async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase':
""" """
Constructs a ContestDatabase object connecting to the default database location with the proper connection settings. Constructs a ContestDatabase object connecting to the default database location with the proper connection settings.
:return: A fully realized ContestDatabase object. :return: A fully realized ContestDatabase object.
""" """
conn = await aiosqlite.connect(constants.DATABASE, detect_types=sqlite3.PARSE_DECLTYPES)
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}') conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES)
if dest.startswith(':memory:'):
logger.info('Asynchronous SQLite3 connection started in memory.')
else:
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
db = ContestDatabase(conn) db = ContestDatabase(conn)
await db.setup() await db.setup()
await conn.commit() await conn.commit()
logger.info('ContestDatabase instance created.')
logger.info('ContestDatabase instance created, database setup.')
return db return db
async def setup(self) -> None: async def setup(self) -> None:
@@ -43,7 +55,8 @@ class ContestDatabase(object):
await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild
(id INTEGER PRIMARY KEY, (id INTEGER PRIMARY KEY,
prefix TEXT DEFAULT '$', prefix TEXT DEFAULT '$',
submission INTEGER NULLABLE)''') submission INTEGER NULLABLE,
period INTEGER)''')
logger.info(f"'guild' table created.") logger.info(f"'guild' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission']) await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission'])
@@ -54,6 +67,18 @@ class ContestDatabase(object):
guild INTEGER, guild INTEGER,
timestamp DATETIME)''') timestamp DATETIME)''')
logger.info(f"'submission' table created.") logger.info(f"'submission' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period'])
if await cur.fetchone() is None:
await self.conn.execute('''CREATE TABLE IF NOT EXISTS period
(id INTEGER PRIMARY KEY,
guild INTEGER,
current_state INTEGER,
started_at TIMESTAMP,
voting_at TIMESTAMP DEFAULT NULL,
finished_at TIMESTAMP DEFAULT NULL)''')
logger.info(f"'period' table created.")
finally: finally:
await cur.close() await cur.close()
@@ -67,31 +92,46 @@ class ContestDatabase(object):
await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id]) await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id])
await self.conn.commit() await self.conn.commit()
async def is_setup(self, guild_id: int) -> bool: async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]:
"""Checks whether the bot is setup to complete submission channel related commands.""" """Retrieves a row from the submission table by the associated unique Guild ID and User ID"""
cur = await self.conn.cursor() cur = await self.conn.cursor()
try: try:
await cur.execute('''SELECT submission FROM guild WHERE id = ?''', [guild_id]) await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
t = await cur.fetchone() row = await cur.fetchone()
print(t) return None if row is None else Submission._make(row)
return t['submission']
finally: finally:
await cur.close() await cur.close()
async def get_prefix(self, guild_id: int) -> str: async def get_guild(self, guild_id: int) -> Optional[Guild]:
"""Gets the prefix from a specific guild in the database.""" """Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor() cur = await self.conn.cursor()
try: try:
await cur.execute('''SELECT prefix FROM guild WHERE id = ?''', [guild_id]) await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id])
return (await cur.fetchone())[0] row = await cur.fetchone()
return None if row is None else Guild._make(row)
finally: finally:
await cur.close() await cur.close()
async def get_submission_channel(self, guild_id: int) -> int: async def get_period(self, period_id) -> Optional[Period]:
cur = await self.conn.cursor() cur = await self.conn.cursor()
try: try:
await cur.execute('''SELECT submission FROM guild WHERE id = ?''', [guild_id]) await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id])
return (await cur.fetchone())[0] row = await cur.fetchone()
return None if row is None else Period._make(row)
finally:
await cur.close()
async def get_current_period(self, guild_id: int) -> Optional[Period]:
"""Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor()
try:
guild = await self.get_guild(guild_id)
if guild is None:
logger.debug(f'Guild {guild_id} does not exist.')
return None
if guild.period is not None:
return await self.get_period(guild.period)
finally: finally:
await cur.close() await cur.close()
@@ -105,20 +145,41 @@ class ContestDatabase(object):
await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id]) await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id])
await self.conn.commit() await self.conn.commit()
async def get_submission(self, guild_id: int, user_id: int) -> Optional[int]:
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT id FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
row = await cur.fetchone()
if row is None:
return None
return row[0]
finally:
await cur.close()
async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None: async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None:
await self.conn.execute( await self.conn.execute(
'''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''', '''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''',
[submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()] [submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()]
) )
await self.conn.commit() await self.conn.commit()
@staticmethod
async def generate_insert_query(table: str, columns: List[str]) -> str:
return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
async def new_period(self, period: Period) -> None:
"""Given a period, adds the period to the table and updates the associated guild."""
cur = await self.conn.cursor()
try:
# Ensure the associated guild exists
if period.guild is None:
return logger.error(f'Period {period} did not include a guild to associate with.')
guild = await self.get_guild(period.guild)
if guild is None:
return logger.error(f'Specified guild {period.guild} does not exist.')
# Add the period to the table
items = filter(lambda item: item[1] is not None, period._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query('period', list(columns))
logger.debug(f'Generated Insert Query: {query}')
await cur.execute(query, values)
await self.conn.commit()
# Update the associated guild's period
await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild])
await self.conn.commit()
finally:
await cur.close()
async def update_period(self, ):