mirror of
https://github.com/Xevion/contest-assistant.git
synced 2025-12-06 15:14:41 -06:00
add guilds to database on connection if not already in database, remove old db.py, switch from query filter_by to simple get by primary key
This commit is contained in:
15
bot/bot.py
15
bot/bot.py
@@ -42,22 +42,29 @@ class ContestBot(commands.Bot):
|
||||
|
||||
if message.guild:
|
||||
with self.get_session() as session:
|
||||
guild = session.query(Guild).filter_by(id=message.guild.id).first()
|
||||
guild = session.query(Guild).get(message.guild.id)
|
||||
base.append(guild.prefix)
|
||||
return base
|
||||
|
||||
async def on_ready(self):
|
||||
"""Communicate that the bot is online now."""
|
||||
logger.info('Bot is now ready and connected to Discord.')
|
||||
guild_count = len(self.guilds)
|
||||
logger.info(
|
||||
f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
|
||||
logger.info(f'Connected as {self.user.name}#{self.user.discriminator} to {guild_count} guild{"s" if guild_count > 1 else ""}.')
|
||||
|
||||
with self.get_session() as session:
|
||||
for guild in self.guilds:
|
||||
_guild: Guild = session.query(Guild).get(guild.id)
|
||||
if _guild is None:
|
||||
logger.warning(f'Guild {guild.name} ({guild.id}) was not inside database on ready. Bot was disconnected or did not add it properly...')
|
||||
session.add(Guild(id=guild.id))
|
||||
|
||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||
"""Handles adding or reactivating a Guild in the database."""
|
||||
logger.info(f'Added to new guild: {guild.name} ({guild.id})')
|
||||
|
||||
with self.get_session() as session:
|
||||
_guild: Guild = session.query(Guild).filter_by(active=False, id=guild.id).first()
|
||||
_guild: Guild = session.query(Guild).get(guild.id)
|
||||
if _guild is None:
|
||||
session.add(Guild(id=guild.id))
|
||||
else:
|
||||
|
||||
208
bot/db.py
208
bot/db.py
@@ -1,208 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from bot import constants
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(constants.LOGGING_LEVEL)
|
||||
|
||||
Guild = namedtuple('Guild', ['id', 'prefix', 'submission', 'period'])
|
||||
Submission = namedtuple('Submission', ['id', 'user', 'guild', 'timestamp'])
|
||||
Period = namedtuple('Period', ['id', 'guild', 'current_state', 'started_at', 'voting_at', 'finished_at'])
|
||||
tables = [Guild, Submission, Period]
|
||||
|
||||
|
||||
class ContestDatabase(object):
|
||||
"""
|
||||
A handler class for a SQLite3 database used by the bot with Async support.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: aiosqlite.Connection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dest: str = constants.DATABASE) -> 'ContestDatabase':
|
||||
"""
|
||||
Constructs a ContestDatabase object connecting to the default database location with the proper connection settings.
|
||||
:return: A fully realized ContestDatabase object.
|
||||
"""
|
||||
|
||||
conn = await aiosqlite.connect(dest, detect_types=sqlite3.PARSE_DECLTYPES)
|
||||
if dest.startswith(':memory:'):
|
||||
logger.info('Asynchronous SQLite3 connection started in memory.')
|
||||
else:
|
||||
logger.info(f'Asynchronous SQLite3 connection made to ./{os.path.relpath(constants.DATABASE)}')
|
||||
|
||||
db = ContestDatabase(conn)
|
||||
await db.setup()
|
||||
await conn.commit()
|
||||
|
||||
logger.info('ContestDatabase instance created, database setup.')
|
||||
|
||||
return db
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Sets up the tables for initial database creation"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['guild'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS guild
|
||||
(id INTEGER PRIMARY KEY,
|
||||
prefix TEXT DEFAULT '$',
|
||||
submission INTEGER NULLABLE,
|
||||
period INTEGER)''')
|
||||
logger.info(f"'guild' table created.")
|
||||
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['submission'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS submission
|
||||
(id INTEGER PRIMARY KEY,
|
||||
user INTEGER,
|
||||
guild INTEGER,
|
||||
timestamp DATETIME)''')
|
||||
logger.info(f"'submission' table created.")
|
||||
|
||||
await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', ['period'])
|
||||
if await cur.fetchone() is None:
|
||||
await self.conn.execute('''CREATE TABLE IF NOT EXISTS period
|
||||
(id INTEGER PRIMARY KEY,
|
||||
guild INTEGER,
|
||||
current_state INTEGER,
|
||||
started_at TIMESTAMP,
|
||||
voting_at TIMESTAMP DEFAULT NULL,
|
||||
finished_at TIMESTAMP DEFAULT NULL)''')
|
||||
logger.info(f"'period' table created.")
|
||||
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def setup_guild(self, guild_id: int) -> None:
|
||||
"""Sets up a guild in the database."""
|
||||
await self.conn.execute('''INSERT INTO guild (id) VALUES (?)''', [guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def set_prefix(self, guild_id: int, new_prefix: str) -> None:
|
||||
"""Updates the prefix for a specific guild in the database"""
|
||||
await self.conn.execute('''UPDATE guild SET prefix = ? WHERE id = ?''', [new_prefix, guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def get_submission(self, guild_id: int, user_id: int) -> Optional[Submission]:
|
||||
"""Retrieves a row from the submission table by the associated unique Guild ID and User ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM submission WHERE guild = ? AND user = ?''', [guild_id, user_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Submission._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_guild(self, guild_id: int) -> Optional[Guild]:
|
||||
"""Retrieves a row from the Guild table by the Guild ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM guild WHERE id = ?''', [guild_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Guild._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_period(self, period_id) -> Optional[Period]:
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
await cur.execute('''SELECT * FROM period WHERE id = ?''', [period_id])
|
||||
row = await cur.fetchone()
|
||||
return None if row is None else Period._make(row)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def get_current_period(self, guild_id: int) -> Optional[Period]:
|
||||
"""Retrieves a row from the Guild table by the Guild ID"""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
guild = await self.get_guild(guild_id)
|
||||
if guild is None:
|
||||
logger.debug(f'Guild {guild_id} does not exist.')
|
||||
return None
|
||||
|
||||
if guild.period is not None:
|
||||
return await self.get_period(guild.period)
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def set_submission_channel(self, guild_id: int, new_submission: int) -> None:
|
||||
"""Updates the submission channel for a specific guild in the database"""
|
||||
await self.conn.execute('''UPDATE guild SET submission = ? WHERE id = ?''', [new_submission, guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def teardown_guild(self, guild_id: int) -> None:
|
||||
"""Removes a guild from the database while completing appropriate teardown actions."""
|
||||
await self.conn.execute('''DELETE FROM guild WHERE id = ?''', [guild_id])
|
||||
await self.conn.commit()
|
||||
|
||||
async def add_submission(self, submission_id: int, guild_id: int, user_id: int, timestamp: int = None) -> None:
|
||||
await self.conn.execute(
|
||||
'''INSERT INTO submission (id, user, guild, timestamp) VALUES (?, ?, ?, ?)''',
|
||||
[submission_id, user_id, guild_id, timestamp or datetime.utcnow().timestamp()]
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
async def generate_insert_query(table: str, columns: List[str]) -> str:
|
||||
"""
|
||||
Generate a basic limited insert query based on a destination table and number of named arguments.
|
||||
|
||||
Does NOT execute the query or insert any values - run this with Connection.execute()!
|
||||
"""
|
||||
query = f'''INSERT INTO {table} ({", ".join(columns)}) VALUES ({", ".join("?" for _ in columns)})'''
|
||||
logger.debug(query)
|
||||
return query
|
||||
|
||||
async def new_period(self, period: Period) -> None:
|
||||
"""Given a period, adds the period to the table and updates the associated guild."""
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
# Ensure the associated guild exists
|
||||
if period.guild is None:
|
||||
return logger.error(f'Period {period} did not include a guild to associate with.')
|
||||
|
||||
guild = await self.get_guild(period.guild)
|
||||
if guild is None:
|
||||
return logger.error(f'Specified guild {period.guild} does not exist.')
|
||||
|
||||
# Add the period to the table
|
||||
items = filter(lambda item: item[1] is not None, period._asdict().items())
|
||||
columns, values = zip(*items)
|
||||
query = await self.generate_insert_query('period', list(columns))
|
||||
|
||||
await cur.execute(query, values)
|
||||
await self.conn.commit()
|
||||
|
||||
# Update the associated guild's period
|
||||
await cur.execute('''UPDATE guild SET period = ? WHERE id = ?''', [cur.lastrowid, period.guild])
|
||||
await self.conn.commit()
|
||||
finally:
|
||||
await cur.close()
|
||||
|
||||
async def update(self, obj: Union[Guild, Submission, Period]) -> None:
|
||||
"""Using the objects's ID, updates a row in the database."""
|
||||
assert obj is not None and isinstance(obj, tuple) and getattr(obj, '_fields', None) is not None, ""
|
||||
assert obj.id is not None, ""
|
||||
|
||||
cur = await self.conn.cursor()
|
||||
try:
|
||||
# Add the period to the table
|
||||
items = filter(lambda item: item[1] is not None, obj._asdict().items())
|
||||
columns, values = zip(*items)
|
||||
query = await self.generate_insert_query(type(object).__name__.lower(), list(columns))
|
||||
|
||||
await cur.execute(query, values)
|
||||
await self.conn.commit()
|
||||
finally:
|
||||
await cur.close()
|
||||
@@ -42,7 +42,7 @@ class Guild(Base):
|
||||
|
||||
active = Column(Boolean, default=True) # Whether or not the bot is active inside the given Guild. Used for better querying.
|
||||
joined = Column(DateTime, default=datetime.datetime.utcnow) # The initial join time for this bot to a particular Discord.
|
||||
last_joined = Column(DateTime, nullable=True) # The last time the bot joined this server.
|
||||
last_joined = Column(DateTime, default=datetime.datetime.utcnow) # The last time the bot joined this server.
|
||||
|
||||
|
||||
def check_not_finished(func):
|
||||
|
||||
Reference in New Issue
Block a user