446 lines
16 KiB
Python
446 lines
16 KiB
Python
"""
|
|
Advanced session management utilities for P2 security features
|
|
"""
|
|
import secrets
|
|
import hashlib
|
|
import json
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import and_, or_, func
|
|
from fastapi import Request, Depends
|
|
from user_agents import parse as parse_user_agent
|
|
|
|
from app.models.sessions import (
|
|
UserSession, SessionActivity, SessionConfiguration,
|
|
SessionSecurityEvent, SessionStatus
|
|
)
|
|
from app.models.user import User
|
|
from app.core.logging import get_logger
|
|
from app.database.base import get_db
|
|
|
|
# Module-level logger used by SessionManager for session lifecycle and security-event messages.
logger = get_logger(__name__)
|
|
|
|
|
|
class SessionManager:
    """
    Advanced session management with security features.

    Wraps a SQLAlchemy ``Session`` and provides:

    * session creation with fixation protection, login-risk scoring and
      concurrent-session limits,
    * per-request validation (active check, optional forced logout on IP
      change, idle timeout) with activity logging,
    * revocation of single sessions or all sessions of a user,
    * bulk expiry of timed-out sessions and monitoring statistics.

    Public methods that change state commit ``self.db`` themselves. The
    private helpers in this class only ``add``/``flush`` and leave committing
    to their public caller.
    """

    # Default configuration.
    # NOTE(review): these constants are not referenced by the methods below;
    # effective timeouts/limits come from SessionConfiguration rows.
    DEFAULT_SESSION_TIMEOUT = timedelta(hours=8)
    DEFAULT_IDLE_TIMEOUT = timedelta(hours=1)
    DEFAULT_MAX_CONCURRENT_SESSIONS = 3

    def __init__(self, db: Session):
        # SQLAlchemy session used for every query and commit in this class.
        self.db = db

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _mask_session_id(session_id: Optional[str]) -> str:
        """Return a log-safe prefix of a session ID.

        Session IDs are bearer credentials: anyone who can read them from the
        logs could replay them. Only a short prefix is logged, which is
        enough to correlate log lines without exposing the token.
        """
        if not session_id:
            return "<none>"
        return f"{session_id[:8]}..."

    @staticmethod
    def _as_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return *value* as a timezone-aware UTC datetime (``None`` passes through).

        Some database backends / column types return naive datetimes.
        Subtracting those from ``datetime.now(timezone.utc)`` raises
        ``TypeError``, so naive values are tagged as UTC.
        NOTE(review): assumes naive values stored for UserSession are UTC —
        confirm against the model's column definitions.
        """
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.utc)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def generate_secure_session_id(self) -> str:
        """Return a new session identifier as a 64-character hex string.

        64 bytes from ``secrets.token_bytes`` plus the current UTC timestamp
        are hashed with SHA-256. The unpredictability comes from ``secrets``.
        """
        random_bytes = secrets.token_bytes(64)
        timestamp = str(datetime.now(timezone.utc).timestamp()).encode()
        return hashlib.sha256(random_bytes + timestamp).hexdigest()

    def create_session(
        self,
        user: User,
        request: Request,
        login_method: str = "password"
    ) -> UserSession:
        """
        Create a new session for *user* and commit it.

        A freshly generated session ID is always used, which prevents session
        fixation. Request metadata (client IP, user agent, device
        fingerprint) is recorded, the login is risk-scored, and the user's
        concurrent-session limit is enforced before the new session is
        added. A ``session_created`` activity is logged, and a
        ``suspicious_login`` security event is recorded when the risk
        assessment flags the login.

        Args:
            user: The authenticated user.
            request: The incoming login request.
            login_method: Label stored on the session (default ``"password"``).

        Returns:
            The committed ``UserSession``.
        """
        # Generate new session ID (prevents session fixation)
        session_id = self.generate_secure_session_id()

        # Extract request metadata
        ip_address = self._get_client_ip(request)
        user_agent = request.headers.get("user-agent", "")
        device_fingerprint = self._generate_device_fingerprint(request)

        # Geographic info (placeholder - would integrate with a GeoIP service)
        country, city = self._get_geographic_info(ip_address)

        # Score the login before the new session exists, so it is compared
        # only against previous sessions.
        is_suspicious, risk_score = self._assess_login_risk(user, ip_address, user_agent)

        config = self._get_session_config(user)
        session_timeout = timedelta(minutes=config.session_timeout_minutes)

        # Make room for the new session under the concurrent-session limit.
        self._enforce_concurrent_session_limits(user, config.max_concurrent_sessions)

        session = UserSession(
            session_id=session_id,
            user_id=user.id,
            ip_address=ip_address,
            user_agent=user_agent,
            device_fingerprint=device_fingerprint,
            country=country,
            city=city,
            is_suspicious=is_suspicious,
            risk_score=risk_score,
            status=SessionStatus.ACTIVE,
            login_method=login_method,
            expires_at=datetime.now(timezone.utc) + session_timeout
        )

        self.db.add(session)
        # Flush so the primary key (session.id) is assigned; _log_activity
        # and _create_security_event reference it.
        self.db.flush()

        self._log_activity(
            session, user, request,
            activity_type="session_created",
            endpoint="/api/auth/login"
        )

        if is_suspicious:
            self._create_security_event(
                session, user,
                event_type="suspicious_login",
                severity="medium",
                description=f"Suspicious login detected: risk score {risk_score}",
                ip_address=ip_address,
                user_agent=user_agent,
                country=country
            )

        self.db.commit()
        logger.info(
            f"Created session {self._mask_session_id(session_id)} "
            f"for user {user.username} from {ip_address}"
        )

        return session

    def validate_session(self, session_id: str, request: Request) -> Optional[UserSession]:
        """
        Validate *session_id* for *request* and update activity tracking.

        Returns ``None`` in any of these cases:

        * the session does not exist;
        * ``UserSession.is_active()`` is false;
        * ``force_logout_on_ip_change`` is set and the client IP differs. The
          session is revoked and an ``ip_address_change`` event is recorded.
        * the idle timeout has elapsed since ``last_activity``. The session is
          marked ``EXPIRED``.

        Otherwise ``last_activity`` is refreshed, a ``session_validation``
        activity is logged, the transaction is committed and the session is
        returned.
        """
        session = self.db.query(UserSession).filter(
            UserSession.session_id == session_id
        ).first()

        # Unknown sessions and sessions the model reports as inactive are
        # rejected without any side effects.
        if not session or not session.is_active():
            return None

        current_ip = self._get_client_ip(request)
        config = self._get_session_config(session.user)

        if config.force_logout_on_ip_change and session.ip_address != current_ip:
            self._create_security_event(
                session, session.user,
                event_type="ip_address_change",
                severity="medium",
                description=f"IP changed from {session.ip_address} to {current_ip}",
                ip_address=current_ip,
                action_taken="session_revoked"
            )
            session.revoke_session("ip_address_change")
            self.db.commit()
            return None

        now = datetime.now(timezone.utc)

        # Idle timeout. Normalizing to aware UTC avoids a naive/aware
        # TypeError. A missing last_activity skips the idle check; the
        # timestamp is set just below.
        idle_timeout = timedelta(minutes=config.idle_timeout_minutes)
        last_activity = self._as_utc(session.last_activity)
        if last_activity is not None and now - last_activity > idle_timeout:
            session.status = SessionStatus.EXPIRED
            self.db.commit()
            return None

        session.last_activity = now

        self._log_activity(
            session, session.user, request,
            activity_type="session_validation",
            endpoint=str(request.url.path)
        )

        self.db.commit()
        return session

    def revoke_session(self, session_id: str, reason: str = "user_logout") -> bool:
        """Revoke the session with *session_id*.

        Returns:
            True if a session was found, revoked and committed, else False.
        """
        session = self.db.query(UserSession).filter(
            UserSession.session_id == session_id
        ).first()

        if not session:
            return False

        session.revoke_session(reason)
        self.db.commit()
        logger.info(f"Revoked session {self._mask_session_id(session_id)}: {reason}")
        return True

    def revoke_all_user_sessions(self, user_id: int, reason: str = "admin_action") -> int:
        """Mark every ACTIVE session of *user_id* as REVOKED in one bulk update.

        Returns:
            The number of rows updated.
        """
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user_id,
                UserSession.status == SessionStatus.ACTIVE
            )
        ).update({
            "status": SessionStatus.REVOKED,
            "revoked_at": datetime.now(timezone.utc),
            "revocation_reason": reason
        })

        self.db.commit()
        logger.info(f"Revoked {count} sessions for user {user_id}: {reason}")
        return count

    def get_active_sessions(self, user_id: int) -> List[UserSession]:
        """Return the user's ACTIVE, not-yet-expired sessions, most recently active first."""
        return self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user_id,
                UserSession.status == SessionStatus.ACTIVE,
                UserSession.expires_at > datetime.now(timezone.utc)
            )
        ).order_by(UserSession.last_activity.desc()).all()

    def cleanup_expired_sessions(self) -> int:
        """Mark ACTIVE sessions whose ``expires_at`` has passed as EXPIRED.

        Returns:
            The number of sessions updated.
        """
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.status == SessionStatus.ACTIVE,
                UserSession.expires_at <= datetime.now(timezone.utc)
            )
        ).update({
            "status": SessionStatus.EXPIRED
        })

        self.db.commit()

        if count > 0:
            logger.info(f"Cleaned up {count} expired sessions")

        return count

    def get_session_statistics(self, user_id: Optional[int] = None) -> Dict[str, Any]:
        """Return session counts for monitoring.

        Args:
            user_id: Restrict statistics to this user. ``None`` means all
                users. An ID of ``0`` is treated as a real ID, not as "all".

        Returns:
            A dict with total, active, suspicious and last-24h counts, plus a
            ``risk_distribution`` of high (>= 70), medium (30-69) and low (< 30)
            risk scores. Sessions with a NULL risk score fall in no bucket.
        """
        query = self.db.query(UserSession)
        if user_id is not None:
            query = query.filter(UserSession.user_id == user_id)

        # Basic counts
        total_sessions = query.count()
        active_sessions = query.filter(UserSession.status == SessionStatus.ACTIVE).count()
        # `== True` builds a SQL expression; it is not a Python truth test.
        suspicious_sessions = query.filter(UserSession.is_suspicious == True).count()  # noqa: E712

        # Recent activity
        last_24h = datetime.now(timezone.utc) - timedelta(days=1)
        recent_sessions = query.filter(UserSession.created_at >= last_24h).count()

        # Risk distribution
        high_risk = query.filter(UserSession.risk_score >= 70).count()
        medium_risk = query.filter(
            and_(UserSession.risk_score >= 30, UserSession.risk_score < 70)
        ).count()
        low_risk = query.filter(UserSession.risk_score < 30).count()

        return {
            "total_sessions": total_sessions,
            "active_sessions": active_sessions,
            "suspicious_sessions": suspicious_sessions,
            "recent_sessions_24h": recent_sessions,
            "risk_distribution": {
                "high": high_risk,
                "medium": medium_risk,
                "low": low_risk
            }
        }

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _enforce_concurrent_session_limits(self, user: User, max_sessions: int) -> None:
        """Revoke the least recently active sessions so a new one fits under *max_sessions*.

        Each revoked session gets a ``concurrent_session_limit`` security
        event. Nothing is committed here.
        """
        active_sessions = self.get_active_sessions(user.id)

        if len(active_sessions) < max_sessions:
            return

        # get_active_sessions orders by last_activity descending, so the tail
        # of the list holds the least recently used sessions. Keep the newest
        # max_sessions - 1 so the session about to be created brings the user
        # exactly to the limit. Clamping at 0 stops a non-positive limit from
        # slicing from the end of the list, which would revoke only one.
        keep = max(max_sessions - 1, 0)
        sessions_to_revoke = active_sessions[keep:]

        for session in sessions_to_revoke:
            session.revoke_session("concurrent_session_limit")
            self._create_security_event(
                session, user,
                event_type="concurrent_session_limit",
                severity="medium",
                description=f"Session revoked due to concurrent session limit ({max_sessions})",
                action_taken="session_revoked"
            )

        if sessions_to_revoke:
            logger.info(
                f"Revoked {len(sessions_to_revoke)} sessions for user "
                f"{user.username} due to concurrent limit"
            )

    def _get_session_config(self, user: User) -> SessionConfiguration:
        """Return the user's SessionConfiguration, falling back to the global one.

        The global configuration is the row with ``user_id IS NULL``. If that
        does not exist either, a default ``SessionConfiguration()`` is added
        and flushed (not committed).
        """
        config = self.db.query(SessionConfiguration).filter(
            SessionConfiguration.user_id == user.id
        ).first()

        if not config:
            config = self.db.query(SessionConfiguration).filter(
                SessionConfiguration.user_id.is_(None)
            ).first()

        if not config:
            config = SessionConfiguration()
            self.db.add(config)
            self.db.flush()

        return config

    def _assess_login_risk(self, user: User, ip_address: str, user_agent: str) -> Tuple[bool, int]:
        """Score a login attempt against the user's recent session history.

        Factors (additive):
            * +30  IP address not seen in the last 30 days
            * +20  current UTC hour not among the hours of the 10 most recent
                   logins in the last 30 days (skipped when there are none)
            * +15  user agent not seen in the last 7 days
            * +25  more than 3 sessions created in the last 10 minutes

        Returns:
            ``(is_suspicious, score)``. A login is suspicious at a score of 50
            or more, and the score is capped at 100.
        """
        now = datetime.now(timezone.utc)
        last_30_days = now - timedelta(days=30)
        risk_score = 0
        risk_factors: List[str] = []

        # New IP address
        previous_ips = {
            row[0]
            for row in self.db.query(UserSession.ip_address).filter(
                and_(
                    UserSession.user_id == user.id,
                    UserSession.created_at >= last_30_days
                )
            ).distinct().all()
        }
        if ip_address not in previous_ips:
            risk_score += 30
            risk_factors.append("new_ip_address")

        # Unusual login time. Ordering by created_at is required for "most
        # recent" to mean anything; an unordered result has no defined tail.
        recent_login_hours = self.db.query(
            func.extract('hour', UserSession.created_at)
        ).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= last_30_days
            )
        ).order_by(UserSession.created_at.desc()).limit(10).all()

        if recent_login_hours:
            # extract() may come back as int, float or Decimal depending on
            # the backend; normalize to int for the membership test.
            common_hours = {int(row[0]) for row in recent_login_hours if row[0] is not None}
            if now.hour not in common_hours:
                risk_score += 20
                risk_factors.append("unusual_time")

        # New user agent (empty/NULL agents are ignored)
        recent_agents = {
            row[0]
            for row in self.db.query(UserSession.user_agent).filter(
                and_(
                    UserSession.user_id == user.id,
                    UserSession.created_at >= now - timedelta(days=7)
                )
            ).distinct().all()
            if row[0]
        }
        if user_agent not in recent_agents:
            risk_score += 15
            risk_factors.append("new_user_agent")

        # Rapid login attempts
        recent_attempts = self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= now - timedelta(minutes=10)
            )
        ).count()
        if recent_attempts > 3:
            risk_score += 25
            risk_factors.append("rapid_attempts")

        if risk_factors:
            logger.debug(
                f"Login risk for user {user.username}: score {risk_score}, "
                f"factors {', '.join(risk_factors)}"
            )

        is_suspicious = risk_score >= 50
        return is_suspicious, min(risk_score, 100)

    def _log_activity(
        self,
        session: UserSession,
        user: User,
        request: Request,
        activity_type: str,
        endpoint: Optional[str] = None
    ) -> None:
        """Add (without committing) a SessionActivity row for *request*.

        *endpoint* defaults to the request's URL path when not given.
        """
        activity = SessionActivity(
            session_id=session.id,
            user_id=user.id,
            activity_type=activity_type,
            endpoint=endpoint or str(request.url.path),
            method=request.method,
            ip_address=self._get_client_ip(request),
            user_agent=request.headers.get("user-agent", "")
        )

        self.db.add(activity)

    def _create_security_event(
        self,
        session: Optional[UserSession],
        user: User,
        event_type: str,
        severity: str,
        description: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        country: Optional[str] = None,
        action_taken: Optional[str] = None
    ) -> None:
        """Add (without committing) a SessionSecurityEvent row and log it as a warning."""
        event = SessionSecurityEvent(
            session_id=session.id if session else None,
            user_id=user.id,
            event_type=event_type,
            severity=severity,
            description=description,
            ip_address=ip_address,
            user_agent=user_agent,
            country=country,
            action_taken=action_taken
        )

        self.db.add(event)
        logger.warning(f"Security event [{severity}]: {event_type} for user {user.username} - {description}")

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP for *request*.

        Precedence: the first entry of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the socket peer address, then ``"unknown"``.

        SECURITY NOTE(review): the forwarded headers are trusted
        unconditionally. That is only safe behind a reverse proxy that
        overwrites them. Otherwise clients can spoof their IP, which defeats
        the IP-change logout and the new-IP risk factor.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        return request.client.host if request.client else "unknown"

    def _generate_device_fingerprint(self, request: Request) -> str:
        """Return an MD5 hex digest of User-Agent, Accept-Language and Accept-Encoding.

        MD5 is used only as a compact, non-secret identifier. Switching
        algorithms would change the values stored on existing sessions.
        """
        user_agent = request.headers.get("user-agent", "")
        accept_language = request.headers.get("accept-language", "")
        accept_encoding = request.headers.get("accept-encoding", "")

        fingerprint_data = f"{user_agent}|{accept_language}|{accept_encoding}"
        return hashlib.md5(fingerprint_data.encode()).hexdigest()

    def _get_geographic_info(self, ip_address: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(country, city)`` for *ip_address*.

        Placeholder: always returns ``(None, None)``. It would integrate with
        a GeoIP service such as MaxMind.
        """
        return None, None
|
|
|
|
|
|
def get_session_manager(db: Session = Depends(get_db)) -> SessionManager:
    """FastAPI dependency that returns a ``SessionManager``.

    The database session is supplied through ``Depends(get_db)``. Each call
    wraps it in a newly constructed manager.
    """
    manager = SessionManager(db)
    return manager
|