diff --git a/bot/exceptions.py b/bot/exceptions.py index 1ad9fae..621dee3 100644 --- a/bot/exceptions.py +++ b/bot/exceptions.py @@ -30,3 +30,10 @@ class DatabaseNoVoteException(ContestException): def __repr__(self) -> str: return 'You can\'t remove a vote that never or no longer exists' + + +class SelfVoteException(ContestException): + """A user tried to vote on his own submission.""" + + def __repr__(self) -> str: + return 'You can\'t vote on your own submission. Please choose another post.' diff --git a/bot/models.py b/bot/models.py index ef3fbbb..fb3b072 100644 --- a/bot/models.py +++ b/bot/models.py @@ -6,10 +6,10 @@ import logging from typing import Iterable, List, TYPE_CHECKING, Tuple, Union import discord -from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, Text +from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, JSON, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship -from sqlalchemy_json import MutableJson +from sqlalchemy_json import NestedMutableList from bot import constants, exceptions, helpers from bot.constants import ReactionMarker @@ -83,11 +83,19 @@ class Submission(Base): id = Column(Integer, primary_key=True) # Doubles as the ID this Guild has in Discord user = Column(Integer) # The ID of the user who submitted it. timestamp = Column(DateTime) # When the Submission was posted - votes: List[int] = Column(MutableJson) # A list of IDs correlating to users who voted on this submission. + votes: List[int] = Column(NestedMutableList.as_mutable(JSON)) # A list of IDs correlating to users who voted on this submission. period_id = Column(Integer, ForeignKey("period.id")) # The id of the period this Submission relates to. period = relationship("Period", back_populates="submissions") # The period this submission was made in. + def __init__(self, **kwds): + # Adds default column behavior for Mutable JSON votes column + kwds.setdefault("votes", []) + super().__init__(**kwds) + + def __repr__(self) -> str: + return 'Submission(id={id}, user={user}, period={period_id}, {votes})'.format(**self.__dict__) + @property def count(self) -> int: """The number of votes cast for this submission.""" @@ -95,7 +103,9 @@ class Submission(Base): def increment(self, user: int) -> None: """Increase the number of votes by one.""" - if user in self.votes: + if user == self.user: + raise exceptions.SelfVoteException() + elif user in self.votes: raise exceptions.DatabaseDoubleVoteException() self.votes.append(user) @@ -242,3 +252,6 @@ class Period(Base): """ self.finished_time = datetime.datetime.utcnow() self.active = False + + def __repr__(self) -> str: + return 'Period(id={id}, guild={guild_id}, {state.name}, active={active})'.format(**self.__dict__)