"""
Session management middleware for P2 security features.
"""

import time
from datetime import datetime, timezone
from typing import Optional

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.orm import Session

from app.database.base import get_db
from app.utils.session_manager import SessionManager
from app.core.logging import get_logger

logger = get_logger(__name__)


class SessionManagementMiddleware(BaseHTTPMiddleware):
    """
    Session validation and activity-tracking middleware.

    For every request whose path is not excluded, this middleware:
    - resolves a session ID (``session_id`` cookie, ``X-Session-ID`` header,
      or a Bearer token prefix) and validates it with
      ``SessionManager.validate_session``;
    - exposes the result as ``request.state.session_info`` and the user as
      ``request.state.user``, falling back to a Bearer-JWT user lookup;
    - after the downstream app responds, logs an ``api_request`` activity,
      applies a rapid-request risk heuristic, and commits;
    - at most once per ``cleanup_interval`` seconds, purges expired sessions.

    Session renewal, concurrent-session limits and fixation protection are
    presumably implemented inside ``SessionManager`` (not visible here).
    """

    # Paths (and their sub-paths) that bypass session handling entirely.
    EXCLUDED_PATHS = {
        "/docs",
        "/redoc",
        "/openapi.json",
        "/api/auth/login",
        "/api/auth/register",
        "/api/auth/refresh",
        "/api/health",
        "/health",
        "/ready",
        "/metrics",
    }

    def __init__(self, app, cleanup_interval: int = 3600):
        super().__init__(app)
        # Seconds between expired-session sweeps; default is one hour.
        self.cleanup_interval = cleanup_interval
        # time.time() of the most recent sweep attempt.
        self.last_cleanup = time.time()

    async def dispatch(self, request: Request, call_next):
        """Validate the session, run the downstream app, then record activity.

        Excluded paths are passed straight through. Any exception, including
        one raised downstream, is logged and re-raised for the global error
        handlers. The database session is always closed.
        """
        start_time = time.time()

        if self._should_skip_middleware(request):
            return await call_next(request)

        # Keep a reference to the dependency generator. Discarding it right
        # after next() lets it be finalized immediately, running get_db's
        # cleanup while we are still using the session.
        db_gen = get_db()
        db = next(db_gen)
        session_manager = SessionManager(db)

        try:
            await self._periodic_cleanup(session_manager)

            session_info = await self._validate_session(request, session_manager)
            request.state.session_info = session_info

            # Expose the authenticated user (if any) for downstream middleware,
            # e.g. user-based rate limiting or logging. Best effort only.
            try:
                request.state.user = session_info.get("user") if session_info else None
            except Exception as e:
                logger.debug(f"Could not propagate session user to request state: {e}")

            # Fallback for JWT-only flows that have no server-side session.
            if not getattr(request.state, "user", None):
                self._attach_user_from_bearer(request, session_manager)

            response = await call_next(request)

            if session_info and session_info.get("session"):
                await self._update_session_activity(
                    request, response, session_info["session"], session_manager, start_time
                )

            return response
        except Exception as e:
            logger.error(f"Session middleware error: {str(e)}")
            # Re-raise to be handled by global error handlers; never re-invoke
            # the downstream app.
            raise
        finally:
            db.close()
            # Finalize the get_db() generator so its own cleanup runs now.
            close_db_gen = getattr(db_gen, "close", None)
            if callable(close_db_gen):
                close_db_gen()

    def _attach_user_from_bearer(self, request: Request, session_manager: SessionManager) -> None:
        """Best-effort: set ``request.state.user`` from a Bearer token.

        The token is passed to ``verify_token``; if it yields a username whose
        ``User`` row exists and is active, that user is attached. Any failure
        is logged at debug level and ignored, so the request always proceeds.
        """
        try:
            # Starlette headers are case-insensitive; one lookup covers both spellings.
            auth_header = request.headers.get("authorization")
            if not auth_header or not auth_header.lower().startswith("bearer "):
                return
            token = auth_header.split(" ", 1)[1].strip()

            from app.auth.security import verify_token  # local import to avoid circular deps

            username = verify_token(token)
            if not username:
                return

            from app.models.user import User  # local import

            user = session_manager.db.query(User).filter(User.username == username).first()
            if user and user.is_active:
                request.state.user = user
        except Exception as e:
            logger.debug(f"Bearer-token user attachment failed: {e}")

    def _should_skip_middleware(self, request: Request) -> bool:
        """Return True when the request path bypasses session handling.

        An excluded entry matches the exact path or any sub-path below it
        (``entry + "/"``), so ``/health`` does not also exempt e.g.
        ``/healthcheck-admin``. Static files and the favicon are skipped too.
        """
        path = request.url.path

        for excluded in self.EXCLUDED_PATHS:
            if path == excluded or path.startswith(excluded.rstrip("/") + "/"):
                return True

        return path.startswith(("/static/", "/favicon.ico"))

    async def _validate_session(self, request: Request, session_manager: SessionManager) -> Optional[dict]:
        """Resolve and validate the request's session.

        Returns:
            ``None`` when no session ID is present or validation fails;
            otherwise a dict with keys ``session``, ``session_id``, ``user``
            (``session.user``) and ``is_valid`` (always True).
        """
        session_id = await self._extract_session_id(request)
        if not session_id:
            return None

        session = session_manager.validate_session(session_id, request)
        if not session:
            return None

        return {
            "session": session,
            "session_id": session_id,
            "user": session.user,
            "is_valid": True,
        }

    async def _extract_session_id(self, request: Request) -> Optional[str]:
        """Return the session identifier carried by the request, or None.

        Sources, in priority order: the ``session_id`` cookie, the
        ``X-Session-ID`` header, then the first 32 characters of a token in a
        ``Bearer `` authorization header (scheme matched case-sensitively).
        """
        session_id = request.cookies.get("session_id")
        if session_id:
            return session_id

        session_id = request.headers.get("X-Session-ID")
        if session_id:
            return session_id

        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # NOTE(review): a token prefix stands in for a real session ID.
            # For standard JWTs, the first 32 characters fall inside the
            # base64url header segment. That segment is typically identical
            # for every token signed with the same algorithm, so distinct
            # tokens may collapse to one ID. Confirm how SessionManager keys
            # sessions; decoding the JWT for a session claim would be safer.
            return auth_header[len("Bearer "):][:32]

        return None

    async def _update_session_activity(
        self,
        request: Request,
        response: Response,
        session,
        session_manager: SessionManager,
        start_time: float,
    ) -> None:
        """Record this request as session activity and commit.

        Steps:
        1. Log an ``api_request`` activity via ``SessionManager._log_activity``.
        2. Stamp the response status code and elapsed milliseconds onto the
           session's last activity row.
        3. Run the risk heuristic.
        4. Commit.

        Failures are logged and swallowed so they never affect the response.
        """
        try:
            duration_ms = int((time.time() - start_time) * 1000)

            session_manager._log_activity(
                session,
                session.user,
                request,
                activity_type="api_request",
                endpoint=request.url.path,
            )

            # Assumes the row just logged is the last element of
            # session.activities. TODO: confirm _log_activity appends to it.
            activities = getattr(session, "activities", None)
            if activities:
                latest_activity = activities[-1]
                latest_activity.status_code = getattr(response, "status_code", None)
                latest_activity.duration_ms = duration_ms

            await self._analyze_activity_patterns(session, session_manager)

            session_manager.db.commit()
        except Exception as e:
            # The DB session is closed by dispatch(), which discards any
            # half-finished transaction.
            logger.error(f"Failed to update session activity: {str(e)}")

    async def _analyze_activity_patterns(self, session, session_manager: SessionManager) -> None:
        """Raise the session's risk score on rapid-fire requests; lock at >= 80.

        Inspects the 10 most recent activity rows for the session. If at
        least 5 exist and the mean gap between consecutive requests is under
        one second:
        - add 10 to ``session.risk_score`` (capped at 100);
        - record a ``rapid_api_calls`` security event.

        If the score is then at least 80, the session is locked and a
        ``session_locked`` event is recorded. Errors are logged, never raised.
        """
        try:
            activities = getattr(session, "activities", None)
            if not activities:
                # No activity rows yet: nothing to analyze, and no model class to query.
                return

            # Take the activity ORM model from an existing row instead of importing it.
            activity_model = type(activities[0])
            recent_activities = (
                session_manager.db.query(activity_model)
                .filter_by(session_id=session.id)
                .order_by(activity_model.timestamp.desc())
                .limit(10)
                .all()
            )

            if len(recent_activities) < 5:
                return

            # Rows are newest-first, so newer - older gives the gap between requests.
            time_diffs = [
                (newer.timestamp - older.timestamp).total_seconds()
                for newer, older in zip(recent_activities, recent_activities[1:])
            ]
            avg_time_diff = sum(time_diffs) / len(time_diffs)

            # Flag possible automation: average gap between requests under 1 second.
            if avg_time_diff < 1.0:
                session.risk_score = min((session.risk_score or 0) + 10, 100)
                session_manager._create_security_event(
                    session,
                    session.user,
                    event_type="rapid_api_calls",
                    severity="medium",
                    description=f"Rapid API calls detected: avg {avg_time_diff:.2f}s between requests",
                )

            # Lock the session once the risk score is too high.
            # NOTE(review): this re-fires on every analyzed request while the
            # score stays >= 80. Confirm lock_session is idempotent.
            if (session.risk_score or 0) >= 80:
                session.lock_session("high_risk_activity")
                session_manager._create_security_event(
                    session,
                    session.user,
                    event_type="session_locked",
                    severity="high",
                    description=f"Session locked due to high risk score: {session.risk_score}",
                    action_taken="session_locked",
                )
        except Exception as e:
            logger.error(f"Failed to analyze activity patterns: {str(e)}")

    async def _periodic_cleanup(self, session_manager: SessionManager) -> None:
        """Purge expired sessions at most once per ``cleanup_interval`` seconds.

        The timestamp is advanced before the attempt, so a persistently
        failing cleanup is retried once per interval, not on every request.
        Errors are logged and swallowed.
        """
        current_time = time.time()
        if current_time - self.last_cleanup <= self.cleanup_interval:
            return

        self.last_cleanup = current_time
        try:
            cleaned_count = session_manager.cleanup_expired_sessions()
            if cleaned_count > 0:
                logger.info(f"Cleaned up {cleaned_count} expired sessions")
        except Exception as e:
            logger.error(f"Failed to cleanup sessions: {str(e)}")


class SessionSecurityMiddleware(BaseHTTPMiddleware):
    """
    Adds session-related response headers.

    Sets informational ``X-Session-Security`` and ``X-Session-Management``
    headers on every response, and disables caching for ``/api/`` paths.
    """

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app, then add session security headers."""
        response = await call_next(request)

        if isinstance(response, Response):
            # Informational only: this middleware does not rotate session IDs itself.
            response.headers["X-Session-Security"] = "fixation-protected"
            # Indicate session management is active.
            response.headers["X-Session-Management"] = "active"

            # API responses may carry session-bound data; keep them out of caches.
            if request.url.path.startswith("/api/"):
                response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"

        return response


class SessionCookieMiddleware(BaseHTTPMiddleware):
    """
    Re-issues the ``session_id`` cookie for requests with a validated session.

    Reads ``request.state.session_info`` (as populated by
    ``SessionManagementMiddleware``). Does nothing when it is absent or has
    no session.
    """

    def __init__(self, app, secure: bool = True, same_site: str = "strict"):
        super().__init__(app)
        self.secure = secure
        self.same_site = same_site

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app, then set the session cookie if a session was validated."""
        response = await call_next(request)

        session_info = getattr(request.state, "session_info", None)
        if session_info and session_info.get("session"):
            # NOTE(review): when the ID was derived from a Bearer token prefix,
            # this writes a token fragment into a cookie. Confirm that is intended.
            response.set_cookie(
                key="session_id",
                value=session_info["session_id"],
                max_age=28800,  # 8 hours
                httponly=True,
                secure=self.secure,
                samesite=self.same_site,
            )

        return response