"""
Security middleware for the FastAPI application.

``SecurityHeadersMiddleware`` adds security headers that protect against
common web vulnerabilities:

- HSTS (HTTP Strict Transport Security)
- CSP (Content Security Policy)
- X-Frame-Options (clickjacking protection)
- X-Content-Type-Options (MIME-sniffing protection)
- X-XSS-Protection (legacy browser XSS filter)
- Referrer-Policy (limits referrer information disclosure)
- Permissions-Policy (restricts browser features)

The module also provides ``RequestSizeLimitMiddleware``, which rejects
requests whose declared Content-Length exceeds a limit, and
``CSRFMiddleware``, which checks the Origin/Referer headers of
state-changing requests.
"""
|
|
from typing import Callable, Dict, Optional
from urllib.parse import urlsplit
from uuid import uuid4

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.config import settings
from app.utils.logging import app_logger
|
|
|
|
# Module logger: every record emitted here is bound with name="security_headers".
_LOGGER_NAME = "security_headers"
logger = app_logger.bind(name=_LOGGER_NAME)
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers to all responses.

    The header set is built once in ``__init__`` from ``config`` and the
    application ``settings``. It is applied to every response, followed by
    path-specific cache-control adjustments.

    Recognised ``config`` keys:
        force_https: truthy value enables the HTTPS-only headers (HSTS and
            the ``upgrade-insecure-requests`` CSP directive) regardless of
            ``settings``.
        hsts: overrides the default ``Strict-Transport-Security`` value.
    """

    # Cache-Control value for responses that must never be stored.
    _NO_STORE_REVALIDATE = "no-store, no-cache, must-revalidate"

    def __init__(self, app, config: Optional[Dict[str, str]] = None):
        super().__init__(app)
        self.config = config or {}
        # Header values do not depend on the request, so compute them once.
        self.headers = self._build_security_headers()

    def _build_security_headers(self) -> Dict[str, str]:
        """Build the headers applied to every response.

        Returns:
            Mapping of header name to value. An empty value means the header
            is removed from responses instead of being set (see ``dispatch``).
        """
        headers = {
            # Prevent MIME type sniffing
            "X-Content-Type-Options": "nosniff",
            # Legacy browser XSS filter
            "X-XSS-Protection": "1; mode=block",
            # Clickjacking protection (CSP frame-ancestors covers modern browsers)
            "X-Frame-Options": "DENY",
            # Send only the origin on cross-origin requests
            "Referrer-Policy": "strict-origin-when-cross-origin",
            # Generic value for the Server header
            "Server": "Delphi-DB",
            # Empty value: strip any X-Powered-By header from responses
            "X-Powered-By": "",
        }

        # HSTS is only meaningful (and only honoured) over HTTPS.
        if self._is_https_environment():
            headers["Strict-Transport-Security"] = self.config.get(
                "hsts", "max-age=31536000; includeSubDomains; preload"
            )

        csp = self._build_csp_header()
        if csp:
            headers["Content-Security-Policy"] = csp

        permissions_policy = self._build_permissions_policy()
        if permissions_policy:
            headers["Permissions-Policy"] = permissions_policy

        return headers

    def _is_https_environment(self) -> bool:
        """Return True when HTTPS-only headers should be emitted.

        The result is True in any of these cases:
        - ``config["force_https"]`` is truthy.
        - The app is not in debug mode (production is assumed to be HTTPS).
        - Secure cookies are enabled in ``settings``.
        """
        return bool(
            self.config.get("force_https", False)
            or not settings.debug
            or settings.secure_cookies
        )

    def _build_csp_header(self) -> str:
        """Build the ``Content-Security-Policy`` header value.

        Directives are joined with "; ". ``upgrade-insecure-requests`` is
        appended only in HTTPS environments.
        """
        csp_directives = {
            "default-src": ["'self'"],
            # NOTE(review): 'unsafe-inline' in script-src largely negates CSP's
            # XSS protection; consider nonces/hashes if inline handlers go away.
            "script-src": [
                "'self'",
                "'unsafe-inline'",  # Required for inline event handlers
                "https://cdn.tailwindcss.com",  # Tailwind CSS CDN if used
            ],
            "style-src": [
                "'self'",
                "'unsafe-inline'",  # Required for component styling
                "https://fonts.googleapis.com",
                "https://cdn.tailwindcss.com",
            ],
            "font-src": ["'self'", "https://fonts.gstatic.com", "data:"],
            "img-src": ["'self'", "data:", "blob:", "https:"],  # any HTTPS image
            "media-src": ["'self'", "blob:"],
            "object-src": ["'none'"],  # no plugins (Flash, etc.)
            "frame-src": ["'none'"],
            "connect-src": ["'self'", "wss:", "ws:"],  # XHR/fetch + WebSockets
            "worker-src": ["'self'", "blob:"],
            "child-src": ["'none'"],
            "form-action": ["'self'"],
            "frame-ancestors": ["'none'"],  # clickjacking protection
            "base-uri": ["'self'"],
            "manifest-src": ["'self'"],
        }

        csp_parts = [
            f"{directive} {' '.join(sources)}"
            for directive, sources in csp_directives.items()
        ]

        if self._is_https_environment():
            csp_parts.append("upgrade-insecure-requests")

        return "; ".join(csp_parts)

    def _build_permissions_policy(self) -> str:
        """Build the ``Permissions-Policy`` header value.

        The header is an RFC 8941 structured-field dictionary. Members have
        the form ``feature=(allowlist)`` and are separated by ", ", with no
        trailing comma. A syntax error makes browsers discard the whole header.
        ``()`` disables a feature; ``(self)`` allows it for this origin only.
        """
        policies = {
            "camera": "()",
            "microphone": "()",
            "geolocation": "()",
            "gyroscope": "()",
            "magnetometer": "()",
            "payment": "()",
            "usb": "()",
            "notifications": "(self)",
            "push": "()",
            "speaker-selection": "()",
            "clipboard-write": "(self)",
            "clipboard-read": "(self)",
            "fullscreen": "(self)",
        }

        return ", ".join(
            f"{feature}={allowlist}" for feature, allowlist in policies.items()
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request, then apply security headers to its response."""
        response = await call_next(request)

        for header, value in self.headers.items():
            if value:
                response.headers[header] = value
            elif header in response.headers:
                # Empty configured value means "never expose this header".
                del response.headers[header]

        self._apply_endpoint_specific_headers(request, response)

        return response

    def _apply_endpoint_specific_headers(self, request: Request, response: Response) -> None:
        """Apply path-specific cache and security headers to ``response``.

        Rules are checked in order, and the first matching rule wins:
        - ``/admin*`` and ``/api/admin/``: never cached.
        - ``/api/``: not cached for non-GET requests or auth/admin paths.
        - ``/static/``: cached publicly for one year.
        - Any other path containing "upload" (upload endpoints and served
          ``/uploads/`` files): never cached.
        """
        path = request.url.path

        if path.startswith(("/admin", "/api/admin/")):
            response.headers["X-Frame-Options"] = "DENY"
            response.headers["Cache-Control"] = self._NO_STORE_REVALIDATE
            response.headers["Pragma"] = "no-cache"

        elif path.startswith("/api/"):
            # Prevent caching of potentially sensitive API responses.
            if request.method != "GET" or "auth" in path or "admin" in path:
                response.headers["Cache-Control"] = self._NO_STORE_REVALIDATE
                response.headers["Pragma"] = "no-cache"

        elif path.startswith("/static/"):
            # Checked before the "upload" substring rule so that static assets
            # whose names contain "upload" are still cacheable.
            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["Cache-Control"] = "public, max-age=31536000"

        elif "upload" in path:
            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["Cache-Control"] = "no-store"
|
|
|
|
|
|
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Middleware to limit request body size to prevent DoS attacks.

    Only the declared ``Content-Length`` header is inspected. Bodies sent
    without it (e.g. ``Transfer-Encoding: chunked``) are not measured here
    and must be bounded elsewhere, for example by the ASGI server or a
    reverse proxy.
    """

    def __init__(self, app, max_size: int = 50 * 1024 * 1024):  # 50 MiB default
        super().__init__(app)
        self.max_size = max_size

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Return a 413 envelope if the declared size exceeds ``max_size``; otherwise forward the request."""
        size = self._declared_size(request)

        if size is not None and size > self.max_size:
            logger.warning(
                "Request size limit exceeded",
                size=size,
                limit=self.max_size,
                path=request.url.path,
                ip=self._get_client_ip(request),
            )
            return self._payload_too_large(request)

        return await call_next(request)

    @staticmethod
    def _declared_size(request: Request) -> Optional[int]:
        """Return the integer ``Content-Length``, or None if absent or malformed.

        A malformed header is deliberately not rejected here (best effort).
        It is logged and the request is allowed through.
        """
        raw = request.headers.get("content-length")
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError:
            logger.debug(
                "Ignoring malformed Content-Length header",
                value=raw,
                path=request.url.path,
            )
            return None

    @staticmethod
    def _payload_too_large(request: Request) -> Response:
        """Build the standardized 413 error envelope with a correlation id.

        The correlation id comes from ``request.state.correlation_id``, then
        the ``X-Correlation-ID`` header, then ``X-Request-ID``. If none is
        present, a fresh UUID4 is generated.
        """
        correlation_id = str(
            getattr(getattr(request, "state", object()), "correlation_id", None)
            or request.headers.get("x-correlation-id")
            or request.headers.get("x-request-id")
            or uuid4()
        )

        body = {
            "success": False,
            "error": {
                "status": 413,
                "code": "http_error",
                "message": "Payload too large",
            },
            "correlation_id": correlation_id,
        }

        response = JSONResponse(status_code=413, content=body)
        response.headers["X-Correlation-ID"] = correlation_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP address for logging.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless
        a trusted proxy overwrites them. Use this value for logging only, never
        for access control.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"
|
|
|
|
|
|
class CSRFMiddleware(BaseHTTPMiddleware):
    """CSRF protection middleware for state-changing operations.

    For POST/PUT/DELETE/PATCH requests to non-exempt paths, the request must
    come from the same origin as the ``Host`` header. The ``Origin`` header is
    checked first, with ``Referer`` as a fallback. Requests carrying neither
    header (typical of non-browser API clients) are allowed through.

    No CSRF tokens are issued or validated here.
    """

    def __init__(self, app, exempt_paths: Optional[list] = None):
        super().__init__(app)
        # Path prefixes that don't require CSRF protection
        self.exempt_paths = exempt_paths or [
            "/api/auth/login",
            "/api/auth/refresh",
            "/health",
            "/static/",
            "/uploads/",
        ]
        # HTTP methods that require CSRF protection
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward same-origin or unprotected requests; reject others with 403."""
        path = request.url.path

        if path.startswith(tuple(self.exempt_paths)) or request.method not in self.protected_methods:
            return await call_next(request)

        origin = request.headers.get("origin", "")
        referer = request.headers.get("referer", "")
        host = request.headers.get("host", "")

        if self._is_same_origin_request(origin, referer, host):
            return await call_next(request)

        logger.warning(
            "CSRF check failed",
            path=path,
            method=request.method,
            origin=origin,
            referer=referer,
            ip=self._get_client_ip(request),
        )
        # Return the response directly. An HTTPException raised inside
        # BaseHTTPMiddleware.dispatch never reaches FastAPI's exception
        # handlers and would surface to the client as a 500 error.
        return self._csrf_failure_response(request)

    def _is_same_origin_request(self, origin: str, referer: str, host: str) -> bool:
        """Decide whether a state-changing request may proceed.

        - If neither Origin nor Referer is present, return True (non-browser
          client).
        - If Host is missing, return False, because there is nothing to
          compare against.
        - If Origin equals ``http(s)://<host>``, return True.
        - Otherwise return True only if the Referer's exact scheme://netloc
          equals ``http(s)://<host>``. The comparison is exact, not a prefix
          match, so ``https://host.evil.com`` is rejected.
        """
        if not origin and not referer:
            return True
        if not host:
            return False

        host = host.lower()
        valid_origins = {f"https://{host}", f"http://{host}"}

        if origin.lower() in valid_origins:
            return True
        return self._origin_of(referer) in valid_origins

    @staticmethod
    def _origin_of(url: str) -> str:
        """Return the lowercase ``scheme://netloc`` of an http(s) URL, or "" if not applicable."""
        try:
            parts = urlsplit(url)
        except ValueError:  # e.g. malformed IPv6 netloc
            return ""
        if parts.scheme not in ("http", "https") or not parts.netloc:
            return ""
        return f"{parts.scheme}://{parts.netloc}".lower()

    @staticmethod
    def _csrf_failure_response(request: Request) -> Response:
        """Build a 403 error envelope matching the size-limit middleware's format."""
        correlation_id = str(
            getattr(getattr(request, "state", object()), "correlation_id", None)
            or request.headers.get("x-correlation-id")
            or request.headers.get("x-request-id")
            or uuid4()
        )

        body = {
            "success": False,
            "error": {
                "status": 403,
                "code": "http_error",
                "message": "CSRF validation failed",
            },
            "correlation_id": correlation_id,
        }

        response = JSONResponse(status_code=403, content=body)
        response.headers["X-Correlation-ID"] = correlation_id
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP address for logging.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless
        a trusted proxy overwrites them. Use this value for logging only.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        if request.client:
            return request.client.host

        return "unknown"
|