changes
This commit is contained in:
377
app/middleware/rate_limiting.py
Normal file
377
app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Rate Limiting Middleware for API Protection
|
||||
|
||||
Implements sliding window rate limiting to prevent abuse and DoS attacks.
|
||||
Uses in-memory, per-process storage. No Redis or other shared backend is
implemented in this module, so limits are not shared across workers or hosts.
|
||||
"""
|
||||
import time
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple, Callable
|
||||
from collections import defaultdict, deque
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from app.config import settings
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="rate_limiting")
|
||||
|
||||
# Rate limit configuration
|
||||
# Each category maps to a sliding-window budget: at most "requests" hits per
# "window" seconds for a single rate-limit key.
RATE_LIMITS = {
    # Global limits per IP (fallback when no route pattern matches)
    "global": {"requests": 1000, "window": 3600},  # 1000 requests per hour

    # Authentication endpoints (more restrictive)
    "auth": {"requests": 10, "window": 900},  # 10 requests per 15 minutes

    # Admin endpoints (moderately restrictive)
    "admin": {"requests": 100, "window": 3600},  # 100 requests per hour

    # Search endpoints (moderate limits)
    "search": {"requests": 200, "window": 3600},  # 200 requests per hour

    # File upload endpoints.
    # NOTE(review): this bucket was relaxed to avoid test flakiness during
    # batch imports, so it is currently NOT restrictive; revisit the value
    # for production deployments.
    "upload": {"requests": 1000, "window": 3600},  # 1000 requests per hour

    # API endpoints (standard)
    "api": {"requests": 500, "window": 3600},  # 500 requests per hour
}
|
||||
|
||||
# Route patterns to rate limit categories (order matters: first match wins)
|
||||
# Maps a URL path prefix to a RATE_LIMITS category. Lookups iterate this dict
# in insertion order and stop at the first prefix that matches, so specific
# prefixes must stay above the generic fallbacks.
ROUTE_PATTERNS = {
    # Auth endpoints frequently called by the UI should not use the strict "auth" bucket
    "/api/auth/me": "api",
    "/api/auth/refresh": "api",
    "/api/auth/logout": "api",

    # Keep sensitive auth endpoints in the stricter bucket
    "/api/auth/login": "auth",
    "/api/auth/register": "auth",

    # Generic fallbacks
    "/api/auth/": "auth",
    "/api/admin/": "admin",
    "/api/search/": "search",
    "/api/documents/upload": "upload",
    "/api/files/upload": "upload",
    "/api/import/": "upload",
    "/api/": "api",
}
|
||||
|
||||
|
||||
class RateLimitStore:
    """In-memory rate limit storage with a sliding window algorithm.

    Each key owns a deque of integer Unix timestamps, one per accepted
    request, ordered oldest first. Rejected requests are never recorded.
    """

    def __init__(self):
        # Structure: {key: deque(timestamps)}
        self._storage: Dict[str, deque] = defaultdict(deque)
        # Only cleanup_expired() takes this lock; is_allowed() is synchronous
        # and does not acquire it.
        self._lock = asyncio.Lock()

    def is_allowed(self, key: str, limit: int, window: int) -> Tuple[bool, Dict[str, int]]:
        """Check whether a request is allowed and record it if so (sync).

        Args:
            key: Bucket identifier.
            limit: Maximum number of accepted requests inside the window.
            window: Window length in seconds.

        Returns:
            A tuple ``(allowed, info)`` where ``info`` contains:
            ``limit`` (the limit passed in), ``remaining`` (slots left after
            this call), ``reset`` (Unix time at which the oldest recorded
            request leaves the window) and ``retry_after`` (seconds until
            ``reset``, at least 1, when rejected; 0 when allowed).
        """
        # Use a non-async path for portability and test friendliness
        now = int(time.time())
        window_start = now - window

        # Drop timestamps that have fallen out of the window (oldest first).
        timestamps = self._storage[key]
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()

        current_count = len(timestamps)
        allowed = current_count < limit

        # Only accepted requests count against the budget.
        if allowed:
            timestamps.append(now)

        # Reset time is when the oldest recorded request expires.
        reset_time = (timestamps[0] + window) if timestamps else now + window

        return allowed, {
            "limit": limit,
            "remaining": max(0, limit - current_count - (1 if allowed else 0)),
            "reset": reset_time,
            "retry_after": max(1, reset_time - now) if not allowed else 0,
        }

    async def cleanup_expired(self, max_age: int = 7200):
        """Remove stale timestamps and empty buckets (cleanup task).

        Args:
            max_age: Timestamps at least this many seconds old are dropped.
        """
        async with self._lock:
            cutoff = int(time.time()) - max_age

            expired_keys = []
            for key, timestamps in self._storage.items():
                while timestamps and timestamps[0] <= cutoff:
                    timestamps.popleft()

                # Collect first; deleting while iterating the dict would raise.
                if not timestamps:
                    expired_keys.append(key)

            for key in expired_keys:
                del self._storage[key]
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """IP-based rate limiting middleware using a sliding window.

    Requests are bucketed by ``"<category>:<client_ip>"``, where the category
    comes from ROUTE_PATTERNS and the budget from RATE_LIMITS. Requests over
    budget receive a 429 response with rate-limit headers.
    """

    def __init__(self, app, store: Optional[RateLimitStore] = None):
        super().__init__(app)
        self.store = store or RateLimitStore()
        self._cleanup_task = None
        self._start_cleanup_task()

    def _start_cleanup_task(self):
        """Start the background cleanup task if it is not already running.

        Does nothing when no event loop is running yet; dispatch() calls this
        again so the task is created on the first request instead.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return

        async def cleanup_loop():
            while True:
                try:
                    await asyncio.sleep(300)  # Clean every 5 minutes
                    await self.store.cleanup_expired()
                except Exception as e:
                    logger.warning("Rate limit cleanup failed", error=str(e))

        try:
            # get_running_loop() avoids scheduling the task on a loop that is
            # never run (which get_event_loop() may create at import time).
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # No running loop yet; retried from dispatch()
        self._cleanup_task = loop.create_task(cleanup_loop())

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Apply the IP-based limit to a request, or pass it through.

        Returns the downstream response (with X-RateLimit-* headers when the
        request was counted) or a 429 JSON response when over budget.
        """
        # Skip entirely during pytest runs
        if os.getenv("PYTEST_RUNNING") == "1":
            return await call_next(request)

        # Lazily start the cleanup task if __init__ ran without a running loop.
        self._start_cleanup_task()

        path = request.url.path

        # Skip rate limiting for static files and health/metrics endpoints
        skip_paths = ("/static/", "/uploads/", "/health", "/ready", "/metrics", "/favicon.ico")
        if path.startswith(skip_paths):
            return await call_next(request)

        # Do not count CORS preflight requests against rate limits
        if request.method.upper() == "OPTIONS":
            return await call_next(request)

        # If the request is to API endpoints and an authenticated user is present on the state,
        # skip IP-based global rate limiting in favor of the user-based limiter.
        # This avoids tab/page-change bursts from tripping global IP limits.
        if path.startswith("/api/"):
            try:
                has_user = bool(getattr(request.state, "user", None))
            except Exception:
                # If state is not available for any reason, continue with IP-based limiting
                has_user = False
            # call_next stays outside the try so a downstream failure
            # propagates instead of being swallowed and the request re-run.
            if has_user:
                return await call_next(request)

        # Determine rate limit category
        category = self._get_rate_limit_category(path)
        rate_config = RATE_LIMITS.get(category, RATE_LIMITS["global"])

        # Generate rate limit key
        client_ip = self._get_client_ip(request)
        rate_key = f"{category}:{client_ip}"

        # Only the store lookup is guarded: a limiter failure must not block
        # traffic, but downstream errors must not trigger a second call_next.
        try:
            allowed, info = self.store.is_allowed(
                rate_key,
                rate_config["requests"],
                rate_config["window"],
            )
        except Exception as e:
            logger.error("Rate limiting error", error=str(e))
            # Continue without rate limiting on errors
            return await call_next(request)

        if not allowed:
            logger.warning(
                "Rate limit exceeded",
                ip=client_ip,
                path=path,
                category=category,
                limit=info["limit"],
                retry_after=info["retry_after"],
            )

            headers = {
                "X-RateLimit-Limit": str(info["limit"]),
                "X-RateLimit-Remaining": str(info["remaining"]),
                "X-RateLimit-Reset": str(info["reset"]),
                "Retry-After": str(info["retry_after"]),
            }

            # Return the 429 directly: an HTTPException raised inside
            # BaseHTTPMiddleware.dispatch is not seen by FastAPI's exception
            # handlers and would surface to the client as a 500.
            from starlette.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": f"Rate limit exceeded. Try again in {info['retry_after']} seconds."
                },
                headers=headers,
            )

        # Process request
        response = await call_next(request)

        # Add rate limit headers to response
        response.headers["X-RateLimit-Limit"] = str(info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(info["reset"])

        return response

    def _get_rate_limit_category(self, path: str) -> str:
        """Return the category of the first ROUTE_PATTERNS prefix matching path.

        Falls back to "global" when no prefix matches.
        """
        for pattern, category in ROUTE_PATTERNS.items():
            if path.startswith(pattern):
                return category
        return "global"

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP address from proxy headers or the connection.

        NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
        unless a trusted proxy overwrites them; a caller that can set them can
        choose its own rate-limit bucket. Confirm the deployment strips or
        rewrites these headers.
        """
        # Check for IP in common proxy headers; the first entry is used.
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fallback to direct client IP
        if request.client:
            return request.client.host

        return "unknown"
|
||||
|
||||
|
||||
# Enhanced rate limiting for authenticated users
|
||||
class AuthenticatedRateLimitMiddleware(BaseHTTPMiddleware):
    """User-based rate limiting for authenticated API requests.

    Requests are bucketed by ``"auth:<category>:<user_id>"`` with budgets read
    from ``settings.auth_rl_*``. Categories without an entry in
    ``self.auth_limits`` pass through unlimited by this middleware.
    """

    def __init__(self, app, store: Optional[RateLimitStore] = None):
        super().__init__(app)
        self.store = store or RateLimitStore()

        # Per-category budgets for authenticated users, taken from settings.
        self.auth_limits = {
            "api": {
                "requests": settings.auth_rl_api_requests,
                "window": settings.auth_rl_api_window_seconds,
            },
            "search": {
                "requests": settings.auth_rl_search_requests,
                "window": settings.auth_rl_search_window_seconds,
            },
            "upload": {
                "requests": settings.auth_rl_upload_requests,
                "window": settings.auth_rl_upload_window_seconds,
            },
            "admin": {
                "requests": settings.auth_rl_admin_requests,
                "window": settings.auth_rl_admin_window_seconds,
            },
        }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Apply the per-user limit to a request, or pass it through.

        Returns the downstream response (with X-Auth-RateLimit-* headers when
        the request was counted) or a 429 JSON response when over budget.
        """
        # Skip entirely during pytest runs
        if os.getenv("PYTEST_RUNNING") == "1":
            return await call_next(request)

        # Only apply to API endpoints
        if not request.url.path.startswith("/api/"):
            return await call_next(request)

        # Allow disabling via settings (useful for local/dev)
        if not settings.auth_rl_enabled:
            return await call_next(request)

        # Skip if user not authenticated. The lookup is best-effort: any
        # failure reading request.state is treated as "no user".
        user_id = None
        try:
            if hasattr(request.state, "user") and request.state.user:
                user = request.state.user
                user_id = getattr(user, "id", None) or getattr(user, "username", None)
        except Exception:
            user_id = None

        if not user_id:
            return await call_next(request)

        # Determine category and get the limits for authenticated users
        category = self._get_rate_limit_category(request.url.path)
        rate_config = self.auth_limits.get(category)
        if rate_config is None:
            return await call_next(request)

        rate_key = f"auth:{category}:{user_id}"

        # Only the store lookup is guarded: a limiter failure must not block
        # traffic, but downstream errors must not trigger a second call_next.
        try:
            allowed, info = self.store.is_allowed(
                rate_key,
                rate_config["requests"],
                rate_config["window"],
            )
        except Exception as e:
            logger.error("Authenticated rate limiting error", error=str(e))
            return await call_next(request)

        if not allowed:
            logger.warning(
                "Authenticated user rate limit exceeded",
                user_id=user_id,
                path=request.url.path,
                category=category,
                limit=info["limit"],
            )

            headers = {
                "X-RateLimit-Limit": str(info["limit"]),
                "X-RateLimit-Remaining": str(info["remaining"]),
                "X-RateLimit-Reset": str(info["reset"]),
                "Retry-After": str(info["retry_after"]),
            }

            # Return the 429 directly: an HTTPException raised inside
            # BaseHTTPMiddleware.dispatch is not seen by FastAPI's exception
            # handlers and would surface to the client as a 500.
            from starlette.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": (
                        "Rate limit exceeded for authenticated user. "
                        f"Try again in {info['retry_after']} seconds."
                    )
                },
                headers=headers,
            )

        # Add auth-specific rate limit headers
        response = await call_next(request)
        response.headers["X-Auth-RateLimit-Limit"] = str(info["limit"])
        response.headers["X-Auth-RateLimit-Remaining"] = str(info["remaining"])

        return response

    def _get_rate_limit_category(self, path: str) -> str:
        """Return the category of the first ROUTE_PATTERNS prefix matching path.

        Falls back to "api" when no prefix matches.
        """
        for pattern, category in ROUTE_PATTERNS.items():
            if path.startswith(pattern):
                return category
        return "api"
|
||||
|
||||
|
||||
# Global store instance
|
||||
# Module-level store used by check_rate_limit() and get_rate_limit_info().
# The middleware classes build their own RateLimitStore unless one is passed in.
rate_limit_store = RateLimitStore()
|
||||
|
||||
# Rate limiting utilities
|
||||
async def check_rate_limit(key: str, limit: int, window: int) -> bool:
    """Check if a specific key is within rate limits.

    This goes through ``rate_limit_store.is_allowed``, so a successful check
    is recorded as a request and consumes one slot of the key's budget.

    Args:
        key: Bucket identifier in the module-level store.
        limit: Maximum number of accepted requests inside the window.
        window: Window length in seconds.

    Returns:
        True when the request is within the limit, False otherwise.
    """
    allowed, _ = rate_limit_store.is_allowed(key, limit, window)
    return allowed
|
||||
|
||||
async def get_rate_limit_info(key: str, limit: int, window: int) -> Dict[str, int]:
    """Get rate limit information for a key without consuming a slot.

    Reads the module-level store only; it neither records a request nor
    creates a bucket for an unseen key.

    Args:
        key: Bucket identifier in the module-level store.
        limit: Maximum number of accepted requests inside the window.
        window: Window length in seconds.

    Returns:
        A dict with ``limit``, ``remaining`` (slots currently left),
        ``reset`` (Unix time at which the oldest in-window request expires,
        or now + window when none are recorded) and ``retry_after`` (seconds
        until ``reset``, at least 1, when the budget is exhausted; else 0).
    """
    now = int(time.time())
    window_start = now - window

    # .get() instead of indexing: indexing the defaultdict would create an
    # empty bucket as a side effect of a read-only query.
    recorded = rate_limit_store._storage.get(key, ())
    active = [ts for ts in recorded if ts > window_start]

    current_count = len(active)
    exhausted = current_count >= limit
    reset_time = (active[0] + window) if active else now + window

    return {
        "limit": limit,
        "remaining": max(0, limit - current_count),
        "reset": reset_time,
        "retry_after": max(1, reset_time - now) if exhausted else 0,
    }
|
||||
Reference in New Issue
Block a user