Files
delphi-database/app/middleware/websocket_middleware.py
HotSwapp bac8cc4bd5 changes
2025-08-18 20:20:04 -05:00

440 lines
16 KiB
Python

"""
WebSocket Middleware and Utilities
This module provides middleware and utilities for WebSocket connections,
including authentication, connection management, and integration with the
WebSocket pool system.
"""
import asyncio
from typing import Optional, Dict, Any, Set, Callable, Awaitable
from urllib.parse import parse_qs
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
from sqlalchemy.orm import Session
from app.database.base import SessionLocal
from app.models.user import User
from app.auth.security import verify_token
from app.services.websocket_pool import (
get_websocket_pool,
websocket_connection,
WebSocketMessage,
MessageType
)
from app.utils.logging import StructuredLogger
class WebSocketAuthenticationError(Exception):
    """Signals that an incoming WebSocket connection could not be authenticated."""
class WebSocketManager:
    """
    High-level WebSocket manager that provides easy-to-use methods
    for handling WebSocket connections with authentication and topic management.

    Wraps the pool returned by ``get_websocket_pool()``. It adds query-string
    token authentication, a welcome message, a receive loop, and handling of the
    standard ping/pong/subscribe/unsubscribe message types.
    """

    def __init__(self):
        self.logger = StructuredLogger("websocket_manager", "INFO")
        self.pool = get_websocket_pool()

    async def authenticate_websocket(self, websocket: WebSocket) -> Optional[User]:
        """
        Authenticate a WebSocket connection using the ``token`` query parameter.

        The token is checked with ``verify_token``. The resulting username must
        match a ``User`` row whose ``is_active`` is truthy.

        Args:
            websocket: WebSocket instance.

        Returns:
            The authenticated User, or None on any failure. Failures and
            unexpected errors are logged, never raised.
        """
        try:
            # parse_qs yields lists; if ``token`` repeats, the first value wins.
            query_params = parse_qs(str(websocket.url.query))
            token = query_params.get('token', [None])[0]
            if not token:
                self.logger.warning("WebSocket authentication failed: no token provided")
                return None

            username = verify_token(token)
            if not username:
                self.logger.warning("WebSocket authentication failed: invalid token")
                return None

            # NOTE(review): this is a synchronous ORM query inside a coroutine, so it
            # blocks the event loop while it runs; consider offloading to a threadpool.
            db: Session = SessionLocal()
            try:
                user = db.query(User).filter(User.username == username).first()
                if not user or not user.is_active:
                    self.logger.warning("WebSocket authentication failed: user not found or inactive",
                                        username=username)
                    return None
                self.logger.info("WebSocket authentication successful",
                                 user_id=user.id,
                                 username=user.username)
                return user
            finally:
                db.close()
        except Exception as e:
            self.logger.error("WebSocket authentication error", error=str(e))
            return None

    async def handle_connection(
        self,
        websocket: WebSocket,
        topics: Optional[Set[str]] = None,
        require_auth: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
    ) -> Optional[str]:
        """
        Handle a WebSocket connection with authentication and message processing.

        Authenticates (optionally) and accepts the socket. It then registers the
        socket with the pool, sends a ``welcome`` message, and runs the receive
        loop until the client disconnects or an error ends it.

        Args:
            websocket: WebSocket instance.
            topics: Initial topics to subscribe to.
            require_auth: Whether authentication is required.
            metadata: Additional metadata for the connection.
            message_handler: Optional coroutine called for every parsed message
                after the standard message types are handled.

        Returns:
            The connection ID once the receive loop has finished, or None if
            authentication failed (the socket is closed with code 4401).
        """
        user = None
        if require_auth:
            user = await self.authenticate_websocket(websocket)
            if not user:
                # The socket has not been accepted yet at this point.
                await websocket.close(code=4401, reason="Authentication failed")
                return None

        await websocket.accept()

        user_id = user.id if user else None
        async with websocket_connection(
            websocket=websocket,
            user_id=user_id,
            topics=topics,
            metadata=metadata
        ) as (connection_id, pool):
            connection_info = await pool.get_connection_info(connection_id)
            if connection_info:
                # Resolve CONNECTED on the enum class itself. Member-to-member access
                # (state.CONNECTED) is discouraged by the enum docs and has not
                # behaved consistently across Python releases.
                connection_info.state = type(connection_info.state).CONNECTED

            welcome_message = WebSocketMessage(
                type="welcome",
                data={
                    "connection_id": connection_id,
                    "user_id": user_id,
                    "topics": list(topics) if topics else [],
                    "timestamp": connection_info.created_at.isoformat() if connection_info else None
                }
            )
            await pool._send_to_connection(connection_id, welcome_message)

            # Blocks until the client disconnects or the loop hits an error.
            await self._message_loop(
                websocket=websocket,
                connection_id=connection_id,
                pool=pool,
                message_handler=message_handler
            )

            return connection_id

    async def _message_loop(
        self,
        websocket: WebSocket,
        connection_id: str,
        pool,
        message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
    ):
        """
        Receive and dispatch text messages until disconnect or error.

        Malformed messages (invalid JSON, a non-object payload, or fields
        ``WebSocketMessage`` rejects) are logged and skipped without closing
        the connection. Errors raised by ``message_handler`` are logged and
        also do not end the loop.
        """
        import json  # imported once per loop, not once per message

        try:
            while True:
                try:
                    data = await websocket.receive_text()

                    connection_info = await pool.get_connection_info(connection_id)
                    if connection_info:
                        connection_info.update_activity()

                    try:
                        message_dict = json.loads(data)
                        if not isinstance(message_dict, dict):
                            # A list/number/string cannot be splatted into keyword arguments.
                            raise ValueError("message must be a JSON object")
                        message = WebSocketMessage(**message_dict)
                    except (json.JSONDecodeError, ValueError, TypeError) as e:
                        # TypeError covers constructor rejections such as unexpected
                        # keyword arguments. Skip the message rather than drop the socket.
                        self.logger.warning("Invalid message format",
                                            connection_id=connection_id,
                                            error=str(e),
                                            data=data[:100])
                        continue

                    await self._handle_standard_message(connection_id, message, pool)

                    if message_handler:
                        try:
                            await message_handler(connection_id, message)
                        except Exception as e:
                            self.logger.error("Error in custom message handler",
                                              connection_id=connection_id,
                                              error=str(e))
                except WebSocketDisconnect:
                    self.logger.info("WebSocket disconnected", connection_id=connection_id)
                    break
                except Exception as e:
                    self.logger.error("Error in message loop",
                                      connection_id=connection_id,
                                      error=str(e))
                    break
        except Exception as e:
            self.logger.error("Fatal error in message loop",
                              connection_id=connection_id,
                              error=str(e))

    async def _handle_standard_message(self, connection_id: str, message: WebSocketMessage, pool):
        """
        Handle the built-in message types: ping, pong, subscribe and unsubscribe.

        Any other message type is ignored here. Subscribe/unsubscribe messages
        without a topic are ignored and get no response.
        """
        if message.type == MessageType.PING.value:
            # Echo the ping's timestamp back in a pong.
            pong_message = WebSocketMessage(
                type=MessageType.PONG.value,
                data={"timestamp": message.timestamp}
            )
            await pool._send_to_connection(connection_id, pong_message)
        elif message.type == MessageType.PONG.value:
            await pool.handle_pong(connection_id)
        elif message.type == MessageType.SUBSCRIBE.value:
            topic = message.topic
            if topic:
                success = await pool.subscribe_to_topic(connection_id, topic)
                response = WebSocketMessage(
                    type="subscription_response",
                    topic=topic,
                    data={"success": success, "action": "subscribe"}
                )
                await pool._send_to_connection(connection_id, response)
        elif message.type == MessageType.UNSUBSCRIBE.value:
            topic = message.topic
            if topic:
                success = await pool.unsubscribe_from_topic(connection_id, topic)
                response = WebSocketMessage(
                    type="subscription_response",
                    topic=topic,
                    data={"success": success, "action": "unsubscribe"}
                )
                await pool._send_to_connection(connection_id, response)

    async def broadcast_to_topic(
        self,
        topic: str,
        message_type: str,
        data: Optional[Dict[str, Any]] = None,
        exclude_connection_id: Optional[str] = None
    ) -> int:
        """Broadcast a message of ``message_type`` to a topic; returns the pool's result."""
        message = WebSocketMessage(
            type=message_type,
            topic=topic,
            data=data
        )
        return await self.pool.broadcast_to_topic(topic, message, exclude_connection_id)

    async def send_to_user(
        self,
        user_id: int,
        message_type: str,
        data: Optional[Dict[str, Any]] = None
    ) -> int:
        """Send a message to all connections for a user; returns the pool's result."""
        message = WebSocketMessage(
            type=message_type,
            data=data
        )
        return await self.pool.send_to_user(user_id, message)

    async def get_stats(self) -> Dict[str, Any]:
        """Return the WebSocket pool statistics."""
        return await self.pool.get_stats()
# Process-wide WebSocketManager, created lazily on first request.
_websocket_manager: Optional[WebSocketManager] = None


def get_websocket_manager() -> WebSocketManager:
    """Return the shared WebSocketManager, constructing it on the first call."""
    global _websocket_manager
    if _websocket_manager is not None:
        return _websocket_manager
    _websocket_manager = WebSocketManager()
    return _websocket_manager
# Utility decorators and functions
def websocket_endpoint(
    topics: Optional[Set[str]] = None,
    require_auth: bool = True,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Decorator for WebSocket endpoints that automatically handles
    connection management, authentication, and cleanup.

    The decorated function is invoked once per successfully parsed incoming
    message, after the standard message types have been handled, as
    ``func(websocket, connection_id, manager, message, *args, **kwargs)``.
    Exceptions it raises are logged by the manager and do not close the
    connection.

    Usage:
        @router.websocket("/my-endpoint")
        @websocket_endpoint(topics={"my_topic"}, require_auth=True)
        async def my_websocket_handler(websocket: WebSocket, connection_id: str,
                                       manager: WebSocketManager, message: WebSocketMessage):
            # Your custom per-message logic here
            pass

    Args:
        topics: Topics the connection is subscribed to on connect.
        require_auth: Whether a valid ``token`` query parameter is required.
        metadata: Extra metadata stored with the connection.

    Returns:
        A decorator that turns the handler into ``wrapper(websocket, *args, **kwargs)``.
    """
    def decorator(func):
        # NOTE(review): the framework inspects ``wrapper``'s signature, not ``func``'s;
        # confirm that ``*args``/``**kwargs`` are not being treated as request parameters.
        async def wrapper(websocket: WebSocket, *args, **kwargs):
            manager = get_websocket_manager()

            async def message_handler(connection_id: str, message: WebSocketMessage):
                await func(websocket, connection_id, manager, message, *args, **kwargs)

            # Returns only after the connection's message loop has ended.
            connection_id = await manager.handle_connection(
                websocket=websocket,
                topics=topics,
                require_auth=require_auth,
                metadata=metadata,
                message_handler=message_handler
            )
            if not connection_id:
                return

            # Wait until the pool stops reporting the connection as alive.
            try:
                while True:
                    await asyncio.sleep(1)
                    connection_info = await manager.pool.get_connection_info(connection_id)
                    if not connection_info or not connection_info.is_alive():
                        break
            except Exception as e:
                # Best-effort wait: report the failure instead of silently swallowing it.
                manager.logger.warning("Error while waiting for WebSocket connection to close",
                                       connection_id=connection_id,
                                       error=str(e))

        return wrapper

    return decorator
async def websocket_auth_dependency(websocket: WebSocket) -> User:
    """
    FastAPI dependency that authenticates a WebSocket before the endpoint runs.

    On failure the socket is closed with code 4401 and
    WebSocketAuthenticationError is raised, so the endpoint body is only
    reached with an authenticated user.

    Usage:
        @router.websocket("/my-endpoint")
        async def my_endpoint(websocket: WebSocket, user: User = Depends(websocket_auth_dependency)):
            # user is guaranteed to be authenticated
            pass
    """
    authenticated = await get_websocket_manager().authenticate_websocket(websocket)
    if authenticated:
        return authenticated
    await websocket.close(code=4401, reason="Authentication failed")
    raise WebSocketAuthenticationError("Authentication failed")
class WebSocketConnectionTracker:
    """
    Utility class to track WebSocket connections and their health.

    Reads connection state from the pool returned by ``get_websocket_pool()``.
    """

    def __init__(self):
        self.logger = StructuredLogger("websocket_tracker", "INFO")

    async def track_connection_health(self, connection_id: str, interval: int = 60):
        """
        Periodically check one connection's health.

        Every ``interval`` seconds the tracker looks the connection up. It stops
        when the connection is no longer in the pool, is stale
        (``is_stale(timeout_seconds=300)``), fails a ping, the task is
        cancelled, or an unexpected error occurs (which is logged).
        """
        pool = get_websocket_pool()
        while True:
            try:
                await asyncio.sleep(interval)
                connection_info = await pool.get_connection_info(connection_id)
                if not connection_info:
                    break

                if connection_info.is_stale(timeout_seconds=300):
                    self.logger.warning("Connection is stale",
                                        connection_id=connection_id,
                                        last_activity=connection_info.last_activity.isoformat())
                    break

                # Only ping connections the pool still considers alive.
                if connection_info.is_alive():
                    success = await pool.ping_connection(connection_id)
                    if not success:
                        self.logger.warning("Failed to ping connection",
                                            connection_id=connection_id)
                        break
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error("Error tracking connection health",
                                  connection_id=connection_id,
                                  error=str(e))
                break

    @staticmethod
    def _current_time_like(reference):
        """
        Return the current time using the same timezone convention as *reference*.

        An aware *reference* yields an aware "now" in its tzinfo. A naive
        *reference* yields a naive "now" expressed in UTC. Either way the
        result can be subtracted from *reference*.
        """
        from datetime import datetime, timezone

        if reference.tzinfo is not None:
            return datetime.now(reference.tzinfo)
        # NOTE(review): assumes naive pool timestamps are UTC (utcnow-style);
        # confirm against how websocket_pool stamps created_at / last_activity.
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def get_connection_metrics(self, connection_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed metrics for a connection.

        Returns:
            A dict of connection state and timings, or None if the connection
            is not in the pool. ``age_seconds`` and ``idle_seconds`` are
            measured against the current clock.
        """
        pool = get_websocket_pool()
        connection_info = await pool.get_connection_info(connection_id)
        if not connection_info:
            return None

        # Measure against the real current time. Using last_activity as "now"
        # would pin idle_seconds to 0 and understate age_seconds.
        now = self._current_time_like(connection_info.last_activity)
        return {
            "connection_id": connection_id,
            "user_id": connection_info.user_id,
            "state": connection_info.state.value,
            "topics": list(connection_info.topics),
            "created_at": connection_info.created_at.isoformat(),
            "last_activity": connection_info.last_activity.isoformat(),
            "age_seconds": (now - connection_info.created_at).total_seconds(),
            "idle_seconds": (now - connection_info.last_activity).total_seconds(),
            "error_count": connection_info.error_count,
            "last_ping": connection_info.last_ping.isoformat() if connection_info.last_ping else None,
            "last_pong": connection_info.last_pong.isoformat() if connection_info.last_pong else None,
            "metadata": connection_info.metadata,
            "is_alive": connection_info.is_alive(),
            "is_stale": connection_info.is_stale()
        }
def get_connection_tracker() -> WebSocketConnectionTracker:
    """Build a fresh WebSocketConnectionTracker; a new instance is returned on every call."""
    tracker = WebSocketConnectionTracker()
    return tracker