begin refactoring into namedtuple style, add period setup, transition to more modular and safer database state management

This commit is contained in:
Xevion
2021-02-09 14:32:32 -06:00
parent 294b34abe0
commit 5defafa36e

View File

@@ -1,8 +1,9 @@
import logging
import os
import sqlite3
from collections import namedtuple
from datetime import datetime
from typing import Optional
from typing import Optional, List
import aiosqlite
@@ -11,6 +12,10 @@ from contest import constants
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)
Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period'])
Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp'])
Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at', ''])
class ContestDatabase(object):
"""
@@ -21,17 +26,24 @@ class ContestDatabase(object):
self.conn = conn
@classmethod
async def create(cls) -> 'ContestDatabase':
async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase':
"""
Constructs a ContestDatabase object connecting to the default database location with the proper connection settings.
:return: A fully realized ContestDatabase object.
"""
conn = await aiosqlite.connect(constants.DATABASE, detect_types=sqlite3.PARSE_DECLTYPES)
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES)
if dest.startswith(':memory:'):
logger.info('Asynchronous SQLite3 connection started in memory.')
else:
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
db = ContestDatabase(conn)
await db.setup()
await conn.commit()
logger.info('ContestDatabase instance created.')
logger.info('ContestDatabase instance created, database setup.')
return db
async def setup(self) -> None:
@@ -43,7 +55,8 @@ class ContestDatabase(object):
await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild
(id INTEGER PRIMARY KEY,
prefix TEXT DEFAULT '$',
submission INTEGER NULLABLE)''')
submission INTEGER NULLABLE,
period INTEGER)''')
logger.info(f"'guild' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission'])
@@ -54,6 +67,18 @@ class ContestDatabase(object):
guild INTEGER,
timestamp DATETIME)''')
logger.info(f"'submission' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period'])
if await cur.fetchone() is None:
await self.conn.execute('''CREATE TABLE IF NOT EXISTS period
(id INTEGER PRIMARY KEY,
guild INTEGER,
current_state INTEGER,
started_at TIMESTAMP,
voting_at TIMESTAMP DEFAULT NULL,
finished_at TIMESTAMP DEFAULT NULL)''')
logger.info(f"'period' table created.")
finally:
await cur.close()
@@ -67,31 +92,46 @@ class ContestDatabase(object):
await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id])
await self.conn.commit()
async def is_setup(self, guild_id: int) -> bool:
"""Checks whether the bot is setup to complete submission channel related commands."""
async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]:
"""Retrieves a row from the submission table by the associated unique Guild ID and User ID"""
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT submission FROM guild WHERE id = ?''', [guild_id])
t = await cur.fetchone()
print(t)
return t['submission']
await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
row = await cur.fetchone()
return None if row is None else Submission._make(row)
finally:
await cur.close()
async def get_prefix(self, guild_id: int) -> str:
"""Gets the prefix from a specific guild in the database."""
async def get_guild(self, guild_id: int) -> Optional[Guild]:
"""Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT prefix FROM guild WHERE id = ?''', [guild_id])
return (await cur.fetchone())[0]
await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id])
row = await cur.fetchone()
return None if row is None else Guild._make(row)
finally:
await cur.close()
async def get_submission_channel(self, guild_id: int) -> int:
async def get_period(self, period_id) -> Optional[Period]:
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT submission FROM guild WHERE id = ?''', [guild_id])
return (await cur.fetchone())[0]
await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id])
row = await cur.fetchone()
return None if row is None else Period._make(row)
finally:
await cur.close()
async def get_current_period(self, guild_id: int) -> Optional[Period]:
"""Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor()
try:
guild = await self.get_guild(guild_id)
if guild is None:
logger.debug(f'Guild {guild_id} does not exist.')
return None
if guild.period is not None:
return await self.get_period(guild.period)
finally:
await cur.close()
@@ -105,20 +145,41 @@ class ContestDatabase(object):
await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id])
await self.conn.commit()
async def get_submission(self, guild_id: int, user_id: int) -> Optional[int]:
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT id FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
row = await cur.fetchone()
if row is None:
return None
return row[0]
finally:
await cur.close()
async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None:
await self.conn.execute(
'''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''',
[submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()]
)
await self.conn.commit()
@staticmethod
async def generate_insert_query(table: str, columns: List[str]) -> str:
return f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
async def new_period(self, period: Period) -> None:
"""Given a period, adds the period to the table and updates the associated guild."""
cur = await self.conn.cursor()
try:
# Ensure the associated guild exists
if period.guild is None:
return logger.error(f'Period {period} did not include a guild to associate with.')
guild = await self.get_guild(period.guild)
if guild is None:
return logger.error(f'Specified guild {period.guild} does not exist.')
# Add the period to the table
items = filter(lambda item: item[1] is not None, period._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query('period', list(columns))
logger.debug(f'Generated Insert Query: {query}')
await cur.execute(query, values)
await self.conn.commit()
# Update the associated guild's period
await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild])
await self.conn.commit()
finally:
await cur.close()
async def update_period(self, ):