diff --git a/server/handler.py b/server/handler.py index e421b3e..8f54135 100644 --- a/server/handler.py +++ b/server/handler.py @@ -5,12 +5,12 @@ import socket import time import uuid from json import JSONDecodeError -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional from shared import constants from shared import helpers # noinspection PyUnresolvedReferences -from shared.exceptions import DataReceptionException +from shared.exceptions import DataReceptionException, StopException from server import db from server.commands import CommandHandler @@ -21,11 +21,14 @@ logger.setLevel(logging.DEBUG) class BaseClient(object): """A simple base class for the client containing basic client communication methods.""" - def __init__(self, conn: socket.socket, all_clients: List['Client'], address) -> None: - self.conn, self.all_clients, self.address = conn, all_clients, address + def __init__(self, conn: socket.socket, all_clients: List['Client'], address, stop_flag: Callable[[], bool]) -> None: + self.conn, self.all_clients, self.address, self.stop_flag = conn, all_clients, address, stop_flag self.db: Optional[db.ServerDatabase] = None + self.conn.settimeout(0.5) + def connect_database(self): + """""" if self.db is None: logger.debug('Connecting client to database.') self.db = db.ServerDatabase() @@ -38,7 +41,7 @@ class BaseClient(object): """Sends a string message as the server to this client.""" # db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, int(time.time())) self.conn.send(helpers.prepare_message( - nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=-1 + nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=-1 )) def broadcast_message(self, message: str) -> None: @@ -46,8 +49,8 @@ class BaseClient(object): timestamp = int(time.time()) message_id = self.db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp) prepared = helpers.prepare_message( - nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id, - timestamp=timestamp + nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id, + timestamp=timestamp ) for client in self.all_clients: client.send(prepared) @@ -68,8 +71,8 @@ class Client(BaseClient): Client.run() should be ran in a thread alongside the other clients. """ - def __init__(self, conn: socket.socket, address: Any, all_clients: List['Client']): - super().__init__(conn, all_clients, address) + def __init__(self, conn: socket.socket, address: Any, all_clients: List['Client'], stop_flag: Callable[[], bool]): + super().__init__(conn, all_clients, address, stop_flag) self.id = str(uuid.uuid4()) self.short_id = self.id[:8] @@ -99,10 +102,10 @@ class Client(BaseClient): def send_connections_list(self) -> None: """Sends a list of connections to the server, identifying their nickname and color""" self.conn.send(helpers.prepare_json( - { - 'type': constants.Types.USER_LIST, - 'users': [{'nickname': other.nickname, 'color': other.color.hex} for other in self.all_clients] - } + { + 'type': constants.Types.USER_LIST, + 'users': [{'nickname': other.nickname, 'color': other.color.hex} for other in self.all_clients] + } )) def send_message_history(self, limit: int, time_limit: int) -> None: @@ -126,12 +129,17 @@ class Client(BaseClient): cur.close() def receive(self) -> Any: - try: - length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8')) - except ValueError: - raise DataReceptionException('The socket did not receive the expected header.') - else: - logger.debug(f'Header received - Length {length}') + while True: + try: + self.check_stop() + length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8')) + except socket.timeout: + continue + except ValueError: + raise DataReceptionException('The socket did not receive the expected header.') + else: + logger.debug(f'Header received - Length {length}') + break try: data = self.conn.recv(length).decode('utf-8') @@ -155,8 +163,14 @@ class Client(BaseClient): for client in self.all_clients: client.send_connections_list() + def check_stop(self) -> None: + """Raises a StopException if the stop flag is set to true by the commanding main thread.""" + stop_flag: bool = self.stop_flag() + if stop_flag: + raise StopException() + def close(self) -> None: - logger.info(f'Client {self.id} closed. ({self.nickname})') + logger.info(f'Shutting down Client {self.id}. ({self.nickname})') self.conn.close() # Close socket connection self.all_clients.remove(self) # Remove the user from the global client list self.broadcast_message(f'{self.nickname} left!') # Now we can broadcast it's exit message @@ -178,7 +192,7 @@ class Client(BaseClient): if data['type'] == constants.Types.REQUEST: if data['request'] == constants.Requests.GET_MESSAGE_HISTORY: self.send_message_history( - limit=data.get('limit', 50), time_limit=data.get('time_limit', 60 * 30) + limit=data.get('limit', 50), time_limit=data.get('time_limit', 60 * 30) ) elif data['type'] == constants.Types.NICKNAME: @@ -189,10 +203,10 @@ class Client(BaseClient): int(time.time())) self.broadcast(helpers.prepare_message( - nickname=self.nickname, - message=data['content'], - color=self.color.hex, - message_id=message_id + nickname=self.nickname, + message=data['content'], + color=self.color.hex, + message_id=message_id )) # Process commands @@ -209,7 +223,11 @@ class Client(BaseClient): self.close() break except ConnectionResetError: - logger.critical('Lost connection to the client. Exiting.') + logger.critical('Lost connection to the client.') + self.close() + break + except StopException: + logger.info('Stop flag received from main thread.') self.close() break except Exception as e: diff --git a/shared/exceptions.py b/shared/exceptions.py index 576469f..62a67d9 100644 --- a/shared/exceptions.py +++ b/shared/exceptions.py @@ -1,5 +1,11 @@ class TCPChatException(BaseException): pass + class DataReceptionException(TCPChatException): pass + + +class StopException(TCPChatException): + """An exception that occurs when a thread finds a stop flag.""" + pass