"""
Global exception handlers that return a consistent JSON error envelope and
propagate the X-Correlation-ID header.

Error responses built here share this body shape (``details`` is present only
when it is not None)::

    {
        "success": false,
        "error": {"status": <int>, "code": <str>, "message": <str>, "details": ...},
        "correlation_id": <str>
    }

and echo the same correlation id in the ``X-Correlation-ID`` response header.
"""
from __future__ import annotations

import logging
import math
from collections.abc import Mapping
from typing import Any, Optional
from uuid import uuid4

from fastapi import FastAPI, Request
from fastapi import HTTPException  # noqa: F401 - no longer referenced here; kept for importers
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette import status as http_status
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request as StarletteRequest  # noqa: F401 - unused here
from starlette.responses import JSONResponse as StarletteJSONResponse  # noqa: F401 - unused here
from starlette.responses import Response


ERROR_HEADER_NAME = "X-Correlation-ID"

# Statuses that must not carry a response body (same set Starlette's default
# HTTPException handler special-cases).
_BODYLESS_STATUS_CODES = frozenset({204, 304})


class _StdlibKwargsLogger:
    """Give a stdlib logger the ``logger.info("event", key=value)`` call style.

    The handlers in this module pass structured fields as keyword arguments.
    A plain ``logging.Logger`` rejects unknown keyword arguments with
    ``TypeError``, which would make an error handler itself fail. This adapter
    renders the fields into the message text instead. Any other attribute
    (``setLevel``, ``handlers``, ...) is delegated to the wrapped logger.
    """

    def __init__(self, inner: logging.Logger | logging.LoggerAdapter) -> None:
        self._inner = inner

    def _emit(self, level: int, event: str, /, exc_info: Any = None, **fields: Any) -> None:
        if fields:
            rendered = " ".join(f"{key}={value!r}" for key, value in fields.items())
            event = f"{event} {rendered}"
        # No positional args are passed, so '%' characters in the text are not
        # treated as format directives by logging.
        self._inner.log(level, event, exc_info=exc_info)

    def debug(self, event: str, /, **fields: Any) -> None:
        self._emit(logging.DEBUG, event, **fields)

    def info(self, event: str, /, **fields: Any) -> None:
        self._emit(logging.INFO, event, **fields)

    def warning(self, event: str, /, **fields: Any) -> None:
        self._emit(logging.WARNING, event, **fields)

    def error(self, event: str, /, **fields: Any) -> None:
        self._emit(logging.ERROR, event, **fields)

    def critical(self, event: str, /, **fields: Any) -> None:
        self._emit(logging.CRITICAL, event, **fields)

    def exception(self, event: str, /, **fields: Any) -> None:
        # Same default as logging.Logger.exception: attach the current traceback.
        fields.setdefault("exc_info", True)
        self._emit(logging.ERROR, event, **fields)

    def __getattr__(self, name: str) -> Any:
        return getattr(self._inner, name)


try:  # Prefer project logger if available
    from app.core.logging import get_logger

    logger: Any = get_logger("middleware.errors")
except Exception:  # pragma: no cover - fallback simple logger
    logger = logging.getLogger("middleware.errors")

if isinstance(logger, (logging.Logger, logging.LoggerAdapter)):
    # Stdlib loggers (the fallback above, or a project get_logger returning one)
    # cannot take keyword fields; wrap them so the handlers below stay safe.
    logger = _StdlibKwargsLogger(logger)


def _get_correlation_id(request: Request) -> str:
    """Resolve the correlation id for *request*.

    Lookup order: ``request.state.correlation_id`` (set by upstream middleware),
    then the ``X-Correlation-ID`` header, then ``X-Request-ID``. If none is
    present, a fresh UUID4 string is generated (it is not stored anywhere).
    """
    state = getattr(request, "state", None)
    from_state: Optional[Any] = getattr(state, "correlation_id", None)
    if from_state:
        # Middleware may store e.g. a UUID object; headers and JSON need str.
        return str(from_state)

    from_headers = request.headers.get("x-correlation-id") or request.headers.get("x-request-id")
    if from_headers:
        return from_headers

    return str(uuid4())


def _json_key(key: Any) -> Any:
    """Keep dict keys ``json.dumps`` accepts as-is (str, int, bool, None); stringify the rest."""
    if key is None or isinstance(key, (str, int)):
        return key
    return str(key)


def _json_safe(value: Any) -> Any:
    """Recursively convert non-JSON-serializable objects into JSON-safe values.

    Exceptions and any other non-JSON leaf (bytes, datetime, Decimal, ...) become
    their ``str()``. Mappings become dicts; lists, tuples and sets become lists.
    Non-finite floats become strings, because ``JSONResponse`` serializes with
    ``allow_nan=False`` and would otherwise raise inside the error handler.
    The overall structure is kept intact so tests inspecting error details
    (e.g. 'loc', 'msg') still work.
    """
    if value is None or isinstance(value, (str, int)):  # bool is an int subclass
        return value
    if isinstance(value, float):
        return value if math.isfinite(value) else str(value)
    if isinstance(value, BaseException):
        return str(value)
    if isinstance(value, Mapping):
        return {_json_key(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [_json_safe(item) for item in value]
    return str(value)


def _build_error_response(
    request: Request,
    *,
    status_code: int,
    message: str,
    code: str,
    details: Any | None = None,
) -> JSONResponse:
    """Build the JSON error envelope and set the correlation-id response header.

    ``details`` is added under ``error.details`` (passed through ``_json_safe``)
    only when it is not None.
    """
    correlation_id = _get_correlation_id(request)
    error: dict[str, Any] = {
        "status": status_code,
        "code": code,
        "message": message,
    }
    if details is not None:
        error["details"] = _json_safe(details)
    body = {
        "success": False,
        "error": error,
        "correlation_id": correlation_id,
    }
    response = JSONResponse(content=body, status_code=status_code)
    response.headers[ERROR_HEADER_NAME] = correlation_id
    return response


async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> Response:
    """Handle FastAPI's or Starlette's ``HTTPException`` with the error envelope.

    A string ``detail`` becomes ``error.message``. Any other detail (dict, list,
    ...) is reported under ``error.details`` with the generic message
    ``"HTTP error"``. Headers attached to the exception (e.g. WWW-Authenticate)
    are copied onto the response. 204/304 get an empty body because those
    statuses must not carry one.
    """
    detail = exc.detail
    if isinstance(detail, str):
        message = detail
        details: Any | None = None
    else:
        message = "HTTP error"
        details = detail

    logger.warning(
        "HTTPException raised",
        status_code=exc.status_code,
        detail=message,
        path=request.url.path,
    )

    response: Response
    if exc.status_code in _BODYLESS_STATUS_CODES:
        response = Response(status_code=exc.status_code)
        response.headers[ERROR_HEADER_NAME] = _get_correlation_id(request)
    else:
        response = _build_error_response(
            request,
            status_code=exc.status_code,
            message=message,
            code="http_error",
            details=details,
        )

    # Preserve any headers set on the HTTPException (e.g., WWW-Authenticate).
    for key, value in (getattr(exc, "headers", None) or {}).items():
        try:
            response.headers[key] = str(value)
        except Exception:  # e.g. non-str key or characters outside latin-1
            logger.warning(
                "Dropping unusable HTTPException header",
                header=repr(key),
                path=request.url.path,
            )
    return response


async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Handle validation errors from FastAPI/Pydantic with a 422 envelope.

    The error list is sanitized once via ``_json_safe`` so that both the log
    line and the response body contain only JSON-safe values.
    """
    errors = _json_safe(exc.errors())
    logger.info(
        "Validation error",
        path=request.url.path,
        errors=errors,
    )
    return _build_error_response(
        request,
        # Numeric literal: the named 422 constant differs between Starlette versions.
        status_code=422,
        message="Validation error",
        code="validation_error",
        details=errors,
    )


async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler for unexpected exceptions (500).

    The full traceback is logged; the client only receives a generic message so
    internals are not leaked.
    """
    try:
        logger.exception("Unhandled exception", path=request.url.path, exc_info=exc)
    except Exception:
        # The configured logger rejected the call; fall back to stdlib logging so
        # the traceback is not lost, and still return the 500 envelope.
        logging.getLogger("middleware.errors").error(
            "Unhandled exception (path=%s)", request.url.path, exc_info=exc
        )
    return _build_error_response(
        request,
        status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
        message="Internal Server Error",
        code="internal_error",
        details=None,
    )


def register_exception_handlers(app: FastAPI) -> None:
    """Register global exception handlers on the provided FastAPI app."""
    # Register on Starlette's base HTTPException. FastAPI's HTTPException
    # subclasses it, and Starlette itself raises the base class (404 for unknown
    # routes, 405 for wrong methods), which a FastAPI-only registration misses.
    app.add_exception_handler(StarletteHTTPException, http_exception_handler)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    app.add_exception_handler(Exception, unhandled_exception_handler)