changes
This commit is contained in:
319
app/middleware/session_middleware.py
Normal file
319
app/middleware/session_middleware.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Session management middleware for P2 security features
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.utils.session_manager import SessionManager
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SessionManagementMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Advanced session management middleware
|
||||
|
||||
Features:
|
||||
- Session validation and renewal
|
||||
- Activity tracking
|
||||
- Concurrent session limits
|
||||
- Session fixation protection
|
||||
- Automatic cleanup
|
||||
"""
|
||||
|
||||
# Endpoints that don't require session validation
|
||||
EXCLUDED_PATHS = {
|
||||
"/docs", "/redoc", "/openapi.json",
|
||||
"/api/auth/login", "/api/auth/register",
|
||||
"/api/auth/refresh", "/api/health",
|
||||
"/health", "/ready", "/metrics",
|
||||
}
|
||||
|
||||
def __init__(self, app, cleanup_interval: int = 3600):
|
||||
super().__init__(app)
|
||||
self.cleanup_interval = cleanup_interval # 1 hour default
|
||||
self.last_cleanup = time.time()
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Main middleware dispatcher"""
|
||||
start_time = time.time()
|
||||
|
||||
# Skip middleware for excluded paths
|
||||
if self._should_skip_middleware(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Get database session
|
||||
db = next(get_db())
|
||||
session_manager = SessionManager(db)
|
||||
|
||||
try:
|
||||
# Perform periodic cleanup
|
||||
await self._periodic_cleanup(session_manager)
|
||||
|
||||
# Process session validation
|
||||
session_info = await self._validate_session(request, session_manager)
|
||||
|
||||
# Add session info to request state
|
||||
request.state.session_info = session_info
|
||||
# Also expose authenticated user (if any) for downstream middleware (e.g., user-based rate limiting, logging)
|
||||
try:
|
||||
request.state.user = session_info.get("user") if session_info else None
|
||||
except Exception:
|
||||
# Be resilient: never break the request due to state propagation
|
||||
pass
|
||||
|
||||
# Fallback: if no server-side session identified, try to attach user from Authorization token for JWT-only flows
|
||||
if not getattr(request.state, "user", None):
|
||||
try:
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
from app.auth.security import verify_token # local import to avoid circular deps
|
||||
username = verify_token(token)
|
||||
if username:
|
||||
from app.models.user import User # local import
|
||||
user = session_manager.db.query(User).filter(User.username == username).first()
|
||||
if user and user.is_active:
|
||||
request.state.user = user
|
||||
except Exception:
|
||||
# Never fail the request if auth attachment fails here
|
||||
pass
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Update session activity
|
||||
if session_info and session_info.get("session"):
|
||||
await self._update_session_activity(
|
||||
request, response, session_info["session"], session_manager, start_time
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Session middleware error: {str(e)}")
|
||||
# Re-raise to be handled by global error handlers; do not re-invoke downstream app
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _should_skip_middleware(self, request: Request) -> bool:
|
||||
"""Check if middleware should be skipped for this request"""
|
||||
path = request.url.path
|
||||
|
||||
# Skip excluded paths
|
||||
if any(path.startswith(excluded) for excluded in self.EXCLUDED_PATHS):
|
||||
return True
|
||||
|
||||
# Skip static files
|
||||
if path.startswith("/static/") or path.startswith("/favicon.ico"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _validate_session(self, request: Request, session_manager: SessionManager) -> Optional[dict]:
|
||||
"""Validate session from request"""
|
||||
# Extract session ID from various sources
|
||||
session_id = await self._extract_session_id(request)
|
||||
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
# Validate session
|
||||
session = session_manager.validate_session(session_id, request)
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"session_id": session_id,
|
||||
"user": session.user,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
async def _extract_session_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract session ID from request"""
|
||||
# Try cookie first
|
||||
session_id = request.cookies.get("session_id")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# Try custom header
|
||||
session_id = request.headers.get("X-Session-ID")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# For JWT-based sessions, extract from authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
# Use JWT token as session identifier for now
|
||||
# In a full implementation, you'd decode JWT and extract session ID
|
||||
token = auth_header[7:]
|
||||
return token[:32] if len(token) > 32 else token
|
||||
|
||||
return None
|
||||
|
||||
async def _update_session_activity(
|
||||
self,
|
||||
request: Request,
|
||||
response: Response,
|
||||
session,
|
||||
session_manager: SessionManager,
|
||||
start_time: float
|
||||
) -> None:
|
||||
"""Update session activity tracking"""
|
||||
try:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log API activity
|
||||
session_manager._log_activity(
|
||||
session, session.user, request,
|
||||
activity_type="api_request",
|
||||
endpoint=request.url.path
|
||||
)
|
||||
|
||||
# Update activity record with response details
|
||||
if hasattr(session, 'activities') and session.activities:
|
||||
latest_activity = session.activities[-1]
|
||||
latest_activity.status_code = getattr(response, 'status_code', None)
|
||||
latest_activity.duration_ms = duration_ms
|
||||
|
||||
# Analyze for suspicious patterns
|
||||
await self._analyze_activity_patterns(session, session_manager)
|
||||
|
||||
session_manager.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update session activity: {str(e)}")
|
||||
|
||||
async def _analyze_activity_patterns(self, session, session_manager: SessionManager) -> None:
|
||||
"""Analyze activity patterns for suspicious behavior"""
|
||||
try:
|
||||
# Get recent activities for this session
|
||||
recent_activities = session_manager.db.query(
|
||||
session_manager.db.query(type(session.activities[0]))
|
||||
).filter_by(session_id=session.id).order_by(
|
||||
type(session.activities[0]).timestamp.desc()
|
||||
).limit(10).all()
|
||||
|
||||
if len(recent_activities) < 5:
|
||||
return
|
||||
|
||||
# Check for rapid API calls (possible automation)
|
||||
time_diffs = []
|
||||
for i in range(1, len(recent_activities)):
|
||||
time_diff = (recent_activities[i-1].timestamp - recent_activities[i].timestamp).total_seconds()
|
||||
time_diffs.append(time_diff)
|
||||
|
||||
avg_time_diff = sum(time_diffs) / len(time_diffs)
|
||||
|
||||
# Flag if average time between requests is < 1 second
|
||||
if avg_time_diff < 1.0:
|
||||
session.risk_score = min(session.risk_score + 10, 100)
|
||||
session_manager._create_security_event(
|
||||
session, session.user,
|
||||
event_type="rapid_api_calls",
|
||||
severity="medium",
|
||||
description=f"Rapid API calls detected: avg {avg_time_diff:.2f}s between requests"
|
||||
)
|
||||
|
||||
# Lock session if risk score is too high
|
||||
if session.risk_score >= 80:
|
||||
session.lock_session("high_risk_activity")
|
||||
session_manager._create_security_event(
|
||||
session, session.user,
|
||||
event_type="session_locked",
|
||||
severity="high",
|
||||
description=f"Session locked due to high risk score: {session.risk_score}",
|
||||
action_taken="session_locked"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to analyze activity patterns: {str(e)}")
|
||||
|
||||
async def _periodic_cleanup(self, session_manager: SessionManager) -> None:
|
||||
"""Perform periodic cleanup of expired sessions"""
|
||||
current_time = time.time()
|
||||
|
||||
if current_time - self.last_cleanup > self.cleanup_interval:
|
||||
try:
|
||||
cleaned_count = session_manager.cleanup_expired_sessions()
|
||||
self.last_cleanup = current_time
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"Cleaned up {cleaned_count} expired sessions")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup sessions: {str(e)}")
|
||||
|
||||
|
||||
class SessionSecurityMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Additional security middleware for session protection
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process security checks for sessions"""
|
||||
|
||||
# Add security headers for session management
|
||||
response = await call_next(request)
|
||||
|
||||
# Add session security headers
|
||||
if isinstance(response, Response):
|
||||
# Prevent session fixation
|
||||
response.headers["X-Session-Security"] = "fixation-protected"
|
||||
|
||||
# Indicate session management is active
|
||||
response.headers["X-Session-Management"] = "active"
|
||||
|
||||
# Add cache control for session-sensitive pages
|
||||
if request.url.path.startswith("/api/"):
|
||||
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SessionCookieMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Secure cookie management for sessions
|
||||
"""
|
||||
|
||||
def __init__(self, app, secure: bool = True, same_site: str = "strict"):
|
||||
super().__init__(app)
|
||||
self.secure = secure
|
||||
self.same_site = same_site
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Handle secure session cookies"""
|
||||
response = await call_next(request)
|
||||
|
||||
# Check if we need to set session cookie
|
||||
session_info = getattr(request.state, 'session_info', None)
|
||||
|
||||
if session_info and session_info.get("session"):
|
||||
session_id = session_info["session_id"]
|
||||
|
||||
# Set secure session cookie
|
||||
response.set_cookie(
|
||||
key="session_id",
|
||||
value=session_id,
|
||||
max_age=28800, # 8 hours
|
||||
httponly=True,
|
||||
secure=self.secure,
|
||||
samesite=self.same_site
|
||||
)
|
||||
|
||||
return response
|
||||
Reference in New Issue
Block a user