""" Advanced session management utilities for P2 security features """ import secrets import hashlib import json from datetime import datetime, timezone, timedelta from typing import Optional, List, Dict, Any, Tuple from sqlalchemy.orm import Session from sqlalchemy import and_, or_, func from fastapi import Request, Depends from user_agents import parse as parse_user_agent from app.models.sessions import ( UserSession, SessionActivity, SessionConfiguration, SessionSecurityEvent, SessionStatus ) from app.models.user import User from app.core.logging import get_logger from app.database.base import get_db logger = get_logger(__name__) class SessionManager: """ Advanced session management with security features """ # Default configuration DEFAULT_SESSION_TIMEOUT = timedelta(hours=8) DEFAULT_IDLE_TIMEOUT = timedelta(hours=1) DEFAULT_MAX_CONCURRENT_SESSIONS = 3 def __init__(self, db: Session): self.db = db def generate_secure_session_id(self) -> str: """Generate cryptographically secure session ID""" # Generate 64 bytes of random data and hash it random_bytes = secrets.token_bytes(64) timestamp = str(datetime.now(timezone.utc).timestamp()).encode() return hashlib.sha256(random_bytes + timestamp).hexdigest() def create_session( self, user: User, request: Request, login_method: str = "password" ) -> UserSession: """ Create new secure session with fixation protection """ # Generate new session ID (prevents session fixation) session_id = self.generate_secure_session_id() # Extract request metadata ip_address = self._get_client_ip(request) user_agent = request.headers.get("user-agent", "") device_fingerprint = self._generate_device_fingerprint(request) # Get geographic info (placeholder - would integrate with GeoIP service) country, city = self._get_geographic_info(ip_address) # Check for suspicious activity is_suspicious, risk_score = self._assess_login_risk(user, ip_address, user_agent) # Get session configuration config = self._get_session_config(user) session_timeout = timedelta(minutes=config.session_timeout_minutes) # Enforce concurrent session limits self._enforce_concurrent_session_limits(user, config.max_concurrent_sessions) # Create session record session = UserSession( session_id=session_id, user_id=user.id, ip_address=ip_address, user_agent=user_agent, device_fingerprint=device_fingerprint, country=country, city=city, is_suspicious=is_suspicious, risk_score=risk_score, status=SessionStatus.ACTIVE, login_method=login_method, expires_at=datetime.now(timezone.utc) + session_timeout ) self.db.add(session) self.db.flush() # Get session ID # Log session creation activity self._log_activity( session, user, request, activity_type="session_created", endpoint="/api/auth/login" ) # Generate security event if suspicious if is_suspicious: self._create_security_event( session, user, event_type="suspicious_login", severity="medium", description=f"Suspicious login detected: risk score {risk_score}", ip_address=ip_address, user_agent=user_agent, country=country ) self.db.commit() logger.info(f"Created session {session_id} for user {user.username} from {ip_address}") return session def validate_session(self, session_id: str, request: Request) -> Optional[UserSession]: """ Validate session and update activity tracking """ session = self.db.query(UserSession).filter( UserSession.session_id == session_id ).first() if not session: return None # Check if session is expired or revoked if not session.is_active(): return None # Check for IP address changes if configured current_ip = self._get_client_ip(request) config = self._get_session_config(session.user) if config.force_logout_on_ip_change and session.ip_address != current_ip: self._create_security_event( session, session.user, event_type="ip_address_change", severity="medium", description=f"IP changed from {session.ip_address} to {current_ip}", ip_address=current_ip, action_taken="session_revoked" ) session.revoke_session("ip_address_change") self.db.commit() return None # Check idle timeout idle_timeout = timedelta(minutes=config.idle_timeout_minutes) if datetime.now(timezone.utc) - session.last_activity > idle_timeout: session.status = SessionStatus.EXPIRED self.db.commit() return None # Update last activity session.last_activity = datetime.now(timezone.utc) # Log activity self._log_activity( session, session.user, request, activity_type="session_validation", endpoint=str(request.url.path) ) self.db.commit() return session def revoke_session(self, session_id: str, reason: str = "user_logout") -> bool: """Revoke a specific session""" session = self.db.query(UserSession).filter( UserSession.session_id == session_id ).first() if session: session.revoke_session(reason) self.db.commit() logger.info(f"Revoked session {session_id}: {reason}") return True return False def revoke_all_user_sessions(self, user_id: int, reason: str = "admin_action") -> int: """Revoke all sessions for a user""" count = self.db.query(UserSession).filter( and_( UserSession.user_id == user_id, UserSession.status == SessionStatus.ACTIVE ) ).update({ "status": SessionStatus.REVOKED, "revoked_at": datetime.now(timezone.utc), "revocation_reason": reason }) self.db.commit() logger.info(f"Revoked {count} sessions for user {user_id}: {reason}") return count def get_active_sessions(self, user_id: int) -> List[UserSession]: """Get all active sessions for a user""" return self.db.query(UserSession).filter( and_( UserSession.user_id == user_id, UserSession.status == SessionStatus.ACTIVE, UserSession.expires_at > datetime.now(timezone.utc) ) ).order_by(UserSession.last_activity.desc()).all() def cleanup_expired_sessions(self) -> int: """Clean up expired sessions""" count = self.db.query(UserSession).filter( and_( UserSession.status == SessionStatus.ACTIVE, UserSession.expires_at <= datetime.now(timezone.utc) ) ).update({ "status": SessionStatus.EXPIRED }) self.db.commit() if count > 0: logger.info(f"Cleaned up {count} expired sessions") return count def get_session_statistics(self, user_id: Optional[int] = None) -> Dict[str, Any]: """Get session statistics for monitoring""" query = self.db.query(UserSession) if user_id: query = query.filter(UserSession.user_id == user_id) # Basic counts total_sessions = query.count() active_sessions = query.filter(UserSession.status == SessionStatus.ACTIVE).count() suspicious_sessions = query.filter(UserSession.is_suspicious == True).count() # Recent activity last_24h = datetime.now(timezone.utc) - timedelta(days=1) recent_sessions = query.filter(UserSession.created_at >= last_24h).count() # Risk distribution high_risk = query.filter(UserSession.risk_score >= 70).count() medium_risk = query.filter( and_(UserSession.risk_score >= 30, UserSession.risk_score < 70) ).count() low_risk = query.filter(UserSession.risk_score < 30).count() return { "total_sessions": total_sessions, "active_sessions": active_sessions, "suspicious_sessions": suspicious_sessions, "recent_sessions_24h": recent_sessions, "risk_distribution": { "high": high_risk, "medium": medium_risk, "low": low_risk } } def _enforce_concurrent_session_limits(self, user: User, max_sessions: int) -> None: """Enforce concurrent session limits""" active_sessions = self.get_active_sessions(user.id) if len(active_sessions) >= max_sessions: # Revoke oldest sessions sessions_to_revoke = active_sessions[max_sessions-1:] for session in sessions_to_revoke: session.revoke_session("concurrent_session_limit") # Create security event self._create_security_event( session, user, event_type="concurrent_session_limit", severity="medium", description=f"Session revoked due to concurrent session limit ({max_sessions})", action_taken="session_revoked" ) logger.info(f"Revoked {len(sessions_to_revoke)} sessions for user {user.username} due to concurrent limit") def _get_session_config(self, user: User) -> SessionConfiguration: """Get session configuration for user""" # Try user-specific config first config = self.db.query(SessionConfiguration).filter( SessionConfiguration.user_id == user.id ).first() if not config: # Try global config config = self.db.query(SessionConfiguration).filter( SessionConfiguration.user_id.is_(None) ).first() if not config: # Create default global config config = SessionConfiguration() self.db.add(config) self.db.flush() return config def _assess_login_risk(self, user: User, ip_address: str, user_agent: str) -> Tuple[bool, int]: """Assess login risk based on historical data""" risk_score = 0 risk_factors = [] # Check for new IP address previous_ips = self.db.query(UserSession.ip_address).filter( and_( UserSession.user_id == user.id, UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30) ) ).distinct().all() if ip_address not in [ip[0] for ip in previous_ips]: risk_score += 30 risk_factors.append("new_ip_address") # Check for unusual login time current_hour = datetime.now(timezone.utc).hour user_login_hours = self.db.query(func.extract('hour', UserSession.created_at)).filter( and_( UserSession.user_id == user.id, UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30) ) ).all() if user_login_hours: common_hours = [hour[0] for hour in user_login_hours] if current_hour not in common_hours[-10:]: # Not in recent login hours risk_score += 20 risk_factors.append("unusual_time") # Check for new user agent recent_agents = self.db.query(UserSession.user_agent).filter( and_( UserSession.user_id == user.id, UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=7) ) ).distinct().all() if user_agent not in [agent[0] for agent in recent_agents if agent[0]]: risk_score += 15 risk_factors.append("new_user_agent") # Check for rapid login attempts recent_attempts = self.db.query(UserSession).filter( and_( UserSession.user_id == user.id, UserSession.created_at >= datetime.now(timezone.utc) - timedelta(minutes=10) ) ).count() if recent_attempts > 3: risk_score += 25 risk_factors.append("rapid_attempts") is_suspicious = risk_score >= 50 return is_suspicious, min(risk_score, 100) def _log_activity( self, session: UserSession, user: User, request: Request, activity_type: str, endpoint: str = None ) -> None: """Log session activity""" activity = SessionActivity( session_id=session.id, user_id=user.id, activity_type=activity_type, endpoint=endpoint or str(request.url.path), method=request.method, ip_address=self._get_client_ip(request), user_agent=request.headers.get("user-agent", "") ) self.db.add(activity) def _create_security_event( self, session: Optional[UserSession], user: User, event_type: str, severity: str, description: str, ip_address: str = None, user_agent: str = None, country: str = None, action_taken: str = None ) -> None: """Create security event record""" event = SessionSecurityEvent( session_id=session.id if session else None, user_id=user.id, event_type=event_type, severity=severity, description=description, ip_address=ip_address, user_agent=user_agent, country=country, action_taken=action_taken ) self.db.add(event) logger.warning(f"Security event [{severity}]: {event_type} for user {user.username} - {description}") def _get_client_ip(self, request: Request) -> str: """Extract client IP address from request""" # Check for forwarded headers first forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: return forwarded_for.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip # Fallback to direct connection return request.client.host if request.client else "unknown" def _generate_device_fingerprint(self, request: Request) -> str: """Generate device fingerprint for tracking""" user_agent = request.headers.get("user-agent", "") accept_language = request.headers.get("accept-language", "") accept_encoding = request.headers.get("accept-encoding", "") fingerprint_data = f"{user_agent}|{accept_language}|{accept_encoding}" return hashlib.md5(fingerprint_data.encode()).hexdigest() def _get_geographic_info(self, ip_address: str) -> Tuple[Optional[str], Optional[str]]: """Get geographic information for IP address""" # Placeholder - would integrate with GeoIP service like MaxMind # For now, return None values return None, None def get_session_manager(db: Session = Depends(get_db)) -> SessionManager: """Dependency injection for session manager""" return SessionManager(db)