From b311700356f30d6def54ef28b1888d237e278845 Mon Sep 17 00:00:00 2001 From: Xevion Date: Fri, 19 Feb 2021 03:29:53 -0600 Subject: [PATCH] Expanded database testing - Tested Submission.count descriptor - Tested Submission.clear_other_votes - Expanded Submission.advance testing to include exceptions, Submission .active, Submission.completed and Submission.voting - Changed db model __repr__ implementations to use f-strings - Fixed Submission.update force kwarg not implemented. - Small formatting changes, new TODO, noinspection on protected access - Votes setter now removes duplicate keys (maintains order) - Removed 'aiosqlite' from Pipfile general package requirements --- Pipfile | 1 - bot/cogs/contest_commands.py | 2 + bot/models.py | 20 +++++---- tests/bot/test_db.py | 87 ++++++++++++++++++++++++++++++------ 4 files changed, 88 insertions(+), 22 deletions(-) diff --git a/Pipfile b/Pipfile index 123fbd5..f56377d 100644 --- a/Pipfile +++ b/Pipfile @@ -7,7 +7,6 @@ verify_ssl = true [packages] discord = "~=1.0.1" -aiosqlite = "~=0.16.1" pytest = "*" pytest-asyncio = "*" sqlalchemy = "*" diff --git a/bot/cogs/contest_commands.py b/bot/cogs/contest_commands.py index 8d9469e..c8c0b6e 100644 --- a/bot/cogs/contest_commands.py +++ b/bot/cogs/contest_commands.py @@ -13,6 +13,7 @@ logger.setLevel(constants.LOGGING_LEVEL) # TODO: Add command error handling to all commands +# TODO: Modify channel subject based on state class ContestCommandsCog(commands.Cog, name='Contest'): """ @@ -39,6 +40,7 @@ class ContestCommandsCog(commands.Cog, name='Contest'): # This prevents any cogs with an overwritten cog_command_error being handled here. cog = ctx.cog + # noinspection PyProtectedMember if cog and cog._get_overridden_method(cog.cog_command_error) is not None: return diff --git a/bot/models.py b/bot/models.py index c151d92..4a859ab 100644 --- a/bot/models.py +++ b/bot/models.py @@ -23,6 +23,7 @@ logger.setLevel(constants.LOGGING_LEVEL) Base = declarative_base() + # TODO: Contest names # TODO: Refactor Period into Contest (major) @@ -83,7 +84,8 @@ class Submission(Base): user = Column(Integer) # The ID of the user who submitted it. timestamp = Column(DateTime) # When the Submission was posted - _votes: List[int] = Column("votes", NestedMutableList.as_mutable(JSON)) # A list of IDs correlating to users who voted on this submission. + _votes: List[int] = Column("votes", + NestedMutableList.as_mutable(JSON)) # A list of IDs correlating to users who voted on this submission. count = Column(Integer, default=0, nullable=False) period_id = Column(Integer, ForeignKey("period.id")) # The id of the period this Submission relates to. @@ -97,6 +99,7 @@ class Submission(Base): @votes.setter def votes(self, votes: List[int]) -> None: """"Setter function for _votes descriptor. Modifies count column.""" + votes = list(dict.fromkeys(votes)) # Remove duplicate values while retaining order self._votes = votes self.count = len(votes) @@ -105,9 +108,6 @@ class Submission(Base): kwargs.setdefault("votes", []) super().__init__(**kwargs) - def __repr__(self) -> str: - return 'Submission(id={id}, user={user}, period={period_id}, {count} votes)'.format(**self.__dict__) - def increment(self, user: int) -> None: """Increase the number of votes by one.""" if user == self.user: @@ -122,7 +122,8 @@ class Submission(Base): raise exceptions.DatabaseNoVoteException() self.votes.remove(user) - def clear_other_votes(self, ignore: Union[int, Iterable[int]], users: Union[int, Iterable[int]], session: 'Session') -> ReactionMarker: + def clear_other_votes(self, ignore: Union[int, Iterable[int]], users: Union[int, Iterable[int]], + session: 'Session') -> List[ReactionMarker]: """ Removes votes from all submissions in the database for a specific user. Returns a list of combination Message and User IDs @@ -138,7 +139,7 @@ class Submission(Base): if len(ignore) == 0: logger.warning(f'Clearing ALL votes for user(s): {users}') if len(users) == 0: return [] - found: List[Tuple[int, int]] = [] + found = [] submissions = session.query(Submission).filter(Submission.id != self.id).all() for submission in submissions: # Ignore submissions in the ignore list @@ -205,7 +206,7 @@ class Submission(Base): ) # Update the current list of votes - if self.period.voting: + if self.period.voting or force: self.votes = list(current) if len(to_remove) > 0: @@ -217,6 +218,9 @@ class Submission(Base): if not saw_self and self.period.voting: await message.add_reaction(constants.Emoji.UPVOTE) + def __repr__(self) -> str: + return f'Submission(id={self.id}, user={self.user}, period={self.period_id}, {self.count} votes)' + class Period(Base): """Represents a particular period of submissions and voting for a given""" @@ -305,4 +309,4 @@ class Period(Base): return "Error." def __repr__(self) -> str: - return 'Period(id={id}, guild={guild_id}, {state.name}, active={active})'.format(**self.__dict__) + return f'Period(id={self.id}, guild={self.guild_id}, {self.state.name}, active={self.active})' diff --git a/tests/bot/test_db.py b/tests/bot/test_db.py index 50fc8c8..766eafc 100644 --- a/tests/bot/test_db.py +++ b/tests/bot/test_db.py @@ -53,21 +53,82 @@ def test_submission_decrement(session: Session) -> None: sub.decrement(1) -def test_advance_state(session: Session) -> None: - guild = Guild(id=1) - per = Period(id=1, guild=guild) - session.add(per) +def test_submission_count_descriptor(session: Session) -> None: + sub = Submission(id=1, user=1) + session.add(sub) session.commit() - assert per.state == PeriodStates.READY - per.advance_state() - assert per.state == PeriodStates.SUBMISSIONS - per.advance_state() - assert per.state == PeriodStates.PAUSED - per.advance_state() - assert per.state == PeriodStates.VOTING - per.advance_state() - assert per.state == PeriodStates.FINISHED + assert sub.votes == [] + assert sub.count == 0 + + sub.votes = [1, 2, 2, 3] + assert sub.votes == [1, 2, 3] + assert sub.count == 3 + + +def test_advance_state(session: Session) -> None: + guild = Guild(id=1) + per1 = Period(id=1, guild=guild) + session.add(per1) + session.commit() + + assert per1.active + assert not per1.completed + assert per1.state == PeriodStates.READY + per1.advance_state() + assert per1.active + assert not per1.completed + assert per1.state == PeriodStates.SUBMISSIONS + per1.advance_state() + assert per1.active + assert not per1.completed + assert per1.state == PeriodStates.PAUSED + per1.advance_state() + assert per1.active + assert not per1.completed + assert per1.voting + assert per1.state == PeriodStates.VOTING + per1.advance_state() + assert per1.state == PeriodStates.FINISHED + assert not per1.active + assert per1.completed + + with pytest.raises(exceptions.FinishedPeriodException): + per1.deactivate() + with pytest.raises(exceptions.FinishedPeriodException): + per1.deactivate() + + per2 = Period(id=2, guild=guild) + session.add(per2) + session.commit() + per2.advance_state() + per2.advance_state() + per2.deactivate() + assert per2.state == PeriodStates.PAUSED + assert not per2.voting + assert not per2.active and not per2.completed + + +def test_submission_clear_other_votes(session: Session) -> None: + guild = Guild(id=1) + per = Period(id=1, guild=guild) + sub1 = Submission(id=1, user=1, period=per) + sub2 = Submission(id=2, user=2, period=per) + sub3 = Submission(id=3, user=3, period=per) + session.add_all([guild, per, sub1, sub2, sub3]) + session.commit() + + sub1.votes = [1, 2] + sub2.votes = [1, 2, 3] + sub3.votes = [2, 3] + + sub2.clear_other_votes(ignore=sub2.id, users=[1, 2], session=session) + + assert sub1.votes == [] + assert sub2.votes == [1, 2, 3] + assert sub3.votes == [3] + + @pytest.fixture()