mirror of
https://github.com/Xevion/tcp-chat.git
synced 2025-12-06 05:16:45 -06:00
switch back to threading for global information object access, implement thread-by-thread database connections with thread locking mechanisms
This commit is contained in:
103
server/db.py
103
server/db.py
@@ -1,52 +1,75 @@
|
|||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import threading
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import constants
|
import constants
|
||||||
|
|
||||||
logger = logging.getLogger('database')
|
logger = logging.getLogger('database')
|
||||||
|
lock = threading.Lock()
|
||||||
conn = sqlite3.connect(constants.DATABASE)
|
|
||||||
logger.debug(f"Connected to '{constants.DATABASE}'")
|
|
||||||
|
|
||||||
logger.debug("Constructing 'message' table.")
|
|
||||||
conn.execute('''CREATE TABLE IF NOT EXISTS message
|
|
||||||
(id INTEGER PRIMARY KEY,
|
|
||||||
nickname TEXT NOT NULL,
|
|
||||||
connection_hash TEXT NOT NULL,
|
|
||||||
color TEXT DEFAULT '#000000',
|
|
||||||
message TEXT DEFAULT '',
|
|
||||||
timestamp INTEGER NOT NULL)''')
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def add_message(nickname: str, user_hash: str, color: str, message: str, timestamp: int) -> int:
|
class Database(object):
|
||||||
"""
|
def __init__(self):
|
||||||
Insert a message into the database. Returns the message ID.
|
logger.debug(f"Connected to '{constants.DATABASE}'")
|
||||||
|
self.conn = sqlite3.connect(constants.DATABASE)
|
||||||
|
self.__isClosed = False
|
||||||
|
|
||||||
:param nickname: A non-unique identifier for the user.
|
@property
|
||||||
:param user_hash: A unique hash (usually) denoting the sender's identity.
|
def is_closed(self) -> bool:
|
||||||
:param color: The color of the user who sent the message.
|
return self.__isClosed
|
||||||
:param message: The string content of the message echoed to all clients.
|
|
||||||
:param timestamp: The epoch time of the sent message.
|
|
||||||
:return: The unique integer primary key chosen for the message, i.e. it's ID.
|
|
||||||
"""
|
|
||||||
cur = conn.cursor()
|
|
||||||
try:
|
|
||||||
cur.execute('''INSERT INTO message (nickname, connection_hash, color, message, timestamp)
|
|
||||||
VALUES (?, ?, ?, ?, ?)''', [nickname, user_hash, color, message, timestamp])
|
|
||||||
conn.commit()
|
|
||||||
logger.debug(f'Message {cur.lastrowid} recorded.')
|
|
||||||
return cur.lastrowid
|
|
||||||
finally:
|
|
||||||
cur.close()
|
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
if self.__isClosed:
|
||||||
|
logger.warning(f'Database connection is already closed.', exc_info=True)
|
||||||
|
else:
|
||||||
|
self.conn.close()
|
||||||
|
self.__isClosed = True
|
||||||
|
|
||||||
def get_messages(columns: List[str] = None):
|
def construct(self):
|
||||||
cur = conn.cursor()
|
with lock:
|
||||||
try:
|
cur = self.conn.cursor()
|
||||||
if columns is None:
|
try:
|
||||||
cur.execute('''SELECT * FROM message''')
|
cur.execute('''SELECT name FROM sqlite_master WHERE type='table' AND name='?';''', 'message')
|
||||||
return cur.fetchall()
|
if cur.fetchone() is None:
|
||||||
finally:
|
self.conn.execute('''CREATE TABLE message
|
||||||
cur.close()
|
(id INTEGER PRIMARY KEY,
|
||||||
|
nickname TEXT NOT NULL,
|
||||||
|
connection_hash TEXT NOT NULL,
|
||||||
|
color TEXT DEFAULT '#000000',
|
||||||
|
message TEXT DEFAULT '',
|
||||||
|
timestamp INTEGER NOT NULL)''')
|
||||||
|
logger.debug("'message' table created.")
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
def add_message(self, nickname: str, user_hash: str, color: str, message: str, timestamp: int) -> int:
|
||||||
|
"""
|
||||||
|
Insert a message into the database. Returns the message ID.
|
||||||
|
|
||||||
|
:param nickname: A non-unique identifier for the user.
|
||||||
|
:param user_hash: A unique hash (usually) denoting the sender's identity.
|
||||||
|
:param color: The color of the user who sent the message.
|
||||||
|
:param message: The string content of the message echoed to all clients.
|
||||||
|
:param timestamp: The epoch time of the sent message.
|
||||||
|
:return: The unique integer primary key chosen for the message, i.e. it's ID.
|
||||||
|
"""
|
||||||
|
with lock:
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute('''INSERT INTO message (nickname, connection_hash, color, message, timestamp)
|
||||||
|
VALUES (?, ?, ?, ?, ?)''', [nickname, user_hash, color, message, timestamp])
|
||||||
|
logger.debug(f'Message {cur.lastrowid} recorded.')
|
||||||
|
return cur.lastrowid
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
def get_messages(self, columns: List[str] = None):
|
||||||
|
with lock:
|
||||||
|
cur = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
if columns is None:
|
||||||
|
cur.execute('''SELECT * FROM message''')
|
||||||
|
return cur.fetchall()
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import helpers
|
|||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from server import db
|
from server import db
|
||||||
from server.commands import CommandHandler
|
from server.commands import CommandHandler
|
||||||
|
from server.db import Database
|
||||||
|
|
||||||
logger = logging.getLogger('handler')
|
logger = logging.getLogger('handler')
|
||||||
|
|
||||||
@@ -20,6 +21,12 @@ class BaseClient(object):
|
|||||||
|
|
||||||
def __init__(self, conn: socket.socket, all_clients: List['Client'], address) -> None:
|
def __init__(self, conn: socket.socket, all_clients: List['Client'], address) -> None:
|
||||||
self.conn, self.all_clients, self.address = conn, all_clients, address
|
self.conn, self.all_clients, self.address = conn, all_clients, address
|
||||||
|
self.db: Optional[Database] = None
|
||||||
|
|
||||||
|
def connect_database(self):
|
||||||
|
if self.db is None:
|
||||||
|
logger.debug('Connecting client to database.')
|
||||||
|
self.db = Database()
|
||||||
|
|
||||||
def send(self, message: bytes) -> None:
|
def send(self, message: bytes) -> None:
|
||||||
"""Sends a pre-encoded message to this client."""
|
"""Sends a pre-encoded message to this client."""
|
||||||
@@ -35,12 +42,13 @@ class BaseClient(object):
|
|||||||
def broadcast_message(self, message: str) -> None:
|
def broadcast_message(self, message: str) -> None:
|
||||||
"""Sends a string message to all connected clients as the Server."""
|
"""Sends a string message to all connected clients as the Server."""
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
message_id = db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp)
|
message_id = self.db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp)
|
||||||
prepared = helpers.prepare_message(
|
prepared = helpers.prepare_message(
|
||||||
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id,
|
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id,
|
||||||
timestamp=timestamp
|
timestamp=timestamp
|
||||||
)
|
)
|
||||||
for client in self.all_clients:
|
for client in self.all_clients:
|
||||||
|
print(f'Sending a message to {client.nickname}')
|
||||||
client.send(prepared)
|
client.send(prepared)
|
||||||
|
|
||||||
def broadcast(self, message: bytes) -> None:
|
def broadcast(self, message: bytes) -> None:
|
||||||
@@ -71,6 +79,16 @@ class Client(BaseClient):
|
|||||||
self.last_nickname_change = None
|
self.last_nickname_change = None
|
||||||
self.last_message_sent = None
|
self.last_message_sent = None
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.last_nickname_change is None:
|
||||||
|
return f'Client({self.id[:8]})'
|
||||||
|
return f'Client({self.nickname}, {self.id[:8]})'
|
||||||
|
|
||||||
|
def connect_database(self) -> None:
|
||||||
|
if self.db is None:
|
||||||
|
logger.debug(f'Connecting Client({self.id[:8]}) to the database.')
|
||||||
|
self.db = Database()
|
||||||
|
|
||||||
def request_nickname(self) -> None:
|
def request_nickname(self) -> None:
|
||||||
"""Send a request for the client's nickname information."""
|
"""Send a request for the client's nickname information."""
|
||||||
self.conn.send(helpers.prepare_request(constants.Requests.REQUEST_NICK))
|
self.conn.send(helpers.prepare_request(constants.Requests.REQUEST_NICK))
|
||||||
@@ -89,19 +107,20 @@ class Client(BaseClient):
|
|||||||
time_limit = min(60 * 30, max(0, time_limit))
|
time_limit = min(60 * 30, max(0, time_limit))
|
||||||
min_time = int(time.time()) - time_limit
|
min_time = int(time.time()) - time_limit
|
||||||
|
|
||||||
cur = db.conn.cursor()
|
with db.lock:
|
||||||
try:
|
cur = self.db.conn.cursor()
|
||||||
cur.execute('''SELECT id, nickname, color, message, timestamp
|
try:
|
||||||
FROM message
|
cur.execute('''SELECT id, nickname, color, message, timestamp
|
||||||
WHERE timestamp >= ?
|
FROM message
|
||||||
ORDER BY timestamp
|
WHERE timestamp >= ?
|
||||||
LIMIT ?''',
|
ORDER BY timestamp
|
||||||
[min_time, limit])
|
LIMIT ?''',
|
||||||
|
[min_time, limit])
|
||||||
|
|
||||||
messages = cur.fetchall()
|
messages = cur.fetchall()
|
||||||
self.send(helpers.prepare_message_history(messages))
|
self.send(helpers.prepare_message_history(messages))
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
|
||||||
def receive(self) -> Any:
|
def receive(self) -> Any:
|
||||||
length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8'))
|
length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8'))
|
||||||
@@ -119,7 +138,15 @@ class Client(BaseClient):
|
|||||||
logger.info(f'{self.nickname} changed their name to {nickname}')
|
logger.info(f'{self.nickname} changed their name to {nickname}')
|
||||||
self.nickname = nickname
|
self.nickname = nickname
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
logger.info(f'Client {self.id} closed. ({self.nickname})')
|
||||||
|
self.conn.close() # Close socket connection
|
||||||
|
self.all_clients.remove(self) # Remove the user from the global client list
|
||||||
|
self.broadcast_message(f'{self.nickname} left!') # Now we can broadcast it's exit message
|
||||||
|
self.db.conn.close() # Close database connection
|
||||||
|
|
||||||
def handle(self) -> None:
|
def handle(self) -> None:
|
||||||
|
self.connect_database()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = self.receive()
|
data = self.receive()
|
||||||
@@ -136,8 +163,8 @@ class Client(BaseClient):
|
|||||||
self.handle_nickname(data['nickname'])
|
self.handle_nickname(data['nickname'])
|
||||||
elif data['type'] == constants.Types.MESSAGE:
|
elif data['type'] == constants.Types.MESSAGE:
|
||||||
# Record the message in the DB.
|
# Record the message in the DB.
|
||||||
message_id = db.add_message(self.nickname, self.id, self.color.hex, data['content'],
|
message_id = self.db.add_message(self.nickname, self.id, self.color.hex, data['content'],
|
||||||
int(time.time()))
|
int(time.time()))
|
||||||
|
|
||||||
self.broadcast(helpers.prepare_message(
|
self.broadcast(helpers.prepare_message(
|
||||||
nickname=self.nickname,
|
nickname=self.nickname,
|
||||||
@@ -154,11 +181,16 @@ class Client(BaseClient):
|
|||||||
msg = self.command.process(args)
|
msg = self.command.process(args)
|
||||||
if msg is not None:
|
if msg is not None:
|
||||||
self.broadcast_message(msg)
|
self.broadcast_message(msg)
|
||||||
|
except DataReceptionException as e:
|
||||||
|
logger.critical(e)
|
||||||
|
logger.warning('Aborting connection to the client.')
|
||||||
|
self.close()
|
||||||
|
break
|
||||||
|
except ConnectionResetError:
|
||||||
|
logger.critical('Lost connection to the client. Exiting.')
|
||||||
|
self.close()
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.critical(e, exc_info=True)
|
logger.critical(e, exc_info=True)
|
||||||
logger.info(f'Client {self.id} closed. ({self.nickname})')
|
self.close()
|
||||||
self.conn.close()
|
|
||||||
self.all_clients.remove(self)
|
|
||||||
self.broadcast_message(f'{self.nickname} left!')
|
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import socket
|
import socket
|
||||||
|
import threading
|
||||||
|
|
||||||
from server import handler
|
from server import handler
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ def receive():
|
|||||||
client.request_nickname()
|
client.request_nickname()
|
||||||
|
|
||||||
# Start Handling Thread For Client
|
# Start Handling Thread For Client
|
||||||
thread = multiprocessing.Process(target=client.handle, name=client.id[:8])
|
thread = threading.Thread(target=client.handle, name=client.id[:8])
|
||||||
thread.start()
|
thread.start()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info('Server closed by user.')
|
logger.info('Server closed by user.')
|
||||||
@@ -38,7 +38,4 @@ def receive():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from server import db
|
|
||||||
|
|
||||||
receive()
|
receive()
|
||||||
db.conn.close()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user