all working
This commit is contained in:
132
app/middleware/errors.py
Normal file
132
app/middleware/errors.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Global exception handlers that return a consistent JSON error envelope
|
||||
and propagate the X-Correlation-ID header.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette import status as http_status
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import JSONResponse as StarletteJSONResponse
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
# Prefer project logger if available
|
||||
from app.core.logging import get_logger
|
||||
logger = get_logger("middleware.errors")
|
||||
except Exception: # pragma: no cover - fallback simple logger
|
||||
import logging
|
||||
logger = logging.getLogger("middleware.errors")
|
||||
|
||||
|
||||
ERROR_HEADER_NAME = "X-Correlation-ID"
|
||||
|
||||
|
||||
def _get_correlation_id(request: Request) -> str:
|
||||
"""Resolve correlation ID from request state, headers, or generate a new one."""
|
||||
# From middleware
|
||||
correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None)
|
||||
if correlation_id:
|
||||
return correlation_id
|
||||
|
||||
# From incoming headers
|
||||
correlation_id = (
|
||||
request.headers.get("x-correlation-id")
|
||||
or request.headers.get("x-request-id")
|
||||
)
|
||||
if correlation_id:
|
||||
return correlation_id
|
||||
|
||||
# Generate a new one if not present
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def _build_error_response(
|
||||
request: Request,
|
||||
*,
|
||||
status_code: int,
|
||||
message: str,
|
||||
code: str,
|
||||
details: Any | None = None,
|
||||
) -> JSONResponse:
|
||||
correlation_id = _get_correlation_id(request)
|
||||
|
||||
body = {
|
||||
"success": False,
|
||||
"error": {
|
||||
"status": status_code,
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
"correlation_id": correlation_id,
|
||||
}
|
||||
if details is not None:
|
||||
body["error"]["details"] = details
|
||||
|
||||
response = JSONResponse(content=body, status_code=status_code)
|
||||
response.headers[ERROR_HEADER_NAME] = correlation_id
|
||||
return response
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""Handle FastAPI HTTPException with envelope and correlation id."""
|
||||
message = exc.detail if isinstance(exc.detail, str) else "HTTP error"
|
||||
logger.warning(
|
||||
"HTTPException raised",
|
||||
status_code=exc.status_code,
|
||||
detail=message,
|
||||
path=request.url.path,
|
||||
)
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=exc.status_code,
|
||||
message=message,
|
||||
code="http_error",
|
||||
details=None,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
"""Handle validation errors from FastAPI/Pydantic."""
|
||||
logger.info(
|
||||
"Validation error",
|
||||
path=request.url.path,
|
||||
errors=exc.errors(),
|
||||
)
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
message="Validation error",
|
||||
code="validation_error",
|
||||
details=exc.errors(),
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all handler for unexpected exceptions (500)."""
|
||||
# Log full exception for diagnostics without leaking internals to clients
|
||||
try:
|
||||
logger.exception("Unhandled exception", path=request.url.path)
|
||||
except Exception:
|
||||
pass
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message="Internal Server Error",
|
||||
code="internal_error",
|
||||
details=None,
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers on the provided FastAPI app."""
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Request/Response Logging Middleware
|
||||
import time
|
||||
import json
|
||||
from typing import Callable
|
||||
from uuid import uuid4
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import StreamingResponse
|
||||
@@ -21,10 +22,19 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
self.log_responses = log_responses
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip logging for static files and health checks
|
||||
# Correlation ID: use incoming header or generate a new one
|
||||
correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4())
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
# Skip logging for static files and health checks (still attach correlation id)
|
||||
skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"]
|
||||
if any(request.url.path.startswith(path) for path in skip_paths):
|
||||
return await call_next(request)
|
||||
response = await call_next(request)
|
||||
try:
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
except Exception:
|
||||
pass
|
||||
return response
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
@@ -41,7 +51,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
query_params=str(request.query_params) if request.query_params else None,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Process request
|
||||
@@ -56,7 +67,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
duration_ms=duration_ms,
|
||||
error=str(e),
|
||||
client_ip=client_ip
|
||||
client_ip=client_ip,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -74,7 +86,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log response details if enabled
|
||||
@@ -84,7 +97,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
size_bytes=response.headers.get("content-length"),
|
||||
content_type=response.headers.get("content-type")
|
||||
content_type=response.headers.get("content-type"),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log slow requests as warnings
|
||||
@@ -94,7 +108,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
duration_ms=duration_ms,
|
||||
status_code=response.status_code
|
||||
status_code=response.status_code,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log authentication-related requests to auth log
|
||||
@@ -106,9 +121,16 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
|
||||
# Attach correlation id header to all responses
|
||||
try:
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return response
|
||||
|
||||
def get_client_ip(self, request: Request) -> str:
|
||||
|
||||
Reference in New Issue
Block a user