"""Security Headers Middleware.

Implements comprehensive security headers to protect against common web
vulnerabilities:

- HSTS (HTTP Strict Transport Security)
- CSP (Content Security Policy)
- X-Frame-Options (Clickjacking protection)
- X-Content-Type-Options (MIME sniffing protection)
- X-XSS-Protection (XSS protection)
- Referrer-Policy (Information disclosure protection)
- Permissions-Policy (Feature policy)

The module also provides a Content-Length based request size limit and an
Origin/Referer based CSRF check.
"""

from typing import Any, Callable, Dict, Optional
from urllib.parse import urlsplit
from uuid import uuid4

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.config import settings
from app.utils.logging import app_logger

logger = app_logger.bind(name="security_headers")

# Cache-Control value used wherever a response must never be stored.
_NO_CACHE = "no-store, no-cache, must-revalidate"


def _extract_client_ip(request: Request) -> str:
    """Return a best-effort client IP address for log records.

    X-Forwarded-For / X-Real-IP are supplied by the client unless a trusted
    proxy overwrites them, so the result is only suitable for logging, never
    for access-control decisions.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        # First entry is the originating client in a proxy chain.
        return forwarded_for.split(",")[0].strip()

    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip

    if request.client:
        return request.client.host

    return "unknown"


def _resolve_correlation_id(request: Request) -> str:
    """Return the request's correlation id.

    Looks at ``request.state.correlation_id``, then the ``X-Correlation-ID``
    and ``X-Request-ID`` headers, and falls back to a fresh UUID4.
    """
    return (
        getattr(getattr(request, "state", object()), "correlation_id", None)
        or request.headers.get("x-correlation-id")
        or request.headers.get("x-request-id")
        or str(uuid4())
    )


def _error_response(request: Request, status_code: int, message: str) -> JSONResponse:
    """Build the standardized JSON error envelope, tagged with a correlation id.

    Middleware must *return* error responses: exceptions raised inside a
    ``BaseHTTPMiddleware`` are outside the application's exception handlers
    and would surface to the client as a 500 error.
    """
    correlation_id = _resolve_correlation_id(request)
    body = {
        "success": False,
        "error": {
            "status": status_code,
            "code": "http_error",
            "message": message,
        },
        "correlation_id": correlation_id,
    }
    response = JSONResponse(status_code=status_code, content=body)
    response.headers["X-Correlation-ID"] = correlation_id
    return response


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers to all responses.

    Optional ``config`` keys read by this class:
        force_https: when truthy, treat the deployment as HTTPS (enables HSTS
            and the CSP ``upgrade-insecure-requests`` directive).
        hsts: override value for the ``Strict-Transport-Security`` header.
    """

    def __init__(self, app, config: Optional[Dict[str, Any]] = None):
        super().__init__(app)
        self.config = config or {}
        # Computed once at construction and reused for every response.
        self.headers = self._build_security_headers()

    def _build_security_headers(self) -> Dict[str, str]:
        """Build security headers based on configuration and environment.

        A header mapped to an empty string is removed from responses rather
        than set (see ``dispatch``).
        """
        headers = {
            # Prevent MIME type sniffing
            "X-Content-Type-Options": "nosniff",
            # XSS Protection (legacy; only honoured by old browsers)
            "X-XSS-Protection": "1; mode=block",
            # Clickjacking protection
            "X-Frame-Options": "DENY",
            # Referrer policy
            "Referrer-Policy": "strict-origin-when-cross-origin",
            # Replace server identification with a generic value
            "Server": "Delphi-DB",
            # Empty value => header is stripped from the response
            "X-Powered-By": "",
        }

        # HSTS (HTTP Strict Transport Security) - only for HTTPS
        if self._is_https_environment():
            headers["Strict-Transport-Security"] = self.config.get(
                "hsts", "max-age=31536000; includeSubDomains; preload"
            )

        # Content Security Policy
        csp = self._build_csp_header()
        if csp:
            headers["Content-Security-Policy"] = csp

        # Permissions Policy (Feature Policy)
        permissions_policy = self._build_permissions_policy()
        if permissions_policy:
            headers["Permissions-Policy"] = permissions_policy

        return headers

    def _is_https_environment(self) -> bool:
        """Return True if the deployment should be treated as HTTPS.

        True when ``config["force_https"]`` is truthy, when ``settings.debug``
        is off (production is assumed to be HTTPS), or when
        ``settings.secure_cookies`` is on.
        """
        if self.config.get("force_https", False):
            return True

        # In production, assume HTTPS
        if not settings.debug:
            return True

        # Secure cookies only make sense over HTTPS
        if settings.secure_cookies:
            return True

        return False

    def _build_csp_header(self) -> str:
        """Build the Content-Security-Policy header value."""
        csp_directives = {
            # Default source
            "default-src": ["'self'"],
            # Script sources - allow self and inline scripts for the app
            "script-src": [
                "'self'",
                "'unsafe-inline'",  # Required for inline event handlers
                "https://cdn.tailwindcss.com",  # Tailwind CSS CDN if used
            ],
            # Style sources - allow self and inline styles
            "style-src": [
                "'self'",
                "'unsafe-inline'",  # Required for component styling
                "https://fonts.googleapis.com",
                "https://cdn.tailwindcss.com",
            ],
            # Font sources
            "font-src": [
                "'self'",
                "https://fonts.gstatic.com",
                "data:",
            ],
            # Image sources
            "img-src": [
                "'self'",
                "data:",
                "blob:",
                "https:",  # Allow HTTPS images
            ],
            # Media sources
            "media-src": ["'self'", "blob:"],
            # Object sources (disable Flash, etc.)
            "object-src": ["'none'"],
            # Frame sources (for embedding)
            "frame-src": ["'none'"],
            # Connect sources (AJAX, WebSocket, etc.)
            "connect-src": [
                "'self'",
                "wss:",  # WebSocket support
                "ws:",  # WebSocket support
            ],
            # Worker sources
            "worker-src": ["'self'", "blob:"],
            # Child sources
            "child-src": ["'none'"],
            # Form action restrictions
            "form-action": ["'self'"],
            # Frame ancestors (clickjacking protection)
            "frame-ancestors": ["'none'"],
            # Base URI restrictions
            "base-uri": ["'self'"],
            # Manifest sources
            "manifest-src": ["'self'"],
        }

        csp_parts = [
            f"{directive} {' '.join(sources)}"
            for directive, sources in csp_directives.items()
        ]

        # Add upgrade insecure requests in HTTPS environments
        if self._is_https_environment():
            csp_parts.append("upgrade-insecure-requests")

        return "; ".join(csp_parts)

    def _build_permissions_policy(self) -> str:
        """Build the Permissions-Policy header value.

        The header is a structured-field dictionary: ``feature=(allowlist)``
        members separated by commas, with no trailing comma (a trailing comma
        makes browsers discard the whole header).
        """
        # Restrictive permissions policy: "()" disables, "(self)" allows same-origin.
        policies = {
            "camera": "()",
            "microphone": "()",
            "geolocation": "()",
            "gyroscope": "()",
            "magnetometer": "()",
            "payment": "()",
            "usb": "()",
            # NOTE(review): "notifications", "push" and "speaker-selection" may
            # not be recognised by all browsers; unknown features are ignored.
            "notifications": "(self)",
            "push": "()",
            "speaker-selection": "()",
            # Allow clipboard access for self
            "clipboard-write": "(self)",
            "clipboard-read": "(self)",
            # Allow fullscreen for self
            "fullscreen": "(self)",
        }
        return ", ".join(
            f"{feature}={allowlist}" for feature, allowlist in policies.items()
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request, then apply global and endpoint-specific headers."""
        response = await call_next(request)

        for header, value in self.headers.items():
            if value:
                response.headers[header] = value
            elif header in response.headers:
                # An empty configured value means "do not expose this header".
                del response.headers[header]

        self._apply_endpoint_specific_headers(request, response)

        return response

    def _apply_endpoint_specific_headers(self, request: Request, response: Response):
        """Apply endpoint-specific security and caching headers.

        Branches are evaluated in order; the first matching path rule wins.
        """
        path = request.url.path

        if path.startswith(("/admin", "/api/admin/")):
            # Admin pages: never frame, never cache.
            response.headers["X-Frame-Options"] = "DENY"
            response.headers["Cache-Control"] = _NO_CACHE
            response.headers["Pragma"] = "no-cache"

        elif path.startswith("/api/"):
            # Prevent caching of state-changing or sensitive API responses.
            if request.method != "GET" or "auth" in path or "admin" in path:
                response.headers["Cache-Control"] = _NO_CACHE
                response.headers["Pragma"] = "no-cache"

        elif path.startswith("/static/"):
            # Static assets: long-lived public caching. Checked before the
            # "upload" substring rule so assets named e.g. "upload.js" stay cacheable.
            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["Cache-Control"] = "public, max-age=31536000"

        elif "upload" in path:
            # Upload endpoints and user-uploaded files (/uploads/...): never cache.
            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["Cache-Control"] = "no-store"


class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Middleware to limit request body size to prevent DoS attacks.

    Only the declared ``Content-Length`` is checked; bodies sent without it
    (e.g. chunked transfer encoding) are not limited by this middleware.
    Malformed ``Content-Length`` values are passed through unchanged.
    """

    def __init__(self, app, max_size: int = 50 * 1024 * 1024):  # 50MB default
        super().__init__(app)
        self.max_size = max_size

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject requests whose declared size exceeds ``max_size`` with a 413."""
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                size = int(content_length)
            except ValueError:
                # Invalid Content-Length header: deliberately let it pass.
                size = None

            if size is not None and size > self.max_size:
                logger.warning(
                    "Request size limit exceeded",
                    size=size,
                    limit=self.max_size,
                    path=request.url.path,
                    ip=self._get_client_ip(request),
                )
                return _error_response(request, 413, "Payload too large")

        return await call_next(request)

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP address from request headers (for logging only)."""
        return _extract_client_ip(request)


class CSRFMiddleware(BaseHTTPMiddleware):
    """CSRF protection middleware for state-changing operations.

    For protected methods on non-exempt paths, the request must come from
    this site's own origin:

    - If ``Origin`` is present (and not ``null``), it must exactly equal
      ``http(s)://<Host>``.
    - Otherwise, if ``Referer`` is present, its scheme://host must equal
      ``http(s)://<Host>``.
    - Requests with neither header are allowed so non-browser clients keep
      working. A ``null`` Origin without a Referer is rejected.
    """

    def __init__(self, app, exempt_paths: Optional[list] = None):
        super().__init__(app)
        # Path prefixes that don't require CSRF protection
        self.exempt_paths = exempt_paths or [
            "/api/auth/login",
            "/api/auth/refresh",
            "/health",
            "/static/",
            "/uploads/",
        ]
        # HTTP methods that require CSRF protection
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject cross-origin state-changing requests with a 403 response."""
        path = request.url.path

        if any(path.startswith(prefix) for prefix in self.exempt_paths):
            return await call_next(request)

        if request.method not in self.protected_methods:
            return await call_next(request)

        # TODO: also validate a CSRF token (X-CSRF-Token / X-CSRFToken) once
        # the application issues one; for now only Origin/Referer are checked.
        referer = request.headers.get("referer", "")
        origin = request.headers.get("origin", "")
        host = request.headers.get("host", "")

        if not self._is_same_origin(origin, referer, host):
            logger.warning(
                "CSRF check failed",
                path=path,
                method=request.method,
                origin=origin,
                referer=referer,
                ip=self._get_client_ip(request),
            )
            # Return (not raise): an HTTPException raised in middleware would
            # bypass exception handlers and become a 500.
            return _error_response(request, 403, "CSRF validation failed")

        return await call_next(request)

    @staticmethod
    def _is_same_origin(origin: str, referer: str, host: str) -> bool:
        """Return True if Origin/Referer identify this host as the request source.

        Comparison is exact on scheme://host[:port] (case-insensitive), never a
        prefix match, so look-alike hosts such as ``example.com.evil.com`` fail.
        """
        normalized_host = host.strip().lower()
        valid_origins = {f"https://{normalized_host}", f"http://{normalized_host}"}

        if origin and origin != "null":
            # Browsers attach Origin to cross-origin state-changing requests;
            # when present it is authoritative.
            return origin.strip().lower() in valid_origins

        if referer:
            try:
                parts = urlsplit(referer)
            except ValueError:
                # Unparseable Referer (e.g. malformed IPv6 literal).
                return False
            return f"{parts.scheme}://{parts.netloc}".lower() in valid_origins

        # No usable headers: allow header-less clients, reject a bare "null" Origin.
        return not origin

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP address from request headers (for logging only)."""
        return _extract_client_ip(request)