""" Request/Response Logging Middleware """ import time import json from typing import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse from app.core.logging import get_logger, log_request logger = get_logger("middleware.logging") class LoggingMiddleware(BaseHTTPMiddleware): """Middleware to log HTTP requests and responses""" def __init__(self, app, log_requests: bool = True, log_responses: bool = False): super().__init__(app) self.log_requests = log_requests self.log_responses = log_responses async def dispatch(self, request: Request, call_next: Callable) -> Response: # Skip logging for static files and health checks skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"] if any(request.url.path.startswith(path) for path in skip_paths): return await call_next(request) # Record start time start_time = time.time() # Extract request details client_ip = self.get_client_ip(request) user_agent = request.headers.get("user-agent", "") # Log request if self.log_requests: logger.info( "Request started", method=request.method, path=request.url.path, query_params=str(request.query_params) if request.query_params else None, client_ip=client_ip, user_agent=user_agent ) # Process request try: response = await call_next(request) except Exception as e: # Log exceptions duration_ms = (time.time() - start_time) * 1000 logger.error( "Request failed with exception", method=request.method, path=request.url.path, duration_ms=duration_ms, error=str(e), client_ip=client_ip ) raise # Calculate duration duration_ms = (time.time() - start_time) * 1000 # Extract user ID from request if available (for authenticated requests) user_id = None if hasattr(request.state, "user") and request.state.user: user_id = getattr(request.state.user, "id", None) or getattr(request.state.user, "username", None) # Log response log_request( method=request.method, path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, user_id=user_id ) # Log response details if enabled if self.log_responses: logger.debug( "Response details", status_code=response.status_code, headers=dict(response.headers), size_bytes=response.headers.get("content-length"), content_type=response.headers.get("content-type") ) # Log slow requests as warnings if duration_ms > 1000: # More than 1 second logger.warning( "Slow request detected", method=request.method, path=request.url.path, duration_ms=duration_ms, status_code=response.status_code ) # Log authentication-related requests to auth log if any(path in request.url.path for path in ["/api/auth/", "/login", "/logout"]): logger.bind(name="auth").info( "Auth endpoint accessed", method=request.method, path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, client_ip=client_ip, user_agent=user_agent ) return response def get_client_ip(self, request: Request) -> str: """Extract client IP address from request headers""" # Check for IP in common proxy headers forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: # Take the first IP in the chain return forwarded_for.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip # Fallback to direct client IP if request.client: return request.client.host return "unknown"