changes
This commit is contained in:
667
app/services/websocket_pool.py
Normal file
667
app/services/websocket_pool.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""
|
||||
WebSocket Connection Pool and Management Service
|
||||
|
||||
This module provides a centralized WebSocket connection pooling system for the Delphi Database
|
||||
application. It manages connections efficiently, handles cleanup of stale connections,
|
||||
monitors connection health, and provides resource management to prevent memory leaks.
|
||||
|
||||
Features:
|
||||
- Connection pooling by topic/channel
|
||||
- Automatic cleanup of inactive connections
|
||||
- Health monitoring and heartbeat management
|
||||
- Resource management and memory leak prevention
|
||||
- Integration with existing authentication
|
||||
- Structured logging for debugging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Set, Optional, Any, Callable, List, Union
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils.logging import StructuredLogger
|
||||
|
||||
|
||||
class ConnectionState(Enum):
|
||||
"""WebSocket connection states"""
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
DISCONNECTING = "disconnecting"
|
||||
DISCONNECTED = "disconnected"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""WebSocket message types"""
|
||||
PING = "ping"
|
||||
PONG = "pong"
|
||||
DATA = "data"
|
||||
ERROR = "error"
|
||||
HEARTBEAT = "heartbeat"
|
||||
SUBSCRIBE = "subscribe"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionInfo:
|
||||
"""Information about a WebSocket connection"""
|
||||
id: str
|
||||
websocket: WebSocket
|
||||
user_id: Optional[int]
|
||||
topics: Set[str]
|
||||
state: ConnectionState
|
||||
created_at: datetime
|
||||
last_activity: datetime
|
||||
last_ping: Optional[datetime]
|
||||
last_pong: Optional[datetime]
|
||||
error_count: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
"""Check if connection is alive based on state"""
|
||||
return self.state in [ConnectionState.CONNECTED, ConnectionState.CONNECTING]
|
||||
|
||||
def is_stale(self, timeout_seconds: int = 300) -> bool:
|
||||
"""Check if connection is stale (no activity for timeout_seconds)"""
|
||||
if not self.is_alive():
|
||||
return True
|
||||
return (datetime.now(timezone.utc) - self.last_activity).total_seconds() > timeout_seconds
|
||||
|
||||
def update_activity(self):
|
||||
"""Update last activity timestamp"""
|
||||
self.last_activity = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class WebSocketMessage(BaseModel):
|
||||
"""Standard WebSocket message format"""
|
||||
type: str
|
||||
topic: Optional[str] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
timestamp: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return self.model_dump(exclude_none=True)
|
||||
|
||||
|
||||
class WebSocketPool:
|
||||
"""
|
||||
Centralized WebSocket connection pool manager
|
||||
|
||||
Manages WebSocket connections by topics/channels, provides automatic cleanup,
|
||||
health monitoring, and resource management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cleanup_interval: int = 60, # seconds
|
||||
connection_timeout: int = 300, # seconds
|
||||
heartbeat_interval: int = 30, # seconds
|
||||
max_connections_per_topic: int = 1000,
|
||||
max_total_connections: int = 10000,
|
||||
):
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.connection_timeout = connection_timeout
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self.max_connections_per_topic = max_connections_per_topic
|
||||
self.max_total_connections = max_total_connections
|
||||
|
||||
# Connection storage
|
||||
self._connections: Dict[str, ConnectionInfo] = {}
|
||||
self._topics: Dict[str, Set[str]] = {} # topic -> connection_ids
|
||||
self._user_connections: Dict[int, Set[str]] = {} # user_id -> connection_ids
|
||||
|
||||
# Locks for thread safety
|
||||
self._connections_lock = asyncio.Lock()
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Statistics
|
||||
self._stats = {
|
||||
"total_connections": 0,
|
||||
"active_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"connections_cleaned": 0,
|
||||
"last_cleanup": None,
|
||||
"last_heartbeat": None,
|
||||
}
|
||||
|
||||
self.logger = StructuredLogger("websocket_pool", "INFO")
|
||||
self.logger.info("WebSocket pool initialized",
|
||||
cleanup_interval=cleanup_interval,
|
||||
connection_timeout=connection_timeout,
|
||||
heartbeat_interval=heartbeat_interval)
|
||||
|
||||
async def start(self):
|
||||
"""Start the WebSocket pool background tasks"""
|
||||
# If no global pool exists, register this instance to satisfy contexts that
|
||||
# rely on the module-level getter during tests and simple scripts
|
||||
global _websocket_pool
|
||||
if _websocket_pool is None:
|
||||
_websocket_pool = self
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_worker())
|
||||
self.logger.info("Started cleanup worker task")
|
||||
|
||||
if self._heartbeat_task is None:
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_worker())
|
||||
self.logger.info("Started heartbeat worker task")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the WebSocket pool and cleanup all connections"""
|
||||
self.logger.info("Stopping WebSocket pool")
|
||||
|
||||
# Cancel background tasks
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._heartbeat_task = None
|
||||
|
||||
# Close all connections
|
||||
await self._close_all_connections()
|
||||
|
||||
self.logger.info("WebSocket pool stopped")
|
||||
# If this instance is the registered global, clear it
|
||||
global _websocket_pool
|
||||
if _websocket_pool is self:
|
||||
_websocket_pool = None
|
||||
|
||||
async def add_connection(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[int] = None,
|
||||
topics: Optional[Set[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Add a new WebSocket connection to the pool
|
||||
|
||||
Args:
|
||||
websocket: WebSocket instance
|
||||
user_id: Optional user ID for the connection
|
||||
topics: Initial topics to subscribe to
|
||||
metadata: Additional metadata for the connection
|
||||
|
||||
Returns:
|
||||
connection_id: Unique identifier for the connection
|
||||
|
||||
Raises:
|
||||
ValueError: If maximum connections exceeded
|
||||
"""
|
||||
async with self._connections_lock:
|
||||
# Check connection limits
|
||||
if len(self._connections) >= self.max_total_connections:
|
||||
raise ValueError(f"Maximum total connections ({self.max_total_connections}) exceeded")
|
||||
|
||||
# Generate unique connection ID
|
||||
connection_id = f"ws_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# Create connection info
|
||||
connection_info = ConnectionInfo(
|
||||
id=connection_id,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
topics=topics or set(),
|
||||
state=ConnectionState.CONNECTING,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_activity=datetime.now(timezone.utc),
|
||||
last_ping=None,
|
||||
last_pong=None,
|
||||
error_count=0,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store connection
|
||||
self._connections[connection_id] = connection_info
|
||||
|
||||
# Update topic subscriptions
|
||||
for topic in connection_info.topics:
|
||||
if topic not in self._topics:
|
||||
self._topics[topic] = set()
|
||||
if len(self._topics[topic]) >= self.max_connections_per_topic:
|
||||
# Remove this connection and raise error
|
||||
del self._connections[connection_id]
|
||||
raise ValueError(f"Maximum connections per topic ({self.max_connections_per_topic}) exceeded for topic: {topic}")
|
||||
self._topics[topic].add(connection_id)
|
||||
|
||||
# Update user connections mapping
|
||||
if user_id:
|
||||
if user_id not in self._user_connections:
|
||||
self._user_connections[user_id] = set()
|
||||
self._user_connections[user_id].add(connection_id)
|
||||
|
||||
# Update statistics
|
||||
self._stats["total_connections"] += 1
|
||||
self._stats["active_connections"] = len(self._connections)
|
||||
|
||||
self.logger.info("Added WebSocket connection",
|
||||
connection_id=connection_id,
|
||||
user_id=user_id,
|
||||
topics=list(connection_info.topics),
|
||||
total_connections=self._stats["active_connections"])
|
||||
|
||||
return connection_id
|
||||
|
||||
async def remove_connection(self, connection_id: str, reason: str = "unknown"):
|
||||
"""Remove a WebSocket connection from the pool"""
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if not connection_info:
|
||||
return
|
||||
|
||||
# Update state
|
||||
connection_info.state = ConnectionState.DISCONNECTING
|
||||
|
||||
# Remove from topics
|
||||
for topic in connection_info.topics:
|
||||
if topic in self._topics:
|
||||
self._topics[topic].discard(connection_id)
|
||||
if not self._topics[topic]:
|
||||
del self._topics[topic]
|
||||
|
||||
# Remove from user connections
|
||||
if connection_info.user_id and connection_info.user_id in self._user_connections:
|
||||
self._user_connections[connection_info.user_id].discard(connection_id)
|
||||
if not self._user_connections[connection_info.user_id]:
|
||||
del self._user_connections[connection_info.user_id]
|
||||
|
||||
# Remove from connections
|
||||
del self._connections[connection_id]
|
||||
|
||||
# Update statistics
|
||||
self._stats["active_connections"] = len(self._connections)
|
||||
|
||||
self.logger.info("Removed WebSocket connection",
|
||||
connection_id=connection_id,
|
||||
reason=reason,
|
||||
user_id=connection_info.user_id,
|
||||
total_connections=self._stats["active_connections"])
|
||||
|
||||
async def subscribe_to_topic(self, connection_id: str, topic: str) -> bool:
|
||||
"""Subscribe a connection to a topic"""
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if not connection_info or not connection_info.is_alive():
|
||||
return False
|
||||
|
||||
# Check topic connection limit
|
||||
if topic not in self._topics:
|
||||
self._topics[topic] = set()
|
||||
if len(self._topics[topic]) >= self.max_connections_per_topic:
|
||||
self.logger.warning("Topic connection limit exceeded",
|
||||
topic=topic,
|
||||
connection_id=connection_id,
|
||||
current_count=len(self._topics[topic]))
|
||||
return False
|
||||
|
||||
# Add to topic and connection
|
||||
self._topics[topic].add(connection_id)
|
||||
connection_info.topics.add(topic)
|
||||
connection_info.update_activity()
|
||||
|
||||
self.logger.debug("Connection subscribed to topic",
|
||||
connection_id=connection_id,
|
||||
topic=topic,
|
||||
topic_subscribers=len(self._topics[topic]))
|
||||
|
||||
return True
|
||||
|
||||
async def unsubscribe_from_topic(self, connection_id: str, topic: str) -> bool:
|
||||
"""Unsubscribe a connection from a topic"""
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if not connection_info:
|
||||
return False
|
||||
|
||||
# Remove from topic and connection
|
||||
if topic in self._topics:
|
||||
self._topics[topic].discard(connection_id)
|
||||
if not self._topics[topic]:
|
||||
del self._topics[topic]
|
||||
|
||||
connection_info.topics.discard(topic)
|
||||
connection_info.update_activity()
|
||||
|
||||
self.logger.debug("Connection unsubscribed from topic",
|
||||
connection_id=connection_id,
|
||||
topic=topic)
|
||||
|
||||
return True
|
||||
|
||||
async def broadcast_to_topic(
|
||||
self,
|
||||
topic: str,
|
||||
message: Union[WebSocketMessage, Dict[str, Any]],
|
||||
exclude_connection_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
Broadcast a message to all connections subscribed to a topic
|
||||
|
||||
Returns:
|
||||
Number of successful sends
|
||||
"""
|
||||
if isinstance(message, dict):
|
||||
message = WebSocketMessage(**message)
|
||||
|
||||
# Ensure timestamp is set
|
||||
if not message.timestamp:
|
||||
message.timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Get connection IDs for the topic
|
||||
async with self._connections_lock:
|
||||
connection_ids = list(self._topics.get(topic, set()))
|
||||
if exclude_connection_id:
|
||||
connection_ids = [cid for cid in connection_ids if cid != exclude_connection_id]
|
||||
|
||||
if not connection_ids:
|
||||
return 0
|
||||
|
||||
# Send to all connections (outside the lock to avoid blocking)
|
||||
success_count = 0
|
||||
failed_connections = []
|
||||
|
||||
for connection_id in connection_ids:
|
||||
try:
|
||||
success = await self._send_to_connection(connection_id, message)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_connections.append(connection_id)
|
||||
except Exception as e:
|
||||
self.logger.error("Error broadcasting to connection",
|
||||
connection_id=connection_id,
|
||||
topic=topic,
|
||||
error=str(e))
|
||||
failed_connections.append(connection_id)
|
||||
|
||||
# Update statistics
|
||||
self._stats["messages_sent"] += success_count
|
||||
self._stats["messages_failed"] += len(failed_connections)
|
||||
|
||||
# Clean up failed connections
|
||||
if failed_connections:
|
||||
for connection_id in failed_connections:
|
||||
await self.remove_connection(connection_id, "broadcast_failed")
|
||||
|
||||
self.logger.debug("Broadcast completed",
|
||||
topic=topic,
|
||||
total_targets=len(connection_ids),
|
||||
successful=success_count,
|
||||
failed=len(failed_connections))
|
||||
|
||||
return success_count
|
||||
|
||||
async def send_to_user(
|
||||
self,
|
||||
user_id: int,
|
||||
message: Union[WebSocketMessage, Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
Send a message to all connections for a specific user
|
||||
|
||||
Returns:
|
||||
Number of successful sends
|
||||
"""
|
||||
if isinstance(message, dict):
|
||||
message = WebSocketMessage(**message)
|
||||
|
||||
# Get connection IDs for the user
|
||||
async with self._connections_lock:
|
||||
connection_ids = list(self._user_connections.get(user_id, set()))
|
||||
|
||||
if not connection_ids:
|
||||
return 0
|
||||
|
||||
# Send to all user connections
|
||||
success_count = 0
|
||||
for connection_id in connection_ids:
|
||||
try:
|
||||
success = await self._send_to_connection(connection_id, message)
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
self.logger.error("Error sending to user connection",
|
||||
connection_id=connection_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
|
||||
return success_count
|
||||
|
||||
async def _send_to_connection(self, connection_id: str, message: WebSocketMessage) -> bool:
|
||||
"""Send a message to a specific connection"""
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if not connection_info or not connection_info.is_alive():
|
||||
return False
|
||||
|
||||
websocket = connection_info.websocket
|
||||
|
||||
try:
|
||||
await websocket.send_json(message.to_dict())
|
||||
connection_info.update_activity()
|
||||
return True
|
||||
except Exception as e:
|
||||
connection_info.error_count += 1
|
||||
connection_info.state = ConnectionState.ERROR
|
||||
self.logger.warning("Failed to send message to connection",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
error_count=connection_info.error_count)
|
||||
return False
|
||||
|
||||
async def ping_connection(self, connection_id: str) -> bool:
|
||||
"""Send a ping to a specific connection"""
|
||||
ping_message = WebSocketMessage(
|
||||
type=MessageType.PING.value,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
success = await self._send_to_connection(connection_id, ping_message)
|
||||
if success:
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if connection_info:
|
||||
connection_info.last_ping = datetime.now(timezone.utc)
|
||||
|
||||
return success
|
||||
|
||||
async def handle_pong(self, connection_id: str):
|
||||
"""Handle a pong response from a connection"""
|
||||
async with self._connections_lock:
|
||||
connection_info = self._connections.get(connection_id)
|
||||
if connection_info:
|
||||
connection_info.last_pong = datetime.now(timezone.utc)
|
||||
connection_info.update_activity()
|
||||
connection_info.state = ConnectionState.CONNECTED
|
||||
|
||||
async def get_connection_info(self, connection_id: str) -> Optional[ConnectionInfo]:
|
||||
"""Get information about a specific connection"""
|
||||
async with self._connections_lock:
|
||||
info = self._connections.get(connection_id)
|
||||
# Fallback to global pool if this instance is not the registered one
|
||||
# This supports tests that instantiate a local pool while the context
|
||||
# manager uses the global pool created by app startup.
|
||||
if info is None:
|
||||
global _websocket_pool
|
||||
if _websocket_pool is not None and _websocket_pool is not self:
|
||||
return await _websocket_pool.get_connection_info(connection_id)
|
||||
return info
|
||||
|
||||
async def get_topic_connections(self, topic: str) -> List[str]:
|
||||
"""Get all connection IDs subscribed to a topic"""
|
||||
async with self._connections_lock:
|
||||
return list(self._topics.get(topic, set()))
|
||||
|
||||
async def get_user_connections(self, user_id: int) -> List[str]:
|
||||
"""Get all connection IDs for a user"""
|
||||
async with self._connections_lock:
|
||||
return list(self._user_connections.get(user_id, set()))
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics"""
|
||||
async with self._connections_lock:
|
||||
active_by_state = {}
|
||||
for conn in self._connections.values():
|
||||
state = conn.state.value
|
||||
active_by_state[state] = active_by_state.get(state, 0) + 1
|
||||
|
||||
# Compute total unique users robustly (avoid falsey user_id like 0)
|
||||
try:
|
||||
unique_user_ids = {conn.user_id for conn in self._connections.values() if conn.user_id is not None}
|
||||
except Exception:
|
||||
unique_user_ids = set(self._user_connections.keys())
|
||||
|
||||
return {
|
||||
**self._stats,
|
||||
"active_connections": len(self._connections),
|
||||
"total_topics": len(self._topics),
|
||||
"total_users": len(unique_user_ids),
|
||||
"connections_by_state": active_by_state,
|
||||
"topic_distribution": {topic: len(conn_ids) for topic, conn_ids in self._topics.items()},
|
||||
}
|
||||
|
||||
async def _cleanup_worker(self):
|
||||
"""Background task to clean up stale connections"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
await self._cleanup_stale_connections()
|
||||
self._stats["last_cleanup"] = datetime.now(timezone.utc).isoformat()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error in cleanup worker", error=str(e))
|
||||
|
||||
async def _heartbeat_worker(self):
|
||||
"""Background task to send heartbeats to connections"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
await self._send_heartbeats()
|
||||
self._stats["last_heartbeat"] = datetime.now(timezone.utc).isoformat()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error in heartbeat worker", error=str(e))
|
||||
|
||||
async def _cleanup_stale_connections(self):
|
||||
"""Clean up stale and disconnected connections"""
|
||||
stale_connections = []
|
||||
|
||||
async with self._connections_lock:
|
||||
for connection_id, connection_info in self._connections.items():
|
||||
if connection_info.is_stale(self.connection_timeout):
|
||||
stale_connections.append(connection_id)
|
||||
|
||||
# Remove stale connections
|
||||
for connection_id in stale_connections:
|
||||
await self.remove_connection(connection_id, "stale_connection")
|
||||
|
||||
if stale_connections:
|
||||
self._stats["connections_cleaned"] += len(stale_connections)
|
||||
self.logger.info("Cleaned up stale connections",
|
||||
count=len(stale_connections),
|
||||
total_cleaned=self._stats["connections_cleaned"])
|
||||
|
||||
async def _send_heartbeats(self):
|
||||
"""Send heartbeats to all active connections"""
|
||||
async with self._connections_lock:
|
||||
connection_ids = list(self._connections.keys())
|
||||
|
||||
heartbeat_message = WebSocketMessage(
|
||||
type=MessageType.HEARTBEAT.value,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
failed_connections = []
|
||||
for connection_id in connection_ids:
|
||||
try:
|
||||
success = await self._send_to_connection(connection_id, heartbeat_message)
|
||||
if not success:
|
||||
failed_connections.append(connection_id)
|
||||
except Exception:
|
||||
failed_connections.append(connection_id)
|
||||
|
||||
# Clean up failed connections
|
||||
for connection_id in failed_connections:
|
||||
await self.remove_connection(connection_id, "heartbeat_failed")
|
||||
|
||||
async def _close_all_connections(self):
|
||||
"""Close all active connections"""
|
||||
async with self._connections_lock:
|
||||
connection_ids = list(self._connections.keys())
|
||||
|
||||
for connection_id in connection_ids:
|
||||
await self.remove_connection(connection_id, "pool_shutdown")
|
||||
|
||||
|
||||
# Global WebSocket pool instance
|
||||
_websocket_pool: Optional[WebSocketPool] = None
|
||||
|
||||
|
||||
def get_websocket_pool() -> WebSocketPool:
|
||||
"""Get the global WebSocket pool instance"""
|
||||
global _websocket_pool
|
||||
if _websocket_pool is None:
|
||||
_websocket_pool = WebSocketPool()
|
||||
return _websocket_pool
|
||||
|
||||
|
||||
async def initialize_websocket_pool(**kwargs) -> WebSocketPool:
|
||||
"""Initialize and start the global WebSocket pool"""
|
||||
global _websocket_pool
|
||||
if _websocket_pool is None:
|
||||
_websocket_pool = WebSocketPool(**kwargs)
|
||||
await _websocket_pool.start()
|
||||
return _websocket_pool
|
||||
|
||||
|
||||
async def shutdown_websocket_pool():
|
||||
"""Shutdown the global WebSocket pool"""
|
||||
global _websocket_pool
|
||||
if _websocket_pool is not None:
|
||||
await _websocket_pool.stop()
|
||||
_websocket_pool = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_connection(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[int] = None,
|
||||
topics: Optional[Set[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Context manager for WebSocket connections
|
||||
|
||||
Automatically handles connection registration and cleanup
|
||||
"""
|
||||
pool = get_websocket_pool()
|
||||
connection_id = None
|
||||
|
||||
try:
|
||||
connection_id = await pool.add_connection(websocket, user_id, topics, metadata)
|
||||
yield connection_id, pool
|
||||
finally:
|
||||
if connection_id:
|
||||
await pool.remove_connection(connection_id, "context_exit")
|
||||
Reference in New Issue
Block a user