From 1f46a98f37f394778c5318b65ade04a57a0ea59a Mon Sep 17 00:00:00 2001 From: Xevion Date: Sat, 13 Feb 2021 08:35:00 -0600 Subject: [PATCH] add guilds to database on connection if not already in database, remove old db.py, switch from query filter_by to simple get by primary key --- bot/bot.py | 15 +++- bot/db.py | 208 -------------------------------------------------- bot/models.py | 2 +- 3 files changed, 12 insertions(+), 213 deletions(-) delete mode 100644 bot/db.py diff --git a/bot/bot.py b/bot/bot.py index 118774b..4bb382e 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -42,22 +42,29 @@ class ContestBot(commands.Bot): if message.guild: with self.get_session() as session: - guild = session.query(Guild).filter_by(id=message.guild.id).first() + guild = session.query(Guild).get(message.guild.id) base.append(guild.prefix) return base async def on_ready(self): + """Communicate that the bot is online now.""" logger.info('Bot is now ready and connected to Discord.') guild_count = len(self.guilds) - logger.info( - f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.') + logger.info(f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.') + + with self.get_session() as session: + for guild in self.guilds: + _guild: Guild = session.query(Guild).get(guild.id) + if _guild is None: + logger.warning(f'Guild {guild.name} ({guild.id}) was not inside database on ready. Bot was disconnected or did not add it properly...') + session.add(Guild(id=guild.id)) async def on_guild_join(self, guild: discord.Guild) -> None: """Handles adding or reactivating a Guild in the database.""" logger.info(f'Added to new guild: {guild.name} ({guild.id})') with self.get_session() as session: - _guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first() + _guild: Guild = session.query(Guild).get(guild.id) if _guild is None: session.add(Guild(id=guild.id)) else: diff --git a/bot/db.py b/bot/db.py deleted file mode 100644 index dee77aa..0000000 --- a/bot/db.py +++ /dev/null @@ -1,208 +0,0 @@ -import logging -import os -import sqlite3 -from collections import namedtuple -from datetime import datetime -from typing import Optional, List, Union - -import aiosqlite - -from bot import constants - -logger = logging.getLogger(__file__) -logger.setLevel(constants.LOGGING_LEVEL) - -Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period']) -Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp']) -Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at']) -tables = [Guild, Submission, Period] - - -class ContestDatabase(object): - """ - A handler class for a SQLite3 database used by the bot with Async support. - """ - - def __init__(self, conn: aiosqlite.Connection) -> None: - self.conn = conn - - @classmethod - async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase': - """ - Constructs a ContestDatabase object connecting to the default database location with the proper connection settings. - :return: A fully realized ContestDatabase object. - """ - - conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES) - if dest.startswith(':memory:'): - logger.info('Asynchronous SQLite3 connection started in memory.') - else: - logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}') - - db = ContestDatabase(conn) - await db.setup() - await conn.commit() - - logger.info('ContestDatabase instance created, database setup.') - - return db - - async def setup(self) -> None: - """Sets up the tables for initial database creation""" - cur = await self.conn.cursor() - try: - await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['guild']) - if await cur.fetchone() is None: - await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild - (id INTEGER PRIMARY KEY, - prefix TEXT DEFAULT '$', - submission INTEGER NULLABLE, - period INTEGER)''') - logger.info(f"'guild' table created.") - - await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission']) - if await cur.fetchone() is None: - await self.conn.execute('''CREATE TABLE IF NOT EXISTS submission - (id INTEGER PRIMARY KEY, - user INTEGER, - guild INTEGER, - timestamp DATETIME)''') - logger.info(f"'submission' table created.") - - await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period']) - if await cur.fetchone() is None: - await self.conn.execute('''CREATE TABLE IF NOT EXISTS period - (id INTEGER PRIMARY KEY, - guild INTEGER, - current_state INTEGER, - started_at TIMESTAMP, - voting_at TIMESTAMP DEFAULT NULL, - finished_at TIMESTAMP DEFAULT NULL)''') - logger.info(f"'period' table created.") - - finally: - await cur.close() - - async def setup_guild(self, guild_id: int) -> None: - """Sets up a guild in the database.""" - await self.conn.execute('''INSERT INTO guild (id) VALUES (?)''', [guild_id]) - await self.conn.commit() - - async def set_prefix(self, guild_id: int, new_prefix: str) -> None: - """Updates the prefix for a specific guild in the database""" - await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id]) - await self.conn.commit() - - async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]: - """Retrieves a row from the submission table by the associated unique Guild ID and User ID""" - cur = await self.conn.cursor() - try: - await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id]) - row = await cur.fetchone() - return None if row is None else Submission._make(row) - finally: - await cur.close() - - async def get_guild(self, guild_id: int) -> Optional[Guild]: - """Retrieves a row from the Guild table by the Guild ID""" - cur = await self.conn.cursor() - try: - await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id]) - row = await cur.fetchone() - return None if row is None else Guild._make(row) - finally: - await cur.close() - - async def get_period(self, period_id) -> Optional[Period]: - cur = await self.conn.cursor() - try: - await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id]) - row = await cur.fetchone() - return None if row is None else Period._make(row) - finally: - await cur.close() - - async def get_current_period(self, guild_id: int) -> Optional[Period]: - """Retrieves a row from the Guild table by the Guild ID""" - cur = await self.conn.cursor() - try: - guild = await self.get_guild(guild_id) - if guild is None: - logger.debug(f'Guild {guild_id} does not exist.') - return None - - if guild.period is not None: - return await self.get_period(guild.period) - finally: - await cur.close() - - async def set_submission_channel(self, guild_id: int, new_submission: int) -> None: - """Updates the submission channel for a specific guild in the database""" - await self.conn.execute('''UPDATE guild SET submission = ? WHERE id = ?''', [new_submission, guild_id]) - await self.conn.commit() - - async def teardown_guild(self, guild_id: int) -> None: - """Removes a guild from the database while completing appropriate teardown actions.""" - await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id]) - await self.conn.commit() - - async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None: - await self.conn.execute( - '''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''', - [submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()] - ) - await self.conn.commit() - - @staticmethod - async def generate_insert_query(table: str, columns: List[str]) -> str: - """ - Generate a basic limited insert query based on a destination table and number of named arguments. - - Does NOT execute the query or insert any values - run this with Connection.execute()! - """ - query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})''' - logger.debug(query) - return query - - async def new_period(self, period: Period) -> None: - """Given a period, adds the period to the table and updates the associated guild.""" - cur = await self.conn.cursor() - try: - # Ensure the associated guild exists - if period.guild is None: - return logger.error(f'Period {period} did not include a guild to associate with.') - - guild = await self.get_guild(period.guild) - if guild is None: - return logger.error(f'Specified guild {period.guild} does not exist.') - - # Add the period to the table - items = filter(lambda item: item[1] is not None, period._asdict().items()) - columns, values = zip(*items) - query = await self.generate_insert_query('period', list(columns)) - - await cur.execute(query, values) - await self.conn.commit() - - # Update the associated guild's period - await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild]) - await self.conn.commit() - finally: - await cur.close() - - async def update(self, obj: Union[Guild, Submission, Period]) -> None: - """Using the objects's ID, updates a row in the database.""" - assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, "" - assert obj.id is not None, "" - - cur = await self.conn.cursor() - try: - # Add the period to the table - items = filter(lambda item: item[1] is not None, obj._asdict().items()) - columns, values = zip(*items) - query = await self.generate_insert_query(type(object).__name__.lower(), list(columns)) - - await cur.execute(query, values) - await self.conn.commit() - finally: - await cur.close() diff --git a/bot/models.py b/bot/models.py index d0cc2f2..a90d16e 100644 --- a/bot/models.py +++ b/bot/models.py @@ -42,7 +42,7 @@ class Guild(Base): active = Column(Boolean, default=True) # Whether or not the bot is active inside the given Guild. Used for better querying. joined = Column(DateTime, default=datetime.datetime.utcnow) # The initial join time for this bot to a particular Discord. - last_joined = Column(DateTime, nullable=True) # The last time the bot joined this server. + last_joined = Column(DateTime, default=datetime.datetime.utcnow) # The last time the bot joined this server. def check_not_finished(func):