"""
WebSocket Middleware and Utilities

This module provides middleware and utilities for WebSocket connections,
including authentication, connection management, and integration with the
WebSocket pool system.
"""

import asyncio
import json
from typing import Optional, Dict, Any, Set, Callable, Awaitable
from urllib.parse import parse_qs

from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
from sqlalchemy.orm import Session

from app.database.base import SessionLocal
from app.models.user import User
from app.auth.security import verify_token
from app.services.websocket_pool import (
    get_websocket_pool,
    websocket_connection,
    WebSocketMessage,
    MessageType,
)
from app.utils.logging import StructuredLogger


class WebSocketAuthenticationError(Exception):
    """Raised when WebSocket authentication fails"""
    pass


class WebSocketManager:
    """
    High-level WebSocket manager that provides easy-to-use methods for
    handling WebSocket connections with authentication and topic management.

    It wraps the shared pool returned by ``get_websocket_pool()``.
    """

    def __init__(self):
        self.logger = StructuredLogger("websocket_manager", "INFO")
        self.pool = get_websocket_pool()

    async def authenticate_websocket(self, websocket: WebSocket) -> Optional[User]:
        """
        Authenticate a WebSocket connection using the ``token`` query parameter.

        The token is validated with ``verify_token``. The username it yields
        must belong to an existing, active ``User``.

        Args:
            websocket: WebSocket instance

        Returns:
            The ``User`` if authentication succeeds, otherwise ``None``.
            Unexpected exceptions are logged and also produce ``None``.
        """
        try:
            # Only the first value is used if ``token`` is repeated.
            query_params = parse_qs(str(websocket.url.query))
            token = query_params.get('token', [None])[0]

            if not token:
                self.logger.warning("WebSocket authentication failed: no token provided")
                return None

            username = verify_token(token)
            if not username:
                self.logger.warning("WebSocket authentication failed: invalid token")
                return None

            # NOTE(review): this is a synchronous DB query inside a coroutine,
            # so it blocks the event loop while it runs. Consider offloading
            # it to an executor if lookups become slow.
            db: Session = SessionLocal()
            try:
                user = db.query(User).filter(User.username == username).first()
                if not user or not user.is_active:
                    self.logger.warning(
                        "WebSocket authentication failed: user not found or inactive",
                        username=username,
                    )
                    return None

                self.logger.info(
                    "WebSocket authentication successful",
                    user_id=user.id,
                    username=user.username,
                )
                return user
            finally:
                db.close()

        except Exception as e:
            self.logger.error("WebSocket authentication error", error=str(e))
            return None

    async def handle_connection(
        self,
        websocket: WebSocket,
        topics: Optional[Set[str]] = None,
        require_auth: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None,
    ) -> Optional[str]:
        """
        Handle a WebSocket connection with authentication and message processing.

        Flow:
        1. Optionally authenticate.
        2. Accept the socket.
        3. Register it with the pool via ``websocket_connection``.
        4. Send a ``welcome`` message.
        5. Run the message loop until it exits.

        Args:
            websocket: WebSocket instance
            topics: Initial topics to subscribe to
            require_auth: Whether authentication is required
            metadata: Additional metadata for the connection
            message_handler: Optional coroutine called for every valid
                incoming message as ``message_handler(connection_id, message)``

        Returns:
            The connection ID once the message loop has finished, or ``None``
            if authentication failed.
        """
        user = None
        if require_auth:
            user = await self.authenticate_websocket(websocket)
            if not user:
                # Close without ever accepting the connection.
                await websocket.close(code=4401, reason="Authentication failed")
                return None

        await websocket.accept()

        user_id = user.id if user else None
        async with websocket_connection(
            websocket=websocket,
            user_id=user_id,
            topics=topics,
            metadata=metadata,
        ) as (connection_id, pool):
            connection_info = await pool.get_connection_info(connection_id)
            if connection_info:
                # Look CONNECTED up on the state's own enum class. This is
                # clearer and more robust than reaching one member through
                # another (``state.CONNECTED``).
                connection_info.state = type(connection_info.state).CONNECTED

            welcome_message = WebSocketMessage(
                type="welcome",
                data={
                    "connection_id": connection_id,
                    "user_id": user_id,
                    "topics": list(topics) if topics else [],
                    "timestamp": connection_info.created_at.isoformat() if connection_info else None,
                },
            )
            await pool._send_to_connection(connection_id, welcome_message)

            await self._message_loop(
                websocket=websocket,
                connection_id=connection_id,
                pool=pool,
                message_handler=message_handler,
            )

        return connection_id

    async def _message_loop(
        self,
        websocket: WebSocket,
        connection_id: str,
        pool,
        message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None,
    ):
        """
        Receive and dispatch text messages until the client disconnects or an
        unrecoverable error occurs.

        Malformed messages are logged and skipped. Errors raised by
        ``message_handler`` are logged and do not end the loop.
        """
        try:
            while True:
                try:
                    data = await websocket.receive_text()

                    # Record activity for staleness tracking.
                    connection_info = await pool.get_connection_info(connection_id)
                    if connection_info:
                        connection_info.update_activity()

                    try:
                        message_dict = json.loads(data)
                        message = WebSocketMessage(**message_dict)
                    except (json.JSONDecodeError, ValueError, TypeError) as e:
                        # TypeError covers valid JSON that is not an object
                        # (``**`` needs a mapping). It also covers payloads
                        # WebSocketMessage rejects with a TypeError. These are
                        # client mistakes: skip the frame, keep the session.
                        self.logger.warning(
                            "Invalid message format",
                            connection_id=connection_id,
                            error=str(e),
                            data=data[:100],
                        )
                        continue

                    # Built-in protocol messages (ping/pong/subscribe/...).
                    await self._handle_standard_message(connection_id, message, pool)

                    if message_handler:
                        try:
                            await message_handler(connection_id, message)
                        except Exception as e:
                            self.logger.error(
                                "Error in custom message handler",
                                connection_id=connection_id,
                                error=str(e),
                            )

                except WebSocketDisconnect:
                    self.logger.info("WebSocket disconnected", connection_id=connection_id)
                    break
                except Exception as e:
                    self.logger.error(
                        "Error in message loop",
                        connection_id=connection_id,
                        error=str(e),
                    )
                    break

        except Exception as e:
            self.logger.error(
                "Fatal error in message loop",
                connection_id=connection_id,
                error=str(e),
            )

    async def _handle_standard_message(self, connection_id: str, message: WebSocketMessage, pool):
        """
        Handle standard WebSocket message types.

        - ping: answered with a pong carrying the ping's timestamp.
        - pong: forwarded to ``pool.handle_pong``.
        - subscribe / unsubscribe: applied when ``message.topic`` is set, and
          acknowledged with a ``subscription_response`` message.

        Other message types are ignored here.
        """
        if message.type == MessageType.PING.value:
            pong_message = WebSocketMessage(
                type=MessageType.PONG.value,
                data={"timestamp": message.timestamp},
            )
            await pool._send_to_connection(connection_id, pong_message)

        elif message.type == MessageType.PONG.value:
            await pool.handle_pong(connection_id)

        elif message.type == MessageType.SUBSCRIBE.value:
            topic = message.topic
            if topic:
                success = await pool.subscribe_to_topic(connection_id, topic)
                response = WebSocketMessage(
                    type="subscription_response",
                    topic=topic,
                    data={"success": success, "action": "subscribe"},
                )
                await pool._send_to_connection(connection_id, response)

        elif message.type == MessageType.UNSUBSCRIBE.value:
            topic = message.topic
            if topic:
                success = await pool.unsubscribe_from_topic(connection_id, topic)
                response = WebSocketMessage(
                    type="subscription_response",
                    topic=topic,
                    data={"success": success, "action": "unsubscribe"},
                )
                await pool._send_to_connection(connection_id, response)

    async def broadcast_to_topic(
        self,
        topic: str,
        message_type: str,
        data: Optional[Dict[str, Any]] = None,
        exclude_connection_id: Optional[str] = None,
    ) -> int:
        """Convenience method to broadcast a message to a topic; returns the pool's result."""
        message = WebSocketMessage(type=message_type, topic=topic, data=data)
        return await self.pool.broadcast_to_topic(topic, message, exclude_connection_id)

    async def send_to_user(
        self,
        user_id: int,
        message_type: str,
        data: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Convenience method to send a message to all connections for a user."""
        message = WebSocketMessage(type=message_type, data=data)
        return await self.pool.send_to_user(user_id, message)

    async def get_stats(self) -> Dict[str, Any]:
        """Get WebSocket pool statistics."""
        return await self.pool.get_stats()


# Global WebSocket manager instance (created lazily by get_websocket_manager).
_websocket_manager: Optional[WebSocketManager] = None


def get_websocket_manager() -> WebSocketManager:
    """Get the global WebSocket manager instance, creating it on first use."""
    global _websocket_manager
    if _websocket_manager is None:
        _websocket_manager = WebSocketManager()
    return _websocket_manager


# Utility decorators and functions

def websocket_endpoint(
    topics: Optional[Set[str]] = None,
    require_auth: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
):
    """
    Decorator for WebSocket endpoints that automatically handles connection
    management, authentication, and cleanup.

    The decorated coroutine is called once per valid incoming message as
    ``func(websocket, connection_id, manager, message, *args, **kwargs)``.

    Usage:
        @router.websocket("/my-endpoint")
        @websocket_endpoint(topics={"my_topic"}, require_auth=True)
        async def my_websocket_handler(websocket: WebSocket, connection_id: str,
                                       manager: WebSocketManager,
                                       message: WebSocketMessage):
            # Your custom logic here
            pass
    """
    def decorator(func):
        async def wrapper(websocket: WebSocket, *args, **kwargs):
            manager = get_websocket_manager()

            async def message_handler(connection_id: str, message: WebSocketMessage):
                # Forward each message to the user's endpoint function.
                await func(websocket, connection_id, manager, message, *args, **kwargs)

            connection_id = await manager.handle_connection(
                websocket=websocket,
                topics=topics,
                require_auth=require_auth,
                metadata=metadata,
                message_handler=message_handler,
            )

            if not connection_id:
                return

            # NOTE(review): handle_connection only returns after its message
            # loop has ended. This poll is therefore likely redundant if
            # websocket_connection removes the connection on exit — confirm.
            try:
                while True:
                    await asyncio.sleep(1)
                    connection_info = await manager.pool.get_connection_info(connection_id)
                    if not connection_info or not connection_info.is_alive():
                        break
            except Exception as e:
                # Best-effort keep-alive: don't propagate, but don't hide it either.
                manager.logger.error(
                    "Error in keep-alive loop",
                    connection_id=connection_id,
                    error=str(e),
                )

        return wrapper

    return decorator


async def websocket_auth_dependency(websocket: WebSocket) -> User:
    """
    FastAPI dependency for WebSocket authentication.

    On failure, the socket is closed with code 4401 and
    ``WebSocketAuthenticationError`` is raised.

    Usage:
        @router.websocket("/my-endpoint")
        async def my_endpoint(websocket: WebSocket, user: User = Depends(websocket_auth_dependency)):
            # user is guaranteed to be authenticated
            pass
    """
    manager = get_websocket_manager()
    user = await manager.authenticate_websocket(websocket)

    if not user:
        await websocket.close(code=4401, reason="Authentication failed")
        raise WebSocketAuthenticationError("Authentication failed")

    return user


class WebSocketConnectionTracker:
    """
    Utility class to track WebSocket connections and their health.
    """

    def __init__(self):
        self.logger = StructuredLogger("websocket_tracker", "INFO")

    async def track_connection_health(self, connection_id: str, interval: int = 60):
        """
        Periodically check one connection's health until it goes away.

        Every ``interval`` seconds the connection is looked up in the pool.
        Tracking stops when any of the following happens:
        - the connection is missing;
        - it is stale (300 s timeout);
        - a ping fails;
        - the task is cancelled;
        - an unexpected error occurs (logged).
        """
        pool = get_websocket_pool()

        while True:
            try:
                await asyncio.sleep(interval)

                connection_info = await pool.get_connection_info(connection_id)
                if not connection_info:
                    break

                if connection_info.is_stale(timeout_seconds=300):
                    self.logger.warning(
                        "Connection is stale",
                        connection_id=connection_id,
                        last_activity=connection_info.last_activity.isoformat(),
                    )
                    break

                if connection_info.is_alive():
                    success = await pool.ping_connection(connection_id)
                    if not success:
                        self.logger.warning("Failed to ping connection", connection_id=connection_id)
                        break

            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(
                    "Error tracking connection health",
                    connection_id=connection_id,
                    error=str(e),
                )
                break

    async def get_connection_metrics(self, connection_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed metrics for a connection, or ``None`` if it is unknown.
        """
        pool = get_websocket_pool()
        connection_info = await pool.get_connection_info(connection_id)

        if not connection_info:
            return None

        # NOTE(review): "now" is last_activity, not the wall clock. As a
        # result, age_seconds measures creation -> last activity, and
        # idle_seconds is always 0.0. Kept as-is (presumably to avoid mixing
        # naive/aware datetimes); confirm the intended semantics.
        now = connection_info.last_activity

        return {
            "connection_id": connection_id,
            "user_id": connection_info.user_id,
            "state": connection_info.state.value,
            "topics": list(connection_info.topics),
            "created_at": connection_info.created_at.isoformat(),
            "last_activity": connection_info.last_activity.isoformat(),
            "age_seconds": (now - connection_info.created_at).total_seconds(),
            "idle_seconds": (now - connection_info.last_activity).total_seconds(),
            "error_count": connection_info.error_count,
            "last_ping": connection_info.last_ping.isoformat() if connection_info.last_ping else None,
            "last_pong": connection_info.last_pong.isoformat() if connection_info.last_pong else None,
            "metadata": connection_info.metadata,
            "is_alive": connection_info.is_alive(),
            "is_stale": connection_info.is_stale(),
        }


def get_connection_tracker() -> WebSocketConnectionTracker:
    """Get a new WebSocket connection tracker instance."""
    return WebSocketConnectionTracker()