This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -0,0 +1,406 @@
"""
Security Headers Middleware
Implements comprehensive security headers to protect against common web vulnerabilities:
- HSTS (HTTP Strict Transport Security)
- CSP (Content Security Policy)
- X-Frame-Options (Clickjacking protection)
- X-Content-Type-Options (MIME sniffing protection)
- X-XSS-Protection (XSS protection)
- Referrer-Policy (Information disclosure protection)
- Permissions-Policy (Feature policy)
"""
from typing import Callable, Dict, Optional
from uuid import uuid4
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.config import settings
from app.utils.logging import app_logger
logger = app_logger.bind(name="security_headers")
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Middleware to add security headers to all responses"""
def __init__(self, app, config: Optional[Dict[str, str]] = None):
super().__init__(app)
self.config = config or {}
self.headers = self._build_security_headers()
def _build_security_headers(self) -> Dict[str, str]:
"""Build security headers based on configuration and environment"""
# Base security headers
headers = {
# Prevent MIME type sniffing
"X-Content-Type-Options": "nosniff",
# XSS Protection (legacy but still useful)
"X-XSS-Protection": "1; mode=block",
# Clickjacking protection
"X-Frame-Options": "DENY",
# Referrer policy
"Referrer-Policy": "strict-origin-when-cross-origin",
# Remove server information
"Server": "Delphi-DB",
# Prevent exposure of sensitive information
"X-Powered-By": "",
}
# HSTS (HTTP Strict Transport Security) - only for HTTPS
if self._is_https_environment():
headers["Strict-Transport-Security"] = self.config.get(
"hsts",
"max-age=31536000; includeSubDomains; preload"
)
# Content Security Policy
csp = self._build_csp_header()
if csp:
headers["Content-Security-Policy"] = csp
# Permissions Policy (Feature Policy)
permissions_policy = self._build_permissions_policy()
if permissions_policy:
headers["Permissions-Policy"] = permissions_policy
return headers
def _is_https_environment(self) -> bool:
"""Check if we're in an HTTPS environment"""
# Check common HTTPS indicators
if self.config.get("force_https", False):
return True
# In production, assume HTTPS
if not settings.debug:
return True
# Check for secure cookies setting
if settings.secure_cookies:
return True
return False
def _build_csp_header(self) -> str:
"""Build Content Security Policy header"""
# Get domain configuration
domain = self.config.get("domain", "'self'")
# CSP directives for the application
csp_directives = {
# Default source
"default-src": ["'self'"],
# Script sources - allow self and inline scripts for the app
"script-src": [
"'self'",
"'unsafe-inline'", # Required for inline event handlers
"https://cdn.tailwindcss.com", # Tailwind CSS CDN if used
],
# Style sources - allow self and inline styles
"style-src": [
"'self'",
"'unsafe-inline'", # Required for component styling
"https://fonts.googleapis.com",
"https://cdn.tailwindcss.com",
],
# Font sources
"font-src": [
"'self'",
"https://fonts.gstatic.com",
"data:",
],
# Image sources
"img-src": [
"'self'",
"data:",
"blob:",
"https:", # Allow HTTPS images
],
# Media sources
"media-src": ["'self'", "blob:"],
# Object sources (disable Flash, etc.)
"object-src": ["'none'"],
# Frame sources (for embedding)
"frame-src": ["'none'"],
# Connect sources (AJAX, WebSocket, etc.)
"connect-src": [
"'self'",
"wss:", # WebSocket support
"ws:", # WebSocket support
],
# Worker sources
"worker-src": ["'self'", "blob:"],
# Child sources
"child-src": ["'none'"],
# Form action restrictions
"form-action": ["'self'"],
# Frame ancestors (clickjacking protection)
"frame-ancestors": ["'none'"],
# Base URI restrictions
"base-uri": ["'self'"],
# Manifest sources
"manifest-src": ["'self'"],
}
# Build CSP string
csp_parts = []
for directive, sources in csp_directives.items():
csp_parts.append(f"{directive} {' '.join(sources)}")
# Add upgrade insecure requests in HTTPS environments
if self._is_https_environment():
csp_parts.append("upgrade-insecure-requests")
return "; ".join(csp_parts)
def _build_permissions_policy(self) -> str:
"""Build Permissions Policy header"""
# Restrictive permissions policy
policies = {
# Disable camera access
"camera": "(),",
# Disable microphone access
"microphone": "(),",
# Disable geolocation
"geolocation": "(),",
# Disable gyroscope
"gyroscope": "(),",
# Disable magnetometer
"magnetometer": "(),",
# Disable payment API
"payment": "(),",
# Disable USB access
"usb": "(),",
# Disable notifications (except for self)
"notifications": "(self),",
# Disable push messaging
"push": "(),",
# Disable speaker selection
"speaker-selection": "(),",
# Allow clipboard access for self
"clipboard-write": "(self),",
"clipboard-read": "(self),",
# Allow fullscreen for self
"fullscreen": "(self),",
}
return " ".join([f"{feature}={policy}" for feature, policy in policies.items()])
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Process the request
response = await call_next(request)
# Add security headers to all responses
for header, value in self.headers.items():
if value: # Only add non-empty headers
response.headers[header] = value
# Special handling for certain endpoints
self._apply_endpoint_specific_headers(request, response)
return response
def _apply_endpoint_specific_headers(self, request: Request, response: Response):
"""Apply endpoint-specific security headers"""
path = request.url.path
# Admin pages - extra security
if path.startswith(("/admin", "/api/admin/")):
# More restrictive CSP for admin pages
response.headers["X-Frame-Options"] = "DENY"
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
# API endpoints - prevent caching of sensitive data
elif path.startswith("/api/"):
# Prevent caching of API responses
if request.method != "GET" or "auth" in path or "admin" in path:
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
# File upload endpoints - additional validation headers
elif "upload" in path:
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["Cache-Control"] = "no-store"
# Static files - allow caching but with security
elif path.startswith(("/static/", "/uploads/")):
response.headers["X-Content-Type-Options"] = "nosniff"
# Allow caching for static resources
if "static" in path:
response.headers["Cache-Control"] = "public, max-age=31536000"
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
"""Middleware to limit request body size to prevent DoS attacks"""
def __init__(self, app, max_size: int = 50 * 1024 * 1024): # 50MB default
super().__init__(app)
self.max_size = max_size
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Check Content-Length header
content_length = request.headers.get("content-length")
if content_length:
try:
size = int(content_length)
if size > self.max_size:
logger.warning(
"Request size limit exceeded",
size=size,
limit=self.max_size,
path=request.url.path,
ip=self._get_client_ip(request)
)
# Build standardized error envelope with correlation id
from starlette.responses import JSONResponse
# Resolve correlation id from state, headers, or generate
correlation_id = (
getattr(getattr(request, "state", object()), "correlation_id", None)
or request.headers.get("x-correlation-id")
or request.headers.get("x-request-id")
or str(uuid4())
)
body = {
"success": False,
"error": {
"status": 413,
"code": "http_error",
"message": "Payload too large",
},
"correlation_id": correlation_id,
}
response = JSONResponse(status_code=413, content=body)
response.headers["X-Correlation-ID"] = correlation_id
return response
except ValueError:
pass # Invalid Content-Length header, let it pass
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request headers"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
class CSRFMiddleware(BaseHTTPMiddleware):
"""CSRF protection middleware for state-changing operations"""
def __init__(self, app, exempt_paths: Optional[list] = None):
super().__init__(app)
# Paths that don't require CSRF protection
self.exempt_paths = exempt_paths or [
"/api/auth/login",
"/api/auth/refresh",
"/health",
"/static/",
"/uploads/",
]
# HTTP methods that require CSRF protection
self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip CSRF protection for exempt paths
if any(request.url.path.startswith(path) for path in self.exempt_paths):
return await call_next(request)
# Skip for safe HTTP methods
if request.method not in self.protected_methods:
return await call_next(request)
# Check for CSRF token in headers
csrf_token = request.headers.get("X-CSRF-Token") or request.headers.get("X-CSRFToken")
# For now, implement a simple CSRF check based on Referer/Origin
# In production, you'd want proper CSRF tokens
referer = request.headers.get("referer", "")
origin = request.headers.get("origin", "")
host = request.headers.get("host", "")
# Allow requests from same origin
valid_origins = [f"https://{host}", f"http://{host}"]
if origin and origin not in valid_origins:
if not referer or not any(referer.startswith(valid) for valid in valid_origins):
logger.warning(
"CSRF check failed",
path=request.url.path,
method=request.method,
origin=origin,
referer=referer,
ip=self._get_client_ip(request)
)
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="CSRF validation failed"
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request headers"""
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"