changes
This commit is contained in:
445
app/utils/session_manager.py
Normal file
445
app/utils/session_manager.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Advanced session management utilities for P2 security features
|
||||
"""
|
||||
import secrets
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, func
|
||||
from fastapi import Request, Depends
|
||||
from user_agents import parse as parse_user_agent
|
||||
|
||||
from app.models.sessions import (
|
||||
UserSession, SessionActivity, SessionConfiguration,
|
||||
SessionSecurityEvent, SessionStatus
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.core.logging import get_logger
|
||||
from app.database.base import get_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Advanced session management with security features
|
||||
"""
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_SESSION_TIMEOUT = timedelta(hours=8)
|
||||
DEFAULT_IDLE_TIMEOUT = timedelta(hours=1)
|
||||
DEFAULT_MAX_CONCURRENT_SESSIONS = 3
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def generate_secure_session_id(self) -> str:
|
||||
"""Generate cryptographically secure session ID"""
|
||||
# Generate 64 bytes of random data and hash it
|
||||
random_bytes = secrets.token_bytes(64)
|
||||
timestamp = str(datetime.now(timezone.utc).timestamp()).encode()
|
||||
return hashlib.sha256(random_bytes + timestamp).hexdigest()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user: User,
|
||||
request: Request,
|
||||
login_method: str = "password"
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create new secure session with fixation protection
|
||||
"""
|
||||
# Generate new session ID (prevents session fixation)
|
||||
session_id = self.generate_secure_session_id()
|
||||
|
||||
# Extract request metadata
|
||||
ip_address = self._get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
device_fingerprint = self._generate_device_fingerprint(request)
|
||||
|
||||
# Get geographic info (placeholder - would integrate with GeoIP service)
|
||||
country, city = self._get_geographic_info(ip_address)
|
||||
|
||||
# Check for suspicious activity
|
||||
is_suspicious, risk_score = self._assess_login_risk(user, ip_address, user_agent)
|
||||
|
||||
# Get session configuration
|
||||
config = self._get_session_config(user)
|
||||
session_timeout = timedelta(minutes=config.session_timeout_minutes)
|
||||
|
||||
# Enforce concurrent session limits
|
||||
self._enforce_concurrent_session_limits(user, config.max_concurrent_sessions)
|
||||
|
||||
# Create session record
|
||||
session = UserSession(
|
||||
session_id=session_id,
|
||||
user_id=user.id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
device_fingerprint=device_fingerprint,
|
||||
country=country,
|
||||
city=city,
|
||||
is_suspicious=is_suspicious,
|
||||
risk_score=risk_score,
|
||||
status=SessionStatus.ACTIVE,
|
||||
login_method=login_method,
|
||||
expires_at=datetime.now(timezone.utc) + session_timeout
|
||||
)
|
||||
|
||||
self.db.add(session)
|
||||
self.db.flush() # Get session ID
|
||||
|
||||
# Log session creation activity
|
||||
self._log_activity(
|
||||
session, user, request,
|
||||
activity_type="session_created",
|
||||
endpoint="/api/auth/login"
|
||||
)
|
||||
|
||||
# Generate security event if suspicious
|
||||
if is_suspicious:
|
||||
self._create_security_event(
|
||||
session, user,
|
||||
event_type="suspicious_login",
|
||||
severity="medium",
|
||||
description=f"Suspicious login detected: risk score {risk_score}",
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
country=country
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Created session {session_id} for user {user.username} from {ip_address}")
|
||||
|
||||
return session
|
||||
|
||||
def validate_session(self, session_id: str, request: Request) -> Optional[UserSession]:
|
||||
"""
|
||||
Validate session and update activity tracking
|
||||
"""
|
||||
session = self.db.query(UserSession).filter(
|
||||
UserSession.session_id == session_id
|
||||
).first()
|
||||
|
||||
if not session:
|
||||
return None
|
||||
|
||||
# Check if session is expired or revoked
|
||||
if not session.is_active():
|
||||
return None
|
||||
|
||||
# Check for IP address changes if configured
|
||||
current_ip = self._get_client_ip(request)
|
||||
config = self._get_session_config(session.user)
|
||||
|
||||
if config.force_logout_on_ip_change and session.ip_address != current_ip:
|
||||
self._create_security_event(
|
||||
session, session.user,
|
||||
event_type="ip_address_change",
|
||||
severity="medium",
|
||||
description=f"IP changed from {session.ip_address} to {current_ip}",
|
||||
ip_address=current_ip,
|
||||
action_taken="session_revoked"
|
||||
)
|
||||
session.revoke_session("ip_address_change")
|
||||
self.db.commit()
|
||||
return None
|
||||
|
||||
# Check idle timeout
|
||||
idle_timeout = timedelta(minutes=config.idle_timeout_minutes)
|
||||
if datetime.now(timezone.utc) - session.last_activity > idle_timeout:
|
||||
session.status = SessionStatus.EXPIRED
|
||||
self.db.commit()
|
||||
return None
|
||||
|
||||
# Update last activity
|
||||
session.last_activity = datetime.now(timezone.utc)
|
||||
|
||||
# Log activity
|
||||
self._log_activity(
|
||||
session, session.user, request,
|
||||
activity_type="session_validation",
|
||||
endpoint=str(request.url.path)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
return session
|
||||
|
||||
def revoke_session(self, session_id: str, reason: str = "user_logout") -> bool:
|
||||
"""Revoke a specific session"""
|
||||
session = self.db.query(UserSession).filter(
|
||||
UserSession.session_id == session_id
|
||||
).first()
|
||||
|
||||
if session:
|
||||
session.revoke_session(reason)
|
||||
self.db.commit()
|
||||
logger.info(f"Revoked session {session_id}: {reason}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def revoke_all_user_sessions(self, user_id: int, reason: str = "admin_action") -> int:
|
||||
"""Revoke all sessions for a user"""
|
||||
count = self.db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.status == SessionStatus.ACTIVE
|
||||
)
|
||||
).update({
|
||||
"status": SessionStatus.REVOKED,
|
||||
"revoked_at": datetime.now(timezone.utc),
|
||||
"revocation_reason": reason
|
||||
})
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Revoked {count} sessions for user {user_id}: {reason}")
|
||||
return count
|
||||
|
||||
def get_active_sessions(self, user_id: int) -> List[UserSession]:
|
||||
"""Get all active sessions for a user"""
|
||||
return self.db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.status == SessionStatus.ACTIVE,
|
||||
UserSession.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
).order_by(UserSession.last_activity.desc()).all()
|
||||
|
||||
def cleanup_expired_sessions(self) -> int:
|
||||
"""Clean up expired sessions"""
|
||||
count = self.db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.status == SessionStatus.ACTIVE,
|
||||
UserSession.expires_at <= datetime.now(timezone.utc)
|
||||
)
|
||||
).update({
|
||||
"status": SessionStatus.EXPIRED
|
||||
})
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
|
||||
return count
|
||||
|
||||
def get_session_statistics(self, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get session statistics for monitoring"""
|
||||
query = self.db.query(UserSession)
|
||||
if user_id:
|
||||
query = query.filter(UserSession.user_id == user_id)
|
||||
|
||||
# Basic counts
|
||||
total_sessions = query.count()
|
||||
active_sessions = query.filter(UserSession.status == SessionStatus.ACTIVE).count()
|
||||
suspicious_sessions = query.filter(UserSession.is_suspicious == True).count()
|
||||
|
||||
# Recent activity
|
||||
last_24h = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
recent_sessions = query.filter(UserSession.created_at >= last_24h).count()
|
||||
|
||||
# Risk distribution
|
||||
high_risk = query.filter(UserSession.risk_score >= 70).count()
|
||||
medium_risk = query.filter(
|
||||
and_(UserSession.risk_score >= 30, UserSession.risk_score < 70)
|
||||
).count()
|
||||
low_risk = query.filter(UserSession.risk_score < 30).count()
|
||||
|
||||
return {
|
||||
"total_sessions": total_sessions,
|
||||
"active_sessions": active_sessions,
|
||||
"suspicious_sessions": suspicious_sessions,
|
||||
"recent_sessions_24h": recent_sessions,
|
||||
"risk_distribution": {
|
||||
"high": high_risk,
|
||||
"medium": medium_risk,
|
||||
"low": low_risk
|
||||
}
|
||||
}
|
||||
|
||||
def _enforce_concurrent_session_limits(self, user: User, max_sessions: int) -> None:
|
||||
"""Enforce concurrent session limits"""
|
||||
active_sessions = self.get_active_sessions(user.id)
|
||||
|
||||
if len(active_sessions) >= max_sessions:
|
||||
# Revoke oldest sessions
|
||||
sessions_to_revoke = active_sessions[max_sessions-1:]
|
||||
for session in sessions_to_revoke:
|
||||
session.revoke_session("concurrent_session_limit")
|
||||
|
||||
# Create security event
|
||||
self._create_security_event(
|
||||
session, user,
|
||||
event_type="concurrent_session_limit",
|
||||
severity="medium",
|
||||
description=f"Session revoked due to concurrent session limit ({max_sessions})",
|
||||
action_taken="session_revoked"
|
||||
)
|
||||
|
||||
logger.info(f"Revoked {len(sessions_to_revoke)} sessions for user {user.username} due to concurrent limit")
|
||||
|
||||
def _get_session_config(self, user: User) -> SessionConfiguration:
|
||||
"""Get session configuration for user"""
|
||||
# Try user-specific config first
|
||||
config = self.db.query(SessionConfiguration).filter(
|
||||
SessionConfiguration.user_id == user.id
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
# Try global config
|
||||
config = self.db.query(SessionConfiguration).filter(
|
||||
SessionConfiguration.user_id.is_(None)
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
# Create default global config
|
||||
config = SessionConfiguration()
|
||||
self.db.add(config)
|
||||
self.db.flush()
|
||||
|
||||
return config
|
||||
|
||||
def _assess_login_risk(self, user: User, ip_address: str, user_agent: str) -> Tuple[bool, int]:
|
||||
"""Assess login risk based on historical data"""
|
||||
risk_score = 0
|
||||
risk_factors = []
|
||||
|
||||
# Check for new IP address
|
||||
previous_ips = self.db.query(UserSession.ip_address).filter(
|
||||
and_(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
|
||||
)
|
||||
).distinct().all()
|
||||
|
||||
if ip_address not in [ip[0] for ip in previous_ips]:
|
||||
risk_score += 30
|
||||
risk_factors.append("new_ip_address")
|
||||
|
||||
# Check for unusual login time
|
||||
current_hour = datetime.now(timezone.utc).hour
|
||||
user_login_hours = self.db.query(func.extract('hour', UserSession.created_at)).filter(
|
||||
and_(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
|
||||
)
|
||||
).all()
|
||||
|
||||
if user_login_hours:
|
||||
common_hours = [hour[0] for hour in user_login_hours]
|
||||
if current_hour not in common_hours[-10:]: # Not in recent login hours
|
||||
risk_score += 20
|
||||
risk_factors.append("unusual_time")
|
||||
|
||||
# Check for new user agent
|
||||
recent_agents = self.db.query(UserSession.user_agent).filter(
|
||||
and_(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=7)
|
||||
)
|
||||
).distinct().all()
|
||||
|
||||
if user_agent not in [agent[0] for agent in recent_agents if agent[0]]:
|
||||
risk_score += 15
|
||||
risk_factors.append("new_user_agent")
|
||||
|
||||
# Check for rapid login attempts
|
||||
recent_attempts = self.db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.created_at >= datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
)
|
||||
).count()
|
||||
|
||||
if recent_attempts > 3:
|
||||
risk_score += 25
|
||||
risk_factors.append("rapid_attempts")
|
||||
|
||||
is_suspicious = risk_score >= 50
|
||||
return is_suspicious, min(risk_score, 100)
|
||||
|
||||
def _log_activity(
|
||||
self,
|
||||
session: UserSession,
|
||||
user: User,
|
||||
request: Request,
|
||||
activity_type: str,
|
||||
endpoint: str = None
|
||||
) -> None:
|
||||
"""Log session activity"""
|
||||
activity = SessionActivity(
|
||||
session_id=session.id,
|
||||
user_id=user.id,
|
||||
activity_type=activity_type,
|
||||
endpoint=endpoint or str(request.url.path),
|
||||
method=request.method,
|
||||
ip_address=self._get_client_ip(request),
|
||||
user_agent=request.headers.get("user-agent", "")
|
||||
)
|
||||
|
||||
self.db.add(activity)
|
||||
|
||||
def _create_security_event(
|
||||
self,
|
||||
session: Optional[UserSession],
|
||||
user: User,
|
||||
event_type: str,
|
||||
severity: str,
|
||||
description: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
country: str = None,
|
||||
action_taken: str = None
|
||||
) -> None:
|
||||
"""Create security event record"""
|
||||
event = SessionSecurityEvent(
|
||||
session_id=session.id if session else None,
|
||||
user_id=user.id,
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
description=description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
country=country,
|
||||
action_taken=action_taken
|
||||
)
|
||||
|
||||
self.db.add(event)
|
||||
logger.warning(f"Security event [{severity}]: {event_type} for user {user.username} - {description}")
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request"""
|
||||
# Check for forwarded headers first
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fallback to direct connection
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _generate_device_fingerprint(self, request: Request) -> str:
|
||||
"""Generate device fingerprint for tracking"""
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
accept_language = request.headers.get("accept-language", "")
|
||||
accept_encoding = request.headers.get("accept-encoding", "")
|
||||
|
||||
fingerprint_data = f"{user_agent}|{accept_language}|{accept_encoding}"
|
||||
return hashlib.md5(fingerprint_data.encode()).hexdigest()
|
||||
|
||||
def _get_geographic_info(self, ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Get geographic information for IP address"""
|
||||
# Placeholder - would integrate with GeoIP service like MaxMind
|
||||
# For now, return None values
|
||||
return None, None
|
||||
|
||||
|
||||
def get_session_manager(db: Session = Depends(get_db)) -> SessionManager:
|
||||
"""Dependency injection for session manager"""
|
||||
return SessionManager(db)
|
||||
Reference in New Issue
Block a user