changes
This commit is contained in:
439
app/middleware/websocket_middleware.py
Normal file
439
app/middleware/websocket_middleware.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
WebSocket Middleware and Utilities
|
||||
|
||||
This module provides middleware and utilities for WebSocket connections,
|
||||
including authentication, connection management, and integration with the
|
||||
WebSocket pool system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, Set, Callable, Awaitable
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import SessionLocal
|
||||
from app.models.user import User
|
||||
from app.auth.security import verify_token
|
||||
from app.services.websocket_pool import (
|
||||
get_websocket_pool,
|
||||
websocket_connection,
|
||||
WebSocketMessage,
|
||||
MessageType
|
||||
)
|
||||
from app.utils.logging import StructuredLogger
|
||||
|
||||
|
||||
class WebSocketAuthenticationError(Exception):
|
||||
"""Raised when WebSocket authentication fails"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
High-level WebSocket manager that provides easy-to-use methods
|
||||
for handling WebSocket connections with authentication and topic management
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = StructuredLogger("websocket_manager", "INFO")
|
||||
self.pool = get_websocket_pool()
|
||||
|
||||
async def authenticate_websocket(self, websocket: WebSocket) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a WebSocket connection using token from query parameters
|
||||
|
||||
Args:
|
||||
websocket: WebSocket instance
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get token from query parameters
|
||||
query_params = parse_qs(str(websocket.url.query))
|
||||
token = query_params.get('token', [None])[0]
|
||||
|
||||
if not token:
|
||||
self.logger.warning("WebSocket authentication failed: no token provided")
|
||||
return None
|
||||
|
||||
# Verify token
|
||||
username = verify_token(token)
|
||||
if not username:
|
||||
self.logger.warning("WebSocket authentication failed: invalid token")
|
||||
return None
|
||||
|
||||
# Get user from database
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user or not user.is_active:
|
||||
self.logger.warning("WebSocket authentication failed: user not found or inactive",
|
||||
username=username)
|
||||
return None
|
||||
|
||||
self.logger.info("WebSocket authentication successful",
|
||||
user_id=user.id,
|
||||
username=user.username)
|
||||
return user
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("WebSocket authentication error", error=str(e))
|
||||
return None
|
||||
|
||||
async def handle_connection(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
topics: Optional[Set[str]] = None,
|
||||
require_auth: bool = True,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Handle a WebSocket connection with authentication and message processing
|
||||
|
||||
Args:
|
||||
websocket: WebSocket instance
|
||||
topics: Initial topics to subscribe to
|
||||
require_auth: Whether authentication is required
|
||||
metadata: Additional metadata for the connection
|
||||
message_handler: Optional function to handle incoming messages
|
||||
|
||||
Returns:
|
||||
Connection ID if successful, None if failed
|
||||
"""
|
||||
user = None
|
||||
if require_auth:
|
||||
user = await self.authenticate_websocket(websocket)
|
||||
if not user:
|
||||
await websocket.close(code=4401, reason="Authentication failed")
|
||||
return None
|
||||
|
||||
# Accept the connection
|
||||
await websocket.accept()
|
||||
|
||||
# Add to pool
|
||||
user_id = user.id if user else None
|
||||
async with websocket_connection(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
topics=topics,
|
||||
metadata=metadata
|
||||
) as (connection_id, pool):
|
||||
|
||||
# Set connection state to connected
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if connection_info:
|
||||
connection_info.state = connection_info.state.CONNECTED
|
||||
|
||||
# Send initial welcome message
|
||||
welcome_message = WebSocketMessage(
|
||||
type="welcome",
|
||||
data={
|
||||
"connection_id": connection_id,
|
||||
"user_id": user_id,
|
||||
"topics": list(topics) if topics else [],
|
||||
"timestamp": connection_info.created_at.isoformat() if connection_info else None
|
||||
}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, welcome_message)
|
||||
|
||||
# Handle messages
|
||||
await self._message_loop(
|
||||
websocket=websocket,
|
||||
connection_id=connection_id,
|
||||
pool=pool,
|
||||
message_handler=message_handler
|
||||
)
|
||||
|
||||
return connection_id
|
||||
|
||||
async def _message_loop(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
connection_id: str,
|
||||
pool,
|
||||
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
|
||||
):
|
||||
"""Handle incoming WebSocket messages"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Update activity
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if connection_info:
|
||||
connection_info.update_activity()
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
import json
|
||||
message_dict = json.loads(data)
|
||||
message = WebSocketMessage(**message_dict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
self.logger.warning("Invalid message format",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
data=data[:100])
|
||||
continue
|
||||
|
||||
# Handle standard message types
|
||||
await self._handle_standard_message(connection_id, message, pool)
|
||||
|
||||
# Call custom message handler if provided
|
||||
if message_handler:
|
||||
try:
|
||||
await message_handler(connection_id, message)
|
||||
except Exception as e:
|
||||
self.logger.error("Error in custom message handler",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
|
||||
except WebSocketDisconnect:
|
||||
self.logger.info("WebSocket disconnected", connection_id=connection_id)
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error in message loop",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Fatal error in message loop",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
|
||||
async def _handle_standard_message(self, connection_id: str, message: WebSocketMessage, pool):
|
||||
"""Handle standard WebSocket message types"""
|
||||
|
||||
if message.type == MessageType.PING.value:
|
||||
# Respond with pong
|
||||
pong_message = WebSocketMessage(
|
||||
type=MessageType.PONG.value,
|
||||
data={"timestamp": message.timestamp}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, pong_message)
|
||||
|
||||
elif message.type == MessageType.PONG.value:
|
||||
# Handle pong response
|
||||
await pool.handle_pong(connection_id)
|
||||
|
||||
elif message.type == MessageType.SUBSCRIBE.value:
|
||||
# Subscribe to topic
|
||||
topic = message.topic
|
||||
if topic:
|
||||
success = await pool.subscribe_to_topic(connection_id, topic)
|
||||
response = WebSocketMessage(
|
||||
type="subscription_response",
|
||||
topic=topic,
|
||||
data={"success": success, "action": "subscribe"}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, response)
|
||||
|
||||
elif message.type == MessageType.UNSUBSCRIBE.value:
|
||||
# Unsubscribe from topic
|
||||
topic = message.topic
|
||||
if topic:
|
||||
success = await pool.unsubscribe_from_topic(connection_id, topic)
|
||||
response = WebSocketMessage(
|
||||
type="subscription_response",
|
||||
topic=topic,
|
||||
data={"success": success, "action": "unsubscribe"}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, response)
|
||||
|
||||
async def broadcast_to_topic(
|
||||
self,
|
||||
topic: str,
|
||||
message_type: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
exclude_connection_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Convenience method to broadcast a message to a topic"""
|
||||
message = WebSocketMessage(
|
||||
type=message_type,
|
||||
topic=topic,
|
||||
data=data
|
||||
)
|
||||
return await self.pool.broadcast_to_topic(topic, message, exclude_connection_id)
|
||||
|
||||
async def send_to_user(
|
||||
self,
|
||||
user_id: int,
|
||||
message_type: str,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> int:
|
||||
"""Convenience method to send a message to all connections for a user"""
|
||||
message = WebSocketMessage(
|
||||
type=message_type,
|
||||
data=data
|
||||
)
|
||||
return await self.pool.send_to_user(user_id, message)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get WebSocket pool statistics"""
|
||||
return await self.pool.get_stats()
|
||||
|
||||
|
||||
# Global WebSocket manager instance
|
||||
_websocket_manager: Optional[WebSocketManager] = None
|
||||
|
||||
|
||||
def get_websocket_manager() -> WebSocketManager:
|
||||
"""Get the global WebSocket manager instance"""
|
||||
global _websocket_manager
|
||||
if _websocket_manager is None:
|
||||
_websocket_manager = WebSocketManager()
|
||||
return _websocket_manager
|
||||
|
||||
|
||||
# Utility decorators and functions
|
||||
|
||||
def websocket_endpoint(
|
||||
topics: Optional[Set[str]] = None,
|
||||
require_auth: bool = True,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Decorator for WebSocket endpoints that automatically handles
|
||||
connection management, authentication, and cleanup
|
||||
|
||||
Usage:
|
||||
@router.websocket("/my-endpoint")
|
||||
@websocket_endpoint(topics={"my_topic"}, require_auth=True)
|
||||
async def my_websocket_handler(websocket: WebSocket, connection_id: str, manager: WebSocketManager):
|
||||
# Your custom logic here
|
||||
pass
|
||||
"""
|
||||
def decorator(func):
|
||||
async def wrapper(websocket: WebSocket, *args, **kwargs):
|
||||
manager = get_websocket_manager()
|
||||
|
||||
async def message_handler(connection_id: str, message: WebSocketMessage):
|
||||
# Call the original function with the message
|
||||
await func(websocket, connection_id, manager, message, *args, **kwargs)
|
||||
|
||||
# Handle the connection
|
||||
connection_id = await manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics=topics,
|
||||
require_auth=require_auth,
|
||||
metadata=metadata,
|
||||
message_handler=message_handler
|
||||
)
|
||||
|
||||
if not connection_id:
|
||||
return
|
||||
|
||||
# Keep the connection alive
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
connection_info = await manager.pool.get_connection_info(connection_id)
|
||||
if not connection_info or not connection_info.is_alive():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def websocket_auth_dependency(websocket: WebSocket) -> User:
|
||||
"""
|
||||
FastAPI dependency for WebSocket authentication
|
||||
|
||||
Usage:
|
||||
@router.websocket("/my-endpoint")
|
||||
async def my_endpoint(websocket: WebSocket, user: User = Depends(websocket_auth_dependency)):
|
||||
# user is guaranteed to be authenticated
|
||||
pass
|
||||
"""
|
||||
manager = get_websocket_manager()
|
||||
user = await manager.authenticate_websocket(websocket)
|
||||
if not user:
|
||||
await websocket.close(code=4401, reason="Authentication failed")
|
||||
raise WebSocketAuthenticationError("Authentication failed")
|
||||
return user
|
||||
|
||||
|
||||
class WebSocketConnectionTracker:
|
||||
"""
|
||||
Utility class to track WebSocket connections and their health
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = StructuredLogger("websocket_tracker", "INFO")
|
||||
|
||||
async def track_connection_health(self, connection_id: str, interval: int = 60):
|
||||
"""Track the health of a specific connection"""
|
||||
pool = get_websocket_pool()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if not connection_info:
|
||||
break
|
||||
|
||||
# Check if connection is healthy
|
||||
if connection_info.is_stale(timeout_seconds=300):
|
||||
self.logger.warning("Connection is stale",
|
||||
connection_id=connection_id,
|
||||
last_activity=connection_info.last_activity.isoformat())
|
||||
break
|
||||
|
||||
# Try to ping the connection
|
||||
if connection_info.is_alive():
|
||||
success = await pool.ping_connection(connection_id)
|
||||
if not success:
|
||||
self.logger.warning("Failed to ping connection",
|
||||
connection_id=connection_id)
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error tracking connection health",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
async def get_connection_metrics(self, connection_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed metrics for a connection"""
|
||||
pool = get_websocket_pool()
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
|
||||
if not connection_info:
|
||||
return None
|
||||
|
||||
now = connection_info.last_activity # Use last_activity for consistency
|
||||
return {
|
||||
"connection_id": connection_id,
|
||||
"user_id": connection_info.user_id,
|
||||
"state": connection_info.state.value,
|
||||
"topics": list(connection_info.topics),
|
||||
"created_at": connection_info.created_at.isoformat(),
|
||||
"last_activity": connection_info.last_activity.isoformat(),
|
||||
"age_seconds": (now - connection_info.created_at).total_seconds(),
|
||||
"idle_seconds": (now - connection_info.last_activity).total_seconds(),
|
||||
"error_count": connection_info.error_count,
|
||||
"last_ping": connection_info.last_ping.isoformat() if connection_info.last_ping else None,
|
||||
"last_pong": connection_info.last_pong.isoformat() if connection_info.last_pong else None,
|
||||
"metadata": connection_info.metadata,
|
||||
"is_alive": connection_info.is_alive(),
|
||||
"is_stale": connection_info.is_stale()
|
||||
}
|
||||
|
||||
|
||||
def get_connection_tracker() -> WebSocketConnectionTracker:
|
||||
"""Get a WebSocket connection tracker instance"""
|
||||
return WebSocketConnectionTracker()
|
||||
Reference in New Issue
Block a user