"""
Request/Response Logging Middleware
"""
|
|
import time
|
|
import json
|
|
from typing import Callable
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import StreamingResponse
|
|
from app.core.logging import get_logger, log_request
|
|
|
|
# Module-level logger for this middleware, obtained from app.core.logging.get_logger
# under the name "middleware.logging".
logger = get_logger("middleware.logging")
|
|
|
|
|
|
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log HTTP requests and responses.

    For every request that is not skipped, the middleware can log the
    incoming request, times the downstream handler, logs any exception raised
    by it (and re-raises), records the completed request through
    ``log_request``, optionally logs response details, warns about slow
    requests, and emits an extra entry for authentication-related paths.
    """

    # Requests whose path starts with one of these prefixes bypass logging
    # entirely (static files, uploads, health checks, favicon).
    _SKIP_PATH_PREFIXES = ("/static/", "/uploads/", "/health", "/favicon.ico")

    # A request is treated as authentication-related when any of these
    # fragments occurs anywhere in its path (substring match, not prefix).
    _AUTH_PATH_MARKERS = ("/api/auth/", "/login", "/logout")

    def __init__(
        self,
        app,
        log_requests: bool = True,
        log_responses: bool = False,
        slow_request_threshold_ms: float = 1000,
    ):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            log_requests: When true, log a "Request started" entry before the
                request is processed.
            log_responses: When true, log response status, headers, size and
                content type at debug level.
            slow_request_threshold_ms: Requests taking longer than this many
                milliseconds are additionally logged as warnings.
        """
        super().__init__(app)
        self.log_requests = log_requests
        self.log_responses = log_responses
        self.slow_request_threshold_ms = slow_request_threshold_ms

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, delegate to ``call_next`` and log the outcome.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next handler
                and returns its response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Exception: Any exception raised by ``call_next`` is logged and
                then re-raised unchanged.
        """
        path = request.url.path

        # Skip logging for static files and health checks.
        # str.startswith accepts a tuple, so this is a single call.
        if path.startswith(self._SKIP_PATH_PREFIXES):
            return await call_next(request)

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        start_time = time.perf_counter()

        # Extract request details
        client_ip = self.get_client_ip(request)
        user_agent = request.headers.get("user-agent", "")

        # Log request
        if self.log_requests:
            logger.info(
                "Request started",
                method=request.method,
                path=path,
                query_params=str(request.query_params) if request.query_params else None,
                client_ip=client_ip,
                user_agent=user_agent,
            )

        # Process request
        try:
            response = await call_next(request)
        except Exception as e:
            # Record the failure with its timing, then let the exception
            # propagate so upstream error handling still runs.
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.error(
                "Request failed with exception",
                method=request.method,
                path=path,
                duration_ms=duration_ms,
                error=str(e),
                client_ip=client_ip,
            )
            raise

        # Calculate duration
        duration_ms = (time.perf_counter() - start_time) * 1000

        # Log response
        log_request(
            method=request.method,
            path=path,
            status_code=response.status_code,
            duration_ms=duration_ms,
            user_id=self._get_user_id(request),
        )

        # Log response details if enabled
        if self.log_responses:
            # NOTE(review): the full header dict is logged, which may include
            # sensitive values such as Set-Cookie; confirm this is acceptable.
            logger.debug(
                "Response details",
                status_code=response.status_code,
                headers=dict(response.headers),
                size_bytes=response.headers.get("content-length"),
                content_type=response.headers.get("content-type"),
            )

        # Log slow requests as warnings
        if duration_ms > self.slow_request_threshold_ms:
            logger.warning(
                "Slow request detected",
                method=request.method,
                path=path,
                duration_ms=duration_ms,
                status_code=response.status_code,
            )

        # Log authentication-related requests; the logger is bound with
        # name="auth" (presumably so the entry is routed to the auth log).
        if self._is_auth_path(path):
            logger.bind(name="auth").info(
                "Auth endpoint accessed",
                method=request.method,
                path=path,
                status_code=response.status_code,
                duration_ms=duration_ms,
                client_ip=client_ip,
                user_agent=user_agent,
            )

        return response

    @classmethod
    def _is_auth_path(cls, path: str) -> bool:
        """Return True when *path* contains any authentication path marker."""
        return any(marker in path for marker in cls._AUTH_PATH_MARKERS)

    @staticmethod
    def _get_user_id(request: Request):
        """Return the id (or, failing that, username) of ``request.state.user``.

        Returns None when no truthy ``user`` is set on ``request.state`` or
        when the user object has neither a truthy ``id`` nor ``username``.
        """
        user = getattr(request.state, "user", None)
        if not user:
            return None
        return getattr(user, "id", None) or getattr(user, "username", None)

    def get_client_ip(self, request: Request) -> str:
        """Extract client IP address from request headers.

        Sources are tried in order: the first non-empty entry of
        ``X-Forwarded-For``, then ``X-Real-IP``, then the address of the
        direct peer (``request.client.host``).

        Args:
            request: The incoming request.

        Returns:
            The client address as a string, or ``"unknown"`` when none of the
            sources yields a value.
        """
        # NOTE: these proxy headers are supplied by the peer; they are only
        # trustworthy when the app sits behind a proxy that sets them.
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # Take the first non-empty IP in the chain; a malformed header
            # such as ", 10.0.0.1" must not yield an empty client address.
            for hop in forwarded_for.split(","):
                hop = hop.strip()
                if hop:
                    return hop

        real_ip = (request.headers.get("x-real-ip") or "").strip()
        if real_ip:
            return real_ip

        # Fallback to direct client IP
        if request.client:
            return request.client.host

        return "unknown"