This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -0,0 +1,667 @@
"""
WebSocket Connection Pool and Management Service
This module provides a centralized WebSocket connection pooling system for the Delphi Database
application. It manages connections efficiently, handles cleanup of stale connections,
monitors connection health, and provides resource management to prevent memory leaks.
Features:
- Connection pooling by topic/channel
- Automatic cleanup of inactive connections
- Health monitoring and heartbeat management
- Resource management and memory leak prevention
- Integration with existing authentication
- Structured logging for debugging
"""
import asyncio
import time
import uuid
from typing import Dict, Set, Optional, Any, Callable, List, Union
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass
from enum import Enum
from contextlib import asynccontextmanager
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from app.utils.logging import StructuredLogger
class ConnectionState(Enum):
    """Lifecycle states recorded for a pooled WebSocket connection.

    The member values are the lowercase string form of each state.
    """

    # States in which a connection is still being set up or is usable.
    CONNECTING = "connecting"
    CONNECTED = "connected"

    # States for a connection that is being torn down or is gone.
    DISCONNECTING = "disconnecting"
    DISCONNECTED = "disconnected"

    # A failure was observed on the connection.
    ERROR = "error"
class MessageType(Enum):
    """Values used for the ``type`` field of a WebSocket message."""

    # Liveness probes.
    PING = "ping"
    PONG = "pong"
    HEARTBEAT = "heartbeat"

    # Payload and failure reporting.
    DATA = "data"
    ERROR = "error"

    # Topic membership changes.
    SUBSCRIBE = "subscribe"
    UNSUBSCRIBE = "unsubscribe"
@dataclass
class ConnectionInfo:
    """Bookkeeping record kept for a single WebSocket connection.

    Holds the socket itself plus the identity, subscription, state and
    timing data needed to decide whether the connection is still usable.
    """

    id: str
    websocket: WebSocket
    user_id: Optional[int]
    topics: Set[str]
    state: ConnectionState
    created_at: datetime
    last_activity: datetime
    last_ping: Optional[datetime]
    last_pong: Optional[datetime]
    error_count: int
    metadata: Dict[str, Any]

    def is_alive(self) -> bool:
        """Return True while the state is CONNECTED or CONNECTING."""
        live_states = (ConnectionState.CONNECTED, ConnectionState.CONNECTING)
        return self.state in live_states

    def is_stale(self, timeout_seconds: int = 300) -> bool:
        """Return True when the connection should be considered stale.

        A connection that is not alive is always stale. Otherwise it is
        stale once more than ``timeout_seconds`` have elapsed since
        ``last_activity`` (compared against the current UTC time).
        """
        if not self.is_alive():
            return True
        idle_for = datetime.now(timezone.utc) - self.last_activity
        return idle_for.total_seconds() > timeout_seconds

    def update_activity(self):
        """Stamp ``last_activity`` with the current UTC time."""
        self.last_activity = datetime.now(timezone.utc)
class WebSocketMessage(BaseModel):
    """Envelope used for messages sent over a WebSocket.

    Only ``type`` is required; every other field defaults to ``None``.
    """

    type: str
    topic: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
    timestamp: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a dict, leaving out fields that are None."""
        payload = self.model_dump(exclude_none=True)
        return payload
class WebSocketPool:
    """
    Centralized WebSocket connection pool manager.

    Keeps three indexes (by connection id, by topic and by user id), runs two
    background tasks once ``start()`` has been awaited (stale-connection
    cleanup and heartbeats), and maintains counters exposed by ``get_stats()``.

    All index mutations are serialized through a single ``asyncio.Lock``.
    That lock coordinates coroutines; it does NOT make the pool safe to use
    from several threads.

    The pool only registers and deregisters connections: no method in this
    class calls ``websocket.close()``.
    """

    def __init__(
        self,
        cleanup_interval: int = 60,  # seconds
        connection_timeout: int = 300,  # seconds
        heartbeat_interval: int = 30,  # seconds
        max_connections_per_topic: int = 1000,
        max_total_connections: int = 10000,
    ):
        """Configure the pool; background tasks are not started here.

        Args:
            cleanup_interval: Seconds between stale-connection sweeps.
            connection_timeout: Seconds of inactivity after which a
                connection is considered stale.
            heartbeat_interval: Seconds between heartbeat rounds.
            max_connections_per_topic: Upper bound on subscribers per topic.
            max_total_connections: Upper bound on registered connections.
        """
        self.cleanup_interval = cleanup_interval
        self.connection_timeout = connection_timeout
        self.heartbeat_interval = heartbeat_interval
        self.max_connections_per_topic = max_connections_per_topic
        self.max_total_connections = max_total_connections

        # Connection storage
        self._connections: Dict[str, ConnectionInfo] = {}
        self._topics: Dict[str, Set[str]] = {}  # topic -> connection_ids
        self._user_connections: Dict[int, Set[str]] = {}  # user_id -> connection_ids

        # Serializes coroutine access to the three indexes above
        # (asyncio.Lock is not a thread lock).
        self._connections_lock = asyncio.Lock()
        self._cleanup_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None

        # Statistics
        self._stats: Dict[str, Any] = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_sent": 0,
            "messages_failed": 0,
            "connections_cleaned": 0,
            "last_cleanup": None,
            "last_heartbeat": None,
        }

        self.logger = StructuredLogger("websocket_pool", "INFO")
        self.logger.info("WebSocket pool initialized",
                         cleanup_interval=cleanup_interval,
                         connection_timeout=connection_timeout,
                         heartbeat_interval=heartbeat_interval)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self):
        """Start the cleanup and heartbeat background tasks.

        If no module-level pool is registered yet, this instance registers
        itself so that code relying on the module-level getter (tests,
        simple scripts) sees it. Tasks that already exist are left alone.
        """
        global _websocket_pool
        if _websocket_pool is None:
            _websocket_pool = self

        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_worker())
            self.logger.info("Started cleanup worker task")

        if self._heartbeat_task is None:
            self._heartbeat_task = asyncio.create_task(self._heartbeat_worker())
            self.logger.info("Started heartbeat worker task")

    async def stop(self):
        """Cancel the background tasks and deregister every connection.

        If this instance is the registered module-level pool, the
        registration is cleared as the last step.
        """
        global _websocket_pool
        self.logger.info("Stopping WebSocket pool")

        # Cancel background tasks
        if self._cleanup_task is not None:
            await self._cancel_task(self._cleanup_task)
            self._cleanup_task = None

        if self._heartbeat_task is not None:
            await self._cancel_task(self._heartbeat_task)
            self._heartbeat_task = None

        await self._close_all_connections()
        self.logger.info("WebSocket pool stopped")

        # If this instance is the registered global, clear it
        if _websocket_pool is self:
            _websocket_pool = None

    @staticmethod
    async def _cancel_task(task: "asyncio.Task") -> None:
        """Cancel ``task`` and wait for it, swallowing the CancelledError."""
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    async def add_connection(
        self,
        websocket: WebSocket,
        user_id: Optional[int] = None,
        topics: Optional[Set[str]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a new WebSocket connection to the pool.

        Args:
            websocket: WebSocket instance
            user_id: Optional user ID for the connection
            topics: Initial topics to subscribe to (copied, not stored by
                reference)
            metadata: Additional metadata for the connection

        Returns:
            connection_id: Unique identifier for the connection

        Raises:
            ValueError: If the total-connection limit or any requested
                topic's subscriber limit would be exceeded. Nothing is
                registered in that case.
        """
        async with self._connections_lock:
            # Check connection limits
            if len(self._connections) >= self.max_total_connections:
                raise ValueError(f"Maximum total connections ({self.max_total_connections}) exceeded")

            # Copy so later subscribe/unsubscribe calls never mutate the
            # caller's set.
            requested_topics: Set[str] = set(topics) if topics else set()

            # Validate every topic limit BEFORE touching any index, so a
            # rejection cannot leave this id behind in other topic sets.
            for topic in requested_topics:
                if len(self._topics.get(topic, ())) >= self.max_connections_per_topic:
                    raise ValueError(f"Maximum connections per topic ({self.max_connections_per_topic}) exceeded for topic: {topic}")

            # Generate unique connection ID; the uuid is truncated, so guard
            # against silently overwriting a live entry.
            connection_id = f"ws_{uuid.uuid4().hex[:12]}"
            while connection_id in self._connections:
                connection_id = f"ws_{uuid.uuid4().hex[:12]}"

            now = datetime.now(timezone.utc)
            connection_info = ConnectionInfo(
                id=connection_id,
                websocket=websocket,
                user_id=user_id,
                topics=requested_topics,
                state=ConnectionState.CONNECTING,
                created_at=now,
                last_activity=now,
                last_ping=None,
                last_pong=None,
                error_count=0,
                metadata=metadata or {}
            )

            # Store connection
            self._connections[connection_id] = connection_info

            # Update topic subscriptions
            for topic in requested_topics:
                self._topics.setdefault(topic, set()).add(connection_id)

            # Update user connections mapping (``is not None`` so that a
            # user id of 0 is indexed too).
            if user_id is not None:
                self._user_connections.setdefault(user_id, set()).add(connection_id)

            # Update statistics
            self._stats["total_connections"] += 1
            self._stats["active_connections"] = len(self._connections)

            self.logger.info("Added WebSocket connection",
                             connection_id=connection_id,
                             user_id=user_id,
                             topics=list(connection_info.topics),
                             total_connections=self._stats["active_connections"])
            return connection_id

    async def remove_connection(self, connection_id: str, reason: str = "unknown"):
        """Remove a connection from every index of the pool.

        Unknown ids are ignored. The underlying WebSocket is not closed.

        Args:
            connection_id: Id returned by ``add_connection``.
            reason: Free-form label recorded in the log entry.
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info:
                return

            # Update state
            connection_info.state = ConnectionState.DISCONNECTING

            # Remove from topics, dropping topic entries that become empty
            for topic in connection_info.topics:
                subscribers = self._topics.get(topic)
                if subscribers is None:
                    continue
                subscribers.discard(connection_id)
                if not subscribers:
                    del self._topics[topic]

            # Remove from user connections (``is not None`` keeps user id 0
            # working).
            owner = connection_info.user_id
            if owner is not None and owner in self._user_connections:
                self._user_connections[owner].discard(connection_id)
                if not self._user_connections[owner]:
                    del self._user_connections[owner]

            # Remove from connections
            del self._connections[connection_id]

            # Anyone still holding the record sees the final state.
            connection_info.state = ConnectionState.DISCONNECTED

            # Update statistics
            self._stats["active_connections"] = len(self._connections)

            self.logger.info("Removed WebSocket connection",
                             connection_id=connection_id,
                             reason=reason,
                             user_id=connection_info.user_id,
                             total_connections=self._stats["active_connections"])

    async def subscribe_to_topic(self, connection_id: str, topic: str) -> bool:
        """Subscribe a connection to a topic.

        Returns:
            True if the connection is subscribed after the call (including
            when it already was); False if the connection is unknown, not
            alive, or the topic's subscriber limit is reached.
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info or not connection_info.is_alive():
                return False

            subscribers = self._topics.get(topic)
            current_count = len(subscribers) if subscribers is not None else 0
            already_subscribed = subscribers is not None and connection_id in subscribers

            # Check topic connection limit. No entry is created for the topic
            # on rejection, and re-subscribing is not counted against it.
            if not already_subscribed and current_count >= self.max_connections_per_topic:
                self.logger.warning("Topic connection limit exceeded",
                                    topic=topic,
                                    connection_id=connection_id,
                                    current_count=current_count)
                return False

            # Add to topic and connection
            self._topics.setdefault(topic, set()).add(connection_id)
            connection_info.topics.add(topic)
            connection_info.update_activity()

            self.logger.debug("Connection subscribed to topic",
                              connection_id=connection_id,
                              topic=topic,
                              topic_subscribers=len(self._topics[topic]))
            return True

    async def unsubscribe_from_topic(self, connection_id: str, topic: str) -> bool:
        """Unsubscribe a connection from a topic.

        Returns:
            False if the connection id is unknown, True otherwise (even if
            the connection was not subscribed to ``topic``).
        """
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info:
                return False

            # Remove from topic and connection
            subscribers = self._topics.get(topic)
            if subscribers is not None:
                subscribers.discard(connection_id)
                if not subscribers:
                    del self._topics[topic]

            connection_info.topics.discard(topic)
            connection_info.update_activity()

            self.logger.debug("Connection unsubscribed from topic",
                              connection_id=connection_id,
                              topic=topic)
            return True

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    @staticmethod
    def _prepare_message(message: Union[WebSocketMessage, Dict[str, Any]]) -> WebSocketMessage:
        """Coerce a dict into a WebSocketMessage and ensure it is timestamped.

        A message object passed in without a timestamp is stamped in place.
        """
        if isinstance(message, dict):
            message = WebSocketMessage(**message)
        if not message.timestamp:
            message.timestamp = datetime.now(timezone.utc).isoformat()
        return message

    async def broadcast_to_topic(
        self,
        topic: str,
        message: Union[WebSocketMessage, Dict[str, Any]],
        exclude_connection_id: Optional[str] = None
    ) -> int:
        """
        Broadcast a message to all connections subscribed to a topic.

        Connections for which the send fails are removed from the pool.

        Args:
            topic: Topic whose subscribers receive the message.
            message: Message object or a dict of its fields.
            exclude_connection_id: Optional subscriber to skip.

        Returns:
            Number of successful sends
        """
        message = self._prepare_message(message)

        # Snapshot the subscriber ids under the lock
        async with self._connections_lock:
            connection_ids = [
                cid for cid in self._topics.get(topic, ())
                if not exclude_connection_id or cid != exclude_connection_id
            ]

        if not connection_ids:
            return 0

        # Send to all connections (outside the lock to avoid blocking)
        success_count = 0
        failed_connections = []
        for connection_id in connection_ids:
            try:
                delivered = await self._send_to_connection(connection_id, message)
            except Exception as e:
                self.logger.error("Error broadcasting to connection",
                                  connection_id=connection_id,
                                  topic=topic,
                                  error=str(e))
                delivered = False
            if delivered:
                success_count += 1
            else:
                failed_connections.append(connection_id)

        # Update statistics
        self._stats["messages_sent"] += success_count
        self._stats["messages_failed"] += len(failed_connections)

        # Clean up failed connections
        for connection_id in failed_connections:
            await self.remove_connection(connection_id, "broadcast_failed")

        self.logger.debug("Broadcast completed",
                          topic=topic,
                          total_targets=len(connection_ids),
                          successful=success_count,
                          failed=len(failed_connections))
        return success_count

    async def send_to_user(
        self,
        user_id: int,
        message: Union[WebSocketMessage, Dict[str, Any]]
    ) -> int:
        """
        Send a message to all connections for a specific user.

        The message is timestamped and the send counters are updated the
        same way ``broadcast_to_topic`` does. Failed connections are not
        removed here; a failed send marks them ERROR in
        ``_send_to_connection``.

        Returns:
            Number of successful sends
        """
        message = self._prepare_message(message)

        # Get connection IDs for the user
        async with self._connections_lock:
            connection_ids = list(self._user_connections.get(user_id, ()))

        if not connection_ids:
            return 0

        # Send to all user connections
        success_count = 0
        for connection_id in connection_ids:
            try:
                if await self._send_to_connection(connection_id, message):
                    success_count += 1
            except Exception as e:
                self.logger.error("Error sending to user connection",
                                  connection_id=connection_id,
                                  user_id=user_id,
                                  error=str(e))

        self._stats["messages_sent"] += success_count
        self._stats["messages_failed"] += len(connection_ids) - success_count
        return success_count

    async def _send_to_connection(self, connection_id: str, message: WebSocketMessage) -> bool:
        """Send a message to one connection.

        Returns:
            True on success. False if the connection is unknown, not alive,
            or the send raised; in the last case the error counter is bumped
            and the state becomes ERROR.
        """
        # Only the lookup happens under the lock; the network send does not.
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if not connection_info or not connection_info.is_alive():
                return False
            websocket = connection_info.websocket

        try:
            await websocket.send_json(message.to_dict())
        except Exception as e:
            connection_info.error_count += 1
            connection_info.state = ConnectionState.ERROR
            self.logger.warning("Failed to send message to connection",
                                connection_id=connection_id,
                                error=str(e),
                                error_count=connection_info.error_count)
            return False

        connection_info.update_activity()
        return True

    async def ping_connection(self, connection_id: str) -> bool:
        """Send a ping to a connection and record ``last_ping`` on success."""
        ping_message = WebSocketMessage(
            type=MessageType.PING.value,
            timestamp=datetime.now(timezone.utc).isoformat()
        )
        success = await self._send_to_connection(connection_id, ping_message)
        if success:
            async with self._connections_lock:
                connection_info = self._connections.get(connection_id)
                if connection_info:
                    connection_info.last_ping = datetime.now(timezone.utc)
        return success

    async def handle_pong(self, connection_id: str):
        """Record a pong: stamp ``last_pong``/activity and mark CONNECTED."""
        async with self._connections_lock:
            connection_info = self._connections.get(connection_id)
            if connection_info:
                connection_info.last_pong = datetime.now(timezone.utc)
                connection_info.update_activity()
                connection_info.state = ConnectionState.CONNECTED

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    async def get_connection_info(self, connection_id: str) -> Optional[ConnectionInfo]:
        """Get information about a specific connection.

        If this instance does not know the id and a different pool is
        registered as the module-level pool, the lookup is delegated to it.
        This supports tests that instantiate a local pool while the context
        manager uses the global pool created by app startup.
        """
        async with self._connections_lock:
            info = self._connections.get(connection_id)
        if info is not None:
            return info

        # Delegate after releasing our own lock; the registered pool will not
        # delegate back because it is itself the registered one.
        registered = _websocket_pool
        if registered is not None and registered is not self:
            return await registered.get_connection_info(connection_id)
        return None

    async def get_topic_connections(self, topic: str) -> List[str]:
        """Get all connection IDs subscribed to a topic."""
        async with self._connections_lock:
            return list(self._topics.get(topic, ()))

    async def get_user_connections(self, user_id: int) -> List[str]:
        """Get all connection IDs for a user."""
        async with self._connections_lock:
            return list(self._user_connections.get(user_id, ()))

    async def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus a live snapshot of the indexes."""
        async with self._connections_lock:
            connections_by_state: Dict[str, int] = {}
            for conn in self._connections.values():
                key = conn.state.value
                connections_by_state[key] = connections_by_state.get(key, 0) + 1

            # ``is not None`` so a user id of 0 is counted.
            unique_user_ids = {
                conn.user_id
                for conn in self._connections.values()
                if conn.user_id is not None
            }

            return {
                **self._stats,
                "active_connections": len(self._connections),
                "total_topics": len(self._topics),
                "total_users": len(unique_user_ids),
                "connections_by_state": connections_by_state,
                "topic_distribution": {topic: len(conn_ids) for topic, conn_ids in self._topics.items()},
            }

    # ------------------------------------------------------------------
    # Background work
    # ------------------------------------------------------------------

    async def _cleanup_worker(self):
        """Background loop: sweep stale connections every ``cleanup_interval``.

        Exits on cancellation; any other error is logged and the loop goes on.
        """
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_stale_connections()
                self._stats["last_cleanup"] = datetime.now(timezone.utc).isoformat()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error("Error in cleanup worker", error=str(e))

    async def _heartbeat_worker(self):
        """Background loop: send heartbeats every ``heartbeat_interval``.

        Exits on cancellation; any other error is logged and the loop goes on.
        """
        while True:
            try:
                await asyncio.sleep(self.heartbeat_interval)
                await self._send_heartbeats()
                self._stats["last_heartbeat"] = datetime.now(timezone.utc).isoformat()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error("Error in heartbeat worker", error=str(e))

    async def _cleanup_stale_connections(self):
        """Remove every connection that ``is_stale(connection_timeout)``."""
        async with self._connections_lock:
            stale_connections = [
                connection_id
                for connection_id, connection_info in self._connections.items()
                if connection_info.is_stale(self.connection_timeout)
            ]

        # remove_connection takes the lock itself, so call it outside.
        for connection_id in stale_connections:
            await self.remove_connection(connection_id, "stale_connection")

        if stale_connections:
            self._stats["connections_cleaned"] += len(stale_connections)
            self.logger.info("Cleaned up stale connections",
                             count=len(stale_connections),
                             total_cleaned=self._stats["connections_cleaned"])

    async def _send_heartbeats(self):
        """Send a heartbeat to every registered connection.

        Connections for which the send does not succeed are removed.
        """
        async with self._connections_lock:
            connection_ids = list(self._connections)

        heartbeat_message = WebSocketMessage(
            type=MessageType.HEARTBEAT.value,
            timestamp=datetime.now(timezone.utc).isoformat()
        )

        failed_connections = []
        for connection_id in connection_ids:
            try:
                success = await self._send_to_connection(connection_id, heartbeat_message)
            except Exception:
                success = False
            if not success:
                failed_connections.append(connection_id)

        # Clean up failed connections
        for connection_id in failed_connections:
            await self.remove_connection(connection_id, "heartbeat_failed")

    async def _close_all_connections(self):
        """Deregister every connection (the sockets are not closed here)."""
        async with self._connections_lock:
            connection_ids = list(self._connections)

        for connection_id in connection_ids:
            await self.remove_connection(connection_id, "pool_shutdown")
# Global WebSocket pool instance.
# Starts out unset; assigned by get_websocket_pool(), initialize_websocket_pool()
# or WebSocketPool.start() when nothing is registered yet, and reset to None by
# WebSocketPool.stop() (for the registered instance) and shutdown_websocket_pool().
_websocket_pool: Optional[WebSocketPool] = None
def get_websocket_pool() -> WebSocketPool:
    """Return the module-level pool, creating it with defaults on first use.

    A pool created here is only constructed; its background tasks are not
    started by this function.
    """
    global _websocket_pool
    if _websocket_pool is not None:
        return _websocket_pool
    _websocket_pool = WebSocketPool()
    return _websocket_pool
async def initialize_websocket_pool(**kwargs) -> WebSocketPool:
    """Ensure the module-level pool exists, start it, and return it.

    ``kwargs`` are forwarded to ``WebSocketPool`` only when no pool is
    registered yet; if one already exists it is started as-is.
    """
    global _websocket_pool
    pool = _websocket_pool
    if pool is None:
        pool = WebSocketPool(**kwargs)
        _websocket_pool = pool
    await pool.start()
    return pool
async def shutdown_websocket_pool():
    """Stop the module-level pool, if any, and clear the registration."""
    global _websocket_pool
    if _websocket_pool is None:
        # Nothing registered: nothing to stop.
        return
    await _websocket_pool.stop()
    _websocket_pool = None
@asynccontextmanager
async def websocket_connection(
    websocket: WebSocket,
    user_id: Optional[int] = None,
    topics: Optional[Set[str]] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Async context manager that registers a WebSocket with the global pool.

    On entry the connection is added to the pool returned by
    ``get_websocket_pool()`` and ``(connection_id, pool)`` is yielded.
    On exit, whether normal or via an exception, the connection is removed
    again with reason ``"context_exit"``. If registration itself raises,
    there is nothing to remove and the error propagates.
    """
    pool = get_websocket_pool()

    # Registration happens outside the try block: a failure here leaves
    # nothing behind that would need cleaning up.
    connection_id = await pool.add_connection(websocket, user_id, topics, metadata)
    try:
        yield connection_id, pool
    finally:
        if connection_id:
            await pool.remove_connection(connection_id, "context_exit")