Files
contest-assistant/bot/db.py
2021-02-12 23:19:35 -06:00

209 lines
9.0 KiB
Python

import logging
import os
import sqlite3
from collections import namedtuple
from datetime import datetime
from typing import Optional, List, Union
import aiosqlite
from bot import constants
logger = logging.getLogger(__file__)
logger.setLevel(constants.LOGGING_LEVEL)
Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period'])
Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp'])
Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at'])
tables = [Guild, Submission, Period]
class ContestDatabase(object):
"""
A handler class for a SQLite3 database used by the bot with Async support.
"""
def __init__(self, conn: aiosqlite.Connection) -> None:
self.conn = conn
@classmethod
async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase':
"""
Constructs a ContestDatabase object connecting to the default database location with the proper connection settings.
:return: A fully realized ContestDatabase object.
"""
conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES)
if dest.startswith(':memory:'):
logger.info('Asynchronous SQLite3 connection started in memory.')
else:
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
db = ContestDatabase(conn)
await db.setup()
await conn.commit()
logger.info('ContestDatabase instance created, database setup.')
return db
async def setup(self) -> None:
"""Sets up the tables for initial database creation"""
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['guild'])
if await cur.fetchone() is None:
await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild
(id INTEGER PRIMARY KEY,
prefix TEXT DEFAULT '$',
submission INTEGER NULLABLE,
period INTEGER)''')
logger.info(f"'guild' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission'])
if await cur.fetchone() is None:
await self.conn.execute('''CREATE TABLE IF NOT EXISTS submission
(id INTEGER PRIMARY KEY,
user INTEGER,
guild INTEGER,
timestamp DATETIME)''')
logger.info(f"'submission' table created.")
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period'])
if await cur.fetchone() is None:
await self.conn.execute('''CREATE TABLE IF NOT EXISTS period
(id INTEGER PRIMARY KEY,
guild INTEGER,
current_state INTEGER,
started_at TIMESTAMP,
voting_at TIMESTAMP DEFAULT NULL,
finished_at TIMESTAMP DEFAULT NULL)''')
logger.info(f"'period' table created.")
finally:
await cur.close()
async def setup_guild(self, guild_id: int) -> None:
"""Sets up a guild in the database."""
await self.conn.execute('''INSERT INTO guild (id) VALUES (?)''', [guild_id])
await self.conn.commit()
async def set_prefix(self, guild_id: int, new_prefix: str) -> None:
"""Updates the prefix for a specific guild in the database"""
await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id])
await self.conn.commit()
async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]:
"""Retrieves a row from the submission table by the associated unique Guild ID and User ID"""
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
row = await cur.fetchone()
return None if row is None else Submission._make(row)
finally:
await cur.close()
async def get_guild(self, guild_id: int) -> Optional[Guild]:
"""Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id])
row = await cur.fetchone()
return None if row is None else Guild._make(row)
finally:
await cur.close()
async def get_period(self, period_id) -> Optional[Period]:
cur = await self.conn.cursor()
try:
await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id])
row = await cur.fetchone()
return None if row is None else Period._make(row)
finally:
await cur.close()
async def get_current_period(self, guild_id: int) -> Optional[Period]:
"""Retrieves a row from the Guild table by the Guild ID"""
cur = await self.conn.cursor()
try:
guild = await self.get_guild(guild_id)
if guild is None:
logger.debug(f'Guild {guild_id} does not exist.')
return None
if guild.period is not None:
return await self.get_period(guild.period)
finally:
await cur.close()
async def set_submission_channel(self, guild_id: int, new_submission: int) -> None:
"""Updates the submission channel for a specific guild in the database"""
await self.conn.execute('''UPDATE guild SET submission = ? WHERE id = ?''', [new_submission, guild_id])
await self.conn.commit()
async def teardown_guild(self, guild_id: int) -> None:
"""Removes a guild from the database while completing appropriate teardown actions."""
await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id])
await self.conn.commit()
async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None:
await self.conn.execute(
'''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''',
[submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()]
)
await self.conn.commit()
@staticmethod
async def generate_insert_query(table: str, columns: List[str]) -> str:
"""
Generate a basic limited insert query based on a destination table and number of named arguments.
Does NOT execute the query or insert any values - run this with Connection.execute()!
"""
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
logger.debug(query)
return query
async def new_period(self, period: Period) -> None:
"""Given a period, adds the period to the table and updates the associated guild."""
cur = await self.conn.cursor()
try:
# Ensure the associated guild exists
if period.guild is None:
return logger.error(f'Period {period} did not include a guild to associate with.')
guild = await self.get_guild(period.guild)
if guild is None:
return logger.error(f'Specified guild {period.guild} does not exist.')
# Add the period to the table
items = filter(lambda item: item[1] is not None, period._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query('period', list(columns))
await cur.execute(query, values)
await self.conn.commit()
# Update the associated guild's period
await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild])
await self.conn.commit()
finally:
await cur.close()
async def update(self, obj: Union[Guild, Submission, Period]) -> None:
"""Using the objects's ID, updates a row in the database."""
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
assert obj.id is not None, ""
cur = await self.conn.cursor()
try:
# Add the period to the table
items = filter(lambda item: item[1] is not None, obj._asdict().items())
columns, values = zip(*items)
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
await cur.execute(query, values)
await self.conn.commit()
finally:
await cur.close()