mirror of
https://github.com/Xevion/tcp-chat.git
synced 2025-12-06 09:16:40 -06:00
Add server-side flag-based thread stopping with socket timeouts
This commit is contained in:
@@ -5,12 +5,12 @@ import socket
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
from shared import constants
|
from shared import constants
|
||||||
from shared import helpers
|
from shared import helpers
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from shared.exceptions import DataReceptionException
|
from shared.exceptions import DataReceptionException, StopException
|
||||||
from server import db
|
from server import db
|
||||||
from server.commands import CommandHandler
|
from server.commands import CommandHandler
|
||||||
|
|
||||||
@@ -21,11 +21,14 @@ logger.setLevel(logging.DEBUG)
|
|||||||
class BaseClient(object):
|
class BaseClient(object):
|
||||||
"""A simple base class for the client containing basic client communication methods."""
|
"""A simple base class for the client containing basic client communication methods."""
|
||||||
|
|
||||||
def __init__(self, conn: socket.socket, all_clients: List['Client'], address) -> None:
|
def __init__(self, conn: socket.socket, all_clients: List['Client'], address, stop_flag: Callable[[], bool]) -> None:
|
||||||
self.conn, self.all_clients, self.address = conn, all_clients, address
|
self.conn, self.all_clients, self.address, self.stop_flag = conn, all_clients, address, stop_flag
|
||||||
self.db: Optional[db.ServerDatabase] = None
|
self.db: Optional[db.ServerDatabase] = None
|
||||||
|
|
||||||
|
self.conn.settimeout(0.5)
|
||||||
|
|
||||||
def connect_database(self):
|
def connect_database(self):
|
||||||
|
""""""
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
logger.debug('Connecting client to database.')
|
logger.debug('Connecting client to database.')
|
||||||
self.db = db.ServerDatabase()
|
self.db = db.ServerDatabase()
|
||||||
@@ -38,7 +41,7 @@ class BaseClient(object):
|
|||||||
"""Sends a string message as the server to this client."""
|
"""Sends a string message as the server to this client."""
|
||||||
# db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, int(time.time()))
|
# db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, int(time.time()))
|
||||||
self.conn.send(helpers.prepare_message(
|
self.conn.send(helpers.prepare_message(
|
||||||
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=-1
|
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=-1
|
||||||
))
|
))
|
||||||
|
|
||||||
def broadcast_message(self, message: str) -> None:
|
def broadcast_message(self, message: str) -> None:
|
||||||
@@ -46,8 +49,8 @@ class BaseClient(object):
|
|||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
message_id = self.db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp)
|
message_id = self.db.add_message('Server', 'server', constants.Colors.BLACK.hex, message, timestamp)
|
||||||
prepared = helpers.prepare_message(
|
prepared = helpers.prepare_message(
|
||||||
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id,
|
nickname='Server', message=message, color=constants.Colors.BLACK.hex, message_id=message_id,
|
||||||
timestamp=timestamp
|
timestamp=timestamp
|
||||||
)
|
)
|
||||||
for client in self.all_clients:
|
for client in self.all_clients:
|
||||||
client.send(prepared)
|
client.send(prepared)
|
||||||
@@ -68,8 +71,8 @@ class Client(BaseClient):
|
|||||||
Client.run() should be ran in a thread alongside the other clients.
|
Client.run() should be ran in a thread alongside the other clients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, conn: socket.socket, address: Any, all_clients: List['Client']):
|
def __init__(self, conn: socket.socket, address: Any, all_clients: List['Client'], stop_flag: Callable[[], bool]):
|
||||||
super().__init__(conn, all_clients, address)
|
super().__init__(conn, all_clients, address, stop_flag)
|
||||||
|
|
||||||
self.id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
self.short_id = self.id[:8]
|
self.short_id = self.id[:8]
|
||||||
@@ -99,10 +102,10 @@ class Client(BaseClient):
|
|||||||
def send_connections_list(self) -> None:
|
def send_connections_list(self) -> None:
|
||||||
"""Sends a list of connections to the server, identifying their nickname and color"""
|
"""Sends a list of connections to the server, identifying their nickname and color"""
|
||||||
self.conn.send(helpers.prepare_json(
|
self.conn.send(helpers.prepare_json(
|
||||||
{
|
{
|
||||||
'type': constants.Types.USER_LIST,
|
'type': constants.Types.USER_LIST,
|
||||||
'users': [{'nickname': other.nickname, 'color': other.color.hex} for other in self.all_clients]
|
'users': [{'nickname': other.nickname, 'color': other.color.hex} for other in self.all_clients]
|
||||||
}
|
}
|
||||||
))
|
))
|
||||||
|
|
||||||
def send_message_history(self, limit: int, time_limit: int) -> None:
|
def send_message_history(self, limit: int, time_limit: int) -> None:
|
||||||
@@ -126,12 +129,17 @@ class Client(BaseClient):
|
|||||||
cur.close()
|
cur.close()
|
||||||
|
|
||||||
def receive(self) -> Any:
|
def receive(self) -> Any:
|
||||||
try:
|
while True:
|
||||||
length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8'))
|
try:
|
||||||
except ValueError:
|
self.check_stop()
|
||||||
raise DataReceptionException('The socket did not receive the expected header.')
|
length = int(self.conn.recv(constants.HEADER_LENGTH).decode('utf-8'))
|
||||||
else:
|
except socket.timeout:
|
||||||
logger.debug(f'Header received - Length {length}')
|
continue
|
||||||
|
except ValueError:
|
||||||
|
raise DataReceptionException('The socket did not receive the expected header.')
|
||||||
|
else:
|
||||||
|
logger.debug(f'Header received - Length {length}')
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = self.conn.recv(length).decode('utf-8')
|
data = self.conn.recv(length).decode('utf-8')
|
||||||
@@ -155,8 +163,14 @@ class Client(BaseClient):
|
|||||||
for client in self.all_clients:
|
for client in self.all_clients:
|
||||||
client.send_connections_list()
|
client.send_connections_list()
|
||||||
|
|
||||||
|
def check_stop(self) -> None:
|
||||||
|
"""Raises a StopException if the stop flag is set to true by the commanding main thread."""
|
||||||
|
stop_flag: bool = self.stop_flag()
|
||||||
|
if stop_flag:
|
||||||
|
raise StopException()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
logger.info(f'Client {self.id} closed. ({self.nickname})')
|
logger.info(f'Shutting down Client {self.id}. ({self.nickname})')
|
||||||
self.conn.close() # Close socket connection
|
self.conn.close() # Close socket connection
|
||||||
self.all_clients.remove(self) # Remove the user from the global client list
|
self.all_clients.remove(self) # Remove the user from the global client list
|
||||||
self.broadcast_message(f'{self.nickname} left!') # Now we can broadcast it's exit message
|
self.broadcast_message(f'{self.nickname} left!') # Now we can broadcast it's exit message
|
||||||
@@ -178,7 +192,7 @@ class Client(BaseClient):
|
|||||||
if data['type'] == constants.Types.REQUEST:
|
if data['type'] == constants.Types.REQUEST:
|
||||||
if data['request'] == constants.Requests.GET_MESSAGE_HISTORY:
|
if data['request'] == constants.Requests.GET_MESSAGE_HISTORY:
|
||||||
self.send_message_history(
|
self.send_message_history(
|
||||||
limit=data.get('limit', 50), time_limit=data.get('time_limit', 60 * 30)
|
limit=data.get('limit', 50), time_limit=data.get('time_limit', 60 * 30)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif data['type'] == constants.Types.NICKNAME:
|
elif data['type'] == constants.Types.NICKNAME:
|
||||||
@@ -189,10 +203,10 @@ class Client(BaseClient):
|
|||||||
int(time.time()))
|
int(time.time()))
|
||||||
|
|
||||||
self.broadcast(helpers.prepare_message(
|
self.broadcast(helpers.prepare_message(
|
||||||
nickname=self.nickname,
|
nickname=self.nickname,
|
||||||
message=data['content'],
|
message=data['content'],
|
||||||
color=self.color.hex,
|
color=self.color.hex,
|
||||||
message_id=message_id
|
message_id=message_id
|
||||||
))
|
))
|
||||||
|
|
||||||
# Process commands
|
# Process commands
|
||||||
@@ -209,7 +223,11 @@ class Client(BaseClient):
|
|||||||
self.close()
|
self.close()
|
||||||
break
|
break
|
||||||
except ConnectionResetError:
|
except ConnectionResetError:
|
||||||
logger.critical('Lost connection to the client. Exiting.')
|
logger.critical('Lost connection to the client.')
|
||||||
|
self.close()
|
||||||
|
break
|
||||||
|
except StopException:
|
||||||
|
logger.info('Stop flag received from main thread.')
|
||||||
self.close()
|
self.close()
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
class TCPChatException(BaseException):
|
class TCPChatException(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DataReceptionException(TCPChatException):
|
class DataReceptionException(TCPChatException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StopException(TCPChatException):
|
||||||
|
"""An exception that occurs when a thread finds a stop flag."""
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user