HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -101,13 +101,21 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
detail=message,
path=request.url.path,
)
return _build_error_response(
response = _build_error_response(
request,
status_code=exc.status_code,
message=message,
code="http_error",
details=None,
)
# Preserve any headers set on the HTTPException (e.g., WWW-Authenticate)
try:
if getattr(exc, "headers", None):
for key, value in exc.headers.items():
response.headers[key] = value
except Exception:
pass
return response
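# With the change above, headers attached to an HTTPException (for example a
# WWW-Authenticate challenge) are copied onto the standardized JSON error
# envelope instead of being dropped. Illustrative usage in a route handler
# (not part of this diff):
#
#     raise HTTPException(
#         status_code=status.HTTP_401_UNAUTHORIZED,
#         detail="Not authenticated",
#         headers={"WWW-Authenticate": "Bearer"},
#     )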
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:

View File

@@ -26,8 +26,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4())
request.state.correlation_id = correlation_id
# Skip logging for static files and health checks (still attach correlation id)
skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"]
# Skip logging for static files, health checks, and metrics (still attach correlation id)
skip_paths = ["/static/", "/uploads/", "/health", "/metrics", "/favicon.ico"]
if any(request.url.path.startswith(path) for path in skip_paths):
response = await call_next(request)
try:

View File

@@ -0,0 +1,377 @@
"""
Rate Limiting Middleware for API Protection
Implements sliding window rate limiting to prevent abuse and DoS attacks.
Uses in-memory storage; a shared backend such as Redis would be needed for multi-process or distributed deployments.
"""
import time
import os
import asyncio
from typing import Dict, Optional, Tuple, Callable
from collections import defaultdict, deque
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.config import settings
from app.utils.logging import app_logger
logger = app_logger.bind(name="rate_limiting")
# Rate limit configuration
RATE_LIMITS = {
# Global limits per IP
"global": {"requests": 1000, "window": 3600}, # 1000 requests per hour
# Authentication endpoints (more restrictive)
"auth": {"requests": 10, "window": 900}, # 10 requests per 15 minutes
# Admin endpoints (moderately restrictive)
"admin": {"requests": 100, "window": 3600}, # 100 requests per hour
# Search endpoints (moderate limits)
"search": {"requests": 200, "window": 3600}, # 200 requests per hour
# File upload endpoints (kept generous so batch imports and test runs are not throttled)
"upload": {"requests": 1000, "window": 3600},  # ample uploads per hour
# API endpoints (standard)
"api": {"requests": 500, "window": 3600}, # 500 requests per hour
}
# Route patterns to rate limit categories (order matters: first match wins)
ROUTE_PATTERNS = {
# Auth endpoints frequently called by the UI should not use the strict "auth" bucket
"/api/auth/me": "api",
"/api/auth/refresh": "api",
"/api/auth/logout": "api",
# Keep sensitive auth endpoints in the stricter bucket
"/api/auth/login": "auth",
"/api/auth/register": "auth",
# Generic fallbacks
"/api/auth/": "auth",
"/api/admin/": "admin",
"/api/search/": "search",
"/api/documents/upload": "upload",
"/api/files/upload": "upload",
"/api/import/": "upload",
"/api/": "api",
}
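# The lookup is first-prefix-match-wins (dict insertion order is preserved),
# which is why "/api/auth/me" lands in the roomier "api" bucket while
# "/api/auth/login" stays in "auth". A standalone sketch mirroring the
# _get_rate_limit_category methods defined further down:
def _resolve_category_example(path: str) -> str:
    for prefix, category in ROUTE_PATTERNS.items():
        if path.startswith(prefix):
            return category
    return "global"
# _resolve_category_example("/api/auth/me")    -> "api"
# _resolve_category_example("/api/auth/login") -> "auth"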
class RateLimitStore:
"""In-memory rate limit storage with sliding window algorithm"""
def __init__(self):
# Structure: {key: deque(timestamps)}
self._storage: Dict[str, deque] = defaultdict(deque)
self._lock = asyncio.Lock()
def is_allowed(self, key: str, limit: int, window: int) -> Tuple[bool, Dict[str, int]]:
"""Check if request is allowed and return rate limit info (sync)."""
# Use a non-async path for portability and test friendliness
now = int(time.time())
window_start = now - window
# Clean old entries
timestamps = self._storage[key]
while timestamps and timestamps[0] <= window_start:
timestamps.popleft()
# Check if limit exceeded
current_count = len(timestamps)
allowed = current_count < limit
# Add current request if allowed
if allowed:
timestamps.append(now)
# Calculate reset time (when oldest request expires)
reset_time = (timestamps[0] + window) if timestamps else now + window
return allowed, {
"limit": limit,
"remaining": max(0, limit - current_count - (1 if allowed else 0)),
"reset": reset_time,
"retry_after": max(1, reset_time - now) if not allowed else 0
}
async def cleanup_expired(self, max_age: int = 7200):
"""Remove expired entries (cleanup task)"""
async with self._lock:
now = int(time.time())
cutoff = now - max_age
expired_keys = []
for key, timestamps in self._storage.items():
# Remove old timestamps
while timestamps and timestamps[0] <= cutoff:
timestamps.popleft()
# Mark empty deques for deletion
if not timestamps:
expired_keys.append(key)
# Clean up empty entries
for key in expired_keys:
del self._storage[key]
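# A minimal sketch of the sliding-window behaviour (the key and limits are
# illustrative, not values the app uses); each call to is_allowed counts:
def _demo_sliding_window() -> None:
    store = RateLimitStore()
    for attempt in range(3):
        allowed, info = store.is_allowed("demo:127.0.0.1", limit=2, window=60)
        # Attempts 0 and 1 are allowed; attempt 2 is rejected and
        # info["retry_after"] reports roughly how long until the oldest
        # timestamp leaves the 60-second window.
        print(attempt, allowed, info["remaining"], info["retry_after"])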
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware with sliding window algorithm"""
def __init__(self, app, store: Optional[RateLimitStore] = None):
super().__init__(app)
self.store = store or RateLimitStore()
self._cleanup_task = None
self._start_cleanup_task()
def _start_cleanup_task(self):
"""Start background cleanup task"""
async def cleanup_loop():
while True:
try:
await asyncio.sleep(300) # Clean every 5 minutes
await self.store.cleanup_expired()
except Exception as e:
logger.warning("Rate limit cleanup failed", error=str(e))
# Create cleanup task
try:
loop = asyncio.get_running_loop()
self._cleanup_task = loop.create_task(cleanup_loop())
except RuntimeError:
pass  # No running event loop yet; the task is started lazily on the first request
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Start the cleanup task lazily if it could not be created at construction time
if self._cleanup_task is None:
self._start_cleanup_task()
# Skip entirely during pytest runs
if os.getenv("PYTEST_RUNNING") == "1":
return await call_next(request)
# Skip rate limiting for static files and health/metrics endpoints
skip_paths = ["/static/", "/uploads/", "/health", "/ready", "/metrics", "/favicon.ico"]
if any(request.url.path.startswith(path) for path in skip_paths):
return await call_next(request)
# Do not count CORS preflight requests against rate limits
if request.method.upper() == "OPTIONS":
return await call_next(request)
# If the request is to API endpoints and an authenticated user is present on the state,
# skip IP-based global rate limiting in favor of the user-based limiter.
# This avoids tab/page-change bursts from tripping global IP limits.
if request.url.path.startswith("/api/"):
try:
if hasattr(request.state, "user") and request.state.user:
return await call_next(request)
except Exception:
# If state is not available for any reason, continue with IP-based limiting
pass
# Determine rate limit category
category = self._get_rate_limit_category(request.url.path)
rate_config = RATE_LIMITS.get(category, RATE_LIMITS["global"])
# Generate rate limit key
client_ip = self._get_client_ip(request)
rate_key = f"{category}:{client_ip}"
# Check rate limit
try:
# is_allowed is synchronous; requests are serialized on the event loop, so no lock is needed here
allowed, info = self.store.is_allowed(
rate_key,
rate_config["requests"],
rate_config["window"]
)
if not allowed:
logger.warning(
"Rate limit exceeded",
ip=client_ip,
path=request.url.path,
category=category,
limit=info["limit"],
retry_after=info["retry_after"]
)
# Return rate limit error
headers = {
"X-RateLimit-Limit": str(info["limit"]),
"X-RateLimit-Remaining": str(info["remaining"]),
"X-RateLimit-Reset": str(info["reset"]),
"Retry-After": str(info["retry_after"])
}
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded. Try again in {info['retry_after']} seconds.",
headers=headers
)
# Process request
response = await call_next(request)
# Add rate limit headers to response
response.headers["X-RateLimit-Limit"] = str(info["limit"])
response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
response.headers["X-RateLimit-Reset"] = str(info["reset"])
return response
except HTTPException:
raise
except Exception as e:
logger.error("Rate limiting error", error=str(e))
# Continue without rate limiting on errors
return await call_next(request)
def _get_rate_limit_category(self, path: str) -> str:
"""Determine rate limit category based on request path"""
for pattern, category in ROUTE_PATTERNS.items():
if path.startswith(pattern):
return category
return "global"
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request headers"""
# Check for IP in common proxy headers
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
# Fallback to direct client IP
if request.client:
return request.client.host
return "unknown"
# Enhanced rate limiting for authenticated users
class AuthenticatedRateLimitMiddleware(BaseHTTPMiddleware):
"""Enhanced rate limiting with user-based limits for authenticated requests"""
def __init__(self, app, store: Optional[RateLimitStore] = None):
super().__init__(app)
self.store = store or RateLimitStore()
# Higher limits for authenticated users
self.auth_limits = {
"api": {
"requests": settings.auth_rl_api_requests,
"window": settings.auth_rl_api_window_seconds,
},
"search": {
"requests": settings.auth_rl_search_requests,
"window": settings.auth_rl_search_window_seconds,
},
"upload": {
"requests": settings.auth_rl_upload_requests,
"window": settings.auth_rl_upload_window_seconds,
},
"admin": {
"requests": settings.auth_rl_admin_requests,
"window": settings.auth_rl_admin_window_seconds,
},
}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip entirely during pytest runs
if os.getenv("PYTEST_RUNNING") == "1":
return await call_next(request)
# Only apply to API endpoints
if not request.url.path.startswith("/api/"):
return await call_next(request)
# Allow disabling via settings (useful for local/dev)
if not settings.auth_rl_enabled:
return await call_next(request)
# Skip if user not authenticated
user_id = None
try:
if hasattr(request.state, "user") and request.state.user:
user_id = getattr(request.state.user, "id", None) or getattr(request.state.user, "username", None)
except Exception:
pass
if not user_id:
return await call_next(request)
# Determine category and get enhanced limits for authenticated users
category = self._get_rate_limit_category(request.url.path)
if category in self.auth_limits:
rate_config = self.auth_limits[category]
rate_key = f"auth:{category}:{user_id}"
try:
allowed, info = self.store.is_allowed(
rate_key,
rate_config["requests"],
rate_config["window"]
)
if not allowed:
logger.warning(
"Authenticated user rate limit exceeded",
user_id=user_id,
path=request.url.path,
category=category,
limit=info["limit"]
)
headers = {
"X-RateLimit-Limit": str(info["limit"]),
"X-RateLimit-Remaining": str(info["remaining"]),
"X-RateLimit-Reset": str(info["reset"]),
"Retry-After": str(info["retry_after"])
}
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded for authenticated user. Try again in {info['retry_after']} seconds.",
headers=headers
)
# Add auth-specific rate limit headers
response = await call_next(request)
response.headers["X-Auth-RateLimit-Limit"] = str(info["limit"])
response.headers["X-Auth-RateLimit-Remaining"] = str(info["remaining"])
return response
except HTTPException:
raise
except Exception as e:
logger.error("Authenticated rate limiting error", error=str(e))
return await call_next(request)
def _get_rate_limit_category(self, path: str) -> str:
"""Determine rate limit category based on request path"""
for pattern, category in ROUTE_PATTERNS.items():
if path.startswith(pattern):
return category
return "api"
# Global store instance
rate_limit_store = RateLimitStore()
# Rate limiting utilities
async def check_rate_limit(key: str, limit: int, window: int) -> bool:
"""Check if a specific key is within rate limits"""
allowed, _ = rate_limit_store.is_allowed(key, limit, window)
return allowed
async def get_rate_limit_info(key: str, limit: int, window: int) -> Dict[str, int]:
"""Get rate limit information for a key"""
_, info = rate_limit_store.is_allowed(key, limit, window)
return info
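# A sketch of how the limiters might be registered on the application (the
# wiring location is an assumption; this module does not show the actual
# registration). Sharing one store lets both limiters see the same windows,
# and the middleware added last runs first (outermost):
def _example_rate_limit_wiring():
    from fastapi import FastAPI
    app = FastAPI()
    app.add_middleware(AuthenticatedRateLimitMiddleware, store=rate_limit_store)
    app.add_middleware(RateLimitMiddleware, store=rate_limit_store)
    return app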

View File

@@ -0,0 +1,406 @@
"""
Security Headers Middleware
Implements comprehensive security headers to protect against common web vulnerabilities:
- HSTS (HTTP Strict Transport Security)
- CSP (Content Security Policy)
- X-Frame-Options (Clickjacking protection)
- X-Content-Type-Options (MIME sniffing protection)
- X-XSS-Protection (XSS protection)
- Referrer-Policy (Information disclosure protection)
- Permissions-Policy (Feature policy)
"""
from typing import Callable, Dict, Optional
from uuid import uuid4
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.config import settings
from app.utils.logging import app_logger
logger = app_logger.bind(name="security_headers")
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Middleware to add security headers to all responses"""
def __init__(self, app, config: Optional[Dict[str, str]] = None):
super().__init__(app)
self.config = config or {}
self.headers = self._build_security_headers()
def _build_security_headers(self) -> Dict[str, str]:
"""Build security headers based on configuration and environment"""
# Base security headers
headers = {
# Prevent MIME type sniffing
"X-Content-Type-Options": "nosniff",
# XSS Protection (legacy but still useful)
"X-XSS-Protection": "1; mode=block",
# Clickjacking protection
"X-Frame-Options": "DENY",
# Referrer policy
"Referrer-Policy": "strict-origin-when-cross-origin",
# Mask the real server software behind a generic value
"Server": "Delphi-DB",
# Keep X-Powered-By empty so no framework fingerprint is advertised
"X-Powered-By": "",
}
# HSTS (HTTP Strict Transport Security) - only for HTTPS
if self._is_https_environment():
headers["Strict-Transport-Security"] = self.config.get(
"hsts",
"max-age=31536000; includeSubDomains; preload"
)
# Content Security Policy
csp = self._build_csp_header()
if csp:
headers["Content-Security-Policy"] = csp
# Permissions Policy (Feature Policy)
permissions_policy = self._build_permissions_policy()
if permissions_policy:
headers["Permissions-Policy"] = permissions_policy
return headers
def _is_https_environment(self) -> bool:
"""Check if we're in an HTTPS environment"""
# Check common HTTPS indicators
if self.config.get("force_https", False):
return True
# In production, assume HTTPS
if not settings.debug:
return True
# Check for secure cookies setting
if settings.secure_cookies:
return True
return False
def _build_csp_header(self) -> str:
"""Build Content Security Policy header"""
# Get domain configuration
domain = self.config.get("domain", "'self'")
# CSP directives for the application
csp_directives = {
# Default source
"default-src": ["'self'"],
# Script sources - allow self and inline scripts for the app
"script-src": [
"'self'",
"'unsafe-inline'", # Required for inline event handlers
"https://cdn.tailwindcss.com", # Tailwind CSS CDN if used
],
# Style sources - allow self and inline styles
"style-src": [
"'self'",
"'unsafe-inline'", # Required for component styling
"https://fonts.googleapis.com",
"https://cdn.tailwindcss.com",
],
# Font sources
"font-src": [
"'self'",
"https://fonts.gstatic.com",
"data:",
],
# Image sources
"img-src": [
"'self'",
"data:",
"blob:",
"https:", # Allow HTTPS images
],
# Media sources
"media-src": ["'self'", "blob:"],
# Object sources (disable Flash, etc.)
"object-src": ["'none'"],
# Frame sources (for embedding)
"frame-src": ["'none'"],
# Connect sources (AJAX, WebSocket, etc.)
"connect-src": [
"'self'",
"wss:", # WebSocket support
"ws:", # WebSocket support
],
# Worker sources
"worker-src": ["'self'", "blob:"],
# Child sources
"child-src": ["'none'"],
# Form action restrictions
"form-action": ["'self'"],
# Frame ancestors (clickjacking protection)
"frame-ancestors": ["'none'"],
# Base URI restrictions
"base-uri": ["'self'"],
# Manifest sources
"manifest-src": ["'self'"],
}
# Build CSP string
csp_parts = []
for directive, sources in csp_directives.items():
csp_parts.append(f"{directive} {' '.join(sources)}")
# Add upgrade insecure requests in HTTPS environments
if self._is_https_environment():
csp_parts.append("upgrade-insecure-requests")
return "; ".join(csp_parts)
def _build_permissions_policy(self) -> str:
"""Build Permissions Policy header"""
# Restrictive permissions policy
policies = {
# Disable camera access
"camera": "()",
# Disable microphone access
"microphone": "()",
# Disable geolocation
"geolocation": "()",
# Disable gyroscope
"gyroscope": "()",
# Disable magnetometer
"magnetometer": "()",
# Disable payment API
"payment": "()",
# Disable USB access
"usb": "()",
# Allow notifications only for same origin
"notifications": "(self)",
# Disable push messaging
"push": "()",
# Disable speaker selection
"speaker-selection": "()",
# Allow clipboard access for same origin
"clipboard-write": "(self)",
"clipboard-read": "(self)",
# Allow fullscreen for same origin
"fullscreen": "(self)",
}
# Permissions-Policy directives are comma-separated
return ", ".join(f"{feature}={allowlist}" for feature, allowlist in policies.items())
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Process the request
response = await call_next(request)
# Add security headers to all responses
for header, value in self.headers.items():
if value: # Only add non-empty headers
response.headers[header] = value
# Special handling for certain endpoints
self._apply_endpoint_specific_headers(request, response)
return response
def _apply_endpoint_specific_headers(self, request: Request, response: Response):
"""Apply endpoint-specific security headers"""
path = request.url.path
# Admin pages - extra security
if path.startswith(("/admin", "/api/admin/")):
# More restrictive CSP for admin pages
response.headers["X-Frame-Options"] = "DENY"
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
# API endpoints - prevent caching of sensitive data
elif path.startswith("/api/"):
# Prevent caching of API responses
if request.method != "GET" or "auth" in path or "admin" in path:
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
# File upload endpoints - additional validation headers
elif "upload" in path:
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["Cache-Control"] = "no-store"
# Static files - allow caching but with security
elif path.startswith(("/static/", "/uploads/")):
response.headers["X-Content-Type-Options"] = "nosniff"
# Allow caching for static resources
if "static" in path:
response.headers["Cache-Control"] = "public, max-age=31536000"
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
"""Middleware to limit request body size to prevent DoS attacks"""
def __init__(self, app, max_size: int = 50 * 1024 * 1024): # 50MB default
super().__init__(app)
self.max_size = max_size
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Check Content-Length header
content_length = request.headers.get("content-length")
if content_length:
try:
size = int(content_length)
if size > self.max_size:
logger.warning(
"Request size limit exceeded",
size=size,
limit=self.max_size,
path=request.url.path,
ip=self._get_client_ip(request)
)
# Build standardized error envelope with correlation id
from starlette.responses import JSONResponse
# Resolve correlation id from state, headers, or generate
correlation_id = (
getattr(getattr(request, "state", object()), "correlation_id", None)
or request.headers.get("x-correlation-id")
or request.headers.get("x-request-id")
or str(uuid4())
)
body = {
"success": False,
"error": {
"status": 413,
"code": "http_error",
"message": "Payload too large",
},
"correlation_id": correlation_id,
}
response = JSONResponse(status_code=413, content=body)
response.headers["X-Correlation-ID"] = correlation_id
return response
except ValueError:
pass # Invalid Content-Length header, let it pass
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request headers"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
class CSRFMiddleware(BaseHTTPMiddleware):
"""CSRF protection middleware for state-changing operations"""
def __init__(self, app, exempt_paths: Optional[list] = None):
super().__init__(app)
# Paths that don't require CSRF protection
self.exempt_paths = exempt_paths or [
"/api/auth/login",
"/api/auth/refresh",
"/health",
"/static/",
"/uploads/",
]
# HTTP methods that require CSRF protection
self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip CSRF protection for exempt paths
if any(request.url.path.startswith(path) for path in self.exempt_paths):
return await call_next(request)
# Skip for safe HTTP methods
if request.method not in self.protected_methods:
return await call_next(request)
# Check for a CSRF token in headers (accepted but not yet validated; reserved for token-based CSRF)
csrf_token = request.headers.get("X-CSRF-Token") or request.headers.get("X-CSRFToken")
# For now, rely on a simple same-origin check using the Origin/Referer headers.
# A production deployment should pair this with proper CSRF tokens.
referer = request.headers.get("referer", "")
origin = request.headers.get("origin", "")
host = request.headers.get("host", "")
# Allow requests from same origin
valid_origins = [f"https://{host}", f"http://{host}"]
if origin and origin not in valid_origins:
if not referer or not any(referer.startswith(valid) for valid in valid_origins):
logger.warning(
"CSRF check failed",
path=request.url.path,
method=request.method,
origin=origin,
referer=referer,
ip=self._get_client_ip(request)
)
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="CSRF validation failed"
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request headers"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"

View File

@@ -0,0 +1,319 @@
"""
Session management middleware for P2 security features
"""
import time
from datetime import datetime, timezone
from typing import Optional
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.orm import Session
from app.database.base import get_db
from app.utils.session_manager import SessionManager
from app.core.logging import get_logger
logger = get_logger(__name__)
class SessionManagementMiddleware(BaseHTTPMiddleware):
"""
Advanced session management middleware
Features:
- Session validation and renewal
- Activity tracking
- Concurrent session limits
- Session fixation protection
- Automatic cleanup
"""
# Endpoints that don't require session validation
EXCLUDED_PATHS = {
"/docs", "/redoc", "/openapi.json",
"/api/auth/login", "/api/auth/register",
"/api/auth/refresh", "/api/health",
"/health", "/ready", "/metrics",
}
def __init__(self, app, cleanup_interval: int = 3600):
super().__init__(app)
self.cleanup_interval = cleanup_interval # 1 hour default
self.last_cleanup = time.time()
async def dispatch(self, request: Request, call_next):
"""Main middleware dispatcher"""
start_time = time.time()
# Skip middleware for excluded paths
if self._should_skip_middleware(request):
return await call_next(request)
# Get database session
db = next(get_db())
session_manager = SessionManager(db)
try:
# Perform periodic cleanup
await self._periodic_cleanup(session_manager)
# Process session validation
session_info = await self._validate_session(request, session_manager)
# Add session info to request state
request.state.session_info = session_info
# Also expose authenticated user (if any) for downstream middleware (e.g., user-based rate limiting, logging)
try:
request.state.user = session_info.get("user") if session_info else None
except Exception:
# Be resilient: never break the request due to state propagation
pass
# Fallback: if no server-side session identified, try to attach user from Authorization token for JWT-only flows
if not getattr(request.state, "user", None):
try:
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
from app.auth.security import verify_token # local import to avoid circular deps
username = verify_token(token)
if username:
from app.models.user import User # local import
user = session_manager.db.query(User).filter(User.username == username).first()
if user and user.is_active:
request.state.user = user
except Exception:
# Never fail the request if auth attachment fails here
pass
# Process request
response = await call_next(request)
# Update session activity
if session_info and session_info.get("session"):
await self._update_session_activity(
request, response, session_info["session"], session_manager, start_time
)
return response
except Exception as e:
logger.error(f"Session middleware error: {str(e)}")
# Re-raise to be handled by global error handlers; do not re-invoke downstream app
raise
finally:
db.close()
def _should_skip_middleware(self, request: Request) -> bool:
"""Check if middleware should be skipped for this request"""
path = request.url.path
# Skip excluded paths
if any(path.startswith(excluded) for excluded in self.EXCLUDED_PATHS):
return True
# Skip static files
if path.startswith("/static/") or path.startswith("/favicon.ico"):
return True
return False
async def _validate_session(self, request: Request, session_manager: SessionManager) -> Optional[dict]:
"""Validate session from request"""
# Extract session ID from various sources
session_id = await self._extract_session_id(request)
if not session_id:
return None
# Validate session
session = session_manager.validate_session(session_id, request)
if not session:
return None
return {
"session": session,
"session_id": session_id,
"user": session.user,
"is_valid": True
}
async def _extract_session_id(self, request: Request) -> Optional[str]:
"""Extract session ID from request"""
# Try cookie first
session_id = request.cookies.get("session_id")
if session_id:
return session_id
# Try custom header
session_id = request.headers.get("X-Session-ID")
if session_id:
return session_id
# For JWT-based sessions, extract from authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
# Use JWT token as session identifier for now
# In a full implementation, you'd decode JWT and extract session ID
token = auth_header[7:]
return token[:32] if len(token) > 32 else token
return None
async def _update_session_activity(
self,
request: Request,
response: Response,
session,
session_manager: SessionManager,
start_time: float
) -> None:
"""Update session activity tracking"""
try:
duration_ms = int((time.time() - start_time) * 1000)
# Log API activity
session_manager._log_activity(
session, session.user, request,
activity_type="api_request",
endpoint=request.url.path
)
# Update activity record with response details
if hasattr(session, 'activities') and session.activities:
latest_activity = session.activities[-1]
latest_activity.status_code = getattr(response, 'status_code', None)
latest_activity.duration_ms = duration_ms
# Analyze for suspicious patterns
await self._analyze_activity_patterns(session, session_manager)
session_manager.db.commit()
except Exception as e:
logger.error(f"Failed to update session activity: {str(e)}")
async def _analyze_activity_patterns(self, session, session_manager: SessionManager) -> None:
"""Analyze activity patterns for suspicious behavior"""
try:
# Get recent activities for this session
if not getattr(session, "activities", None):
return
activity_model = type(session.activities[0])
recent_activities = (
session_manager.db.query(activity_model)
.filter_by(session_id=session.id)
.order_by(activity_model.timestamp.desc())
.limit(10)
.all()
)
if len(recent_activities) < 5:
return
# Check for rapid API calls (possible automation)
time_diffs = []
for i in range(1, len(recent_activities)):
time_diff = (recent_activities[i-1].timestamp - recent_activities[i].timestamp).total_seconds()
time_diffs.append(time_diff)
avg_time_diff = sum(time_diffs) / len(time_diffs)
# Flag if average time between requests is < 1 second
if avg_time_diff < 1.0:
session.risk_score = min(session.risk_score + 10, 100)
session_manager._create_security_event(
session, session.user,
event_type="rapid_api_calls",
severity="medium",
description=f"Rapid API calls detected: avg {avg_time_diff:.2f}s between requests"
)
# Lock session if risk score is too high
if session.risk_score >= 80:
session.lock_session("high_risk_activity")
session_manager._create_security_event(
session, session.user,
event_type="session_locked",
severity="high",
description=f"Session locked due to high risk score: {session.risk_score}",
action_taken="session_locked"
)
except Exception as e:
logger.error(f"Failed to analyze activity patterns: {str(e)}")
async def _periodic_cleanup(self, session_manager: SessionManager) -> None:
"""Perform periodic cleanup of expired sessions"""
current_time = time.time()
if current_time - self.last_cleanup > self.cleanup_interval:
try:
cleaned_count = session_manager.cleanup_expired_sessions()
self.last_cleanup = current_time
if cleaned_count > 0:
logger.info(f"Cleaned up {cleaned_count} expired sessions")
except Exception as e:
logger.error(f"Failed to cleanup sessions: {str(e)}")
class SessionSecurityMiddleware(BaseHTTPMiddleware):
"""
Additional security middleware for session protection
"""
def __init__(self, app):
super().__init__(app)
async def dispatch(self, request: Request, call_next):
"""Process security checks for sessions"""
# Add security headers for session management
response = await call_next(request)
# Add session security headers
if isinstance(response, Response):
# Advertise that session-fixation protection is in place
response.headers["X-Session-Security"] = "fixation-protected"
# Indicate session management is active
response.headers["X-Session-Management"] = "active"
# Add cache control for session-sensitive pages
if request.url.path.startswith("/api/"):
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
class SessionCookieMiddleware(BaseHTTPMiddleware):
"""
Secure cookie management for sessions
"""
def __init__(self, app, secure: bool = True, same_site: str = "strict"):
super().__init__(app)
self.secure = secure
self.same_site = same_site
async def dispatch(self, request: Request, call_next):
"""Handle secure session cookies"""
response = await call_next(request)
# Check if we need to set session cookie
session_info = getattr(request.state, 'session_info', None)
if session_info and session_info.get("session"):
session_id = session_info["session_id"]
# Set secure session cookie
response.set_cookie(
key="session_id",
value=session_id,
max_age=28800, # 8 hours
httponly=True,
secure=self.secure,
samesite=self.same_site
)
return response
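# One reasonable way to register the session middlewares (sketch; the actual
# wiring and ordering live elsewhere in the app). SessionManagementMiddleware
# populates request.state.session_info, which SessionCookieMiddleware reads
# after the response is produced:
def _example_session_wiring(app) -> None:
    app.add_middleware(SessionCookieMiddleware, secure=True, same_site="strict")
    app.add_middleware(SessionSecurityMiddleware)
    app.add_middleware(SessionManagementMiddleware, cleanup_interval=3600)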

View File

@@ -0,0 +1,439 @@
"""
WebSocket Middleware and Utilities
This module provides middleware and utilities for WebSocket connections,
including authentication, connection management, and integration with the
WebSocket pool system.
"""
import asyncio
from typing import Optional, Dict, Any, Set, Callable, Awaitable
from urllib.parse import parse_qs
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
from sqlalchemy.orm import Session
from app.database.base import SessionLocal
from app.models.user import User
from app.auth.security import verify_token
from app.services.websocket_pool import (
get_websocket_pool,
websocket_connection,
WebSocketMessage,
MessageType
)
from app.utils.logging import StructuredLogger
class WebSocketAuthenticationError(Exception):
"""Raised when WebSocket authentication fails"""
pass
class WebSocketManager:
"""
High-level WebSocket manager that provides easy-to-use methods
for handling WebSocket connections with authentication and topic management
"""
def __init__(self):
self.logger = StructuredLogger("websocket_manager", "INFO")
self.pool = get_websocket_pool()
async def authenticate_websocket(self, websocket: WebSocket) -> Optional[User]:
"""
Authenticate a WebSocket connection using token from query parameters
Args:
websocket: WebSocket instance
Returns:
User object if authentication successful, None otherwise
"""
try:
# Get token from query parameters
query_params = parse_qs(str(websocket.url.query))
token = query_params.get('token', [None])[0]
if not token:
self.logger.warning("WebSocket authentication failed: no token provided")
return None
# Verify token
username = verify_token(token)
if not username:
self.logger.warning("WebSocket authentication failed: invalid token")
return None
# Get user from database
db: Session = SessionLocal()
try:
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
self.logger.warning("WebSocket authentication failed: user not found or inactive",
username=username)
return None
self.logger.info("WebSocket authentication successful",
user_id=user.id,
username=user.username)
return user
finally:
db.close()
except Exception as e:
self.logger.error("WebSocket authentication error", error=str(e))
return None
async def handle_connection(
self,
websocket: WebSocket,
topics: Optional[Set[str]] = None,
require_auth: bool = True,
metadata: Optional[Dict[str, Any]] = None,
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
) -> Optional[str]:
"""
Handle a WebSocket connection with authentication and message processing
Args:
websocket: WebSocket instance
topics: Initial topics to subscribe to
require_auth: Whether authentication is required
metadata: Additional metadata for the connection
message_handler: Optional function to handle incoming messages
Returns:
Connection ID if successful, None if failed
"""
user = None
if require_auth:
user = await self.authenticate_websocket(websocket)
if not user:
await websocket.close(code=4401, reason="Authentication failed")
return None
# Accept the connection
await websocket.accept()
# Add to pool
user_id = user.id if user else None
async with websocket_connection(
websocket=websocket,
user_id=user_id,
topics=topics,
metadata=metadata
) as (connection_id, pool):
# Set connection state to connected (access the enum via its type; member-to-member access was removed in Python 3.12)
connection_info = await pool.get_connection_info(connection_id)
if connection_info:
connection_info.state = type(connection_info.state).CONNECTED
# Send initial welcome message
welcome_message = WebSocketMessage(
type="welcome",
data={
"connection_id": connection_id,
"user_id": user_id,
"topics": list(topics) if topics else [],
"timestamp": connection_info.created_at.isoformat() if connection_info else None
}
)
await pool._send_to_connection(connection_id, welcome_message)
# Handle messages
await self._message_loop(
websocket=websocket,
connection_id=connection_id,
pool=pool,
message_handler=message_handler
)
return connection_id
async def _message_loop(
self,
websocket: WebSocket,
connection_id: str,
pool,
message_handler: Optional[Callable[[str, WebSocketMessage], Awaitable[None]]] = None
):
"""Handle incoming WebSocket messages"""
try:
while True:
try:
# Receive message
data = await websocket.receive_text()
# Update activity
connection_info = await pool.get_connection_info(connection_id)
if connection_info:
connection_info.update_activity()
# Parse message
try:
import json
message_dict = json.loads(data)
message = WebSocketMessage(**message_dict)
except (json.JSONDecodeError, ValueError) as e:
self.logger.warning("Invalid message format",
connection_id=connection_id,
error=str(e),
data=data[:100])
continue
# Handle standard message types
await self._handle_standard_message(connection_id, message, pool)
# Call custom message handler if provided
if message_handler:
try:
await message_handler(connection_id, message)
except Exception as e:
self.logger.error("Error in custom message handler",
connection_id=connection_id,
error=str(e))
except WebSocketDisconnect:
self.logger.info("WebSocket disconnected", connection_id=connection_id)
break
except Exception as e:
self.logger.error("Error in message loop",
connection_id=connection_id,
error=str(e))
break
except Exception as e:
self.logger.error("Fatal error in message loop",
connection_id=connection_id,
error=str(e))
async def _handle_standard_message(self, connection_id: str, message: WebSocketMessage, pool):
"""Handle standard WebSocket message types"""
if message.type == MessageType.PING.value:
# Respond with pong
pong_message = WebSocketMessage(
type=MessageType.PONG.value,
data={"timestamp": message.timestamp}
)
await pool._send_to_connection(connection_id, pong_message)
elif message.type == MessageType.PONG.value:
# Handle pong response
await pool.handle_pong(connection_id)
elif message.type == MessageType.SUBSCRIBE.value:
# Subscribe to topic
topic = message.topic
if topic:
success = await pool.subscribe_to_topic(connection_id, topic)
response = WebSocketMessage(
type="subscription_response",
topic=topic,
data={"success": success, "action": "subscribe"}
)
await pool._send_to_connection(connection_id, response)
elif message.type == MessageType.UNSUBSCRIBE.value:
# Unsubscribe from topic
topic = message.topic
if topic:
success = await pool.unsubscribe_from_topic(connection_id, topic)
response = WebSocketMessage(
type="subscription_response",
topic=topic,
data={"success": success, "action": "unsubscribe"}
)
await pool._send_to_connection(connection_id, response)
async def broadcast_to_topic(
self,
topic: str,
message_type: str,
data: Optional[Dict[str, Any]] = None,
exclude_connection_id: Optional[str] = None
) -> int:
"""Convenience method to broadcast a message to a topic"""
message = WebSocketMessage(
type=message_type,
topic=topic,
data=data
)
return await self.pool.broadcast_to_topic(topic, message, exclude_connection_id)
async def send_to_user(
self,
user_id: int,
message_type: str,
data: Optional[Dict[str, Any]] = None
) -> int:
"""Convenience method to send a message to all connections for a user"""
message = WebSocketMessage(
type=message_type,
data=data
)
return await self.pool.send_to_user(user_id, message)
async def get_stats(self) -> Dict[str, Any]:
"""Get WebSocket pool statistics"""
return await self.pool.get_stats()
# Global WebSocket manager instance
_websocket_manager: Optional[WebSocketManager] = None
def get_websocket_manager() -> WebSocketManager:
"""Get the global WebSocket manager instance"""
global _websocket_manager
if _websocket_manager is None:
_websocket_manager = WebSocketManager()
return _websocket_manager
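# A sketch of an endpoint built on the manager (router, path and topic are
# illustrative; the application's real WebSocket routes are defined elsewhere):
def _example_ws_route():
    from fastapi import APIRouter
    router = APIRouter()

    @router.websocket("/ws/updates")
    async def updates_ws(websocket: WebSocket):
        manager = get_websocket_manager()
        # Authenticates via the ?token=... query parameter, subscribes to the
        # "updates" topic, then runs the message loop until disconnect.
        await manager.handle_connection(websocket, topics={"updates"})

    return router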
# Utility decorators and functions
def websocket_endpoint(
topics: Optional[Set[str]] = None,
require_auth: bool = True,
metadata: Optional[Dict[str, Any]] = None
):
"""
Decorator for WebSocket endpoints that automatically handles
connection management, authentication, and cleanup
Usage:
@router.websocket("/my-endpoint")
@websocket_endpoint(topics={"my_topic"}, require_auth=True)
async def my_websocket_handler(websocket: WebSocket, connection_id: str, manager: WebSocketManager):
# Your custom logic here
pass
"""
def decorator(func):
async def wrapper(websocket: WebSocket, *args, **kwargs):
manager = get_websocket_manager()
async def message_handler(connection_id: str, message: WebSocketMessage):
# Call the original function with the message
await func(websocket, connection_id, manager, message, *args, **kwargs)
# Handle the connection
connection_id = await manager.handle_connection(
websocket=websocket,
topics=topics,
require_auth=require_auth,
metadata=metadata,
message_handler=message_handler
)
if not connection_id:
return
# handle_connection blocks until the client disconnects; this loop is only a safety net in case it returns while the connection is still registered
try:
while True:
await asyncio.sleep(1)
connection_info = await manager.pool.get_connection_info(connection_id)
if not connection_info or not connection_info.is_alive():
break
except Exception:
pass
return wrapper
return decorator
async def websocket_auth_dependency(websocket: WebSocket) -> User:
"""
FastAPI dependency for WebSocket authentication
Usage:
@router.websocket("/my-endpoint")
async def my_endpoint(websocket: WebSocket, user: User = Depends(websocket_auth_dependency)):
# user is guaranteed to be authenticated
pass
"""
manager = get_websocket_manager()
user = await manager.authenticate_websocket(websocket)
if not user:
await websocket.close(code=4401, reason="Authentication failed")
raise WebSocketAuthenticationError("Authentication failed")
return user
class WebSocketConnectionTracker:
"""
Utility class to track WebSocket connections and their health
"""
def __init__(self):
self.logger = StructuredLogger("websocket_tracker", "INFO")
async def track_connection_health(self, connection_id: str, interval: int = 60):
"""Track the health of a specific connection"""
pool = get_websocket_pool()
while True:
try:
await asyncio.sleep(interval)
connection_info = await pool.get_connection_info(connection_id)
if not connection_info:
break
# Check if connection is healthy
if connection_info.is_stale(timeout_seconds=300):
self.logger.warning("Connection is stale",
connection_id=connection_id,
last_activity=connection_info.last_activity.isoformat())
break
# Try to ping the connection
if connection_info.is_alive():
success = await pool.ping_connection(connection_id)
if not success:
self.logger.warning("Failed to ping connection",
connection_id=connection_id)
break
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error("Error tracking connection health",
connection_id=connection_id,
error=str(e))
break
async def get_connection_metrics(self, connection_id: str) -> Optional[Dict[str, Any]]:
"""Get detailed metrics for a connection"""
pool = get_websocket_pool()
connection_info = await pool.get_connection_info(connection_id)
if not connection_info:
return None
from datetime import datetime, timezone  # local import for the age calculations below
# Use a real "now" so idle_seconds is meaningful; match the tz-awareness of the stored timestamps
now = datetime.now(timezone.utc) if connection_info.created_at.tzinfo else datetime.now()
return {
"connection_id": connection_id,
"user_id": connection_info.user_id,
"state": connection_info.state.value,
"topics": list(connection_info.topics),
"created_at": connection_info.created_at.isoformat(),
"last_activity": connection_info.last_activity.isoformat(),
"age_seconds": (now - connection_info.created_at).total_seconds(),
"idle_seconds": (now - connection_info.last_activity).total_seconds(),
"error_count": connection_info.error_count,
"last_ping": connection_info.last_ping.isoformat() if connection_info.last_ping else None,
"last_pong": connection_info.last_pong.isoformat() if connection_info.last_pong else None,
"metadata": connection_info.metadata,
"is_alive": connection_info.is_alive(),
"is_stale": connection_info.is_stale()
}
def get_connection_tracker() -> WebSocketConnectionTracker:
"""Get a WebSocket connection tracker instance"""
return WebSocketConnectionTracker()
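# A sketch of a client connecting with the token query parameter the manager
# expects. The URL/path, the third-party "websockets" package, and the literal
# message fields (which assume MessageType.SUBSCRIBE.value == "subscribe" and
# that WebSocketMessage accepts these keys) are all illustrative assumptions:
async def _example_ws_client(jwt: str) -> None:
    import json
    import websockets
    url = f"ws://localhost:8000/ws/updates?token={jwt}"
    async with websockets.connect(url) as ws:
        print(json.loads(await ws.recv()))  # the "welcome" message sent on connect
        await ws.send(json.dumps({"type": "subscribe", "topic": "updates", "data": {}}))
        print(json.loads(await ws.recv()))  # the subscription_response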