diff --git a/tests/bot/test_db.py b/tests/bot/test_db.py index c1785b8..613b258 100644 --- a/tests/bot/test_db.py +++ b/tests/bot/test_db.py @@ -1,74 +1,71 @@ +import datetime +import random + import pytest +from sqlalchemy.orm import Session, sessionmaker +from itertools import count +from bot.models import Guild, Submission, Period +from main import load_db -from bot.db import ContestDatabase, tables +numbers = count() + +@pytest.fixture(scope='class') +async def SessionClass(): + engine = load_db('sqlite:///') + yield sessionmaker(bind=engine) + engine.dispose() + +class TestSubmissions: + @pytest.fixture() + async def session(self, SessionClass) -> Session: + session = SessionClass() + yield session + session.commit() + session.close() + + @pytest.fixture() + def guild(self, session) -> Guild: + guild = Guild(id=next(numbers), submission_channel=next(numbers)) + session.add(guild) + session.commit() + yield guild + session.delete(guild) + session.close() + + @pytest.mark.asyncio + async def test_submission_base(self, session) -> None: + period = Period(id=next(numbers)) + session.add(period) + submission = Submission(id=next(numbers), user=next(numbers), timestamp=datetime.datetime.utcnow(), period=period) + session.add(submission) -@pytest.fixture() -async def db() -> ContestDatabase: - db = await ContestDatabase.create(':memory:') - yield db - await db.conn.close() +class TestGuilds: + @pytest.fixture() + async def session(self, SessionClass) -> Session: + session = SessionClass() + yield session + session.commit() + session.close() + @pytest.mark.asyncio + async def test_guild_base(session) -> None: + guild = Guild(id=0) + session.commit() + for guild in session.query(Guild).all(): + print(guild) + pass -@pytest.mark.asyncio -async def test_table_setup(db) -> None: - """Test that all tables were setup by the database.""" - cur = await db.conn.cursor() - try: - for table_namedtuple in tables: - await cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name = ?;''', - [table_namedtuple.__name__.lower()]) - rows = list(await cur.fetchall()) - assert len(rows) == 1 - finally: - await cur.close() + class TestPeriods: + @pytest.fixture() + async def session(self, SessionClass) -> Session: + session = SessionClass() + yield session + session.commit() + session.close() - -@pytest.mark.asyncio -async def test_guild_setup(db) -> None: - await db.setup_guild(0) - guild = await db.get_guild(0) - assert guild is not None - assert guild.submission is None - - assert await db.get_guild(1) is None - - -@pytest.mark.asyncio -async def test_update(db) -> None: - pass - - -@pytest.mark.asyncio -async def test_generate_update_query(db) -> None: - pass - - -@pytest.mark.asyncio -async def test_insert(db) -> None: - """Test automatic namedtuple query""" - pass - - -@pytest.mark.asyncio -async def test_generate_insert_query(db) -> None: - """Test INSERT query generation.""" - pass - - -@pytest.mark.asyncio -async def test_submissions(db) -> None: - """Test all submission related helper functions.""" - pass - - -@pytest.mark.asyncio -async def test_guilds(db) -> None: - """Test all guild related helper functions""" - pass - - -@pytest.mark.asyncio -async def test_periods(db) -> None: - """Test all period related helper functions.""" - pass + @pytest.mark.asyncio + async def test_period_base(session) -> None: + period = Period(id=1, guild_id=1) + session.commit() + pass