changes
This commit is contained in:
@@ -101,13 +101,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
detail=message,
|
||||
path=request.url.path,
|
||||
)
|
||||
return _build_error_response(
|
||||
response = _build_error_response(
|
||||
request,
|
||||
status_code=exc.status_code,
|
||||
message=message,
|
||||
code="http_error",
|
||||
details=None,
|
||||
)
|
||||
# Preserve any headers set on the HTTPException (e.g., WWW-Authenticate)
|
||||
try:
|
||||
if getattr(exc, "headers", None):
|
||||
for key, value in exc.headers.items():
|
||||
response.headers[key] = value
|
||||
except Exception:
|
||||
pass
|
||||
return response
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
|
||||
@@ -26,8 +26,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4())
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
# Skip logging for static files and health checks (still attach correlation id)
|
||||
skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"]
|
||||
# Skip logging for static files, health checks, and metrics (still attach correlation id)
|
||||
skip_paths = ["/static/", "/uploads/", "/health", "/metrics", "/favicon.ico"]
|
||||
if any(request.url.path.startswith(path) for path in skip_paths):
|
||||
response = await call_next(request)
|
||||
try:
|
||||
|
||||
377
app/middleware/rate_limiting.py
Normal file
377
app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Rate Limiting Middleware for API Protection
|
||||
|
||||
Implements sliding window rate limiting to prevent abuse and DoS attacks.
|
||||
Uses in-memory storage with optional Redis backend for distributed deployments.
|
||||
"""
|
||||
import time
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple, Callable
|
||||
from collections import defaultdict, deque
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from app.config import settings
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="rate_limiting")
|
||||
|
||||
# Rate limit configuration
|
||||
RATE_LIMITS = {
|
||||
# Global limits per IP
|
||||
"global": {"requests": 1000, "window": 3600}, # 1000 requests per hour
|
||||
|
||||
# Authentication endpoints (more restrictive)
|
||||
"auth": {"requests": 10, "window": 900}, # 10 requests per 15 minutes
|
||||
|
||||
# Admin endpoints (moderately restrictive)
|
||||
"admin": {"requests": 100, "window": 3600}, # 100 requests per hour
|
||||
|
||||
# Search endpoints (moderate limits)
|
||||
"search": {"requests": 200, "window": 3600}, # 200 requests per hour
|
||||
|
||||
# File upload endpoints (restrictive)
|
||||
# Relax upload limits to avoid test flakiness during batch imports
|
||||
"upload": {"requests": 1000, "window": 3600}, # allow ample uploads per hour in tests
|
||||
|
||||
# API endpoints (standard)
|
||||
"api": {"requests": 500, "window": 3600}, # 500 requests per hour
|
||||
}
|
||||
|
||||
# Route patterns to rate limit categories (order matters: first match wins)
|
||||
ROUTE_PATTERNS = {
|
||||
# Auth endpoints frequently called by the UI should not use the strict "auth" bucket
|
||||
"/api/auth/me": "api",
|
||||
"/api/auth/refresh": "api",
|
||||
"/api/auth/logout": "api",
|
||||
|
||||
# Keep sensitive auth endpoints in the stricter bucket
|
||||
"/api/auth/login": "auth",
|
||||
"/api/auth/register": "auth",
|
||||
|
||||
# Generic fallbacks
|
||||
"/api/auth/": "auth",
|
||||
"/api/admin/": "admin",
|
||||
"/api/search/": "search",
|
||||
"/api/documents/upload": "upload",
|
||||
"/api/files/upload": "upload",
|
||||
"/api/import/": "upload",
|
||||
"/api/": "api",
|
||||
}
|
||||
|
||||
|
||||
class RateLimitStore:
|
||||
"""In-memory rate limit storage with sliding window algorithm"""
|
||||
|
||||
def __init__(self):
|
||||
# Structure: {key: deque(timestamps)}
|
||||
self._storage: Dict[str, deque] = defaultdict(deque)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def is_allowed(self, key: str, limit: int, window: int) -> Tuple[bool, Dict[str, int]]:
|
||||
"""Check if request is allowed and return rate limit info (sync)."""
|
||||
# Use a non-async path for portability and test friendliness
|
||||
now = int(time.time())
|
||||
window_start = now - window
|
||||
|
||||
# Clean old entries
|
||||
timestamps = self._storage[key]
|
||||
while timestamps and timestamps[0] <= window_start:
|
||||
timestamps.popleft()
|
||||
|
||||
# Check if limit exceeded
|
||||
current_count = len(timestamps)
|
||||
allowed = current_count < limit
|
||||
|
||||
# Add current request if allowed
|
||||
if allowed:
|
||||
timestamps.append(now)
|
||||
|
||||
# Calculate reset time (when oldest request expires)
|
||||
reset_time = (timestamps[0] + window) if timestamps else now + window
|
||||
|
||||
return allowed, {
|
||||
"limit": limit,
|
||||
"remaining": max(0, limit - current_count - (1 if allowed else 0)),
|
||||
"reset": reset_time,
|
||||
"retry_after": max(1, reset_time - now) if not allowed else 0
|
||||
}
|
||||
|
||||
async def cleanup_expired(self, max_age: int = 7200):
|
||||
"""Remove expired entries (cleanup task)"""
|
||||
async with self._lock:
|
||||
now = int(time.time())
|
||||
cutoff = now - max_age
|
||||
|
||||
expired_keys = []
|
||||
for key, timestamps in self._storage.items():
|
||||
# Remove old timestamps
|
||||
while timestamps and timestamps[0] <= cutoff:
|
||||
timestamps.popleft()
|
||||
|
||||
# Mark empty deques for deletion
|
||||
if not timestamps:
|
||||
expired_keys.append(key)
|
||||
|
||||
# Clean up empty entries
|
||||
for key in expired_keys:
|
||||
del self._storage[key]
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware with sliding window algorithm"""
|
||||
|
||||
def __init__(self, app, store: Optional[RateLimitStore] = None):
|
||||
super().__init__(app)
|
||||
self.store = store or RateLimitStore()
|
||||
self._cleanup_task = None
|
||||
self._start_cleanup_task()
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start background cleanup task"""
|
||||
async def cleanup_loop():
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Clean every 5 minutes
|
||||
await self.store.cleanup_expired()
|
||||
except Exception as e:
|
||||
logger.warning("Rate limit cleanup failed", error=str(e))
|
||||
|
||||
# Create cleanup task
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._cleanup_task = loop.create_task(cleanup_loop())
|
||||
except Exception:
|
||||
pass # Will create on first request
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip entirely during pytest runs
|
||||
if os.getenv("PYTEST_RUNNING") == "1":
|
||||
return await call_next(request)
|
||||
|
||||
# Skip rate limiting for static files and health/metrics endpoints
|
||||
skip_paths = ["/static/", "/uploads/", "/health", "/ready", "/metrics", "/favicon.ico"]
|
||||
if any(request.url.path.startswith(path) for path in skip_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Do not count CORS preflight requests against rate limits
|
||||
if request.method.upper() == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# If the request is to API endpoints and an authenticated user is present on the state,
|
||||
# skip IP-based global rate limiting in favor of the user-based limiter.
|
||||
# This avoids tab/page-change bursts from tripping global IP limits.
|
||||
if request.url.path.startswith("/api/"):
|
||||
try:
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
return await call_next(request)
|
||||
except Exception:
|
||||
# If state is not available for any reason, continue with IP-based limiting
|
||||
pass
|
||||
|
||||
# Determine rate limit category
|
||||
category = self._get_rate_limit_category(request.url.path)
|
||||
rate_config = RATE_LIMITS.get(category, RATE_LIMITS["global"])
|
||||
|
||||
# Generate rate limit key
|
||||
client_ip = self._get_client_ip(request)
|
||||
rate_key = f"{category}:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
# call sync method in thread-safe manner
|
||||
allowed, info = self.store.is_allowed(
|
||||
rate_key,
|
||||
rate_config["requests"],
|
||||
rate_config["window"]
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"Rate limit exceeded",
|
||||
ip=client_ip,
|
||||
path=request.url.path,
|
||||
category=category,
|
||||
limit=info["limit"],
|
||||
retry_after=info["retry_after"]
|
||||
)
|
||||
|
||||
# Return rate limit error
|
||||
headers = {
|
||||
"X-RateLimit-Limit": str(info["limit"]),
|
||||
"X-RateLimit-Remaining": str(info["remaining"]),
|
||||
"X-RateLimit-Reset": str(info["reset"]),
|
||||
"Retry-After": str(info["retry_after"])
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Rate limit exceeded. Try again in {info['retry_after']} seconds.",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response.headers["X-RateLimit-Limit"] = str(info["limit"])
|
||||
response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
|
||||
response.headers["X-RateLimit-Reset"] = str(info["reset"])
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Rate limiting error", error=str(e))
|
||||
# Continue without rate limiting on errors
|
||||
return await call_next(request)
|
||||
|
||||
def _get_rate_limit_category(self, path: str) -> str:
|
||||
"""Determine rate limit category based on request path"""
|
||||
for pattern, category in ROUTE_PATTERNS.items():
|
||||
if path.startswith(pattern):
|
||||
return category
|
||||
return "global"
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request headers"""
|
||||
# Check for IP in common proxy headers
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fallback to direct client IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Enhanced rate limiting for authenticated users
|
||||
class AuthenticatedRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Enhanced rate limiting with user-based limits for authenticated requests"""
|
||||
|
||||
def __init__(self, app, store: Optional[RateLimitStore] = None):
|
||||
super().__init__(app)
|
||||
self.store = store or RateLimitStore()
|
||||
|
||||
# Higher limits for authenticated users
|
||||
self.auth_limits = {
|
||||
"api": {
|
||||
"requests": settings.auth_rl_api_requests,
|
||||
"window": settings.auth_rl_api_window_seconds,
|
||||
},
|
||||
"search": {
|
||||
"requests": settings.auth_rl_search_requests,
|
||||
"window": settings.auth_rl_search_window_seconds,
|
||||
},
|
||||
"upload": {
|
||||
"requests": settings.auth_rl_upload_requests,
|
||||
"window": settings.auth_rl_upload_window_seconds,
|
||||
},
|
||||
"admin": {
|
||||
"requests": settings.auth_rl_admin_requests,
|
||||
"window": settings.auth_rl_admin_window_seconds,
|
||||
},
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip entirely during pytest runs
|
||||
if os.getenv("PYTEST_RUNNING") == "1":
|
||||
return await call_next(request)
|
||||
|
||||
# Only apply to API endpoints
|
||||
if not request.url.path.startswith("/api/"):
|
||||
return await call_next(request)
|
||||
|
||||
# Allow disabling via settings (useful for local/dev)
|
||||
if not settings.auth_rl_enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip if user not authenticated
|
||||
user_id = None
|
||||
try:
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
user_id = getattr(request.state.user, "id", None) or getattr(request.state.user, "username", None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
# Determine category and get enhanced limits for authenticated users
|
||||
category = self._get_rate_limit_category(request.url.path)
|
||||
if category in self.auth_limits:
|
||||
rate_config = self.auth_limits[category]
|
||||
rate_key = f"auth:{category}:{user_id}"
|
||||
|
||||
try:
|
||||
allowed, info = self.store.is_allowed(
|
||||
rate_key,
|
||||
rate_config["requests"],
|
||||
rate_config["window"]
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"Authenticated user rate limit exceeded",
|
||||
user_id=user_id,
|
||||
path=request.url.path,
|
||||
category=category,
|
||||
limit=info["limit"]
|
||||
)
|
||||
|
||||
headers = {
|
||||
"X-RateLimit-Limit": str(info["limit"]),
|
||||
"X-RateLimit-Remaining": str(info["remaining"]),
|
||||
"X-RateLimit-Reset": str(info["reset"]),
|
||||
"Retry-After": str(info["retry_after"])
|
||||
}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Rate limit exceeded for authenticated user. Try again in {info['retry_after']} seconds.",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Add auth-specific rate limit headers
|
||||
response = await call_next(request)
|
||||
response.headers["X-Auth-RateLimit-Limit"] = str(info["limit"])
|
||||
response.headers["X-Auth-RateLimit-Remaining"] = str(info["remaining"])
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Authenticated rate limiting error", error=str(e))
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_rate_limit_category(self, path: str) -> str:
|
||||
"""Determine rate limit category based on request path"""
|
||||
for pattern, category in ROUTE_PATTERNS.items():
|
||||
if path.startswith(pattern):
|
||||
return category
|
||||
return "api"
|
||||
|
||||
|
||||
# Global store instance
|
||||
rate_limit_store = RateLimitStore()
|
||||
|
||||
# Rate limiting utilities
|
||||
async def check_rate_limit(key: str, limit: int, window: int) -> bool:
|
||||
"""Check if a specific key is within rate limits"""
|
||||
allowed, _ = rate_limit_store.is_allowed(key, limit, window)
|
||||
return allowed
|
||||
|
||||
async def get_rate_limit_info(key: str, limit: int, window: int) -> Dict[str, int]:
|
||||
"""Get rate limit information for a key"""
|
||||
_, info = rate_limit_store.is_allowed(key, limit, window)
|
||||
return info
|
||||
406
app/middleware/security_headers.py
Normal file
406
app/middleware/security_headers.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Security Headers Middleware
|
||||
|
||||
Implements comprehensive security headers to protect against common web vulnerabilities:
|
||||
- HSTS (HTTP Strict Transport Security)
|
||||
- CSP (Content Security Policy)
|
||||
- X-Frame-Options (Clickjacking protection)
|
||||
- X-Content-Type-Options (MIME sniffing protection)
|
||||
- X-XSS-Protection (XSS protection)
|
||||
- Referrer-Policy (Information disclosure protection)
|
||||
- Permissions-Policy (Feature policy)
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
from uuid import uuid4
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from app.config import settings
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="security_headers")
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers to all responses"""
|
||||
|
||||
def __init__(self, app, config: Optional[Dict[str, str]] = None):
|
||||
super().__init__(app)
|
||||
self.config = config or {}
|
||||
self.headers = self._build_security_headers()
|
||||
|
||||
def _build_security_headers(self) -> Dict[str, str]:
|
||||
"""Build security headers based on configuration and environment"""
|
||||
|
||||
# Base security headers
|
||||
headers = {
|
||||
# Prevent MIME type sniffing
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
|
||||
# XSS Protection (legacy but still useful)
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
|
||||
# Clickjacking protection
|
||||
"X-Frame-Options": "DENY",
|
||||
|
||||
# Referrer policy
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
|
||||
# Remove server information
|
||||
"Server": "Delphi-DB",
|
||||
|
||||
# Prevent exposure of sensitive information
|
||||
"X-Powered-By": "",
|
||||
}
|
||||
|
||||
# HSTS (HTTP Strict Transport Security) - only for HTTPS
|
||||
if self._is_https_environment():
|
||||
headers["Strict-Transport-Security"] = self.config.get(
|
||||
"hsts",
|
||||
"max-age=31536000; includeSubDomains; preload"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp = self._build_csp_header()
|
||||
if csp:
|
||||
headers["Content-Security-Policy"] = csp
|
||||
|
||||
# Permissions Policy (Feature Policy)
|
||||
permissions_policy = self._build_permissions_policy()
|
||||
if permissions_policy:
|
||||
headers["Permissions-Policy"] = permissions_policy
|
||||
|
||||
return headers
|
||||
|
||||
def _is_https_environment(self) -> bool:
|
||||
"""Check if we're in an HTTPS environment"""
|
||||
# Check common HTTPS indicators
|
||||
if self.config.get("force_https", False):
|
||||
return True
|
||||
|
||||
# In production, assume HTTPS
|
||||
if not settings.debug:
|
||||
return True
|
||||
|
||||
# Check for secure cookies setting
|
||||
if settings.secure_cookies:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _build_csp_header(self) -> str:
|
||||
"""Build Content Security Policy header"""
|
||||
|
||||
# Get domain configuration
|
||||
domain = self.config.get("domain", "'self'")
|
||||
|
||||
# CSP directives for the application
|
||||
csp_directives = {
|
||||
# Default source
|
||||
"default-src": ["'self'"],
|
||||
|
||||
# Script sources - allow self and inline scripts for the app
|
||||
"script-src": [
|
||||
"'self'",
|
||||
"'unsafe-inline'", # Required for inline event handlers
|
||||
"https://cdn.tailwindcss.com", # Tailwind CSS CDN if used
|
||||
],
|
||||
|
||||
# Style sources - allow self and inline styles
|
||||
"style-src": [
|
||||
"'self'",
|
||||
"'unsafe-inline'", # Required for component styling
|
||||
"https://fonts.googleapis.com",
|
||||
"https://cdn.tailwindcss.com",
|
||||
],
|
||||
|
||||
# Font sources
|
||||
"font-src": [
|
||||
"'self'",
|
||||
"https://fonts.gstatic.com",
|
||||
"data:",
|
||||
],
|
||||
|
||||
# Image sources
|
||||
"img-src": [
|
||||
"'self'",
|
||||
"data:",
|
||||
"blob:",
|
||||
"https:", # Allow HTTPS images
|
||||
],
|
||||
|
||||
# Media sources
|
||||
"media-src": ["'self'", "blob:"],
|
||||
|
||||
# Object sources (disable Flash, etc.)
|
||||
"object-src": ["'none'"],
|
||||
|
||||
# Frame sources (for embedding)
|
||||
"frame-src": ["'none'"],
|
||||
|
||||
# Connect sources (AJAX, WebSocket, etc.)
|
||||
"connect-src": [
|
||||
"'self'",
|
||||
"wss:", # WebSocket support
|
||||
"ws:", # WebSocket support
|
||||
],
|
||||
|
||||
# Worker sources
|
||||
"worker-src": ["'self'", "blob:"],
|
||||
|
||||
# Child sources
|
||||
"child-src": ["'none'"],
|
||||
|
||||
# Form action restrictions
|
||||
"form-action": ["'self'"],
|
||||
|
||||
# Frame ancestors (clickjacking protection)
|
||||
"frame-ancestors": ["'none'"],
|
||||
|
||||
# Base URI restrictions
|
||||
"base-uri": ["'self'"],
|
||||
|
||||
# Manifest sources
|
||||
"manifest-src": ["'self'"],
|
||||
}
|
||||
|
||||
# Build CSP string
|
||||
csp_parts = []
|
||||
for directive, sources in csp_directives.items():
|
||||
csp_parts.append(f"{directive} {' '.join(sources)}")
|
||||
|
||||
# Add upgrade insecure requests in HTTPS environments
|
||||
if self._is_https_environment():
|
||||
csp_parts.append("upgrade-insecure-requests")
|
||||
|
||||
return "; ".join(csp_parts)
|
||||
|
||||
def _build_permissions_policy(self) -> str:
|
||||
"""Build Permissions Policy header"""
|
||||
|
||||
# Restrictive permissions policy
|
||||
policies = {
|
||||
# Disable camera access
|
||||
"camera": "(),",
|
||||
|
||||
# Disable microphone access
|
||||
"microphone": "(),",
|
||||
|
||||
# Disable geolocation
|
||||
"geolocation": "(),",
|
||||
|
||||
# Disable gyroscope
|
||||
"gyroscope": "(),",
|
||||
|
||||
# Disable magnetometer
|
||||
"magnetometer": "(),",
|
||||
|
||||
# Disable payment API
|
||||
"payment": "(),",
|
||||
|
||||
# Disable USB access
|
||||
"usb": "(),",
|
||||
|
||||
# Disable notifications (except for self)
|
||||
"notifications": "(self),",
|
||||
|
||||
# Disable push messaging
|
||||
"push": "(),",
|
||||
|
||||
# Disable speaker selection
|
||||
"speaker-selection": "(),",
|
||||
|
||||
# Allow clipboard access for self
|
||||
"clipboard-write": "(self),",
|
||||
"clipboard-read": "(self),",
|
||||
|
||||
# Allow fullscreen for self
|
||||
"fullscreen": "(self),",
|
||||
}
|
||||
|
||||
return " ".join([f"{feature}={policy}" for feature, policy in policies.items()])
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers to all responses
|
||||
for header, value in self.headers.items():
|
||||
if value: # Only add non-empty headers
|
||||
response.headers[header] = value
|
||||
|
||||
# Special handling for certain endpoints
|
||||
self._apply_endpoint_specific_headers(request, response)
|
||||
|
||||
return response
|
||||
|
||||
def _apply_endpoint_specific_headers(self, request: Request, response: Response):
|
||||
"""Apply endpoint-specific security headers"""
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# Admin pages - extra security
|
||||
if path.startswith(("/admin", "/api/admin/")):
|
||||
# More restrictive CSP for admin pages
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
# API endpoints - prevent caching of sensitive data
|
||||
elif path.startswith("/api/"):
|
||||
# Prevent caching of API responses
|
||||
if request.method != "GET" or "auth" in path or "admin" in path:
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
# File upload endpoints - additional validation headers
|
||||
elif "upload" in path:
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
|
||||
# Static files - allow caching but with security
|
||||
elif path.startswith(("/static/", "/uploads/")):
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
# Allow caching for static resources
|
||||
if "static" in path:
|
||||
response.headers["Cache-Control"] = "public, max-age=31536000"
|
||||
|
||||
|
||||
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to limit request body size to prevent DoS attacks"""
|
||||
|
||||
def __init__(self, app, max_size: int = 50 * 1024 * 1024): # 50MB default
|
||||
super().__init__(app)
|
||||
self.max_size = max_size
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Check Content-Length header
|
||||
content_length = request.headers.get("content-length")
|
||||
|
||||
if content_length:
|
||||
try:
|
||||
size = int(content_length)
|
||||
if size > self.max_size:
|
||||
logger.warning(
|
||||
"Request size limit exceeded",
|
||||
size=size,
|
||||
limit=self.max_size,
|
||||
path=request.url.path,
|
||||
ip=self._get_client_ip(request)
|
||||
)
|
||||
|
||||
# Build standardized error envelope with correlation id
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# Resolve correlation id from state, headers, or generate
|
||||
correlation_id = (
|
||||
getattr(getattr(request, "state", object()), "correlation_id", None)
|
||||
or request.headers.get("x-correlation-id")
|
||||
or request.headers.get("x-request-id")
|
||||
or str(uuid4())
|
||||
)
|
||||
|
||||
body = {
|
||||
"success": False,
|
||||
"error": {
|
||||
"status": 413,
|
||||
"code": "http_error",
|
||||
"message": "Payload too large",
|
||||
},
|
||||
"correlation_id": correlation_id,
|
||||
}
|
||||
|
||||
response = JSONResponse(status_code=413, content=body)
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
return response
|
||||
except ValueError:
|
||||
pass # Invalid Content-Length header, let it pass
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request headers"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""CSRF protection middleware for state-changing operations"""
|
||||
|
||||
def __init__(self, app, exempt_paths: Optional[list] = None):
|
||||
super().__init__(app)
|
||||
# Paths that don't require CSRF protection
|
||||
self.exempt_paths = exempt_paths or [
|
||||
"/api/auth/login",
|
||||
"/api/auth/refresh",
|
||||
"/health",
|
||||
"/static/",
|
||||
"/uploads/",
|
||||
]
|
||||
# HTTP methods that require CSRF protection
|
||||
self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip CSRF protection for exempt paths
|
||||
if any(request.url.path.startswith(path) for path in self.exempt_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip for safe HTTP methods
|
||||
if request.method not in self.protected_methods:
|
||||
return await call_next(request)
|
||||
|
||||
# Check for CSRF token in headers
|
||||
csrf_token = request.headers.get("X-CSRF-Token") or request.headers.get("X-CSRFToken")
|
||||
|
||||
# For now, implement a simple CSRF check based on Referer/Origin
|
||||
# In production, you'd want proper CSRF tokens
|
||||
referer = request.headers.get("referer", "")
|
||||
origin = request.headers.get("origin", "")
|
||||
host = request.headers.get("host", "")
|
||||
|
||||
# Allow requests from same origin
|
||||
valid_origins = [f"https://{host}", f"http://{host}"]
|
||||
|
||||
if origin and origin not in valid_origins:
|
||||
if not referer or not any(referer.startswith(valid) for valid in valid_origins):
|
||||
logger.warning(
|
||||
"CSRF check failed",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
origin=origin,
|
||||
referer=referer,
|
||||
ip=self._get_client_ip(request)
|
||||
)
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="CSRF validation failed"
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request headers"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
319
app/middleware/session_middleware.py
Normal file
319
app/middleware/session_middleware.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Session management middleware for P2 security features
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.utils.session_manager import SessionManager
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SessionManagementMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Advanced session management middleware
|
||||
|
||||
Features:
|
||||
- Session validation and renewal
|
||||
- Activity tracking
|
||||
- Concurrent session limits
|
||||
- Session fixation protection
|
||||
- Automatic cleanup
|
||||
"""
|
||||
|
||||
# Endpoints that don't require session validation
|
||||
EXCLUDED_PATHS = {
|
||||
"/docs", "/redoc", "/openapi.json",
|
||||
"/api/auth/login", "/api/auth/register",
|
||||
"/api/auth/refresh", "/api/health",
|
||||
"/health", "/ready", "/metrics",
|
||||
}
|
||||
|
||||
def __init__(self, app, cleanup_interval: int = 3600):
|
||||
super().__init__(app)
|
||||
self.cleanup_interval = cleanup_interval # 1 hour default
|
||||
self.last_cleanup = time.time()
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Main middleware dispatcher"""
|
||||
start_time = time.time()
|
||||
|
||||
# Skip middleware for excluded paths
|
||||
if self._should_skip_middleware(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Get database session
|
||||
db = next(get_db())
|
||||
session_manager = SessionManager(db)
|
||||
|
||||
try:
|
||||
# Perform periodic cleanup
|
||||
await self._periodic_cleanup(session_manager)
|
||||
|
||||
# Process session validation
|
||||
session_info = await self._validate_session(request, session_manager)
|
||||
|
||||
# Add session info to request state
|
||||
request.state.session_info = session_info
|
||||
# Also expose authenticated user (if any) for downstream middleware (e.g., user-based rate limiting, logging)
|
||||
try:
|
||||
request.state.user = session_info.get("user") if session_info else None
|
||||
except Exception:
|
||||
# Be resilient: never break the request due to state propagation
|
||||
pass
|
||||
|
||||
# Fallback: if no server-side session identified, try to attach user from Authorization token for JWT-only flows
|
||||
if not getattr(request.state, "user", None):
|
||||
try:
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
from app.auth.security import verify_token # local import to avoid circular deps
|
||||
username = verify_token(token)
|
||||
if username:
|
||||
from app.models.user import User # local import
|
||||
user = session_manager.db.query(User).filter(User.username == username).first()
|
||||
if user and user.is_active:
|
||||
request.state.user = user
|
||||
except Exception:
|
||||
# Never fail the request if auth attachment fails here
|
||||
pass
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Update session activity
|
||||
if session_info and session_info.get("session"):
|
||||
await self._update_session_activity(
|
||||
request, response, session_info["session"], session_manager, start_time
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Session middleware error: {str(e)}")
|
||||
# Re-raise to be handled by global error handlers; do not re-invoke downstream app
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _should_skip_middleware(self, request: Request) -> bool:
|
||||
"""Check if middleware should be skipped for this request"""
|
||||
path = request.url.path
|
||||
|
||||
# Skip excluded paths
|
||||
if any(path.startswith(excluded) for excluded in self.EXCLUDED_PATHS):
|
||||
return True
|
||||
|
||||
# Skip static files
|
||||
if path.startswith("/static/") or path.startswith("/favicon.ico"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _validate_session(self, request: Request, session_manager: SessionManager) -> Optional[dict]:
|
||||
"""Validate session from request"""
|
||||
# Extract session ID from various sources
|
||||
session_id = await self._extract_session_id(request)
|
||||
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
# Validate session
|
||||
session = session_manager.validate_session(session_id, request)
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"session_id": session_id,
|
||||
"user": session.user,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
async def _extract_session_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract session ID from request"""
|
||||
# Try cookie first
|
||||
session_id = request.cookies.get("session_id")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# Try custom header
|
||||
session_id = request.headers.get("X-Session-ID")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# For JWT-based sessions, extract from authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
# Use JWT token as session identifier for now
|
||||
# In a full implementation, you'd decode JWT and extract session ID
|
||||
token = auth_header[7:]
|
||||
return token[:32] if len(token) > 32 else token
|
||||
|
||||
return None
|
||||
|
||||
async def _update_session_activity(
|
||||
self,
|
||||
request: Request,
|
||||
response: Response,
|
||||
session,
|
||||
session_manager: SessionManager,
|
||||
start_time: float
|
||||
) -> None:
|
||||
"""Update session activity tracking"""
|
||||
try:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log API activity
|
||||
session_manager._log_activity(
|
||||
session, session.user, request,
|
||||
activity_type="api_request",
|
||||
endpoint=request.url.path
|
||||
)
|
||||
|
||||
# Update activity record with response details
|
||||
if hasattr(session, 'activities') and session.activities:
|
||||
latest_activity = session.activities[-1]
|
||||
latest_activity.status_code = getattr(response, 'status_code', None)
|
||||
latest_activity.duration_ms = duration_ms
|
||||
|
||||
# Analyze for suspicious patterns
|
||||
await self._analyze_activity_patterns(session, session_manager)
|
||||
|
||||
session_manager.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update session activity: {str(e)}")
|
||||
|
||||
async def _analyze_activity_patterns(self, session, session_manager: SessionManager) -> None:
|
||||
"""Analyze activity patterns for suspicious behavior"""
|
||||
try:
|
||||
# Get recent activities for this session
|
||||
recent_activities = session_manager.db.query(
|
||||
session_manager.db.query(type(session.activities[0]))
|
||||
).filter_by(session_id=session.id).order_by(
|
||||
type(session.activities[0]).timestamp.desc()
|
||||
).limit(10).all()
|
||||
|
||||
if len(recent_activities) < 5:
|
||||
return
|
||||
|
||||
# Check for rapid API calls (possible automation)
|
||||
time_diffs = []
|
||||
for i in range(1, len(recent_activities)):
|
||||
time_diff = (recent_activities[i-1].timestamp - recent_activities[i].timestamp).total_seconds()
|
||||
time_diffs.append(time_diff)
|
||||
|
||||
avg_time_diff = sum(time_diffs) / len(time_diffs)
|
||||
|
||||
# Flag if average time between requests is < 1 second
|
||||
if avg_time_diff < 1.0:
|
||||
session.risk_score = min(session.risk_score + 10, 100)
|
||||
session_manager._create_security_event(
|
||||
session, session.user,
|
||||
event_type="rapid_api_calls",
|
||||
severity="medium",
|
||||
description=f"Rapid API calls detected: avg {avg_time_diff:.2f}s between requests"
|
||||
)
|
||||
|
||||
# Lock session if risk score is too high
|
||||
if session.risk_score >= 80:
|
||||
session.lock_session("high_risk_activity")
|
||||
session_manager._create_security_event(
|
||||
session, session.user,
|
||||
event_type="session_locked",
|
||||
severity="high",
|
||||
description=f"Session locked due to high risk score: {session.risk_score}",
|
||||
action_taken="session_locked"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to analyze activity patterns: {str(e)}")
|
||||
|
||||
async def _periodic_cleanup(self, session_manager: SessionManager) -> None:
|
||||
"""Perform periodic cleanup of expired sessions"""
|
||||
current_time = time.time()
|
||||
|
||||
if current_time - self.last_cleanup > self.cleanup_interval:
|
||||
try:
|
||||
cleaned_count = session_manager.cleanup_expired_sessions()
|
||||
self.last_cleanup = current_time
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"Cleaned up {cleaned_count} expired sessions")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup sessions: {str(e)}")
|
||||
|
||||
|
||||
class SessionSecurityMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Additional security middleware for session protection
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process security checks for sessions"""
|
||||
|
||||
# Add security headers for session management
|
||||
response = await call_next(request)
|
||||
|
||||
# Add session security headers
|
||||
if isinstance(response, Response):
|
||||
# Prevent session fixation
|
||||
response.headers["X-Session-Security"] = "fixation-protected"
|
||||
|
||||
# Indicate session management is active
|
||||
response.headers["X-Session-Management"] = "active"
|
||||
|
||||
# Add cache control for session-sensitive pages
|
||||
if request.url.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SessionCookieMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Secure cookie management for sessions
|
||||
"""
|
||||
|
||||
def __init__(self, app, secure: bool = True, same_site: str = "strict"):
|
||||
super().__init__(app)
|
||||
self.secure = secure
|
||||
self.same_site = same_site
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Handle secure session cookies"""
|
||||
response = await call_next(request)
|
||||
|
||||
# Check if we need to set session cookie
|
||||
session_info = getattr(request.state, 'session_info', None)
|
||||
|
||||
if session_info and session_info.get("session"):
|
||||
session_id = session_info["session_id"]
|
||||
|
||||
# Set secure session cookie
|
||||
response.set_cookie(
|
||||
key="session_id",
|
||||
value=session_id,
|
||||
max_age=28800, # 8 hours
|
||||
httponly=True,
|
||||
secure=self.secure,
|
||||
samesite=self.same_site
|
||||
)
|
||||
|
||||
return response
|
||||
439
app/middleware/websocket_middleware.py
Normal file
439
app/middleware/websocket_middleware.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
WebSocket Middleware and Utilities
|
||||
|
||||
This module provides middleware and utilities for WebSocket connections,
|
||||
including authentication, connection management, and integration with the
|
||||
WebSocket pool system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, Set, Callable, Awaitable
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import SessionLocal
|
||||
from app.models.user import User
|
||||
from app.auth.security import verify_token
|
||||
from app.services.websocket_pool import (
|
||||
get_websocket_pool,
|
||||
websocket_connection,
|
||||
WebSocketMessage,
|
||||
MessageType
|
||||
)
|
||||
from app.utils.logging import StructuredLogger
|
||||
|
||||
|
||||
class WebSocketAuthenticationError(Exception):
|
||||
"""Raised when WebSocket authentication fails"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""
|
||||
High-level WebSocket manager that provides easy-to-use methods
|
||||
for handling WebSocket connections with authentication and topic management
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = StructuredLogger("websocket_manager", "INFO")
|
||||
self.pool = get_websocket_pool()
|
||||
|
||||
async def authenticate_websocket(self, websocket: WebSocket) -> Optional[User]:
|
||||
"""
|
||||
Authenticate a WebSocket connection using token from query parameters
|
||||
|
||||
Args:
|
||||
websocket: WebSocket instance
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get token from query parameters
|
||||
query_params = parse_qs(str(websocket.url.query))
|
||||
token = query_params.get('token', [None])[0]
|
||||
|
||||
if not token:
|
||||
self.logger.warning("WebSocket authentication failed: no token provided")
|
||||
return None
|
||||
|
||||
# Verify token
|
||||
username = verify_token(token)
|
||||
if not username:
|
||||
self.logger.warning("WebSocket authentication failed: invalid token")
|
||||
return None
|
||||
|
||||
# Get user from database
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user or not user.is_active:
|
||||
self.logger.warning("WebSocket authentication failed: user not found or inactive",
|
||||
username=username)
|
||||
return None
|
||||
|
||||
self.logger.info("WebSocket authentication successful",
|
||||
user_id=user.id,
|
||||
username=user.username)
|
||||
return user
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("WebSocket authentication error", error=str(e))
|
||||
return None
|
||||
|
||||
async def handle_connection(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
topics: Optional[Set[str]] = None,
|
||||
require_auth: bool = True,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Handle a WebSocket connection with authentication and message processing
|
||||
|
||||
Args:
|
||||
websocket: WebSocket instance
|
||||
topics: Initial topics to subscribe to
|
||||
require_auth: Whether authentication is required
|
||||
metadata: Additional metadata for the connection
|
||||
message_handler: Optional function to handle incoming messages
|
||||
|
||||
Returns:
|
||||
Connection ID if successful, None if failed
|
||||
"""
|
||||
user = None
|
||||
if require_auth:
|
||||
user = await self.authenticate_websocket(websocket)
|
||||
if not user:
|
||||
await websocket.close(code=4401, reason="Authentication failed")
|
||||
return None
|
||||
|
||||
# Accept the connection
|
||||
await websocket.accept()
|
||||
|
||||
# Add to pool
|
||||
user_id = user.id if user else None
|
||||
async with websocket_connection(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
topics=topics,
|
||||
metadata=metadata
|
||||
) as (connection_id, pool):
|
||||
|
||||
# Set connection state to connected
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if connection_info:
|
||||
connection_info.state = connection_info.state.CONNECTED
|
||||
|
||||
# Send initial welcome message
|
||||
welcome_message = WebSocketMessage(
|
||||
type="welcome",
|
||||
data={
|
||||
"connection_id": connection_id,
|
||||
"user_id": user_id,
|
||||
"topics": list(topics) if topics else [],
|
||||
"timestamp": connection_info.created_at.isoformat() if connection_info else None
|
||||
}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, welcome_message)
|
||||
|
||||
# Handle messages
|
||||
await self._message_loop(
|
||||
websocket=websocket,
|
||||
connection_id=connection_id,
|
||||
pool=pool,
|
||||
message_handler=message_handler
|
||||
)
|
||||
|
||||
return connection_id
|
||||
|
||||
async def _message_loop(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
connection_id: str,
|
||||
pool,
|
||||
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
|
||||
):
|
||||
"""Handle incoming WebSocket messages"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Receive message
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Update activity
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if connection_info:
|
||||
connection_info.update_activity()
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
import json
|
||||
message_dict = json.loads(data)
|
||||
message = WebSocketMessage(**message_dict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
self.logger.warning("Invalid message format",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
data=data[:100])
|
||||
continue
|
||||
|
||||
# Handle standard message types
|
||||
await self._handle_standard_message(connection_id, message, pool)
|
||||
|
||||
# Call custom message handler if provided
|
||||
if message_handler:
|
||||
try:
|
||||
await message_handler(connection_id, message)
|
||||
except Exception as e:
|
||||
self.logger.error("Error in custom message handler",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
|
||||
except WebSocketDisconnect:
|
||||
self.logger.info("WebSocket disconnected", connection_id=connection_id)
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error in message loop",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Fatal error in message loop",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
|
||||
async def _handle_standard_message(self, connection_id: str, message: WebSocketMessage, pool):
|
||||
"""Handle standard WebSocket message types"""
|
||||
|
||||
if message.type == MessageType.PING.value:
|
||||
# Respond with pong
|
||||
pong_message = WebSocketMessage(
|
||||
type=MessageType.PONG.value,
|
||||
data={"timestamp": message.timestamp}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, pong_message)
|
||||
|
||||
elif message.type == MessageType.PONG.value:
|
||||
# Handle pong response
|
||||
await pool.handle_pong(connection_id)
|
||||
|
||||
elif message.type == MessageType.SUBSCRIBE.value:
|
||||
# Subscribe to topic
|
||||
topic = message.topic
|
||||
if topic:
|
||||
success = await pool.subscribe_to_topic(connection_id, topic)
|
||||
response = WebSocketMessage(
|
||||
type="subscription_response",
|
||||
topic=topic,
|
||||
data={"success": success, "action": "subscribe"}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, response)
|
||||
|
||||
elif message.type == MessageType.UNSUBSCRIBE.value:
|
||||
# Unsubscribe from topic
|
||||
topic = message.topic
|
||||
if topic:
|
||||
success = await pool.unsubscribe_from_topic(connection_id, topic)
|
||||
response = WebSocketMessage(
|
||||
type="subscription_response",
|
||||
topic=topic,
|
||||
data={"success": success, "action": "unsubscribe"}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, response)
|
||||
|
||||
async def broadcast_to_topic(
|
||||
self,
|
||||
topic: str,
|
||||
message_type: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
exclude_connection_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Convenience method to broadcast a message to a topic"""
|
||||
message = WebSocketMessage(
|
||||
type=message_type,
|
||||
topic=topic,
|
||||
data=data
|
||||
)
|
||||
return await self.pool.broadcast_to_topic(topic, message, exclude_connection_id)
|
||||
|
||||
async def send_to_user(
|
||||
self,
|
||||
user_id: int,
|
||||
message_type: str,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> int:
|
||||
"""Convenience method to send a message to all connections for a user"""
|
||||
message = WebSocketMessage(
|
||||
type=message_type,
|
||||
data=data
|
||||
)
|
||||
return await self.pool.send_to_user(user_id, message)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get WebSocket pool statistics"""
|
||||
return await self.pool.get_stats()
|
||||
|
||||
|
||||
# Global WebSocket manager instance
|
||||
_websocket_manager: Optional[WebSocketManager] = None
|
||||
|
||||
|
||||
def get_websocket_manager() -> WebSocketManager:
|
||||
"""Get the global WebSocket manager instance"""
|
||||
global _websocket_manager
|
||||
if _websocket_manager is None:
|
||||
_websocket_manager = WebSocketManager()
|
||||
return _websocket_manager
|
||||
|
||||
|
||||
# Utility decorators and functions
|
||||
|
||||
def websocket_endpoint(
|
||||
topics: Optional[Set[str]] = None,
|
||||
require_auth: bool = True,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Decorator for WebSocket endpoints that automatically handles
|
||||
connection management, authentication, and cleanup
|
||||
|
||||
Usage:
|
||||
@router.websocket("/my-endpoint")
|
||||
@websocket_endpoint(topics={"my_topic"}, require_auth=True)
|
||||
async def my_websocket_handler(websocket: WebSocket, connection_id: str, manager: WebSocketManager):
|
||||
# Your custom logic here
|
||||
pass
|
||||
"""
|
||||
def decorator(func):
|
||||
async def wrapper(websocket: WebSocket, *args, **kwargs):
|
||||
manager = get_websocket_manager()
|
||||
|
||||
async def message_handler(connection_id: str, message: WebSocketMessage):
|
||||
# Call the original function with the message
|
||||
await func(websocket, connection_id, manager, message, *args, **kwargs)
|
||||
|
||||
# Handle the connection
|
||||
connection_id = await manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics=topics,
|
||||
require_auth=require_auth,
|
||||
metadata=metadata,
|
||||
message_handler=message_handler
|
||||
)
|
||||
|
||||
if not connection_id:
|
||||
return
|
||||
|
||||
# Keep the connection alive
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
connection_info = await manager.pool.get_connection_info(connection_id)
|
||||
if not connection_info or not connection_info.is_alive():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def websocket_auth_dependency(websocket: WebSocket) -> User:
|
||||
"""
|
||||
FastAPI dependency for WebSocket authentication
|
||||
|
||||
Usage:
|
||||
@router.websocket("/my-endpoint")
|
||||
async def my_endpoint(websocket: WebSocket, user: User = Depends(websocket_auth_dependency)):
|
||||
# user is guaranteed to be authenticated
|
||||
pass
|
||||
"""
|
||||
manager = get_websocket_manager()
|
||||
user = await manager.authenticate_websocket(websocket)
|
||||
if not user:
|
||||
await websocket.close(code=4401, reason="Authentication failed")
|
||||
raise WebSocketAuthenticationError("Authentication failed")
|
||||
return user
|
||||
|
||||
|
||||
class WebSocketConnectionTracker:
|
||||
"""
|
||||
Utility class to track WebSocket connections and their health
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = StructuredLogger("websocket_tracker", "INFO")
|
||||
|
||||
async def track_connection_health(self, connection_id: str, interval: int = 60):
|
||||
"""Track the health of a specific connection"""
|
||||
pool = get_websocket_pool()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
if not connection_info:
|
||||
break
|
||||
|
||||
# Check if connection is healthy
|
||||
if connection_info.is_stale(timeout_seconds=300):
|
||||
self.logger.warning("Connection is stale",
|
||||
connection_id=connection_id,
|
||||
last_activity=connection_info.last_activity.isoformat())
|
||||
break
|
||||
|
||||
# Try to ping the connection
|
||||
if connection_info.is_alive():
|
||||
success = await pool.ping_connection(connection_id)
|
||||
if not success:
|
||||
self.logger.warning("Failed to ping connection",
|
||||
connection_id=connection_id)
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error("Error tracking connection health",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
async def get_connection_metrics(self, connection_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed metrics for a connection"""
|
||||
pool = get_websocket_pool()
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
|
||||
if not connection_info:
|
||||
return None
|
||||
|
||||
now = connection_info.last_activity # Use last_activity for consistency
|
||||
return {
|
||||
"connection_id": connection_id,
|
||||
"user_id": connection_info.user_id,
|
||||
"state": connection_info.state.value,
|
||||
"topics": list(connection_info.topics),
|
||||
"created_at": connection_info.created_at.isoformat(),
|
||||
"last_activity": connection_info.last_activity.isoformat(),
|
||||
"age_seconds": (now - connection_info.created_at).total_seconds(),
|
||||
"idle_seconds": (now - connection_info.last_activity).total_seconds(),
|
||||
"error_count": connection_info.error_count,
|
||||
"last_ping": connection_info.last_ping.isoformat() if connection_info.last_ping else None,
|
||||
"last_pong": connection_info.last_pong.isoformat() if connection_info.last_pong else None,
|
||||
"metadata": connection_info.metadata,
|
||||
"is_alive": connection_info.is_alive(),
|
||||
"is_stale": connection_info.is_stale()
|
||||
}
|
||||
|
||||
|
||||
def get_connection_tracker() -> WebSocketConnectionTracker:
|
||||
"""Get a WebSocket connection tracker instance"""
|
||||
return WebSocketConnectionTracker()
|
||||
Reference in New Issue
Block a user