changes
This commit is contained in:
406
app/middleware/security_headers.py
Normal file
406
app/middleware/security_headers.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Security Headers Middleware
|
||||
|
||||
Implements comprehensive security headers to protect against common web vulnerabilities:
|
||||
- HSTS (HTTP Strict Transport Security)
|
||||
- CSP (Content Security Policy)
|
||||
- X-Frame-Options (Clickjacking protection)
|
||||
- X-Content-Type-Options (MIME sniffing protection)
|
||||
- X-XSS-Protection (XSS protection)
|
||||
- Referrer-Policy (Information disclosure protection)
|
||||
- Permissions-Policy (Feature policy)
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
from uuid import uuid4
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from app.config import settings
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="security_headers")
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers to all responses"""
|
||||
|
||||
def __init__(self, app, config: Optional[Dict[str, str]] = None):
|
||||
super().__init__(app)
|
||||
self.config = config or {}
|
||||
self.headers = self._build_security_headers()
|
||||
|
||||
def _build_security_headers(self) -> Dict[str, str]:
|
||||
"""Build security headers based on configuration and environment"""
|
||||
|
||||
# Base security headers
|
||||
headers = {
|
||||
# Prevent MIME type sniffing
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
|
||||
# XSS Protection (legacy but still useful)
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
|
||||
# Clickjacking protection
|
||||
"X-Frame-Options": "DENY",
|
||||
|
||||
# Referrer policy
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
|
||||
# Remove server information
|
||||
"Server": "Delphi-DB",
|
||||
|
||||
# Prevent exposure of sensitive information
|
||||
"X-Powered-By": "",
|
||||
}
|
||||
|
||||
# HSTS (HTTP Strict Transport Security) - only for HTTPS
|
||||
if self._is_https_environment():
|
||||
headers["Strict-Transport-Security"] = self.config.get(
|
||||
"hsts",
|
||||
"max-age=31536000; includeSubDomains; preload"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp = self._build_csp_header()
|
||||
if csp:
|
||||
headers["Content-Security-Policy"] = csp
|
||||
|
||||
# Permissions Policy (Feature Policy)
|
||||
permissions_policy = self._build_permissions_policy()
|
||||
if permissions_policy:
|
||||
headers["Permissions-Policy"] = permissions_policy
|
||||
|
||||
return headers
|
||||
|
||||
def _is_https_environment(self) -> bool:
|
||||
"""Check if we're in an HTTPS environment"""
|
||||
# Check common HTTPS indicators
|
||||
if self.config.get("force_https", False):
|
||||
return True
|
||||
|
||||
# In production, assume HTTPS
|
||||
if not settings.debug:
|
||||
return True
|
||||
|
||||
# Check for secure cookies setting
|
||||
if settings.secure_cookies:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _build_csp_header(self) -> str:
|
||||
"""Build Content Security Policy header"""
|
||||
|
||||
# Get domain configuration
|
||||
domain = self.config.get("domain", "'self'")
|
||||
|
||||
# CSP directives for the application
|
||||
csp_directives = {
|
||||
# Default source
|
||||
"default-src": ["'self'"],
|
||||
|
||||
# Script sources - allow self and inline scripts for the app
|
||||
"script-src": [
|
||||
"'self'",
|
||||
"'unsafe-inline'", # Required for inline event handlers
|
||||
"https://cdn.tailwindcss.com", # Tailwind CSS CDN if used
|
||||
],
|
||||
|
||||
# Style sources - allow self and inline styles
|
||||
"style-src": [
|
||||
"'self'",
|
||||
"'unsafe-inline'", # Required for component styling
|
||||
"https://fonts.googleapis.com",
|
||||
"https://cdn.tailwindcss.com",
|
||||
],
|
||||
|
||||
# Font sources
|
||||
"font-src": [
|
||||
"'self'",
|
||||
"https://fonts.gstatic.com",
|
||||
"data:",
|
||||
],
|
||||
|
||||
# Image sources
|
||||
"img-src": [
|
||||
"'self'",
|
||||
"data:",
|
||||
"blob:",
|
||||
"https:", # Allow HTTPS images
|
||||
],
|
||||
|
||||
# Media sources
|
||||
"media-src": ["'self'", "blob:"],
|
||||
|
||||
# Object sources (disable Flash, etc.)
|
||||
"object-src": ["'none'"],
|
||||
|
||||
# Frame sources (for embedding)
|
||||
"frame-src": ["'none'"],
|
||||
|
||||
# Connect sources (AJAX, WebSocket, etc.)
|
||||
"connect-src": [
|
||||
"'self'",
|
||||
"wss:", # WebSocket support
|
||||
"ws:", # WebSocket support
|
||||
],
|
||||
|
||||
# Worker sources
|
||||
"worker-src": ["'self'", "blob:"],
|
||||
|
||||
# Child sources
|
||||
"child-src": ["'none'"],
|
||||
|
||||
# Form action restrictions
|
||||
"form-action": ["'self'"],
|
||||
|
||||
# Frame ancestors (clickjacking protection)
|
||||
"frame-ancestors": ["'none'"],
|
||||
|
||||
# Base URI restrictions
|
||||
"base-uri": ["'self'"],
|
||||
|
||||
# Manifest sources
|
||||
"manifest-src": ["'self'"],
|
||||
}
|
||||
|
||||
# Build CSP string
|
||||
csp_parts = []
|
||||
for directive, sources in csp_directives.items():
|
||||
csp_parts.append(f"{directive} {' '.join(sources)}")
|
||||
|
||||
# Add upgrade insecure requests in HTTPS environments
|
||||
if self._is_https_environment():
|
||||
csp_parts.append("upgrade-insecure-requests")
|
||||
|
||||
return "; ".join(csp_parts)
|
||||
|
||||
def _build_permissions_policy(self) -> str:
|
||||
"""Build Permissions Policy header"""
|
||||
|
||||
# Restrictive permissions policy
|
||||
policies = {
|
||||
# Disable camera access
|
||||
"camera": "(),",
|
||||
|
||||
# Disable microphone access
|
||||
"microphone": "(),",
|
||||
|
||||
# Disable geolocation
|
||||
"geolocation": "(),",
|
||||
|
||||
# Disable gyroscope
|
||||
"gyroscope": "(),",
|
||||
|
||||
# Disable magnetometer
|
||||
"magnetometer": "(),",
|
||||
|
||||
# Disable payment API
|
||||
"payment": "(),",
|
||||
|
||||
# Disable USB access
|
||||
"usb": "(),",
|
||||
|
||||
# Disable notifications (except for self)
|
||||
"notifications": "(self),",
|
||||
|
||||
# Disable push messaging
|
||||
"push": "(),",
|
||||
|
||||
# Disable speaker selection
|
||||
"speaker-selection": "(),",
|
||||
|
||||
# Allow clipboard access for self
|
||||
"clipboard-write": "(self),",
|
||||
"clipboard-read": "(self),",
|
||||
|
||||
# Allow fullscreen for self
|
||||
"fullscreen": "(self),",
|
||||
}
|
||||
|
||||
return " ".join([f"{feature}={policy}" for feature, policy in policies.items()])
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers to all responses
|
||||
for header, value in self.headers.items():
|
||||
if value: # Only add non-empty headers
|
||||
response.headers[header] = value
|
||||
|
||||
# Special handling for certain endpoints
|
||||
self._apply_endpoint_specific_headers(request, response)
|
||||
|
||||
return response
|
||||
|
||||
def _apply_endpoint_specific_headers(self, request: Request, response: Response):
|
||||
"""Apply endpoint-specific security headers"""
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# Admin pages - extra security
|
||||
if path.startswith(("/admin", "/api/admin/")):
|
||||
# More restrictive CSP for admin pages
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
# API endpoints - prevent caching of sensitive data
|
||||
elif path.startswith("/api/"):
|
||||
# Prevent caching of API responses
|
||||
if request.method != "GET" or "auth" in path or "admin" in path:
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
# File upload endpoints - additional validation headers
|
||||
elif "upload" in path:
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
|
||||
# Static files - allow caching but with security
|
||||
elif path.startswith(("/static/", "/uploads/")):
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
# Allow caching for static resources
|
||||
if "static" in path:
|
||||
response.headers["Cache-Control"] = "public, max-age=31536000"
|
||||
|
||||
|
||||
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to limit request body size to prevent DoS attacks"""
|
||||
|
||||
def __init__(self, app, max_size: int = 50 * 1024 * 1024): # 50MB default
|
||||
super().__init__(app)
|
||||
self.max_size = max_size
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Check Content-Length header
|
||||
content_length = request.headers.get("content-length")
|
||||
|
||||
if content_length:
|
||||
try:
|
||||
size = int(content_length)
|
||||
if size > self.max_size:
|
||||
logger.warning(
|
||||
"Request size limit exceeded",
|
||||
size=size,
|
||||
limit=self.max_size,
|
||||
path=request.url.path,
|
||||
ip=self._get_client_ip(request)
|
||||
)
|
||||
|
||||
# Build standardized error envelope with correlation id
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# Resolve correlation id from state, headers, or generate
|
||||
correlation_id = (
|
||||
getattr(getattr(request, "state", object()), "correlation_id", None)
|
||||
or request.headers.get("x-correlation-id")
|
||||
or request.headers.get("x-request-id")
|
||||
or str(uuid4())
|
||||
)
|
||||
|
||||
body = {
|
||||
"success": False,
|
||||
"error": {
|
||||
"status": 413,
|
||||
"code": "http_error",
|
||||
"message": "Payload too large",
|
||||
},
|
||||
"correlation_id": correlation_id,
|
||||
}
|
||||
|
||||
response = JSONResponse(status_code=413, content=body)
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
return response
|
||||
except ValueError:
|
||||
pass # Invalid Content-Length header, let it pass
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request headers"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""CSRF protection middleware for state-changing operations"""
|
||||
|
||||
def __init__(self, app, exempt_paths: Optional[list] = None):
|
||||
super().__init__(app)
|
||||
# Paths that don't require CSRF protection
|
||||
self.exempt_paths = exempt_paths or [
|
||||
"/api/auth/login",
|
||||
"/api/auth/refresh",
|
||||
"/health",
|
||||
"/static/",
|
||||
"/uploads/",
|
||||
]
|
||||
# HTTP methods that require CSRF protection
|
||||
self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip CSRF protection for exempt paths
|
||||
if any(request.url.path.startswith(path) for path in self.exempt_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip for safe HTTP methods
|
||||
if request.method not in self.protected_methods:
|
||||
return await call_next(request)
|
||||
|
||||
# Check for CSRF token in headers
|
||||
csrf_token = request.headers.get("X-CSRF-Token") or request.headers.get("X-CSRFToken")
|
||||
|
||||
# For now, implement a simple CSRF check based on Referer/Origin
|
||||
# In production, you'd want proper CSRF tokens
|
||||
referer = request.headers.get("referer", "")
|
||||
origin = request.headers.get("origin", "")
|
||||
host = request.headers.get("host", "")
|
||||
|
||||
# Allow requests from same origin
|
||||
valid_origins = [f"https://{host}", f"http://{host}"]
|
||||
|
||||
if origin and origin not in valid_origins:
|
||||
if not referer or not any(referer.startswith(valid) for valid in valid_origins):
|
||||
logger.warning(
|
||||
"CSRF check failed",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
origin=origin,
|
||||
referer=referer,
|
||||
ip=self._get_client_ip(request)
|
||||
)
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="CSRF validation failed"
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request headers"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
Reference in New Issue
Block a user