668 lines
25 KiB
Python
668 lines
25 KiB
Python
"""
WebSocket Connection Pool and Management Service

This module provides a centralized WebSocket connection pooling system for the Delphi Database
application. It manages connections efficiently, handles cleanup of stale connections,
monitors connection health, and provides resource management to prevent memory leaks.

Features:
- Connection pooling by topic/channel
- Automatic cleanup of inactive connections
- Health monitoring and heartbeat management
- Resource management and memory leak prevention
- Per-user connection tracking (user IDs are supplied by the caller's authentication layer)
- Structured logging for debugging
"""
|
|
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
from typing import Dict, Set, Optional, Any, Callable, List, Union
|
|
from datetime import datetime, timezone, timedelta
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from pydantic import BaseModel
|
|
|
|
from app.utils.logging import StructuredLogger
|
|
|
|
|
|
class ConnectionState(Enum):
    """Lifecycle states a pooled WebSocket connection can be in."""

    CONNECTING = "connecting"  # registered, no pong received yet
    CONNECTED = "connected"  # at least one pong has been handled
    DISCONNECTING = "disconnecting"  # being removed from the pool
    DISCONNECTED = "disconnected"  # removed from the pool
    ERROR = "error"  # a send to this connection failed
|
|
|
|
|
|
class MessageType(Enum):
    """Values used for the ``type`` field of a WebSocket message."""

    # Liveness probing
    PING = "ping"
    PONG = "pong"
    # Payload delivery and failures
    DATA = "data"
    ERROR = "error"
    # Periodic keep-alive sent by the pool
    HEARTBEAT = "heartbeat"
    # Topic membership requests
    SUBSCRIBE = "subscribe"
    UNSUBSCRIBE = "unsubscribe"
|
|
|
|
|
|
@dataclass
class ConnectionInfo:
    """Bookkeeping record for one pooled WebSocket connection.

    Holds the socket, its owner and topic subscriptions, its lifecycle
    state, and the timestamps the pool uses for staleness checks.
    """

    id: str
    websocket: WebSocket
    user_id: Optional[int]
    topics: Set[str]  # topics this connection is currently subscribed to
    state: ConnectionState
    created_at: datetime
    last_activity: datetime  # refreshed by update_activity()
    last_ping: Optional[datetime]
    last_pong: Optional[datetime]
    error_count: int  # number of failed sends observed
    metadata: Dict[str, Any]

    def is_alive(self) -> bool:
        """Return True while the state is CONNECTING or CONNECTED."""
        live_states = (ConnectionState.CONNECTED, ConnectionState.CONNECTING)
        return self.state in live_states

    def is_stale(self, timeout_seconds: int = 300) -> bool:
        """Return True if the connection is not alive or idle too long.

        Args:
            timeout_seconds: Idle period (seconds since last_activity) after
                which a live connection counts as stale.
        """
        if self.is_alive():
            idle_for = datetime.now(timezone.utc) - self.last_activity
            return idle_for.total_seconds() > timeout_seconds
        # Any non-live state is reported as stale so cleanup can reap it.
        return True

    def update_activity(self):
        """Stamp last_activity with the current UTC time."""
        now = datetime.now(timezone.utc)
        self.last_activity = now
|
|
|
|
|
|
class WebSocketMessage(BaseModel):
    """Envelope for messages the pool sends over a WebSocket.

    Only ``type`` is required; optional fields left as None are dropped
    when the message is serialized with ``to_dict``.
    """

    type: str  # presumably a MessageType value (e.g. "ping"); not enforced here
    topic: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
    timestamp: Optional[str] = None  # ISO-8601 string when filled in by the pool
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-ready dict, omitting every field whose value is None."""
        serialized = self.model_dump(exclude_none=True)
        return serialized
|
|
|
|
|
|
class WebSocketPool:
    """
    Centralized WebSocket connection pool manager.

    Indexes connections by id, by topic, and by user id, and enforces
    per-topic and global connection limits. Once ``start()`` is awaited, it
    also runs two background tasks: periodic removal of stale connections
    and a periodic heartbeat message to every connection.

    The index dictionaries are guarded by one ``asyncio.Lock``. Socket sends
    are performed outside that lock, so a slow client cannot stall
    registration, subscription, or delivery to other clients.
    """

    def __init__(
        self,
        cleanup_interval: int = 60,  # seconds between stale-connection sweeps
        connection_timeout: int = 300,  # idle seconds before a connection is stale
        heartbeat_interval: int = 30,  # seconds between heartbeat rounds
        max_connections_per_topic: int = 1000,
        max_total_connections: int = 10000,
    ):
        self.cleanup_interval = cleanup_interval
        self.connection_timeout = connection_timeout
        self.heartbeat_interval = heartbeat_interval
        self.max_connections_per_topic = max_connections_per_topic
        self.max_total_connections = max_total_connections

        # Connection storage
        self._connections: Dict[str, ConnectionInfo] = {}
        self._topics: Dict[str, Set[str]] = {}  # topic -> connection_ids
        self._user_connections: Dict[int, Set[str]] = {}  # user_id -> connection_ids

        # Guards the three index dicts above. This is an asyncio lock: it
        # serializes coroutines on one event loop, not OS threads.
        self._connections_lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None

        # Statistics
        self._stats = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_sent": 0,
            "messages_failed": 0,
            "connections_cleaned": 0,
            "last_cleanup": None,
            "last_heartbeat": None,
        }

        self.logger = StructuredLogger("websocket_pool", "INFO")
        self.logger.info("WebSocket pool initialized",
                         cleanup_interval=cleanup_interval,
                         connection_timeout=connection_timeout,
                         heartbeat_interval=heartbeat_interval)

    async def start(self):
        """Start the cleanup and heartbeat background tasks.

        Calling this again while the tasks exist does nothing. If no
        module-level pool is registered yet, this instance registers itself,
        so code that uses the module-level getter (tests, simple scripts)
        sees the same pool.
        """
        global _websocket_pool
        if _websocket_pool is None:
            _websocket_pool = self

        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_worker())
            self.logger.info("Started cleanup worker task")

        if self._heartbeat_task is None:
            self._heartbeat_task = asyncio.create_task(self._heartbeat_worker())
            self.logger.info("Started heartbeat worker task")

    async def stop(self):
        """Cancel the background tasks and deregister every connection.

        If this instance is the registered module-level pool, the global
        reference is cleared as well.
        """
        global _websocket_pool
        self.logger.info("Stopping WebSocket pool")

        await self._cancel_task(self._cleanup_task)
        self._cleanup_task = None
        await self._cancel_task(self._heartbeat_task)
        self._heartbeat_task = None

        await self._close_all_connections()

        self.logger.info("WebSocket pool stopped")

        if _websocket_pool is self:
            _websocket_pool = None

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* (if any) and wait for it to finish unwinding."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def add_connection(
        self,
        websocket: WebSocket,
        user_id: Optional[int] = None,
        topics: Optional[Set[str]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a new WebSocket connection to the pool.

        Args:
            websocket: WebSocket instance
            user_id: Optional user ID for the connection (0 is a valid ID)
            topics: Initial topics to subscribe to (copied; the caller's
                collection is never mutated)
            metadata: Additional metadata for the connection (shallow-copied)

        Returns:
            connection_id: Unique identifier for the connection

        Raises:
            ValueError: If the total limit or any requested topic's limit
                would be exceeded. In that case nothing is registered.
        """
        # Private copies, so later subscribe/unsubscribe calls don't mutate
        # the caller's objects.
        requested_topics: Set[str] = set(topics) if topics else set()
        connection_metadata: Dict[str, Any] = dict(metadata) if metadata else {}

        async with self._connections_lock:
            # Check connection limits
            if len(self._connections) >= self.max_total_connections:
                raise ValueError(f"Maximum total connections ({self.max_total_connections}) exceeded")

            # Validate every topic limit before mutating any index, so a
            # rejected connection leaves no partial registrations behind.
            for topic in requested_topics:
                if len(self._topics.get(topic, ())) >= self.max_connections_per_topic:
                    raise ValueError(f"Maximum connections per topic ({self.max_connections_per_topic}) exceeded for topic: {topic}")

            connection_id = f"ws_{uuid.uuid4().hex[:12]}"
            now = datetime.now(timezone.utc)
            connection_info = ConnectionInfo(
                id=connection_id,
                websocket=websocket,
                user_id=user_id,
                topics=requested_topics,
                state=ConnectionState.CONNECTING,
                created_at=now,
                last_activity=now,
                last_ping=None,
                last_pong=None,
                error_count=0,
                metadata=connection_metadata,
            )

            # Register in all indexes
            self._connections[connection_id] = connection_info
            for topic in requested_topics:
                self._topics.setdefault(topic, set()).add(connection_id)
            if user_id is not None:
                self._user_connections.setdefault(user_id, set()).add(connection_id)

            # Update statistics
            self._stats["total_connections"] += 1
            self._stats["active_connections"] = len(self._connections)

            self.logger.info("Added WebSocket connection",
                             connection_id=connection_id,
                             user_id=user_id,
                             topics=list(connection_info.topics),
                             total_connections=self._stats["active_connections"])

            return connection_id

    async def remove_connection(self, connection_id: str, reason: str = "unknown"):
        """Deregister a connection from every index; unknown ids are ignored.

        This does not close the underlying WebSocket. The ConnectionInfo
        object is left in the DISCONNECTED state.
        """
        async with self._connections_lock:
            connection_info = self._connections.pop(connection_id, None)
            if connection_info is None:
                return

            connection_info.state = ConnectionState.DISCONNECTING

            # Remove from topics, dropping topics that become empty
            for topic in connection_info.topics:
                subscribers = self._topics.get(topic)
                if subscribers is None:
                    continue
                subscribers.discard(connection_id)
                if not subscribers:
                    del self._topics[topic]

            # Remove from user connections (user_id 0 is a valid id)
            user_id = connection_info.user_id
            if user_id is not None:
                user_ids = self._user_connections.get(user_id)
                if user_ids is not None:
                    user_ids.discard(connection_id)
                    if not user_ids:
                        del self._user_connections[user_id]

            connection_info.state = ConnectionState.DISCONNECTED
            self._stats["active_connections"] = len(self._connections)

            self.logger.info("Removed WebSocket connection",
                             connection_id=connection_id,
                             reason=reason,
                             user_id=user_id,
                             total_connections=self._stats["active_connections"])

    async def subscribe_to_topic(self, connection_id: str, topic: str) -> bool:
        """Subscribe a live connection to *topic*.

        Returns:
            True if the connection is (now) subscribed. False if the
            connection is unknown or not alive, or if the topic is at its
            connection limit.
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info or not connection_info.is_alive():
                return False

            # Already subscribed: nothing to add, so the topic limit is irrelevant.
            if topic in connection_info.topics:
                connection_info.update_activity()
                return True

            current_count = len(self._topics.get(topic, ()))
            if current_count >= self.max_connections_per_topic:
                self.logger.warning("Topic connection limit exceeded",
                                    topic=topic,
                                    connection_id=connection_id,
                                    current_count=current_count)
                return False

            subscribers = self._topics.setdefault(topic, set())
            subscribers.add(connection_id)
            connection_info.topics.add(topic)
            connection_info.update_activity()

            self.logger.debug("Connection subscribed to topic",
                              connection_id=connection_id,
                              topic=topic,
                              topic_subscribers=len(subscribers))

            return True

    async def unsubscribe_from_topic(self, connection_id: str, topic: str) -> bool:
        """Unsubscribe a connection from *topic*.

        Returns True whenever the connection exists, even if it was not
        subscribed to *topic*.
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info:
                return False

            subscribers = self._topics.get(topic)
            if subscribers is not None:
                subscribers.discard(connection_id)
                if not subscribers:
                    del self._topics[topic]

            connection_info.topics.discard(topic)
            connection_info.update_activity()

            self.logger.debug("Connection unsubscribed from topic",
                              connection_id=connection_id,
                              topic=topic)

            return True

    async def broadcast_to_topic(
        self,
        topic: str,
        message: Union[WebSocketMessage, Dict[str, Any]],
        exclude_connection_id: Optional[str] = None
    ) -> int:
        """
        Broadcast a message to all connections subscribed to a topic.

        A missing timestamp is filled in on *message* before sending.
        Connections whose send fails are removed with reason
        "broadcast_failed".

        Returns:
            Number of successful sends
        """
        if isinstance(message, dict):
            message = WebSocketMessage(**message)

        # Ensure timestamp is set
        if not message.timestamp:
            message.timestamp = datetime.now(timezone.utc).isoformat()

        # Snapshot the subscriber ids under the lock; deliver outside it.
        async with self._connections_lock:
            connection_ids = [
                cid for cid in self._topics.get(topic, ())
                if cid != exclude_connection_id
            ]

        if not connection_ids:
            return 0

        success_count = await self._deliver(
            connection_ids,
            message,
            error_log="Error broadcasting to connection",
            removal_reason="broadcast_failed",
            topic=topic,
        )

        self.logger.debug("Broadcast completed",
                          topic=topic,
                          total_targets=len(connection_ids),
                          successful=success_count,
                          failed=len(connection_ids) - success_count)

        return success_count

    async def send_to_user(
        self,
        user_id: int,
        message: Union[WebSocketMessage, Dict[str, Any]]
    ) -> int:
        """
        Send a message to all connections for a specific user.

        Like broadcasting, this updates the send statistics and removes
        connections whose send fails (reason "send_failed").

        Returns:
            Number of successful sends
        """
        if isinstance(message, dict):
            message = WebSocketMessage(**message)

        async with self._connections_lock:
            connection_ids = list(self._user_connections.get(user_id, ()))

        if not connection_ids:
            return 0

        return await self._deliver(
            connection_ids,
            message,
            error_log="Error sending to user connection",
            removal_reason="send_failed",
            user_id=user_id,
        )

    async def _deliver(
        self,
        connection_ids: List[str],
        message: WebSocketMessage,
        *,
        error_log: str,
        removal_reason: str,
        **log_context: Any,
    ) -> int:
        """Send *message* to each id, record stats, drop failures; return the success count."""
        success_count = 0
        failed_connections: List[str] = []

        for connection_id in connection_ids:
            try:
                delivered = await self._send_to_connection(connection_id, message)
            except Exception as e:
                self.logger.error(error_log,
                                  connection_id=connection_id,
                                  error=str(e),
                                  **log_context)
                delivered = False
            if delivered:
                success_count += 1
            else:
                failed_connections.append(connection_id)

        self._stats["messages_sent"] += success_count
        self._stats["messages_failed"] += len(failed_connections)

        for connection_id in failed_connections:
            await self.remove_connection(connection_id, removal_reason)

        return success_count

    async def _send_to_connection(self, connection_id: str, message: WebSocketMessage) -> bool:
        """Send *message* to one connection.

        The lock is held only to look up the connection; the network send
        runs outside it. On failure the error count is incremented, a live
        connection is moved to ERROR, and False is returned.

        NOTE(review): a successful outbound send (including heartbeats)
        refreshes last_activity. That means connection_timeout effectively
        measures outbound rather than inbound silence; confirm this is
        intended.
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if connection_info is None or not connection_info.is_alive():
                return False

        try:
            await connection_info.websocket.send_json(message.to_dict())
        except Exception as e:
            connection_info.error_count += 1
            # Don't overwrite DISCONNECTED if it was removed concurrently.
            if connection_info.is_alive():
                connection_info.state = ConnectionState.ERROR
            self.logger.warning("Failed to send message to connection",
                                connection_id=connection_id,
                                error=str(e),
                                error_count=connection_info.error_count)
            return False

        connection_info.update_activity()
        return True

    async def ping_connection(self, connection_id: str) -> bool:
        """Send a PING message; on success record last_ping."""
        ping_message = WebSocketMessage(
            type=MessageType.PING.value,
            timestamp=datetime.now(timezone.utc).isoformat()
        )

        success = await self._send_to_connection(connection_id, ping_message)
        if success:
            async with self._connections_lock:
                connection_info = self._connections.get(connection_id)
                if connection_info:
                    connection_info.last_ping = datetime.now(timezone.utc)

        return success

    async def handle_pong(self, connection_id: str):
        """Record a pong: refresh timestamps and mark the connection CONNECTED."""
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if connection_info:
                connection_info.last_pong = datetime.now(timezone.utc)
                connection_info.update_activity()
                connection_info.state = ConnectionState.CONNECTED

    async def get_connection_info(self, connection_id: str) -> Optional[ConnectionInfo]:
        """Get information about a specific connection.

        If this pool doesn't know the id and a different module-level pool
        is registered, that pool is consulted instead. This supports tests
        that create a local pool while the context manager uses the global
        pool created at app startup.
        """
        async with self._connections_lock:
            info = self._connections.get(connection_id)

        if info is not None:
            return info

        # Consult the global pool after releasing our own lock.
        fallback = _websocket_pool
        if fallback is not None and fallback is not self:
            return await fallback.get_connection_info(connection_id)
        return None

    async def get_topic_connections(self, topic: str) -> List[str]:
        """Get all connection IDs subscribed to a topic."""
        async with self._connections_lock:
            return list(self._topics.get(topic, set()))

    async def get_user_connections(self, user_id: int) -> List[str]:
        """Get all connection IDs for a user."""
        async with self._connections_lock:
            return list(self._user_connections.get(user_id, set()))

    async def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of pool statistics plus current per-state and per-topic counts."""
        async with self._connections_lock:
            connections_by_state: Dict[str, int] = {}
            unique_user_ids = set()
            for conn in self._connections.values():
                state = conn.state.value
                connections_by_state[state] = connections_by_state.get(state, 0) + 1
                # `is not None` so that user_id 0 is counted
                if conn.user_id is not None:
                    unique_user_ids.add(conn.user_id)

            return {
                **self._stats,
                "active_connections": len(self._connections),
                "total_topics": len(self._topics),
                "total_users": len(unique_user_ids),
                "connections_by_state": connections_by_state,
                "topic_distribution": {topic: len(conn_ids) for topic, conn_ids in self._topics.items()},
            }

    async def _cleanup_worker(self):
        """Background loop: every cleanup_interval seconds, remove stale connections."""
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_stale_connections()
                self._stats["last_cleanup"] = datetime.now(timezone.utc).isoformat()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the worker alive; one failed sweep shouldn't stop cleanup.
                self.logger.error("Error in cleanup worker", error=str(e))

    async def _heartbeat_worker(self):
        """Background loop: every heartbeat_interval seconds, send heartbeats."""
        while True:
            try:
                await asyncio.sleep(self.heartbeat_interval)
                await self._send_heartbeats()
                self._stats["last_heartbeat"] = datetime.now(timezone.utc).isoformat()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error("Error in heartbeat worker", error=str(e))

    async def _cleanup_stale_connections(self):
        """Remove connections that are not alive or idle past connection_timeout."""
        async with self._connections_lock:
            stale_connections = [
                connection_id
                for connection_id, connection_info in self._connections.items()
                if connection_info.is_stale(self.connection_timeout)
            ]

        for connection_id in stale_connections:
            await self.remove_connection(connection_id, "stale_connection")

        if stale_connections:
            self._stats["connections_cleaned"] += len(stale_connections)
            self.logger.info("Cleaned up stale connections",
                             count=len(stale_connections),
                             total_cleaned=self._stats["connections_cleaned"])

    async def _send_heartbeats(self):
        """Send one HEARTBEAT message to every registered connection.

        Connections whose send fails are removed with reason
        "heartbeat_failed". Heartbeats are not counted in messages_sent or
        messages_failed.
        """
        async with self._connections_lock:
            connection_ids = list(self._connections.keys())

        heartbeat_message = WebSocketMessage(
            type=MessageType.HEARTBEAT.value,
            timestamp=datetime.now(timezone.utc).isoformat()
        )

        failed_connections = []
        for connection_id in connection_ids:
            try:
                success = await self._send_to_connection(connection_id, heartbeat_message)
                if not success:
                    failed_connections.append(connection_id)
            except Exception:
                failed_connections.append(connection_id)

        for connection_id in failed_connections:
            await self.remove_connection(connection_id, "heartbeat_failed")

    async def _close_all_connections(self):
        """Deregister every connection.

        NOTE(review): this only removes connections from the pool. It does
        not call websocket.close(); socket shutdown is left to the endpoint
        that owns each socket.
        """
        async with self._connections_lock:
            connection_ids = list(self._connections.keys())

        for connection_id in connection_ids:
            await self.remove_connection(connection_id, "pool_shutdown")
|
|
|
|
|
# Module-level pool singleton. It is assigned by get_websocket_pool(),
# initialize_websocket_pool(), and WebSocketPool.start() (when unset), and
# cleared by WebSocketPool.stop() / shutdown_websocket_pool().
_websocket_pool: Optional[WebSocketPool] = None


def get_websocket_pool() -> WebSocketPool:
    """Return the module-level pool, creating it lazily on first use.

    A pool created here has default settings and its background tasks are
    not started; call ``initialize_websocket_pool()`` or ``start()`` for that.
    """
    global _websocket_pool
    if _websocket_pool is not None:
        return _websocket_pool
    _websocket_pool = WebSocketPool()
    return _websocket_pool
|
|
|
|
|
|
async def initialize_websocket_pool(**kwargs) -> WebSocketPool:
    """Create (if needed) and start the module-level WebSocket pool.

    ``kwargs`` are forwarded to ``WebSocketPool`` only when no global pool
    exists yet. If one was already created (e.g. via ``get_websocket_pool``),
    they are ignored and the existing pool is started as-is.
    """
    global _websocket_pool
    pool = _websocket_pool
    if pool is None:
        pool = WebSocketPool(**kwargs)
        _websocket_pool = pool
    await pool.start()
    return pool
|
|
|
|
|
|
async def shutdown_websocket_pool():
    """Stop the module-level pool (if any) and clear the global reference."""
    global _websocket_pool
    active = _websocket_pool
    if active is None:
        return
    await active.stop()
    _websocket_pool = None
|
|
|
|
|
|
@asynccontextmanager
async def websocket_connection(
    websocket: WebSocket,
    user_id: Optional[int] = None,
    topics: Optional[Set[str]] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Register *websocket* with the module-level pool for the duration of a block.

    Yields ``(connection_id, pool)``. On exit, including on error, the
    connection is removed with reason "context_exit". If registration itself
    raises (e.g. a limit is exceeded), the exception propagates and nothing
    needs removing.
    """
    pool = get_websocket_pool()
    connection_id = await pool.add_connection(websocket, user_id, topics, metadata)
    try:
        yield connection_id, pool
    finally:
        await pool.remove_connection(connection_id, "context_exit")
|