""" Request/Response Logging Middleware """ import time import json from typing import Callable from uuid import uuid4 from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse from app.core.logging import get_logger, log_request logger = get_logger("middleware.logging") class LoggingMiddleware(BaseHTTPMiddleware): """Middleware to log HTTP requests and responses""" def __init__(self, app, log_requests: bool = True, log_responses: bool = False): super().__init__(app) self.log_requests = log_requests self.log_responses = log_responses async def dispatch(self, request: Request, call_next: Callable) -> Response: # Correlation ID: use incoming header or generate a new one correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4()) request.state.correlation_id = correlation_id # Skip logging for static files and health checks (still attach correlation id) skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"] if any(request.url.path.startswith(path) for path in skip_paths): response = await call_next(request) try: response.headers["X-Correlation-ID"] = correlation_id except Exception: pass return response # Record start time start_time = time.time() # Extract request details client_ip = self.get_client_ip(request) user_agent = request.headers.get("user-agent", "") # Log request if self.log_requests: logger.info( "Request started", method=request.method, path=request.url.path, query_params=str(request.query_params) if request.query_params else None, client_ip=client_ip, user_agent=user_agent, correlation_id=correlation_id, ) # Process request try: response = await call_next(request) except Exception as e: # Log exceptions duration_ms = (time.time() - start_time) * 1000 logger.error( "Request failed with exception", method=request.method, path=request.url.path, duration_ms=duration_ms, error=str(e), client_ip=client_ip, correlation_id=correlation_id, ) raise # Calculate duration duration_ms = (time.time() - start_time) * 1000 # Extract user ID from request if available (for authenticated requests) user_id = None if hasattr(request.state, "user") and request.state.user: user_id = getattr(request.state.user, "id", None) or getattr(request.state.user, "username", None) # Log response log_request( method=request.method, path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, user_id=user_id, correlation_id=correlation_id, ) # Log response details if enabled if self.log_responses: logger.debug( "Response details", status_code=response.status_code, headers=dict(response.headers), size_bytes=response.headers.get("content-length"), content_type=response.headers.get("content-type"), correlation_id=correlation_id, ) # Log slow requests as warnings if duration_ms > 1000: # More than 1 second logger.warning( "Slow request detected", method=request.method, path=request.url.path, duration_ms=duration_ms, status_code=response.status_code, correlation_id=correlation_id, ) # Log authentication-related requests to auth log if any(path in request.url.path for path in ["/api/auth/", "/login", "/logout"]): logger.bind(name="auth").info( "Auth endpoint accessed", method=request.method, path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, client_ip=client_ip, user_agent=user_agent, correlation_id=correlation_id, ) # Attach correlation id header to all responses try: response.headers["X-Correlation-ID"] = correlation_id except Exception: pass return response def get_client_ip(self, request: Request) -> str: """Extract client IP address from request headers""" # Check for IP in common proxy headers forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: # Take the first IP in the chain return forwarded_for.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip # Fallback to direct client IP if request.client: return request.client.host return "unknown"