"""
Global exception handlers that return a consistent JSON error envelope
and propagate the X-Correlation-ID header.
"""
|
|
from __future__ import annotations

import math
from typing import Any, Optional
from uuid import uuid4

from fastapi import FastAPI, Request
from fastapi import HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette import status as http_status
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse as StarletteJSONResponse
|
|
|
|
try:
    # Prefer the project logger if available. The handlers in this module pass
    # structured keyword fields (e.g. ``path=...``), so the project logger is
    # presumably structlog-like -- NOTE(review): confirm against app.core.logging.
    from app.core.logging import get_logger

    logger = get_logger("middleware.errors")
except Exception:  # pragma: no cover - fallback simple logger
    import logging

    class _KeywordFieldLogger(logging.LoggerAdapter):
        """Stdlib logger wrapper that tolerates structured keyword fields.

        A plain ``logging.Logger`` raises ``TypeError`` for calls such as
        ``logger.warning("msg", path="/x")``. This adapter renders every
        keyword that the stdlib does not understand as ``key=value`` text
        appended to the message, and forwards the stdlib keywords
        (``exc_info``, ``stack_info``, ``stacklevel``, ``extra``) untouched.
        """

        _STDLIB_KWARGS = frozenset({"exc_info", "stack_info", "stacklevel", "extra"})

        def log(self, level: int, msg: object, *args: object, **kwargs: Any) -> None:
            """Log ``msg`` at ``level``, folding unknown keywords into the message."""
            if not self.isEnabledFor(level):
                return
            fields = {k: v for k, v in kwargs.items() if k not in self._STDLIB_KWARGS}
            passthrough = {k: v for k, v in kwargs.items() if k in self._STDLIB_KWARGS}
            if fields:
                rendered = " ".join(f"{key}={value!r}" for key, value in fields.items())
                if args:
                    # The message will be %-formatted with ``args``; keep any
                    # literal percent signs in the rendered fields intact.
                    rendered = rendered.replace("%", "%%")
                msg = f"{msg} {rendered}"
            self.logger.log(level, msg, *args, **passthrough)

    logger = _KeywordFieldLogger(logging.getLogger("middleware.errors"), {})


# Response header used to echo the correlation ID back to the client.
ERROR_HEADER_NAME = "X-Correlation-ID"
|
|
|
|
|
|
def _get_correlation_id(request: Request) -> str:
|
|
"""Resolve correlation ID from request state, headers, or generate a new one."""
|
|
# From middleware
|
|
correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None)
|
|
if correlation_id:
|
|
return correlation_id
|
|
|
|
# From incoming headers
|
|
correlation_id = (
|
|
request.headers.get("x-correlation-id")
|
|
or request.headers.get("x-request-id")
|
|
)
|
|
if correlation_id:
|
|
return correlation_id
|
|
|
|
# Generate a new one if not present
|
|
return str(uuid4())
|
|
|
|
|
|
def _json_safe(value: Any) -> Any:
|
|
"""Recursively convert non-JSON-serializable objects (like Exceptions) into strings.
|
|
|
|
Keeps overall structure intact so tests inspecting error details (e.g. 'loc', 'msg') still work.
|
|
"""
|
|
# Exception -> string message
|
|
if isinstance(value, BaseException):
|
|
return str(value)
|
|
# Mapping types
|
|
if isinstance(value, dict):
|
|
return {k: _json_safe(v) for k, v in value.items()}
|
|
# Sequence types
|
|
if isinstance(value, (list, tuple)):
|
|
return [
|
|
_json_safe(v) for v in value
|
|
]
|
|
return value
|
|
|
|
|
|
def _build_error_response(
    request: Request,
    *,
    status_code: int,
    message: str,
    code: str,
    details: Any | None = None,
) -> JSONResponse:
    """Build the standard JSON error envelope for ``request``.

    The body has the shape::

        {
            "success": False,
            "error": {"status": ..., "code": ..., "message": ..., "details": ...},
            "correlation_id": ...,
        }

    ``details`` is included only when it is not ``None`` and is passed through
    ``_json_safe`` first. The correlation ID is also written to the
    ``ERROR_HEADER_NAME`` response header.

    Args:
        request: Request used to resolve the correlation ID.
        status_code: HTTP status for the response, repeated in the body.
        message: Human-readable error message.
        code: Machine-readable error code.
        details: Optional extra error information.

    Returns:
        The populated ``JSONResponse``.
    """
    request_correlation_id = _get_correlation_id(request)

    error_section = {"status": status_code, "code": code, "message": message}
    if details is not None:
        error_section["details"] = _json_safe(details)

    envelope = {
        "success": False,
        "error": error_section,
        "correlation_id": request_correlation_id,
    }

    envelope_response = JSONResponse(content=envelope, status_code=status_code)
    envelope_response.headers[ERROR_HEADER_NAME] = request_correlation_id
    return envelope_response
|
|
|
|
|
|
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle an HTTPException with the error envelope and correlation ID.

    A string ``exc.detail`` becomes the envelope message. Any other non-None
    detail (e.g. a dict or list) is reported under ``error.details`` with the
    generic message ``"HTTP error"``, so structured information is not lost.
    Headers attached to the exception (e.g. ``WWW-Authenticate``) are copied
    onto the response.

    Args:
        request: The request being processed.
        exc: The raised HTTP exception.

    Returns:
        The enveloped error response carrying ``exc.status_code``.
    """
    detail = exc.detail
    if isinstance(detail, str):
        message, details = detail, None
    else:
        message, details = "HTTP error", detail

    path = request.url.path
    try:
        logger.warning(
            "HTTPException raised",
            status_code=exc.status_code,
            detail=message,
            path=path,
        )
    except TypeError:
        # A stdlib-style logger rejects structured keyword fields; emit the
        # same information as plain text instead of failing the handler.
        logger.warning(
            f"HTTPException raised status_code={exc.status_code} detail={message!r} path={path}"
        )

    response = _build_error_response(
        request,
        status_code=exc.status_code,
        message=message,
        code="http_error",
        details=details,
    )

    # Preserve any headers set on the HTTPException (e.g., WWW-Authenticate).
    extra_headers = getattr(exc, "headers", None)
    if extra_headers:
        try:
            header_items = list(dict(extra_headers).items())
        except (TypeError, ValueError):
            logger.warning("Ignoring malformed headers on HTTPException")
            header_items = []
        for key, value in header_items:
            try:
                response.headers[str(key)] = str(value)
            except ValueError:
                # Header text that cannot be encoded for the wire: skip this
                # one header but keep the rest.
                logger.warning(f"Skipping unencodable HTTPException header {key!r}")
    return response
|
|
|
|
|
|
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Handle request validation errors from FastAPI/Pydantic.

    Responds with HTTP 422 and places the list from ``exc.errors()`` under
    ``error.details`` in the envelope.

    Args:
        request: The request that failed validation.
        exc: The validation error raised while parsing the request.

    Returns:
        The enveloped 422 response.
    """
    # Compute once so the logged errors and the returned errors are identical.
    errors = exc.errors()
    path = request.url.path
    try:
        logger.info(
            "Validation error",
            path=path,
            errors=errors,
        )
    except TypeError:
        # A stdlib-style logger rejects structured keyword fields; emit the
        # same information as plain text instead of failing the handler.
        logger.info(f"Validation error path={path} errors={errors!r}")

    return _build_error_response(
        request,
        status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY,
        message="Validation error",
        code="validation_error",
        details=errors,
    )
|
|
|
|
|
|
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler for unexpected exceptions (500).

    The exception is logged for diagnostics, and the client receives only a
    generic ``"Internal Server Error"`` envelope so internals are not leaked.

    Args:
        request: The request being processed when the exception escaped.
        exc: The unhandled exception.

    Returns:
        The enveloped 500 response.
    """
    try:
        path = request.url.path
        try:
            logger.exception("Unhandled exception", path=path)
        except TypeError:
            # A stdlib-style logger rejects the ``path`` keyword. Without this
            # fallback the exception would never be logged. Pass ``exc``
            # explicitly because the active exception here is the TypeError.
            logger.error(f"Unhandled exception path={path}", exc_info=exc)
    except Exception:
        # Deliberate best-effort: a logging failure must never replace the
        # 500 response the client is about to receive.
        pass

    return _build_error_response(
        request,
        status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
        message="Internal Server Error",
        code="internal_error",
        details=None,
    )
|
|
|
|
|
|
def register_exception_handlers(app: FastAPI) -> None:
    """Register the global exception handlers on the provided FastAPI app.

    ``http_exception_handler`` is registered for Starlette's ``HTTPException``
    as well as FastAPI's subclass. Errors raised by Starlette itself (such as
    404 for an unknown route or 405 for a disallowed method) use the base
    class and would otherwise bypass the error envelope.

    Args:
        app: The application to attach the handlers to.
    """
    app.add_exception_handler(StarletteHTTPException, http_exception_handler)
    app.add_exception_handler(HTTPException, http_exception_handler)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)
    app.add_exception_handler(Exception, unhandled_exception_handler)
|
|
|
|
|