From 12917edaa4c90bcbd181ed2abf9da4ff69648954 Mon Sep 17 00:00:00 2001 From: Xevion Date: Fri, 22 Jan 2021 09:38:46 -0600 Subject: [PATCH] switch back to threading for global information object access, implement thread-by-thread database connections with thread locking mechanisms --- server/db.py | 103 ++++++++++++++++++++++++++++------------------ server/handler.py | 72 +++++++++++++++++++++++--------- server/main.py | 7 +--- 3 files changed, 117 insertions(+), 65 deletions(-) diff --git a/server/db.py b/server/db.py index ef333b0..6bbf865 100644 --- a/server/db.py +++ b/server/db.py @@ -1,52 +1,75 @@ import logging import sqlite3 +import threading from typing import List import constants logger = logging.getLogger('database') - -conn = sqlite3.connect(constants.DATABASE) -logger.debug(f"Connected to '{constants.DATABASE}'") - -logger.debug("Constructing 'message' table.") -conn.execute('''CREATE TABLE IF NOT EXISTS message - (id INTEGER PRIMARY KEY, - nickname TEXT NOT NULL, - connection_hash TEXT NOT NULL, - color TEXT DEFAULT '#000000', - message TEXT DEFAULT '', - timestamp INTEGER NOT NULL)''') -conn.commit() +lock = threading.Lock() -def add_message(nickname: str, user_hash: str, color: str, message: str, timestamp: int) -> int: - """ - Insert a message into the database. Returns the message ID. +class Database(object): + def __init__(self): + logger.debug(f"Connected to '{constants.DATABASE}'") + self.conn = sqlite3.connect(constants.DATABASE) + self.__isClosed = False - :param nickname: A non-unique identifier for the user. - :param user_hash: A unique hash (usually) denoting the sender's identity. - :param color: The color of the user who sent the message. - :param message: The string content of the message echoed to all clients. - :param timestamp: The epoch time of the sent message. - :return: The unique integer primary key chosen for the message, i.e. it's ID. - """ - cur = conn.cursor() - try: - cur.execute('''INSERT INTO message (nickname, connection_hash, color, message, timestamp) - VALUES (?, ?, ?, ?, ?)''', [nickname, user_hash, color, message, timestamp]) - conn.commit() - logger.debug(f'Message {cur.lastrowid} recorded.') - return cur.lastrowid - finally: - cur.close() + @property + def is_closed(self) -> bool: + return self.__isClosed + def close(self) -> None: + if self.__isClosed: + logger.warning(f'Database connection is already closed.', exc_info=True) + else: + self.conn.close() + self.__isClosed = True -def get_messages(columns: List[str] = None): - cur = conn.cursor() - try: - if columns is None: - cur.execute('''SELECT * FROM message''') - return cur.fetchall() - finally: - cur.close() + def construct(self): + with lock: + cur = self.conn.cursor() + try: + cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name='?';''', 'message') + if cur.fetchone() is None: + self.conn.execute('''CREATE TABLE message + (id INTEGER PRIMARY KEY, + nickname TEXT NOT NULL, + connection_hash TEXT NOT NULL, + color TEXT DEFAULT '#000000', + message TEXT DEFAULT '', + timestamp INTEGER NOT NULL)''') + logger.debug("'message' table created.") + finally: + cur.close() + + def add_message(self, nickname: str, user_hash: str, color: str, message: str, timestamp: int) -> int: + """ + Insert a message into the database. Returns the message ID. + + :param nickname: A non-unique identifier for the user. + :param user_hash: A unique hash (usually) denoting the sender's identity. + :param color: The color of the user who sent the message. + :param message: The string content of the message echoed to all clients. + :param timestamp: The epoch time of the sent message. + :return: The unique integer primary key chosen for the message, i.e. it's ID. + """ + with lock: + cur = self.conn.cursor() + try: + cur.execute('''INSERT INTO message (nickname, connection_hash, color, message, timestamp) + VALUES (?, ?, ?, ?, ?)''', [nickname, user_hash, color, message, timestamp]) + logger.debug(f'Message {cur.lastrowid} recorded.') + return cur.lastrowid + finally: + cur.close() + + def get_messages(self, columns: List[str] = None): + with lock: + cur = self.conn.cursor() + try: + if columns is None: + cur.execute('''SELECT * FROM message''') + return cur.fetchall() + finally: + cur.close() diff --git a/server/handler.py b/server/handler.py index 509f668..556c546 100644 --- a/server/handler.py +++ b/server/handler.py @@ -11,6 +11,7 @@ import helpers # noinspection PyUnresolvedReferences from server import db from server.commands import CommandHandler +from server.db import Database logger = logging.getLogger('handler') @@ -20,6 +21,12 @@ class BaseClient(object): def __init__(self, conn: socket.socket, all_clients: List['Client'], address) -> None: self.conn, self.all_clients, self.address = conn, all_clients, address + self.db: Optional[Database] = None + + def connect_database(self): + if self.db is None: + logger.debug('Connecting client to database.') + self.db = Database() def send(self, message: bytes) -> None: """Sends a pre-encoded message to this client.""" @@ -35,12 +42,13 @@ class BaseClient(object): def broadcast_message(self, message: str) -> None: """Sends a string message to all connected clients as the Server.""" timestamp = int(time.time()) - message_id = db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp) + message_id = self.db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp) prepared = helpers.prepare_message( nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id, timestamp=timestamp ) for client in self.all_clients: + print(f'Sending a message to {client.nickname}') client.send(prepared) def broadcast(self, message: bytes) -> None: @@ -71,6 +79,16 @@ class Client(BaseClient): self.last_nickname_change = None self.last_message_sent = None + def __repr__(self) -> str: + if self.last_nickname_change is None: + return f'Client({self.id[:8]})' + return f'Client({self.nickname}, {self.id[:8]})' + + def connect_database(self) -> None: + if self.db is None: + logger.debug(f'Connecting Client({self.id[:8]}) to the database.') + self.db = Database() + def request_nickname(self) -> None: """Send a request for the client's nickname information.""" self.conn.send(helpers.prepare_request(constants.Requests.REQUEST_NICK)) @@ -89,19 +107,20 @@ class Client(BaseClient): time_limit = min(60 * 30, max(0, time_limit)) min_time = int(time.time()) - time_limit - cur = db.conn.cursor() - try: - cur.execute('''SELECT id, nickname, color, message, timestamp - FROM message - WHERE timestamp >= ? - ORDER BY timestamp - LIMIT ?''', - [min_time, limit]) + with db.lock: + cur = self.db.conn.cursor() + try: + cur.execute('''SELECT id, nickname, color, message, timestamp + FROM message + WHERE timestamp >= ? + ORDER BY timestamp + LIMIT ?''', + [min_time, limit]) - messages = cur.fetchall() - self.send(helpers.prepare_message_history(messages)) - finally: - cur.close() + messages = cur.fetchall() + self.send(helpers.prepare_message_history(messages)) + finally: + cur.close() def receive(self) -> Any: length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8')) @@ -119,7 +138,15 @@ class Client(BaseClient): logger.info(f'{self.nickname} changed their name to {nickname}') self.nickname = nickname + def close(self) -> None: + logger.info(f'Client {self.id} closed. ({self.nickname})') + self.conn.close() # Close socket connection + self.all_clients.remove(self) # Remove the user from the global client list + self.broadcast_message(f'{self.nickname} left!') # Now we can broadcast it's exit message + self.db.conn.close() # Close database connection + def handle(self) -> None: + self.connect_database() while True: try: data = self.receive() @@ -136,8 +163,8 @@ class Client(BaseClient): self.handle_nickname(data['nickname']) elif data['type'] == constants.Types.MESSAGE: # Record the message in the DB. - message_id = db.add_message(self.nickname, self.id, self.color.hex, data['content'], - int(time.time())) + message_id = self.db.add_message(self.nickname, self.id, self.color.hex, data['content'], + int(time.time())) self.broadcast(helpers.prepare_message( nickname=self.nickname, @@ -154,11 +181,16 @@ class Client(BaseClient): msg = self.command.process(args) if msg is not None: self.broadcast_message(msg) - + except DataReceptionException as e: + logger.critical(e) + logger.warning('Aborting connection to the client.') + self.close() + break + except ConnectionResetError: + logger.critical('Lost connection to the client. Exiting.') + self.close() + break except Exception as e: logger.critical(e, exc_info=True) - logger.info(f'Client {self.id} closed. ({self.nickname})') - self.conn.close() - self.all_clients.remove(self) - self.broadcast_message(f'{self.nickname} left!') + self.close() break diff --git a/server/main.py b/server/main.py index f16b91a..f637270 100644 --- a/server/main.py +++ b/server/main.py @@ -1,6 +1,6 @@ import logging -import multiprocessing import socket +import threading from server import handler @@ -29,7 +29,7 @@ def receive(): client.request_nickname() # Start Handling Thread For Client - thread = multiprocessing.Process(target=client.handle, name=client.id[:8]) + thread = threading.Thread(target=client.handle, name=client.id[:8]) thread.start() except KeyboardInterrupt: logger.info('Server closed by user.') @@ -38,7 +38,4 @@ def receive(): if __name__ == '__main__': - from server import db - receive() - db.conn.close()