Files
contest-assistant/tests/bot/test_db.py
Xevion 8d8fe83b5f Proper basic tests for db.Submission methods
Testing for Submission.increment, Submission.decrement with exceptions
Randomized database generation fixture with query test
2021-02-17 03:38:04 -06:00

113 lines
2.8 KiB
Python

import random
from itertools import count
from typing import Generator
import pytest
from sqlalchemy.orm import Session, sessionmaker
from bot import exceptions
from bot.models import Guild, Period, PeriodStates, Submission
from main import load_db
# Module-wide, monotonically increasing counter. The ``database`` fixture draws
# from it with ``next(numbers)`` so that every id it generates is distinct.
numbers = count()
@pytest.fixture(scope='class')
def SessionClass():
    """Yield a ``sessionmaker`` bound to the engine returned by ``load_db``.

    The engine is created from the ``sqlite:///`` URL and is disposed of once
    the fixture is torn down.
    """
    db_engine = load_db('sqlite:///')
    session_factory = sessionmaker(bind=db_engine)
    yield session_factory
    # Teardown: release the engine's connection pool.
    db_engine.dispose()
@pytest.fixture(scope='function')
def session(SessionClass: sessionmaker):
    """Yield a new ``Session`` built from ``SessionClass`` and always close it."""
    db_session: Session = SessionClass()
    try:
        yield db_session
    finally:
        # Runs on teardown whether or not the test body succeeded.
        db_session.close()
def test_submission_increment(session: Session):
    """Check ``Submission.increment`` for self-votes, normal votes and repeats."""
    author_id, voter_id = 1, 2

    submission = Submission(id=1, user=author_id)
    session.bulk_save_objects([submission])
    session.commit()

    # Voting with the same id as the submission's ``user`` must be rejected.
    with pytest.raises(exceptions.SelfVoteException):
        submission.increment(author_id)

    # A different user's vote is recorded.
    submission.increment(voter_id)
    assert submission.votes == [voter_id]

    # The same user voting a second time must be rejected.
    with pytest.raises(exceptions.DatabaseDoubleVoteException):
        submission.increment(voter_id)
def test_submission_decrement(session: Session) -> None:
    """Check ``Submission.decrement`` removes a vote and rejects a missing one."""
    voter_id = 1

    submission = Submission(id=1, user=1)
    session.add(submission)
    session.commit()

    # Seed the vote list directly so that only ``decrement`` is under test.
    submission.votes = [voter_id]

    submission.decrement(voter_id)
    assert submission.votes == []

    # With the vote already removed, a second removal must raise.
    with pytest.raises(exceptions.DatabaseNoVoteException):
        submission.decrement(voter_id)
def test_advance_state(session: Session) -> None:
    """Check that ``Period.advance_state`` steps through the states in order."""
    period = Period(id=1, guild=Guild(id=1))
    session.add(period)
    session.commit()

    # State right after the period has been committed.
    assert period.state == PeriodStates.READY

    # Each call to ``advance_state`` must land on the next state in this order.
    expected_sequence = (
        PeriodStates.SUBMISSIONS,
        PeriodStates.PAUSED,
        PeriodStates.VOTING,
        PeriodStates.FINISHED,
    )
    for expected_state in expected_sequence:
        period.advance_state()
        assert period.state == expected_state
@pytest.fixture()
def database(SessionClass: sessionmaker) -> Generator[Session, None, None]:
    """Yield a session over a database pre-populated with generated rows.

    One guild is created. For each of three rounds and for each member of
    ``PeriodStates``, a period is created and advanced until it either reaches
    that state or stops being active; fifty submissions are then attached to
    it. All ids come from the module-level ``numbers`` counter.

    The session is closed on teardown, and also if population itself fails.
    """

    def shuffled_cycle(ids) -> Generator[int, None, None]:
        """Endlessly yield ``ids``, reshuffling in place before every pass."""
        index = 0
        random.shuffle(ids)
        while True:
            if index >= len(ids):
                random.shuffle(ids)
                index = 0
            yield ids[index]
            index += 1

    session: Session = SessionClass()
    try:
        user_ids = [next(numbers) for _ in range(50)]
        guild = Guild(id=next(numbers), submission_channel=next(numbers))
        user_stream = shuffled_cycle(user_ids)

        for _ in range(3):
            for state in PeriodStates:
                period = Period(id=next(numbers), guild=guild)
                # Step the new period forward until it reaches the target
                # state, or until it reports that it is no longer active.
                while period.active and period.state != state:
                    period.advance_state()

                # 50 submissions per period matches the 50 user ids, so each
                # period consumes exactly one full pass of the user cycle.
                for _ in range(50):
                    submission = Submission(
                        id=next(numbers),
                        user=next(user_stream),
                        period=period,
                    )
                    session.add(submission)
                session.add(period)

        session.add(guild)
        session.commit()

        yield session
    finally:
        # Mirror the ``session`` fixture: close even if setup or commit fails.
        session.close()
def test_database(database: Session) -> None:
    """Check that the ``database`` fixture persisted the expected row counts.

    The fixture creates one guild, then runs three rounds over every member
    of ``PeriodStates``, creating one period and fifty submissions each time.
    """
    guilds = database.query(Guild).all()
    periods = database.query(Period).all()
    submissions = database.query(Submission).all()

    # 3 rounds x one period per state; 50 submissions per period.
    expected_periods = 3 * len(PeriodStates)

    assert len(guilds) == 1
    assert len(periods) == expected_periods
    assert len(submissions) == 50 * expected_periods