Changes in this commit

app/utils/database_security.py (new file, +379 lines)
@@ -0,0 +1,379 @@
"""
Database Security Utilities

Provides utilities for secure database operations and SQL injection prevention:
- Parameterized query helpers
- SQL injection detection and prevention
- Safe query building utilities
- Database security auditing
"""
import re
from typing import Any, Dict, List, Optional, Union
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.sql import ClauseElement
from app.utils.logging import app_logger

logger = app_logger.bind(name="database_security")

# Dangerous SQL patterns that could indicate injection attempts
DANGEROUS_PATTERNS = [
    # SQL injection keywords
    r'\bunion\s+select\b',
    r'\b(drop\s+table)\b',
    r'\bdelete\s+from\b',
    r'\b(insert\s+into)\b',
    r'\b(update\s+.*set)\b',
    r'\b(alter\s+table)\b',
    r'\b(create\s+table)\b',
    r'\b(truncate\s+table)\b',

    # Command execution (no trailing \b: a word boundary after "(" never matches
    # when the next character is ")" or whitespace)
    r'\b(exec\s*\()',
    r'\b(execute\s*\()',
    r'\b(system\s*\()',
    r'\b(shell_exec\s*\()',
    r'\b(eval\s*\()',

    # Comment-based attacks
    r'(;\s*drop\s+table\b)',
    r'(--\s*$)',
    r'(/\*.*?\*/)',
    r'(\#.*$)',

    # Quote escaping attempts
    r"(';|\";|\")",
    r"(\\\\'|\\\\\"|\\\\x)",

    # Hex/unicode encoding
    r'(0x[0-9a-fA-F]+)',
    r'(\\u[0-9a-fA-F]{4})',

    # Boolean-based attacks
    r'\b(1=1|1=0|true|false)\b',
    r"\b(or\s+1\s*=\s*1|and\s+1\s*=\s*1)\b",

    # Time-based attacks
    r'\b(sleep\s*\(|delay\s*\(|waitfor\s+delay)\b',

    # File operations
    r'\b(load_file\s*\(|into\s+outfile|into\s+dumpfile)\b',

    # Information schema access
    r'\b(information_schema|sys\.|pg_)\b',
    # Subselect usage in WHERE clause that may indicate enumeration
    r'\b\(\s*select\b',
]

# Compiled regex patterns for performance
COMPILED_PATTERNS = [re.compile(pattern, re.IGNORECASE | re.MULTILINE) for pattern in DANGEROUS_PATTERNS]


class SQLSecurityValidator:
    """Validates SQL queries and parameters for security issues"""

    @staticmethod
    def validate_query_string(query: str) -> List[str]:
        """Validate a SQL query string for potential injection attempts"""
        issues: List[str] = []

        for i, pattern in enumerate(COMPILED_PATTERNS):
            try:
                matches = pattern.findall(query)
            except Exception:
                matches = []
            if matches:
                issues.append(
                    f"Potentially dangerous SQL pattern detected: {DANGEROUS_PATTERNS[i]} -> {str(matches)[:80]}"
                )

        # Heuristic fallback to catch common cases without relying on regex quirks
        if not issues:
            ql = (query or "").lower()
            if "; drop table" in ql or ";drop table" in ql:
                issues.append("Heuristic: DROP TABLE detected")
            if " union select " in ql or " union\nselect " in ql:
                issues.append("Heuristic: UNION SELECT detected")
            if " or 1=1" in ql or " and 1=1" in ql or " or 1 = 1" in ql or " and 1 = 1" in ql:
                issues.append("Heuristic: tautology (1=1) detected")
            if "(select" in ql:
                issues.append("Heuristic: subselect in WHERE detected")

        return issues

    @staticmethod
    def validate_parameter_value(param_name: str, param_value: Any) -> List[str]:
        """Validate a parameter value for potential injection attempts"""
        issues = []

        if param_value is None:
            return issues

        # Convert to string for pattern matching
        str_value = str(param_value)

        # Check for dangerous patterns in parameter values
        for i, pattern in enumerate(COMPILED_PATTERNS):
            matches = pattern.findall(str_value)
            if matches:
                issues.append(f"Parameter '{param_name}' contains dangerous pattern: {DANGEROUS_PATTERNS[i]}")

        # Additional parameter-specific checks
        if isinstance(param_value, str):
            # Check for excessive length (potential buffer overflow)
            if len(param_value) > 10000:
                issues.append(f"Parameter '{param_name}' is excessively long ({len(param_value)} characters)")

            # Check for null bytes
            if '\x00' in param_value:
                issues.append(f"Parameter '{param_name}' contains null bytes")

            # Check for control characters
            if any(ord(c) < 32 and c not in '\t\n\r' for c in param_value):
                issues.append(f"Parameter '{param_name}' contains suspicious control characters")

        return issues

    @staticmethod
    def validate_query_with_params(query: str, params: Dict[str, Any]) -> List[str]:
        """Validate a complete query with its parameters"""
        issues = []

        # Validate the query string
        query_issues = SQLSecurityValidator.validate_query_string(query)
        issues.extend(query_issues)

        # Validate each parameter
        for param_name, param_value in params.items():
            param_issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
            issues.extend(param_issues)

        return issues
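

# Illustrative sketch (not part of the committed module): what the validator
# reports for a classic injection attempt. Exact messages depend on
# DANGEROUS_PATTERNS above.
def _example_flag_injection_attempt() -> List[str]:
    return SQLSecurityValidator.validate_query_with_params(
        "SELECT * FROM users WHERE name = :name OR 1=1 --",
        {"name": "alice"},
    )  # non-empty: the tautology and SQL-comment patterns both match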


class SecureQueryBuilder:
    """Builds secure parameterized queries to prevent SQL injection"""

    @staticmethod
    def safe_text_query(query: str, params: Optional[Dict[str, Any]] = None) -> ClauseElement:
        """Create a safe text query with parameter validation"""
        params = params or {}

        # Validate the query and parameters
        issues = SQLSecurityValidator.validate_query_with_params(query, params)

        if issues:
            logger.warning("Potential security issues in query", query=query[:100], issues=issues)
            # In production, you might want to raise an exception or sanitize
            # For now, we'll log and proceed with caution

        return text(query)

    @staticmethod
    def build_like_clause(column: ClauseElement, search_term: str, case_sensitive: bool = False) -> ClauseElement:
        """Build a safe LIKE clause with proper escaping"""
        # Escape the escape character first, then the LIKE wildcards
        escaped_term = search_term.replace('\\', r'\\').replace('%', r'\%').replace('_', r'\_')

        if case_sensitive:
            return column.like(f"%{escaped_term}%", escape='\\')
        else:
            return column.ilike(f"%{escaped_term}%", escape='\\')

    @staticmethod
    def build_in_clause(column: ClauseElement, values: List[Any]) -> ClauseElement:
        """Build a safe IN clause with parameter binding"""
        if not values:
            return column.in_([])  # Empty list

        # Validate each value
        for i, value in enumerate(values):
            issues = SQLSecurityValidator.validate_parameter_value(f"in_value_{i}", value)
            if issues:
                logger.warning(f"Security issues in IN clause value: {issues}")

        return column.in_(values)

    @staticmethod
    def build_fts_query(search_terms: List[str], exact_phrase: bool = False) -> str:
        """Build a safe FTS (Full Text Search) query"""
        if not search_terms:
            return ""

        safe_terms = []
        for term in search_terms:
            # Remove or escape dangerous characters for FTS
            # SQLite FTS5 has its own escaping rules
            safe_term = re.sub(r'[^\w\s\-\.]', '', term.strip())
            if safe_term:
                if exact_phrase:
                    # Escape quotes for phrase search
                    safe_term = safe_term.replace('"', '""')
                    safe_terms.append(f'"{safe_term}"')
                else:
                    safe_terms.append(safe_term)

        if exact_phrase:
            return ' '.join(safe_terms)
        else:
            return ' AND '.join(safe_terms)
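

# Illustrative sketch (not part of the committed module): building an escaped
# LIKE filter and an FTS match string from raw user input. `column` stands in
# for any mapped column attribute.
def _example_build_search_filters(column: ClauseElement) -> tuple:
    like_clause = SecureQueryBuilder.build_like_clause(column, "o'brien_50%")
    # the %, _ and \ in the term are escaped, so they only match literally
    fts = SecureQueryBuilder.build_fts_query(["smith", "qdro"])  # -> 'smith AND qdro'
    return like_clause, fts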


class DatabaseAuditor:
    """Audits database operations for security compliance"""

    @staticmethod
    def audit_query_execution(query: str, params: Dict[str, Any], execution_time: float) -> None:
        """Audit a query execution for security and performance"""
        # Security audit
        security_issues = SQLSecurityValidator.validate_query_with_params(query, params)

        if security_issues:
            logger.warning(
                "Query executed with security issues",
                query=query[:200],
                issues=security_issues,
                execution_time=execution_time
            )

        # Performance audit
        if execution_time > 5.0:  # Slow query threshold
            logger.warning(
                "Slow query detected",
                query=query[:200],
                execution_time=execution_time
            )

        # Pattern audit - detect potentially problematic queries
        if re.search(r'\bSELECT\s+\*\s+FROM\b', query, re.IGNORECASE):
            logger.info("SELECT * query detected - consider specifying columns", query=query[:100])

        if re.search(r'\bLIMIT\s+\d{4,}\b', query, re.IGNORECASE):
            logger.info("Large LIMIT detected", query=query[:100])


def execute_secure_query(
    db: Session,
    query: str,
    params: Optional[Dict[str, Any]] = None,
    audit: bool = True
) -> Any:
    """Execute a query with security validation and auditing"""
    import time

    params = params or {}
    start_time = time.time()

    try:
        # Validate query security
        if audit:
            issues = SQLSecurityValidator.validate_query_with_params(query, params)
            if issues:
                logger.warning("Executing query with potential security issues", issues=issues)

        # Create safe text query
        safe_query = SecureQueryBuilder.safe_text_query(query, params)

        # Execute query
        result = db.execute(safe_query, params)

        # Audit execution
        if audit:
            execution_time = time.time() - start_time
            DatabaseAuditor.audit_query_execution(query, params, execution_time)

        return result

    except Exception as e:
        execution_time = time.time() - start_time
        logger.error(
            "Query execution failed",
            query=query[:200],
            error=str(e),
            execution_time=execution_time
        )
        raise
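

# Illustrative sketch (not part of the committed module): values always travel
# as bound parameters rather than being formatted into the SQL string. `db` is
# assumed to be an active SQLAlchemy Session; `files` is one of the tables
# whitelisted further below.
def _example_run_parameterized_query(db: Session) -> list:
    result = execute_secure_query(
        db,
        "SELECT file_no, status FROM files WHERE status = :status LIMIT 50",
        {"status": "open"},
    )
    return result.fetchall()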


def sanitize_fts_query(query: str) -> str:
    """Sanitize user input for FTS queries"""
    if not query:
        return ""

    # Remove potentially dangerous characters
    # Keep alphanumeric, spaces, and basic punctuation
    sanitized = re.sub(r'[^\w\s\-\.\,\!\?\"\']', ' ', query)

    # Remove excessive whitespace
    sanitized = re.sub(r'\s+', ' ', sanitized).strip()

    # Limit length
    if len(sanitized) > 500:
        sanitized = sanitized[:500]

    return sanitized


def create_safe_search_conditions(
    search_terms: List[str],
    searchable_columns: List[ClauseElement],
    case_sensitive: bool = False,
    exact_phrase: bool = False
) -> Optional[ClauseElement]:
    """Create safe search conditions for multiple columns"""
    from sqlalchemy import or_, and_

    if not search_terms or not searchable_columns:
        return None

    search_conditions = []

    if exact_phrase:
        # Single phrase search across all columns
        phrase = ' '.join(search_terms)
        phrase_conditions = []
        for column in searchable_columns:
            phrase_conditions.append(
                SecureQueryBuilder.build_like_clause(column, phrase, case_sensitive)
            )
        search_conditions.append(or_(*phrase_conditions))
    else:
        # Each term must match at least one column
        for term in search_terms:
            term_conditions = []
            for column in searchable_columns:
                term_conditions.append(
                    SecureQueryBuilder.build_like_clause(column, term, case_sensitive)
                )
            search_conditions.append(or_(*term_conditions))

    return and_(*search_conditions) if search_conditions else None
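

# Illustrative sketch (not part of the committed module): turning free-text
# input into a filter across several columns. `columns` would typically be
# mapped attributes such as first/last/email on the rolodex model.
def _example_search_filter(raw_input: str, columns: List[ClauseElement]) -> Optional[ClauseElement]:
    terms = sanitize_fts_query(raw_input).split()
    return create_safe_search_conditions(terms, columns)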


# Whitelist of allowed column names for dynamic queries
ALLOWED_SORT_COLUMNS = {
    'rolodex': ['id', 'first', 'last', 'city', 'email', 'created_at', 'updated_at'],
    'files': ['file_no', 'id', 'regarding', 'status', 'file_type', 'opened', 'closed', 'created_at', 'updated_at'],
    'ledger': ['id', 'file_no', 't_code', 'amount', 'date', 'created_at', 'updated_at'],
    'qdros': ['id', 'file_no', 'form_name', 'status', 'created_at', 'updated_at'],
}

def validate_sort_column(table: str, column: str) -> bool:
    """Validate that a sort column is allowed for a table"""
    allowed_columns = ALLOWED_SORT_COLUMNS.get(table, [])
    return column in allowed_columns


def safe_order_by(table: str, sort_column: str, sort_direction: str = 'asc') -> Optional[str]:
    """Create a safe ORDER BY clause with whitelist validation"""
    # Validate sort column
    if not validate_sort_column(table, sort_column):
        logger.warning(f"Invalid sort column '{sort_column}' for table '{table}'")
        return None

    # Validate sort direction
    if sort_direction.lower() not in ['asc', 'desc']:
        logger.warning(f"Invalid sort direction '{sort_direction}'")
        return None

    return f"{sort_column} {sort_direction.upper()}"
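

# Illustrative sketch (not part of the committed module): a caller-supplied sort
# falls back to a known-good default instead of being interpolated blindly.
def _example_sorted_files_query(db: Session, sort_column: str, sort_direction: str):
    order_clause = safe_order_by("files", sort_column, sort_direction) or "file_no ASC"
    # order_clause is whitelist-validated above, so interpolating it is safe here
    query = f"SELECT file_no, status FROM files ORDER BY {order_clause} LIMIT :limit"
    return execute_secure_query(db, query, {"limit": 100})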

app/utils/enhanced_audit.py (new file, +668 lines)
@@ -0,0 +1,668 @@
|
||||
"""
|
||||
Enhanced audit logging utilities for P2 security features
|
||||
"""
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, func
|
||||
from fastapi import Request
|
||||
from user_agents import parse as parse_user_agent
|
||||
|
||||
from app.models.audit_enhanced import (
|
||||
EnhancedAuditLog, SecurityAlert, ComplianceReport,
|
||||
AuditRetentionPolicy, SIEMIntegration,
|
||||
SecurityEventType, SecurityEventSeverity, ComplianceStandard
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EnhancedAuditLogger:
|
||||
"""
|
||||
Enhanced audit logging system with security event tracking
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def log_security_event(
|
||||
self,
|
||||
event_type: SecurityEventType,
|
||||
title: str,
|
||||
description: str,
|
||||
user: Optional[User] = None,
|
||||
session_id: Optional[str] = None,
|
||||
request: Optional[Request] = None,
|
||||
severity: SecurityEventSeverity = SecurityEventSeverity.INFO,
|
||||
outcome: str = "success",
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
data_before: Optional[Dict[str, Any]] = None,
|
||||
data_after: Optional[Dict[str, Any]] = None,
|
||||
risk_factors: Optional[List[str]] = None,
|
||||
threat_indicators: Optional[List[str]] = None,
|
||||
compliance_standards: Optional[List[ComplianceStandard]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
custom_fields: Optional[Dict[str, Any]] = None,
|
||||
correlation_id: Optional[str] = None
|
||||
) -> EnhancedAuditLog:
|
||||
"""
|
||||
Log a comprehensive security event
|
||||
"""
|
||||
# Generate unique event ID
|
||||
event_id = str(uuid.uuid4())
|
||||
|
||||
# Extract request metadata
|
||||
source_ip = None
|
||||
user_agent = None
|
||||
endpoint = None
|
||||
http_method = None
|
||||
request_id = None
|
||||
|
||||
if request:
|
||||
source_ip = self._get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
endpoint = str(request.url.path)
|
||||
http_method = request.method
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
|
||||
# Determine event category
|
||||
event_category = self._categorize_event(event_type)
|
||||
|
||||
# Calculate risk score
|
||||
risk_score = self._calculate_risk_score(
|
||||
event_type, severity, risk_factors, threat_indicators
|
||||
)
|
||||
|
||||
# Get geographic info (placeholder - would integrate with GeoIP)
|
||||
country, region, city = self._get_geographic_info(source_ip)
|
||||
|
||||
# Create audit log entry
|
||||
audit_log = EnhancedAuditLog(
|
||||
event_id=event_id,
|
||||
event_type=event_type.value,
|
||||
event_category=event_category,
|
||||
severity=severity.value,
|
||||
title=title,
|
||||
description=description,
|
||||
outcome=outcome,
|
||||
user_id=user.id if user else None,
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
country=country,
|
||||
region=region,
|
||||
city=city,
|
||||
endpoint=endpoint,
|
||||
http_method=http_method,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
resource_name=resource_name,
|
||||
risk_score=risk_score,
|
||||
correlation_id=correlation_id or str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Set JSON data
|
||||
if data_before:
|
||||
audit_log.set_data_before(data_before)
|
||||
if data_after:
|
||||
audit_log.set_data_after(data_after)
|
||||
if risk_factors:
|
||||
audit_log.set_risk_factors(risk_factors)
|
||||
if threat_indicators:
|
||||
audit_log.set_threat_indicators(threat_indicators)
|
||||
if compliance_standards:
|
||||
audit_log.set_compliance_standards([std.value for std in compliance_standards])
|
||||
if tags:
|
||||
audit_log.set_tags(tags)
|
||||
if custom_fields:
|
||||
audit_log.set_custom_fields(custom_fields)
|
||||
|
||||
# Save to database
|
||||
self.db.add(audit_log)
|
||||
self.db.flush() # Get ID for further processing
|
||||
|
||||
# Check for security alerts
|
||||
self._check_security_alerts(audit_log)
|
||||
|
||||
# Send to SIEM systems
|
||||
self._send_to_siem(audit_log)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Security event logged: {event_type.value}",
|
||||
extra={
|
||||
"event_id": event_id,
|
||||
"user_id": user.id if user else None,
|
||||
"severity": severity.value,
|
||||
"risk_score": risk_score
|
||||
}
|
||||
)
|
||||
|
||||
return audit_log
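
# Illustrative usage (sketch, not part of the committed module; `db`,
# `current_user` and `request` come from the surrounding FastAPI endpoint):
#
#   auditor = EnhancedAuditLogger(db)
#   auditor.log_security_event(
#       event_type=SecurityEventType.DATA_EXPORT,
#       title="Export blocked",
#       description="CSV export denied for unprivileged user",
#       user=current_user,
#       request=request,
#       severity=SecurityEventSeverity.MEDIUM,
#       outcome="blocked",
#       resource_type="files",
#       resource_id="ALL",
#   )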
|
||||
|
||||
def log_data_access(
|
||||
self,
|
||||
user: User,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
action: str, # read, write, delete, export
|
||||
request: Optional[Request] = None,
|
||||
session_id: Optional[str] = None,
|
||||
record_count: Optional[int] = None,
|
||||
data_volume: Optional[int] = None,
|
||||
compliance_standards: Optional[List[ComplianceStandard]] = None
|
||||
) -> EnhancedAuditLog:
|
||||
"""
|
||||
Log data access events for compliance
|
||||
"""
|
||||
event_type_map = {
|
||||
"read": SecurityEventType.DATA_READ,
|
||||
"write": SecurityEventType.DATA_WRITE,
|
||||
"delete": SecurityEventType.DATA_DELETE,
|
||||
"export": SecurityEventType.DATA_EXPORT
|
||||
}
|
||||
|
||||
event_type = event_type_map.get(action, SecurityEventType.DATA_READ)
|
||||
|
||||
return self.log_security_event(
|
||||
event_type=event_type,
|
||||
title=f"Data {action} operation",
|
||||
description=f"User {user.username} performed {action} on {resource_type} {resource_id}",
|
||||
user=user,
|
||||
session_id=session_id,
|
||||
request=request,
|
||||
severity=SecurityEventSeverity.INFO,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
compliance_standards=compliance_standards or [ComplianceStandard.SOX],
|
||||
custom_fields={
|
||||
"record_count": record_count,
|
||||
"data_volume": data_volume
|
||||
}
|
||||
)
|
||||
|
||||
def log_authentication_event(
|
||||
self,
|
||||
event_type: SecurityEventType,
|
||||
username: str,
|
||||
request: Request,
|
||||
user: Optional[User] = None,
|
||||
session_id: Optional[str] = None,
|
||||
outcome: str = "success",
|
||||
details: Optional[str] = None,
|
||||
risk_factors: Optional[List[str]] = None
|
||||
) -> EnhancedAuditLog:
|
||||
"""
|
||||
Log authentication-related events
|
||||
"""
|
||||
severity = SecurityEventSeverity.INFO
|
||||
if outcome == "failure" or risk_factors:
|
||||
severity = SecurityEventSeverity.MEDIUM
|
||||
if event_type == SecurityEventType.ACCOUNT_LOCKED:
|
||||
severity = SecurityEventSeverity.HIGH
|
||||
|
||||
return self.log_security_event(
|
||||
event_type=event_type,
|
||||
title=f"Authentication event: {event_type.value}",
|
||||
description=details or f"Authentication {outcome} for user {username}",
|
||||
user=user,
|
||||
session_id=session_id,
|
||||
request=request,
|
||||
severity=severity,
|
||||
outcome=outcome,
|
||||
risk_factors=risk_factors,
|
||||
compliance_standards=[ComplianceStandard.SOX, ComplianceStandard.ISO27001]
|
||||
)
|
||||
|
||||
def log_admin_action(
|
||||
self,
|
||||
admin_user: User,
|
||||
action: str,
|
||||
target_resource: str,
|
||||
request: Request,
|
||||
session_id: Optional[str] = None,
|
||||
data_before: Optional[Dict[str, Any]] = None,
|
||||
data_after: Optional[Dict[str, Any]] = None,
|
||||
affected_user_id: Optional[int] = None
|
||||
) -> EnhancedAuditLog:
|
||||
"""
|
||||
Log administrative actions for compliance
|
||||
"""
|
||||
return self.log_security_event(
|
||||
event_type=SecurityEventType.CONFIGURATION_CHANGE,
|
||||
title=f"Administrative action: {action}",
|
||||
description=f"Admin {admin_user.username} performed {action} on {target_resource}",
|
||||
user=admin_user,
|
||||
session_id=session_id,
|
||||
request=request,
|
||||
severity=SecurityEventSeverity.MEDIUM,
|
||||
resource_type="admin",
|
||||
resource_id=target_resource,
|
||||
data_before=data_before,
|
||||
data_after=data_after,
|
||||
compliance_standards=[ComplianceStandard.SOX, ComplianceStandard.SOC2],
|
||||
tags=["admin_action", "configuration_change"],
|
||||
custom_fields={
|
||||
"affected_user_id": affected_user_id
|
||||
}
|
||||
)
|
||||
|
||||
def create_security_alert(
|
||||
self,
|
||||
rule_id: str,
|
||||
rule_name: str,
|
||||
title: str,
|
||||
description: str,
|
||||
severity: SecurityEventSeverity,
|
||||
triggering_events: List[str],
|
||||
confidence: int = 100,
|
||||
time_window_minutes: Optional[int] = None,
|
||||
affected_users: Optional[List[int]] = None,
|
||||
affected_resources: Optional[List[str]] = None
|
||||
) -> SecurityAlert:
|
||||
"""
|
||||
Create a security alert based on detected patterns
|
||||
"""
|
||||
alert_id = str(uuid.uuid4())
|
||||
|
||||
alert = SecurityAlert(
|
||||
alert_id=alert_id,
|
||||
rule_id=rule_id,
|
||||
rule_name=rule_name,
|
||||
title=title,
|
||||
description=description,
|
||||
severity=severity.value,
|
||||
confidence=confidence,
|
||||
event_count=len(triggering_events),
|
||||
time_window_minutes=time_window_minutes,
|
||||
first_seen=datetime.now(timezone.utc),
|
||||
last_seen=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Set JSON fields
|
||||
alert.triggering_events = json.dumps(triggering_events)
|
||||
if affected_users:
|
||||
alert.affected_users = json.dumps(affected_users)
|
||||
if affected_resources:
|
||||
alert.affected_resources = json.dumps(affected_resources)
|
||||
|
||||
self.db.add(alert)
|
||||
self.db.commit()
|
||||
|
||||
logger.warning(
|
||||
f"Security alert created: {title}",
|
||||
extra={
|
||||
"alert_id": alert_id,
|
||||
"severity": severity.value,
|
||||
"confidence": confidence,
|
||||
"event_count": len(triggering_events)
|
||||
}
|
||||
)
|
||||
|
||||
return alert
|
||||
|
||||
def search_audit_logs(
|
||||
self,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
event_types: Optional[List[SecurityEventType]] = None,
|
||||
severities: Optional[List[SecurityEventSeverity]] = None,
|
||||
user_ids: Optional[List[int]] = None,
|
||||
source_ips: Optional[List[str]] = None,
|
||||
resource_types: Optional[List[str]] = None,
|
||||
outcomes: Optional[List[str]] = None,
|
||||
min_risk_score: Optional[int] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
limit: int = 1000,
|
||||
offset: int = 0
|
||||
) -> List[EnhancedAuditLog]:
|
||||
"""
|
||||
Search audit logs with comprehensive filtering
|
||||
"""
|
||||
query = self.db.query(EnhancedAuditLog)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
query = query.filter(EnhancedAuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(EnhancedAuditLog.timestamp <= end_date)
|
||||
if event_types:
|
||||
query = query.filter(EnhancedAuditLog.event_type.in_([et.value for et in event_types]))
|
||||
if severities:
|
||||
query = query.filter(EnhancedAuditLog.severity.in_([s.value for s in severities]))
|
||||
if user_ids:
|
||||
query = query.filter(EnhancedAuditLog.user_id.in_(user_ids))
|
||||
if source_ips:
|
||||
query = query.filter(EnhancedAuditLog.source_ip.in_(source_ips))
|
||||
if resource_types:
|
||||
query = query.filter(EnhancedAuditLog.resource_type.in_(resource_types))
|
||||
if outcomes:
|
||||
query = query.filter(EnhancedAuditLog.outcome.in_(outcomes))
|
||||
if min_risk_score is not None:
|
||||
query = query.filter(EnhancedAuditLog.risk_score >= min_risk_score)
|
||||
if correlation_id:
|
||||
query = query.filter(EnhancedAuditLog.correlation_id == correlation_id)
|
||||
|
||||
return query.order_by(EnhancedAuditLog.timestamp.desc()).offset(offset).limit(limit).all()
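
# Illustrative usage (sketch): the most recent high-risk failures for one user
# over the last 24 hours.
#
#   logs = auditor.search_audit_logs(
#       start_date=datetime.now(timezone.utc) - timedelta(hours=24),
#       user_ids=[42],
#       outcomes=["failure", "blocked"],
#       min_risk_score=70,
#       limit=100,
#   )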
|
||||
|
||||
def generate_compliance_report(
|
||||
self,
|
||||
standard: ComplianceStandard,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
generated_by: User,
|
||||
report_type: str = "periodic"
|
||||
) -> ComplianceReport:
|
||||
"""
|
||||
Generate compliance report for specified standard and date range
|
||||
"""
|
||||
report_id = str(uuid.uuid4())
|
||||
|
||||
# Query relevant audit logs
|
||||
logs = self.search_audit_logs(
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# Filter logs relevant to the compliance standard
|
||||
relevant_logs = [
|
||||
log for log in logs
|
||||
if standard.value in (log.get_compliance_standards() or [])
|
||||
]
|
||||
|
||||
# Calculate metrics
|
||||
total_events = len(relevant_logs)
|
||||
security_events = len([log for log in relevant_logs if log.event_category == "security"])
|
||||
violations = len([log for log in relevant_logs if log.outcome in ["failure", "blocked"]])
|
||||
high_risk_events = len([log for log in relevant_logs if log.risk_score >= 70])
|
||||
|
||||
# Generate report content
|
||||
summary = {
|
||||
"total_events": total_events,
|
||||
"security_events": security_events,
|
||||
"violations": violations,
|
||||
"high_risk_events": high_risk_events,
|
||||
"compliance_percentage": ((total_events - violations) / total_events * 100) if total_events > 0 else 100
|
||||
}
|
||||
|
||||
report = ComplianceReport(
|
||||
report_id=report_id,
|
||||
standard=standard.value,
|
||||
report_type=report_type,
|
||||
title=f"{standard.value.upper()} Compliance Report",
|
||||
description=f"Compliance report for {standard.value.upper()} from {start_date.date()} to {end_date.date()}",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
summary=json.dumps(summary),
|
||||
total_events=total_events,
|
||||
security_events=security_events,
|
||||
violations=violations,
|
||||
high_risk_events=high_risk_events,
|
||||
generated_by=generated_by.id,
|
||||
status="ready"
|
||||
)
|
||||
|
||||
self.db.add(report)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Compliance report generated: {standard.value}",
|
||||
extra={
|
||||
"report_id": report_id,
|
||||
"total_events": total_events,
|
||||
"violations": violations
|
||||
}
|
||||
)
|
||||
|
||||
return report
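
# Worked example for the summary above: with total_events=200 and violations=5,
# compliance_percentage = (200 - 5) / 200 * 100 = 97.5; when there are no
# relevant events at all, the report defaults to 100.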
|
||||
|
||||
def cleanup_old_logs(self) -> int:
|
||||
"""
|
||||
Clean up old audit logs based on retention policies
|
||||
"""
|
||||
# Get active retention policies
|
||||
policies = self.db.query(AuditRetentionPolicy).filter(
|
||||
AuditRetentionPolicy.is_active == True
|
||||
).order_by(AuditRetentionPolicy.priority.desc()).all()
|
||||
|
||||
cleaned_count = 0
|
||||
|
||||
for policy in policies:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=policy.retention_days)
|
||||
|
||||
# Build query for logs to delete
|
||||
query = self.db.query(EnhancedAuditLog).filter(
|
||||
EnhancedAuditLog.timestamp < cutoff_date
|
||||
)
|
||||
|
||||
# Apply event type filter if specified
|
||||
if policy.event_types:
|
||||
event_types = json.loads(policy.event_types)
|
||||
query = query.filter(EnhancedAuditLog.event_type.in_(event_types))
|
||||
|
||||
# Apply compliance standards filter if specified
|
||||
if policy.compliance_standards:
|
||||
standards = json.loads(policy.compliance_standards)
|
||||
# This is a simplified check - in practice, you'd want more sophisticated filtering
|
||||
for standard in standards:
|
||||
query = query.filter(EnhancedAuditLog.compliance_standards.contains(standard))
|
||||
|
||||
# Delete matching logs
|
||||
count = query.count()
|
||||
query.delete(synchronize_session=False)
|
||||
cleaned_count += count
|
||||
|
||||
logger.info(f"Cleaned {count} logs using policy {policy.policy_name}")
|
||||
|
||||
self.db.commit()
|
||||
return cleaned_count
|
||||
|
||||
def _categorize_event(self, event_type: SecurityEventType) -> str:
|
||||
"""Categorize event type into broader categories"""
|
||||
auth_events = {
|
||||
SecurityEventType.LOGIN_SUCCESS, SecurityEventType.LOGIN_FAILURE,
|
||||
SecurityEventType.LOGOUT, SecurityEventType.SESSION_EXPIRED,
|
||||
SecurityEventType.PASSWORD_CHANGE, SecurityEventType.ACCOUNT_LOCKED
|
||||
}
|
||||
|
||||
security_events = {
|
||||
SecurityEventType.SUSPICIOUS_ACTIVITY, SecurityEventType.ATTACK_DETECTED,
|
||||
SecurityEventType.SECURITY_VIOLATION, SecurityEventType.IP_BLOCKED,
|
||||
SecurityEventType.ACCESS_DENIED, SecurityEventType.UNAUTHORIZED_ACCESS
|
||||
}
|
||||
|
||||
data_events = {
|
||||
SecurityEventType.DATA_READ, SecurityEventType.DATA_WRITE,
|
||||
SecurityEventType.DATA_DELETE, SecurityEventType.DATA_EXPORT,
|
||||
SecurityEventType.BULK_OPERATION
|
||||
}
|
||||
|
||||
if event_type in auth_events:
|
||||
return "authentication"
|
||||
elif event_type in security_events:
|
||||
return "security"
|
||||
elif event_type in data_events:
|
||||
return "data_access"
|
||||
else:
|
||||
return "system"
|
||||
|
||||
def _calculate_risk_score(
|
||||
self,
|
||||
event_type: SecurityEventType,
|
||||
severity: SecurityEventSeverity,
|
||||
risk_factors: Optional[List[str]],
|
||||
threat_indicators: Optional[List[str]]
|
||||
) -> int:
|
||||
"""Calculate risk score for the event"""
|
||||
base_scores = {
|
||||
SecurityEventSeverity.CRITICAL: 80,
|
||||
SecurityEventSeverity.HIGH: 60,
|
||||
SecurityEventSeverity.MEDIUM: 40,
|
||||
SecurityEventSeverity.LOW: 20,
|
||||
SecurityEventSeverity.INFO: 10
|
||||
}
|
||||
|
||||
score = base_scores.get(severity, 10)
|
||||
|
||||
# Add points for risk factors
|
||||
if risk_factors:
|
||||
score += len(risk_factors) * 5
|
||||
|
||||
# Add points for threat indicators
|
||||
if threat_indicators:
|
||||
score += len(threat_indicators) * 10
|
||||
|
||||
# Event type modifiers
|
||||
high_risk_events = {
|
||||
SecurityEventType.ATTACK_DETECTED,
|
||||
SecurityEventType.PRIVILEGE_ESCALATION,
|
||||
SecurityEventType.UNAUTHORIZED_ACCESS
|
||||
}
|
||||
|
||||
if event_type in high_risk_events:
|
||||
score += 20
|
||||
|
||||
return min(score, 100) # Cap at 100
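
# Worked example: a HIGH-severity event (base 60) with two risk factors (+5 each)
# and one threat indicator (+10) that is also an UNAUTHORIZED_ACCESS event (+20)
# scores min(100, 60 + 10 + 10 + 20) = 100.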
|
||||
|
||||
def _check_security_alerts(self, audit_log: EnhancedAuditLog) -> None:
|
||||
"""Check if audit log should trigger security alerts"""
|
||||
# Example: Multiple failed logins from same IP
|
||||
if audit_log.event_type == SecurityEventType.LOGIN_FAILURE.value:
|
||||
recent_failures = self.db.query(EnhancedAuditLog).filter(
|
||||
and_(
|
||||
EnhancedAuditLog.event_type == SecurityEventType.LOGIN_FAILURE.value,
|
||||
EnhancedAuditLog.source_ip == audit_log.source_ip,
|
||||
EnhancedAuditLog.timestamp >= datetime.now(timezone.utc) - timedelta(minutes=15)
|
||||
)
|
||||
).count()
|
||||
|
||||
if recent_failures >= 5:
|
||||
self.create_security_alert(
|
||||
rule_id="failed_login_threshold",
|
||||
rule_name="Multiple Failed Logins",
|
||||
title=f"Multiple failed logins from {audit_log.source_ip}",
|
||||
description=f"{recent_failures} failed login attempts in 15 minutes",
|
||||
severity=SecurityEventSeverity.HIGH,
|
||||
triggering_events=[audit_log.event_id],
|
||||
time_window_minutes=15
|
||||
)
|
||||
|
||||
# Example: High risk score threshold
|
||||
if audit_log.risk_score >= 80:
|
||||
self.create_security_alert(
|
||||
rule_id="high_risk_event",
|
||||
rule_name="High Risk Security Event",
|
||||
title=f"High risk event detected: {audit_log.title}",
|
||||
description=f"Event with risk score {audit_log.risk_score} detected",
|
||||
severity=SecurityEventSeverity.HIGH,
|
||||
triggering_events=[audit_log.event_id],
|
||||
confidence=audit_log.risk_score
|
||||
)
|
||||
|
||||
def _send_to_siem(self, audit_log: EnhancedAuditLog) -> None:
|
||||
"""Send audit log to configured SIEM systems"""
|
||||
# Get active SIEM integrations
|
||||
integrations = self.db.query(SIEMIntegration).filter(
|
||||
SIEMIntegration.is_active == True
|
||||
).all()
|
||||
|
||||
for integration in integrations:
|
||||
try:
|
||||
# Check if event should be sent based on filters
|
||||
if self._should_send_to_siem(audit_log, integration):
|
||||
# In a real implementation, this would send to the actual SIEM
|
||||
# For now, just log the intent
|
||||
logger.debug(
|
||||
f"Sending event to SIEM {integration.integration_name}",
|
||||
extra={"event_id": audit_log.event_id}
|
||||
)
|
||||
|
||||
# Update statistics
|
||||
integration.events_sent += 1
|
||||
integration.last_sync = datetime.now(timezone.utc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to SIEM {integration.integration_name}: {str(e)}")
|
||||
integration.errors_count += 1
|
||||
integration.last_error = str(e)
|
||||
integration.is_healthy = False
|
||||
|
||||
def _should_send_to_siem(self, audit_log: EnhancedAuditLog, integration: SIEMIntegration) -> bool:
|
||||
"""Check if audit log should be sent to specific SIEM integration"""
|
||||
# Check severity threshold
|
||||
severity_order = ["info", "low", "medium", "high", "critical"]
|
||||
if severity_order.index(audit_log.severity) < severity_order.index(integration.severity_threshold):
|
||||
return False
|
||||
|
||||
# Check event type filter
|
||||
if integration.event_types:
|
||||
allowed_types = json.loads(integration.event_types)
|
||||
if audit_log.event_type not in allowed_types:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP from request"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_geographic_info(self, ip_address: Optional[str]) -> tuple:
|
||||
"""Get geographic information for IP address"""
|
||||
# Placeholder - would integrate with GeoIP service
|
||||
return None, None, None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def audit_context(
|
||||
db: Session,
|
||||
user: Optional[User] = None,
|
||||
session_id: Optional[str] = None,
|
||||
request: Optional[Request] = None,
|
||||
correlation_id: Optional[str] = None
|
||||
):
|
||||
"""Context manager for audit logging"""
|
||||
auditor = EnhancedAuditLogger(db)
|
||||
|
||||
# Set correlation ID for this context
|
||||
if not correlation_id:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
yield auditor
|
||||
except Exception as e:
|
||||
# Log the exception as a security event
|
||||
auditor.log_security_event(
|
||||
event_type=SecurityEventType.SECURITY_VIOLATION,
|
||||
title="System error occurred",
|
||||
description=f"Exception in audit context: {str(e)}",
|
||||
user=user,
|
||||
session_id=session_id,
|
||||
request=request,
|
||||
severity=SecurityEventSeverity.HIGH,
|
||||
outcome="error",
|
||||
correlation_id=correlation_id
|
||||
)
|
||||
raise
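
# Illustrative usage (sketch; `db`, `current_user`, `request` and `file_no` are
# assumed from the surrounding endpoint): any exception raised inside the block
# is recorded as a SECURITY_VIOLATION before being re-raised.
#
#   with audit_context(db, user=current_user, request=request) as auditor:
#       auditor.log_data_access(current_user, "files", file_no, "read", request=request)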
|
||||
|
||||
|
||||
def get_enhanced_audit_logger(db: Session) -> EnhancedAuditLogger:
|
||||
"""Dependency injection for enhanced audit logger"""
|
||||
return EnhancedAuditLogger(db)
|
||||
app/utils/enhanced_auth.py (new file, +540 lines)
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Enhanced Authentication Utilities
|
||||
|
||||
Provides advanced authentication features including:
|
||||
- Password complexity validation
|
||||
- Account lockout protection
|
||||
- Session management
|
||||
- Login attempt tracking
|
||||
- Suspicious activity detection
|
||||
"""
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, List, Tuple, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, and_
|
||||
from fastapi import HTTPException, status, Request
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.models.user import User
|
||||
try:
|
||||
# Optional: enhanced features may rely on this model
|
||||
from app.models.auth import LoginAttempt # type: ignore
|
||||
except Exception: # pragma: no cover - older schemas may not include this model
|
||||
LoginAttempt = None # type: ignore
|
||||
from app.config import settings
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="enhanced_auth")
|
||||
|
||||
# Password complexity configuration
|
||||
PASSWORD_CONFIG = {
|
||||
"min_length": 8,
|
||||
"max_length": 128,
|
||||
"require_uppercase": True,
|
||||
"require_lowercase": True,
|
||||
"require_digits": True,
|
||||
"require_special_chars": True,
|
||||
"special_chars": "!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||
"max_consecutive_chars": 3,
|
||||
"prevent_common_passwords": True,
|
||||
}
|
||||
|
||||
# Account lockout configuration
|
||||
LOCKOUT_CONFIG = {
|
||||
"max_attempts": 5,
|
||||
"lockout_duration": 900, # 15 minutes
|
||||
"window_duration": 900, # 15 minutes
|
||||
"progressive_delay": True,
|
||||
"notify_on_lockout": True,
|
||||
}
|
||||
|
||||
# Common weak passwords to prevent
|
||||
COMMON_PASSWORDS = {
|
||||
"password", "123456", "password123", "admin", "qwerty", "letmein",
|
||||
"welcome", "monkey", "1234567890", "password1", "123456789",
|
||||
"welcome123", "admin123", "root", "toor", "pass", "test", "guest",
|
||||
"user", "login", "default", "changeme", "secret", "administrator"
|
||||
}
|
||||
|
||||
# Password validation context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class PasswordValidator:
|
||||
"""Advanced password validation with security requirements"""
|
||||
|
||||
@staticmethod
|
||||
def validate_password_strength(password: str) -> Tuple[bool, List[str]]:
|
||||
"""Validate password strength and return detailed feedback"""
|
||||
errors = []
|
||||
|
||||
# Length check
|
||||
if len(password) < PASSWORD_CONFIG["min_length"]:
|
||||
errors.append(f"Password must be at least {PASSWORD_CONFIG['min_length']} characters long")
|
||||
|
||||
if len(password) > PASSWORD_CONFIG["max_length"]:
|
||||
errors.append(f"Password must not exceed {PASSWORD_CONFIG['max_length']} characters")
|
||||
|
||||
# Character requirements
|
||||
if PASSWORD_CONFIG["require_uppercase"] and not re.search(r'[A-Z]', password):
|
||||
errors.append("Password must contain at least one uppercase letter")
|
||||
|
||||
if PASSWORD_CONFIG["require_lowercase"] and not re.search(r'[a-z]', password):
|
||||
errors.append("Password must contain at least one lowercase letter")
|
||||
|
||||
if PASSWORD_CONFIG["require_digits"] and not re.search(r'\d', password):
|
||||
errors.append("Password must contain at least one digit")
|
||||
|
||||
if PASSWORD_CONFIG["require_special_chars"]:
|
||||
special_chars = PASSWORD_CONFIG["special_chars"]
|
||||
if not re.search(f'[{re.escape(special_chars)}]', password):
|
||||
errors.append(f"Password must contain at least one special character ({special_chars[:10]}...)")
|
||||
|
||||
# Consecutive character check
|
||||
max_consecutive = PASSWORD_CONFIG["max_consecutive_chars"]
|
||||
for i in range(len(password) - max_consecutive):
|
||||
substr = password[i:i + max_consecutive + 1]
|
||||
if len(set(substr)) == 1: # All same character
|
||||
errors.append(f"Password cannot contain more than {max_consecutive} consecutive identical characters")
|
||||
break
|
||||
|
||||
# Common password check
|
||||
if PASSWORD_CONFIG["prevent_common_passwords"]:
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
errors.append("Password is too common and easily guessable")
|
||||
|
||||
# Sequential character check
|
||||
if PasswordValidator._contains_sequence(password):
|
||||
errors.append("Password cannot contain common keyboard sequences")
|
||||
|
||||
# Dictionary word check (basic)
|
||||
if PasswordValidator._is_dictionary_word(password):
|
||||
errors.append("Password should not be a common dictionary word")
|
||||
|
||||
return len(errors) == 0, errors
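
# Illustrative usage (sketch): surfacing the individual complexity failures to
# the caller instead of a bare pass/fail flag.
#
#   ok, problems = PasswordValidator.validate_password_strength(candidate)
#   if not ok:
#       raise HTTPException(status_code=422, detail={"password": problems})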
|
||||
|
||||
@staticmethod
|
||||
def _contains_sequence(password: str) -> bool:
|
||||
"""Check for common keyboard sequences"""
|
||||
sequences = [
|
||||
"123456789", "987654321", "abcdefgh", "zyxwvuts",
|
||||
"qwertyui", "asdfghjk", "zxcvbnm", "uioplkjh",
|
||||
"qazwsxed", "plmoknij"
|
||||
]
|
||||
|
||||
password_lower = password.lower()
|
||||
for seq in sequences:
|
||||
if seq in password_lower or seq[::-1] in password_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_dictionary_word(password: str) -> bool:
|
||||
"""Basic check for common dictionary words"""
|
||||
# Simple check for common English words
|
||||
common_words = {
|
||||
"password", "computer", "internet", "database", "security",
|
||||
"welcome", "hello", "world", "admin", "user", "login",
|
||||
"system", "server", "network", "access", "control"
|
||||
}
|
||||
|
||||
return password.lower() in common_words
|
||||
|
||||
@staticmethod
|
||||
def generate_password_strength_score(password: str) -> int:
|
||||
"""Generate a password strength score from 0-100"""
|
||||
score = 0
|
||||
|
||||
# Length score (up to 25 points)
|
||||
score += min(25, len(password) * 2)
|
||||
|
||||
# Character diversity (up to 40 points)
|
||||
if re.search(r'[a-z]', password):
|
||||
score += 5
|
||||
if re.search(r'[A-Z]', password):
|
||||
score += 5
|
||||
if re.search(r'\d', password):
|
||||
score += 5
|
||||
if re.search(r'[!@#$%^&*()_+\-=\[\]{}|;:,.<>?]', password):
|
||||
score += 10
|
||||
|
||||
# Bonus for multiple character types
|
||||
char_types = sum([
|
||||
bool(re.search(r'[a-z]', password)),
|
||||
bool(re.search(r'[A-Z]', password)),
|
||||
bool(re.search(r'\d', password)),
|
||||
bool(re.search(r'[!@#$%^&*()_+\-=\[\]{}|;:,.<>?]', password))
|
||||
])
|
||||
score += char_types * 3
|
||||
|
||||
# Length bonus
|
||||
if len(password) >= 12:
|
||||
score += 10
|
||||
if len(password) >= 16:
|
||||
score += 5
|
||||
|
||||
# Penalties
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
score -= 25
|
||||
|
||||
# Check for patterns
|
||||
if re.search(r'(.)\1{2,}', password): # Repeated characters
|
||||
score -= 10
|
||||
|
||||
return max(0, min(100, score))
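
# Worked example of the score above: "Summer2024!" (11 characters) earns 22
# length points, 5 + 5 + 5 + 10 for the four character classes and a 12-point
# diversity bonus, with no length bonus and no penalties: 59 out of 100.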
|
||||
|
||||
|
||||
class AccountLockoutManager:
|
||||
"""Manages account lockout and login attempt tracking"""
|
||||
|
||||
@staticmethod
|
||||
def record_login_attempt(
|
||||
db: Session,
|
||||
username: str,
|
||||
success: bool,
|
||||
ip_address: str,
|
||||
user_agent: str,
|
||||
failure_reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""Record a login attempt in the database"""
|
||||
try:
|
||||
if LoginAttempt is None:
|
||||
# Schema not available; log-only fallback
|
||||
logger.info(
|
||||
"Login attempt (no model)",
|
||||
username=username,
|
||||
success=success,
|
||||
ip=ip_address,
|
||||
reason=failure_reason
|
||||
)
|
||||
return
|
||||
|
||||
attempt = LoginAttempt( # type: ignore[call-arg]
|
||||
username=username,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
success=1 if success else 0,
|
||||
failure_reason=failure_reason,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(attempt)
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Login attempt recorded",
|
||||
username=username,
|
||||
success=success,
|
||||
ip=ip_address,
|
||||
reason=failure_reason
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to record login attempt", error=str(e))
|
||||
db.rollback()
|
||||
|
||||
@staticmethod
|
||||
def is_account_locked(db: Session, username: str) -> Tuple[bool, Optional[datetime]]:
|
||||
"""Check if an account is locked due to failed attempts"""
|
||||
try:
|
||||
if LoginAttempt is None:
|
||||
return False, None
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now - timedelta(seconds=LOCKOUT_CONFIG["window_duration"])
|
||||
|
||||
# Count failed attempts within the window
|
||||
failed_attempts = db.query(func.count(LoginAttempt.id)).filter( # type: ignore[attr-defined]
|
||||
and_(
|
||||
LoginAttempt.username == username,
|
||||
LoginAttempt.success == 0,
|
||||
LoginAttempt.timestamp >= window_start
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if failed_attempts >= LOCKOUT_CONFIG["max_attempts"]:
|
||||
# Get the time of the last failed attempt
|
||||
last_attempt = db.query(LoginAttempt.timestamp).filter( # type: ignore[attr-defined]
|
||||
and_(
|
||||
LoginAttempt.username == username,
|
||||
LoginAttempt.success == 0
|
||||
)
|
||||
).order_by(LoginAttempt.timestamp.desc()).first()
|
||||
|
||||
if last_attempt:
|
||||
unlock_time = last_attempt[0] + timedelta(seconds=LOCKOUT_CONFIG["lockout_duration"])
|
||||
if now < unlock_time:
|
||||
return True, unlock_time
|
||||
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error("Failed to check account lockout", error=str(e))
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def get_lockout_info(db: Session, username: str) -> Dict[str, Any]:
|
||||
"""Get detailed lockout information for an account"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now - timedelta(seconds=LOCKOUT_CONFIG["window_duration"])
|
||||
if LoginAttempt is None:
|
||||
return {
|
||||
"is_locked": False,
|
||||
"failed_attempts": 0,
|
||||
"max_attempts": LOCKOUT_CONFIG["max_attempts"],
|
||||
"attempts_remaining": LOCKOUT_CONFIG["max_attempts"],
|
||||
"unlock_time": None,
|
||||
"window_start": window_start.isoformat(),
|
||||
"lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
|
||||
}
|
||||
|
||||
# Get recent failed attempts
|
||||
failed_attempts = db.query(LoginAttempt).filter( # type: ignore[arg-type]
|
||||
and_(
|
||||
LoginAttempt.username == username,
|
||||
LoginAttempt.success == 0,
|
||||
LoginAttempt.timestamp >= window_start
|
||||
)
|
||||
).order_by(LoginAttempt.timestamp.desc()).all()
|
||||
|
||||
failed_count = len(failed_attempts)
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(db, username)
|
||||
|
||||
return {
|
||||
"is_locked": is_locked,
|
||||
"failed_attempts": failed_count,
|
||||
"max_attempts": LOCKOUT_CONFIG["max_attempts"],
|
||||
"attempts_remaining": max(0, LOCKOUT_CONFIG["max_attempts"] - failed_count),
|
||||
"unlock_time": unlock_time.isoformat() if unlock_time else None,
|
||||
"window_start": window_start.isoformat(),
|
||||
"lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to get lockout info", error=str(e))
|
||||
return {
|
||||
"is_locked": False,
|
||||
"failed_attempts": 0,
|
||||
"max_attempts": LOCKOUT_CONFIG["max_attempts"],
|
||||
"attempts_remaining": LOCKOUT_CONFIG["max_attempts"],
|
||||
"unlock_time": None,
|
||||
"window_start": window_start.isoformat() if 'window_start' in locals() else None,
|
||||
"lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def reset_failed_attempts(db: Session, username: str) -> None:
|
||||
"""Reset failed login attempts for successful login"""
|
||||
try:
|
||||
# We don't delete the records, just mark successful login
|
||||
# The lockout check will naturally reset due to time window
|
||||
logger.info("Failed attempts naturally reset for successful login", username=username)
|
||||
except Exception as e:
|
||||
logger.error("Failed to reset attempts", error=str(e))
|
||||
|
||||
|
||||
class SuspiciousActivityDetector:
|
||||
"""Detects and reports suspicious authentication activity"""
|
||||
|
||||
@staticmethod
|
||||
def detect_suspicious_patterns(db: Session, timeframe_hours: int = 24) -> List[Dict[str, Any]]:
|
||||
"""Detect suspicious login patterns"""
|
||||
alerts = []
|
||||
try:
|
||||
if LoginAttempt is None:
|
||||
return []
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=timeframe_hours)
|
||||
|
||||
# Get all login attempts in timeframe
|
||||
attempts = db.query(LoginAttempt).filter( # type: ignore[arg-type]
|
||||
LoginAttempt.timestamp >= cutoff_time
|
||||
).all()
|
||||
|
||||
# Analyze patterns
|
||||
ip_attempts = {}
|
||||
username_attempts = {}
|
||||
|
||||
for attempt in attempts:
|
||||
# Group by IP
|
||||
if attempt.ip_address not in ip_attempts:
|
||||
ip_attempts[attempt.ip_address] = []
|
||||
ip_attempts[attempt.ip_address].append(attempt)
|
||||
|
||||
# Group by username
|
||||
if attempt.username not in username_attempts:
|
||||
username_attempts[attempt.username] = []
|
||||
username_attempts[attempt.username].append(attempt)
|
||||
|
||||
# Check for suspicious IP activity
|
||||
for ip, attempts_list in ip_attempts.items():
|
||||
failed_attempts = [a for a in attempts_list if not a.success]
|
||||
if len(failed_attempts) >= 10: # Many failed attempts from one IP
|
||||
alerts.append({
|
||||
"type": "suspicious_ip",
|
||||
"severity": "high",
|
||||
"ip_address": ip,
|
||||
"failed_attempts": len(failed_attempts),
|
||||
"usernames_targeted": list(set(a.username for a in failed_attempts)),
|
||||
"timeframe": f"{timeframe_hours} hours"
|
||||
})
|
||||
|
||||
# Check for account targeting
|
||||
for username, attempts_list in username_attempts.items():
|
||||
failed_attempts = [a for a in attempts_list if not a.success]
|
||||
unique_ips = set(a.ip_address for a in failed_attempts)
|
||||
|
||||
if len(failed_attempts) >= 5 and len(unique_ips) > 2:
|
||||
alerts.append({
|
||||
"type": "account_targeted",
|
||||
"severity": "medium",
|
||||
"username": username,
|
||||
"failed_attempts": len(failed_attempts),
|
||||
"source_ips": list(unique_ips),
|
||||
"timeframe": f"{timeframe_hours} hours"
|
||||
})
|
||||
|
||||
return alerts
|
||||
except Exception as e:
|
||||
logger.error("Failed to detect suspicious patterns", error=str(e))
|
||||
return []
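
# Illustrative usage (sketch): a periodic job could surface these alerts to
# operators or feed them into the enhanced audit log.
#
#   for alert in SuspiciousActivityDetector.detect_suspicious_patterns(db, timeframe_hours=24):
#       logger.warning("Suspicious authentication pattern detected", alert=alert)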
|
||||
|
||||
@staticmethod
|
||||
def is_login_suspicious(
|
||||
db: Session,
|
||||
username: str,
|
||||
ip_address: str,
|
||||
user_agent: str
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Check if a login attempt is suspicious"""
|
||||
warnings = []
|
||||
try:
|
||||
if LoginAttempt is None:
|
||||
return False, []
|
||||
# Check for unusual IP
|
||||
recent_ips = db.query(LoginAttempt.ip_address).filter( # type: ignore[attr-defined]
|
||||
and_(
|
||||
LoginAttempt.username == username,
|
||||
LoginAttempt.success == 1,
|
||||
LoginAttempt.timestamp >= datetime.now(timezone.utc) - timedelta(days=30)
|
||||
)
|
||||
).distinct().all()
|
||||
|
||||
known_ips = {ip[0] for ip in recent_ips}
|
||||
if ip_address not in known_ips and len(known_ips) > 0:
|
||||
warnings.append("Login from new IP address")
|
||||
|
||||
# Check for unusual time
|
||||
now = datetime.now(timezone.utc)
|
||||
if now.hour < 6 or now.hour > 22:  # Outside normal business hours (UTC)
|
||||
warnings.append("Login outside normal business hours")
|
||||
|
||||
# Check for rapid attempts from same IP
|
||||
recent_attempts = db.query(func.count(LoginAttempt.id)).filter( # type: ignore[attr-defined]
|
||||
and_(
|
||||
LoginAttempt.ip_address == ip_address,
|
||||
LoginAttempt.timestamp >= datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if recent_attempts > 3:
|
||||
warnings.append("Multiple rapid login attempts from same IP")
|
||||
|
||||
return len(warnings) > 0, warnings
|
||||
except Exception as e:
|
||||
logger.error("Failed to check suspicious login", error=str(e))
|
||||
return False, []
|
||||
|
||||
|
||||
def validate_and_authenticate_user(
|
||||
db: Session,
|
||||
username: str,
|
||||
password: str,
|
||||
request: Request
|
||||
) -> Tuple[Optional[User], List[str]]:
|
||||
"""Enhanced user authentication with security checks"""
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# Extract request information
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
# Check account lockout
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(db, username)
|
||||
if is_locked:
|
||||
AccountLockoutManager.record_login_attempt(
|
||||
db, username, False, ip_address, user_agent, "Account locked"
|
||||
)
|
||||
unlock_str = unlock_time.strftime("%Y-%m-%d %H:%M:%S UTC") if unlock_time else "unknown"
|
||||
errors.append(f"Account is locked due to too many failed attempts. Try again after {unlock_str}")
|
||||
return None, errors
|
||||
|
||||
# Find user
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
AccountLockoutManager.record_login_attempt(
|
||||
db, username, False, ip_address, user_agent, "User not found"
|
||||
)
|
||||
errors.append("Invalid username or password")
|
||||
return None, errors
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
AccountLockoutManager.record_login_attempt(
|
||||
db, username, False, ip_address, user_agent, "User account disabled"
|
||||
)
|
||||
errors.append("User account is disabled")
|
||||
return None, errors
|
||||
|
||||
# Verify password
|
||||
from app.auth.security import verify_password
|
||||
if not verify_password(password, user.hashed_password):
|
||||
AccountLockoutManager.record_login_attempt(
|
||||
db, username, False, ip_address, user_agent, "Invalid password"
|
||||
)
|
||||
errors.append("Invalid username or password")
|
||||
return None, errors
|
||||
|
||||
# Check for suspicious activity
|
||||
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
|
||||
db, username, ip_address, user_agent
|
||||
)
|
||||
|
||||
if is_suspicious:
|
||||
logger.warning(
|
||||
"Suspicious login detected",
|
||||
username=username,
|
||||
ip=ip_address,
|
||||
warnings=warnings
|
||||
)
|
||||
# You could require additional verification here
|
||||
|
||||
# Successful login
|
||||
AccountLockoutManager.record_login_attempt(
|
||||
db, username, True, ip_address, user_agent, None
|
||||
)
|
||||
|
||||
# Update last login time
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
return user, []
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Authentication error", error=str(e))
|
||||
errors.append("Authentication service temporarily unavailable")
|
||||
return None, errors
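
# Illustrative usage in a login route (sketch; `form` and the token helper are
# assumptions, not part of this module):
#
#   user, errors = validate_and_authenticate_user(db, form.username, form.password, request)
#   if user is None:
#       raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="; ".join(errors))
#   token = create_access_token(user)  # hypothetical helper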
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""Extract client IP from request headers"""
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
app/utils/file_security.py (new file, +342 lines)
@@ -0,0 +1,342 @@
"""
File Security and Validation Utilities

Comprehensive security validation for file uploads to prevent:
- Path traversal attacks
- File type spoofing
- DoS attacks via large files
- Malicious file uploads
- Directory traversal
"""
import os
import re
import hashlib
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
from fastapi import HTTPException, UploadFile

# Try to import python-magic, fall back to extension-based detection
try:
    import magic
    MAGIC_AVAILABLE = True
except ImportError:
    MAGIC_AVAILABLE = False

# File size limits (bytes)
MAX_FILE_SIZES = {
    'document': 10 * 1024 * 1024,   # 10MB for documents
    'csv': 50 * 1024 * 1024,        # 50MB for CSV imports
    'template': 5 * 1024 * 1024,    # 5MB for templates
    'image': 2 * 1024 * 1024,       # 2MB for images
    'default': 10 * 1024 * 1024,    # 10MB default
}
|
||||
|
||||
# Allowed MIME types for security
|
||||
ALLOWED_MIME_TYPES = {
|
||||
'document': {
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
},
|
||||
'csv': {
|
||||
'text/csv',
|
||||
'text/plain',
|
||||
'application/csv',
|
||||
},
|
||||
'template': {
|
||||
'application/pdf',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
},
|
||||
'image': {
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
'image/webp',
|
||||
}
|
||||
}
|
||||
|
||||
# File extensions mapping to categories
|
||||
FILE_EXTENSIONS = {
|
||||
'document': {'.pdf', '.doc', '.docx'},
|
||||
'csv': {'.csv', '.txt'},
|
||||
'template': {'.pdf', '.docx'},
|
||||
'image': {'.jpg', '.jpeg', '.png', '.gif', '.webp'},
|
||||
}
|
||||
|
||||
# Dangerous file extensions that should never be uploaded
|
||||
DANGEROUS_EXTENSIONS = {
|
||||
'.exe', '.bat', '.cmd', '.com', '.scr', '.pif', '.vbs', '.js',
|
||||
'.jar', '.app', '.deb', '.pkg', '.dmg', '.rpm', '.msi', '.dll',
|
||||
'.so', '.dylib', '.sys', '.drv', '.ocx', '.cpl', '.scf', '.lnk',
|
||||
'.ps1', '.ps2', '.psc1', '.psc2', '.msh', '.msh1', '.msh2', '.mshxml',
|
||||
'.msh1xml', '.msh2xml', '.scf', '.inf', '.reg', '.vb', '.vbe', '.asp',
|
||||
'.aspx', '.php', '.jsp', '.jspx', '.py', '.rb', '.pl', '.sh', '.bash'
|
||||
}
|
||||
|
||||
|
||||
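As a quick illustration of how these tables are meant to be consulted together (sizes in bytes, extensions and MIME types keyed by upload category), a throwaway sanity check for the 'csv' category might look like this; the snippet is illustrative only and not part of the module.

category = 'csv'
assert MAX_FILE_SIZES[category] == 50 * 1024 * 1024       # 52,428,800 bytes
assert '.csv' in FILE_EXTENSIONS[category]                 # extension allow-list
assert 'text/csv' in ALLOWED_MIME_TYPES[category]          # MIME allow-list
assert '.exe' in DANGEROUS_EXTENSIONS                      # always rejected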

class FileSecurityValidator:
    """Comprehensive file security validation"""

    def __init__(self):
        self.magic_mime = None
        if MAGIC_AVAILABLE:
            try:
                self.magic_mime = magic.Magic(mime=True)
            except Exception:
                self.magic_mime = None

    def sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks"""
        if not filename:
            raise HTTPException(status_code=400, detail="Filename cannot be empty")

        # Remove any path separators and dangerous characters
        filename = os.path.basename(filename)
        filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)

        # Remove leading/trailing dots and spaces
        filename = filename.strip('. ')

        # Ensure filename is not empty after sanitization
        if not filename:
            raise HTTPException(status_code=400, detail="Invalid filename")

        # Limit filename length
        if len(filename) > 255:
            name, ext = os.path.splitext(filename)
            filename = name[:250] + ext

        return filename

    def validate_file_extension(self, filename: str, category: str) -> str:
        """Validate file extension against allowed types"""
        if not filename:
            raise HTTPException(status_code=400, detail="Filename required")

        # Get file extension
        _, ext = os.path.splitext(filename.lower())

        # Check for dangerous extensions
        if ext in DANGEROUS_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '{ext}' is not allowed for security reasons"
            )

        # Check against allowed extensions for category
        allowed_extensions = FILE_EXTENSIONS.get(category, set())
        if ext not in allowed_extensions:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="Invalid file type")

        return ext

    def _detect_mime_from_content(self, content: bytes, filename: str) -> str:
        """Detect MIME type from file content or extension"""
        if self.magic_mime:
            try:
                return self.magic_mime.from_buffer(content)
            except Exception:
                pass

        # Fallback to extension-based detection and basic content inspection
        _, ext = os.path.splitext(filename.lower())

        # Basic content-based detection for common file types
        if content.startswith(b'%PDF'):
            return 'application/pdf'
        elif content.startswith(b'PK\x03\x04') and ext in ['.docx', '.xlsx', '.pptx']:
            if ext == '.docx':
                return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
        elif content.startswith(b'\xd0\xcf\x11\xe0') and ext in ['.doc', '.xls', '.ppt']:
            if ext == '.doc':
                return 'application/msword'
        elif content.startswith(b'\xff\xd8\xff'):
            return 'image/jpeg'
        elif content.startswith(b'\x89PNG'):
            return 'image/png'
        elif content.startswith(b'GIF8'):
            return 'image/gif'
        elif content.startswith(b'RIFF') and b'WEBP' in content[:20]:
            return 'image/webp'

        # Extension-based fallback
        extension_to_mime = {
            '.pdf': 'application/pdf',
            '.doc': 'application/msword',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.csv': 'text/csv',
            '.txt': 'text/plain',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
        }

        return extension_to_mime.get(ext, 'application/octet-stream')

    def validate_mime_type(self, content: bytes, filename: str, category: str) -> str:
        """Validate MIME type using content inspection and file extension"""
        if not content:
            raise HTTPException(status_code=400, detail="File content is empty")

        # Detect MIME type
        detected_mime = self._detect_mime_from_content(content, filename)

        # Check against allowed MIME types
        allowed_mimes = ALLOWED_MIME_TYPES.get(category, set())
        if detected_mime not in allowed_mimes:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="Invalid file type")

        return detected_mime

    def validate_file_size(self, content: bytes, category: str) -> int:
        """Validate file size against limits"""
        size = len(content)
        max_size = MAX_FILE_SIZES.get(category, MAX_FILE_SIZES['default'])

        if size == 0:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="No file uploaded")

        if size > max_size:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="File too large")

        return size

    def scan_for_malware_patterns(self, content: bytes, filename: str) -> None:
        """Basic malware pattern detection"""
        # Check for common malware signatures
        malware_patterns = [
            b'<script',
            b'javascript:',
            b'vbscript:',
            b'data:text/html',
            b'<?php',
            b'<% ',
            b'eval(',
            b'exec(',
            b'system(',
            b'shell_exec(',
            b'passthru(',
            b'cmd.exe',
            b'powershell',
        ]

        content_lower = content.lower()
        for pattern in malware_patterns:
            if pattern in content_lower:
                raise HTTPException(
                    status_code=400,
                    detail="File contains potentially malicious content and cannot be uploaded"
                )

    def generate_secure_path(self, base_dir: str, filename: str, subdir: Optional[str] = None) -> str:
        """Generate secure file path preventing directory traversal"""
        # Sanitize filename
        safe_filename = self.sanitize_filename(filename)

        # Build path components
        path_parts = [base_dir]
        if subdir:
            # Sanitize subdirectory name
            safe_subdir = re.sub(r'[^a-zA-Z0-9_-]', '_', subdir)
            path_parts.append(safe_subdir)
        path_parts.append(safe_filename)

        # Use Path to safely join and resolve
        full_path = Path(*path_parts).resolve()
        base_path = Path(base_dir).resolve()

        # Ensure the resolved path is within the base directory; relative_to raises
        # ValueError for paths outside base_path and avoids the false positives of a
        # plain string-prefix check (e.g. /data vs /data2)
        try:
            full_path.relative_to(base_path)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail="Invalid file path - directory traversal detected"
            )

        return str(full_path)

    async def validate_upload_file(
        self,
        file: UploadFile,
        category: str,
        max_size_override: Optional[int] = None
    ) -> Tuple[bytes, str, str, str]:
        """
        Comprehensive validation of uploaded file

        Returns: (content, sanitized_filename, file_extension, mime_type)
        """
        # Check if file was uploaded
        if not file.filename:
            raise HTTPException(status_code=400, detail="No file uploaded")

        # Read file content
        content = await file.read()

        # Validate file size
        if max_size_override:
            max_size = max_size_override
            if len(content) > max_size:
                raise HTTPException(
                    status_code=400,
                    detail=f"File size exceeds limit ({max_size:,} bytes)"
                )
        else:
            self.validate_file_size(content, category)

        # Sanitize filename
        safe_filename = self.sanitize_filename(file.filename)

        # Validate file extension
        file_ext = self.validate_file_extension(safe_filename, category)

        # Validate MIME type using actual file content
        mime_type = self.validate_mime_type(content, safe_filename, category)

        # Scan for malware patterns
        self.scan_for_malware_patterns(content, safe_filename)

        return content, safe_filename, file_ext, mime_type


# Global instance for use across the application
file_validator = FileSecurityValidator()
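A sketch of how the shared file_validator instance might be used from an upload endpoint. The route path, the "document" category, and the storage directory below are assumptions for illustration; the actual endpoints in this codebase may differ.

from fastapi import APIRouter, File, UploadFile

router = APIRouter()

@router.post("/api/documents/upload")
async def upload_document(file: UploadFile = File(...)):
    # Runs size, extension, MIME and malware-pattern checks in one call
    content, safe_name, ext, mime = await file_validator.validate_upload_file(file, category="document")

    # Build a traversal-safe destination path under the upload root
    create_upload_directory("/srv/app/uploads/documents")
    dest = file_validator.generate_secure_path("/srv/app/uploads/documents", safe_name)

    with open(dest, "wb") as fh:
        fh.write(content)

    return {"filename": safe_name, "mime_type": mime, "size": len(content)}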

def validate_csv_content(content: str) -> None:
    """Additional validation for CSV content"""
    # Check for SQL injection patterns in CSV content
    sql_patterns = [
        r'(union\s+select)',
        r'(drop\s+table)',
        r'(delete\s+from)',
        r'(insert\s+into)',
        r'(update\s+.*set)',
        r'(exec\s*\()',
        r'(<script)',
        r'(javascript:)',
    ]

    content_lower = content.lower()
    for pattern in sql_patterns:
        if re.search(pattern, content_lower):
            raise HTTPException(
                status_code=400,
                detail="CSV content contains potentially malicious data"
            )


def create_upload_directory(path: str) -> None:
    """Safely create upload directory with proper permissions"""
    try:
        os.makedirs(path, mode=0o755, exist_ok=True)
    except OSError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Could not create upload directory: {str(e)}"
        )
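For CSV imports, the decoded text can be screened with validate_csv_content before it is parsed. A minimal sketch; the decode strategy and the use of csv.DictReader are assumptions, not part of this commit.

import csv
import io
from fastapi import UploadFile

async def read_csv_upload(file: UploadFile) -> list[dict]:
    # Reuse the shared validator, then screen the text before parsing it
    raw, _, _, _ = await file_validator.validate_upload_file(file, category="csv")
    text = raw.decode("utf-8", errors="replace")
    validate_csv_content(text)  # raises HTTPException on suspicious patterns
    return list(csv.DictReader(io.StringIO(text)))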
app/utils/logging.py
@@ -17,6 +17,8 @@ class StructuredLogger:
    def __init__(self, name: str, level: str = "INFO"):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(getattr(logging, level.upper()))
        # Support bound context similar to loguru's bind
        self._bound_context: Dict[str, Any] = {}

        if not self.logger.handlers:
            self._setup_handlers()
@@ -68,13 +70,24 @@ class StructuredLogger:

    def _log(self, level: int, message: str, **kwargs):
        """Internal method to log with structured data."""
        context: Dict[str, Any] = {}
        if self._bound_context:
            context.update(self._bound_context)
        if kwargs:
-           structured_message = f"{message} | Context: {json.dumps(kwargs, default=str)}"
            context.update(kwargs)
        if context:
            structured_message = f"{message} | Context: {json.dumps(context, default=str)}"
        else:
            structured_message = message

        self.logger.log(level, structured_message)

    def bind(self, **kwargs):
        """Bind default context fields (compatibility with loguru-style usage)."""
        if kwargs:
            self._bound_context.update(kwargs)
        return self


class ImportLogger(StructuredLogger):
    """Specialized logger for import operations."""
@@ -261,8 +274,25 @@ def log_function_call(logger: StructuredLogger = None, level: str = "DEBUG"):
    return decorator


# Local logger cache and factory to avoid circular imports with app.core.logging
_loggers: dict[str, StructuredLogger] = {}


def get_logger(name: str) -> StructuredLogger:
    """Return a cached StructuredLogger instance.

    This implementation is self-contained to avoid importing app.core.logging,
    which would create a circular import (core -> utils -> core).
    """
    logger = _loggers.get(name)
    if logger is None:
        logger = StructuredLogger(name, getattr(settings, 'log_level', 'INFO'))
        _loggers[name] = logger
    return logger


# Pre-configured logger instances
-app_logger = StructuredLogger("application")
app_logger = get_logger("application")
import_logger = ImportLogger()
security_logger = SecurityLogger()
database_logger = DatabaseLogger()
@@ -270,16 +300,16 @@ database_logger = DatabaseLogger()
# Convenience functions
def log_info(message: str, **kwargs):
    """Quick info logging."""
-   app_logger.info(message, **kwargs)
    get_logger("application").info(message, **kwargs)

def log_warning(message: str, **kwargs):
    """Quick warning logging."""
-   app_logger.warning(message, **kwargs)
    get_logger("application").warning(message, **kwargs)

def log_error(message: str, **kwargs):
    """Quick error logging."""
-   app_logger.error(message, **kwargs)
    get_logger("application").error(message, **kwargs)

def log_debug(message: str, **kwargs):
    """Quick debug logging."""
-   app_logger.debug(message, **kwargs)
    get_logger("application").debug(message, **kwargs)
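A short usage sketch of the new bind/get_logger behavior; the logger name and context fields are arbitrary examples.

log = get_logger("imports").bind(job_id=42)
log.info("Import started", rows=100)
# emits: Import started | Context: {"job_id": 42, "rows": 100}

Note that because get_logger caches one instance per name, bind() mutates that shared instance, so bound fields persist for every later caller that asks for the same name.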
445
app/utils/session_manager.py
Normal file
@@ -0,0 +1,445 @@
"""
Advanced session management utilities for P2 security features
"""
import secrets
import hashlib
import json
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from fastapi import Request, Depends
from user_agents import parse as parse_user_agent

from app.models.sessions import (
    UserSession, SessionActivity, SessionConfiguration,
    SessionSecurityEvent, SessionStatus
)
from app.models.user import User
from app.core.logging import get_logger
from app.database.base import get_db

logger = get_logger(__name__)


class SessionManager:
    """
    Advanced session management with security features
    """

    # Default configuration
    DEFAULT_SESSION_TIMEOUT = timedelta(hours=8)
    DEFAULT_IDLE_TIMEOUT = timedelta(hours=1)
    DEFAULT_MAX_CONCURRENT_SESSIONS = 3

    def __init__(self, db: Session):
        self.db = db

    def generate_secure_session_id(self) -> str:
        """Generate cryptographically secure session ID"""
        # Generate 64 bytes of random data and hash it
        random_bytes = secrets.token_bytes(64)
        timestamp = str(datetime.now(timezone.utc).timestamp()).encode()
        return hashlib.sha256(random_bytes + timestamp).hexdigest()

    def create_session(
        self,
        user: User,
        request: Request,
        login_method: str = "password"
    ) -> UserSession:
        """
        Create new secure session with fixation protection
        """
        # Generate new session ID (prevents session fixation)
        session_id = self.generate_secure_session_id()

        # Extract request metadata
        ip_address = self._get_client_ip(request)
        user_agent = request.headers.get("user-agent", "")
        device_fingerprint = self._generate_device_fingerprint(request)

        # Get geographic info (placeholder - would integrate with GeoIP service)
        country, city = self._get_geographic_info(ip_address)

        # Check for suspicious activity
        is_suspicious, risk_score = self._assess_login_risk(user, ip_address, user_agent)

        # Get session configuration
        config = self._get_session_config(user)
        session_timeout = timedelta(minutes=config.session_timeout_minutes)

        # Enforce concurrent session limits
        self._enforce_concurrent_session_limits(user, config.max_concurrent_sessions)

        # Create session record
        session = UserSession(
            session_id=session_id,
            user_id=user.id,
            ip_address=ip_address,
            user_agent=user_agent,
            device_fingerprint=device_fingerprint,
            country=country,
            city=city,
            is_suspicious=is_suspicious,
            risk_score=risk_score,
            status=SessionStatus.ACTIVE,
            login_method=login_method,
            expires_at=datetime.now(timezone.utc) + session_timeout
        )

        self.db.add(session)
        self.db.flush()  # Get session ID

        # Log session creation activity
        self._log_activity(
            session, user, request,
            activity_type="session_created",
            endpoint="/api/auth/login"
        )

        # Generate security event if suspicious
        if is_suspicious:
            self._create_security_event(
                session, user,
                event_type="suspicious_login",
                severity="medium",
                description=f"Suspicious login detected: risk score {risk_score}",
                ip_address=ip_address,
                user_agent=user_agent,
                country=country
            )

        self.db.commit()
        logger.info(f"Created session {session_id} for user {user.username} from {ip_address}")

        return session

    def validate_session(self, session_id: str, request: Request) -> Optional[UserSession]:
        """
        Validate session and update activity tracking
        """
        session = self.db.query(UserSession).filter(
            UserSession.session_id == session_id
        ).first()

        if not session:
            return None

        # Check if session is expired or revoked
        if not session.is_active():
            return None

        # Check for IP address changes if configured
        current_ip = self._get_client_ip(request)
        config = self._get_session_config(session.user)

        if config.force_logout_on_ip_change and session.ip_address != current_ip:
            self._create_security_event(
                session, session.user,
                event_type="ip_address_change",
                severity="medium",
                description=f"IP changed from {session.ip_address} to {current_ip}",
                ip_address=current_ip,
                action_taken="session_revoked"
            )
            session.revoke_session("ip_address_change")
            self.db.commit()
            return None

        # Check idle timeout
        idle_timeout = timedelta(minutes=config.idle_timeout_minutes)
        if datetime.now(timezone.utc) - session.last_activity > idle_timeout:
            session.status = SessionStatus.EXPIRED
            self.db.commit()
            return None

        # Update last activity
        session.last_activity = datetime.now(timezone.utc)

        # Log activity
        self._log_activity(
            session, session.user, request,
            activity_type="session_validation",
            endpoint=str(request.url.path)
        )

        self.db.commit()
        return session

    def revoke_session(self, session_id: str, reason: str = "user_logout") -> bool:
        """Revoke a specific session"""
        session = self.db.query(UserSession).filter(
            UserSession.session_id == session_id
        ).first()

        if session:
            session.revoke_session(reason)
            self.db.commit()
            logger.info(f"Revoked session {session_id}: {reason}")
            return True

        return False

    def revoke_all_user_sessions(self, user_id: int, reason: str = "admin_action") -> int:
        """Revoke all sessions for a user"""
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user_id,
                UserSession.status == SessionStatus.ACTIVE
            )
        ).update({
            "status": SessionStatus.REVOKED,
            "revoked_at": datetime.now(timezone.utc),
            "revocation_reason": reason
        })

        self.db.commit()
        logger.info(f"Revoked {count} sessions for user {user_id}: {reason}")
        return count

    def get_active_sessions(self, user_id: int) -> List[UserSession]:
        """Get all active sessions for a user"""
        return self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user_id,
                UserSession.status == SessionStatus.ACTIVE,
                UserSession.expires_at > datetime.now(timezone.utc)
            )
        ).order_by(UserSession.last_activity.desc()).all()

    def cleanup_expired_sessions(self) -> int:
        """Clean up expired sessions"""
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.status == SessionStatus.ACTIVE,
                UserSession.expires_at <= datetime.now(timezone.utc)
            )
        ).update({
            "status": SessionStatus.EXPIRED
        })

        self.db.commit()

        if count > 0:
            logger.info(f"Cleaned up {count} expired sessions")

        return count

    def get_session_statistics(self, user_id: Optional[int] = None) -> Dict[str, Any]:
        """Get session statistics for monitoring"""
        query = self.db.query(UserSession)
        if user_id:
            query = query.filter(UserSession.user_id == user_id)

        # Basic counts
        total_sessions = query.count()
        active_sessions = query.filter(UserSession.status == SessionStatus.ACTIVE).count()
        suspicious_sessions = query.filter(UserSession.is_suspicious == True).count()

        # Recent activity
        last_24h = datetime.now(timezone.utc) - timedelta(days=1)
        recent_sessions = query.filter(UserSession.created_at >= last_24h).count()

        # Risk distribution
        high_risk = query.filter(UserSession.risk_score >= 70).count()
        medium_risk = query.filter(
            and_(UserSession.risk_score >= 30, UserSession.risk_score < 70)
        ).count()
        low_risk = query.filter(UserSession.risk_score < 30).count()

        return {
            "total_sessions": total_sessions,
            "active_sessions": active_sessions,
            "suspicious_sessions": suspicious_sessions,
            "recent_sessions_24h": recent_sessions,
            "risk_distribution": {
                "high": high_risk,
                "medium": medium_risk,
                "low": low_risk
            }
        }

    def _enforce_concurrent_session_limits(self, user: User, max_sessions: int) -> None:
        """Enforce concurrent session limits"""
        active_sessions = self.get_active_sessions(user.id)

        if len(active_sessions) >= max_sessions:
            # Revoke oldest sessions
            sessions_to_revoke = active_sessions[max_sessions-1:]
            for session in sessions_to_revoke:
                session.revoke_session("concurrent_session_limit")

                # Create security event
                self._create_security_event(
                    session, user,
                    event_type="concurrent_session_limit",
                    severity="medium",
                    description=f"Session revoked due to concurrent session limit ({max_sessions})",
                    action_taken="session_revoked"
                )

            logger.info(f"Revoked {len(sessions_to_revoke)} sessions for user {user.username} due to concurrent limit")

    def _get_session_config(self, user: User) -> SessionConfiguration:
        """Get session configuration for user"""
        # Try user-specific config first
        config = self.db.query(SessionConfiguration).filter(
            SessionConfiguration.user_id == user.id
        ).first()

        if not config:
            # Try global config
            config = self.db.query(SessionConfiguration).filter(
                SessionConfiguration.user_id.is_(None)
            ).first()

        if not config:
            # Create default global config
            config = SessionConfiguration()
            self.db.add(config)
            self.db.flush()

        return config

    def _assess_login_risk(self, user: User, ip_address: str, user_agent: str) -> Tuple[bool, int]:
        """Assess login risk based on historical data"""
        risk_score = 0
        risk_factors = []

        # Check for new IP address
        previous_ips = self.db.query(UserSession.ip_address).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
            )
        ).distinct().all()

        if ip_address not in [ip[0] for ip in previous_ips]:
            risk_score += 30
            risk_factors.append("new_ip_address")

        # Check for unusual login time
        current_hour = datetime.now(timezone.utc).hour
        user_login_hours = self.db.query(func.extract('hour', UserSession.created_at)).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
            )
        ).all()

        if user_login_hours:
            common_hours = [hour[0] for hour in user_login_hours]
            if current_hour not in common_hours[-10:]:  # Not in recent login hours
                risk_score += 20
                risk_factors.append("unusual_time")

        # Check for new user agent
        recent_agents = self.db.query(UserSession.user_agent).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=7)
            )
        ).distinct().all()

        if user_agent not in [agent[0] for agent in recent_agents if agent[0]]:
            risk_score += 15
            risk_factors.append("new_user_agent")

        # Check for rapid login attempts
        recent_attempts = self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user.id,
                UserSession.created_at >= datetime.now(timezone.utc) - timedelta(minutes=10)
            )
        ).count()

        if recent_attempts > 3:
            risk_score += 25
            risk_factors.append("rapid_attempts")

        is_suspicious = risk_score >= 50
        return is_suspicious, min(risk_score, 100)
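    # Illustrative example of how the additive scoring above plays out (editorial
    # note, not part of the original file): a login from a new IP address (+30)
    # with a new user agent (+15) after several rapid attempts (+25) scores 70,
    # so is_suspicious is True (threshold 50) and the session falls into the
    # "high" risk bucket (>= 70) reported by get_session_statistics().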

    def _log_activity(
        self,
        session: UserSession,
        user: User,
        request: Request,
        activity_type: str,
        endpoint: Optional[str] = None
    ) -> None:
        """Log session activity"""
        activity = SessionActivity(
            session_id=session.id,
            user_id=user.id,
            activity_type=activity_type,
            endpoint=endpoint or str(request.url.path),
            method=request.method,
            ip_address=self._get_client_ip(request),
            user_agent=request.headers.get("user-agent", "")
        )

        self.db.add(activity)

    def _create_security_event(
        self,
        session: Optional[UserSession],
        user: User,
        event_type: str,
        severity: str,
        description: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        country: Optional[str] = None,
        action_taken: Optional[str] = None
    ) -> None:
        """Create security event record"""
        event = SessionSecurityEvent(
            session_id=session.id if session else None,
            user_id=user.id,
            event_type=event_type,
            severity=severity,
            description=description,
            ip_address=ip_address,
            user_agent=user_agent,
            country=country,
            action_taken=action_taken
        )

        self.db.add(event)
        logger.warning(f"Security event [{severity}]: {event_type} for user {user.username} - {description}")

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP address from request"""
        # Check for forwarded headers first
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fallback to direct connection
        return request.client.host if request.client else "unknown"

    def _generate_device_fingerprint(self, request: Request) -> str:
        """Generate device fingerprint for tracking"""
        user_agent = request.headers.get("user-agent", "")
        accept_language = request.headers.get("accept-language", "")
        accept_encoding = request.headers.get("accept-encoding", "")

        fingerprint_data = f"{user_agent}|{accept_language}|{accept_encoding}"
        return hashlib.md5(fingerprint_data.encode()).hexdigest()

    def _get_geographic_info(self, ip_address: str) -> Tuple[Optional[str], Optional[str]]:
        """Get geographic information for IP address"""
        # Placeholder - would integrate with GeoIP service like MaxMind
        # For now, return None values
        return None, None


def get_session_manager(db: Session = Depends(get_db)) -> SessionManager:
    """Dependency injection for session manager"""
    return SessionManager(db)
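To round out the picture, a sketch of how SessionManager might be used from auth endpoints via the get_session_manager dependency. The cookie name, route paths, and the scheduler note are assumptions for illustration only; they are not part of this commit.

from fastapi import APIRouter, Depends, Request, Response

router = APIRouter()

@router.post("/api/auth/login")
def login(request: Request, response: Response,
          sessions: SessionManager = Depends(get_session_manager)):
    user = ...  # authenticate the credentials first (see authenticate_user above)
    session = sessions.create_session(user, request, login_method="password")
    # HttpOnly + Secure keeps the session ID out of script reach and off plain HTTP
    response.set_cookie("session_id", session.session_id,
                        httponly=True, secure=True, samesite="lax")
    return {"detail": "logged in"}

@router.post("/api/auth/logout")
def logout(request: Request, sessions: SessionManager = Depends(get_session_manager)):
    session_id = request.cookies.get("session_id", "")
    sessions.revoke_session(session_id, reason="user_logout")
    return {"detail": "logged out"}

# A periodic job (cron, APScheduler, etc.) can call cleanup_expired_sessions()
# so stale ACTIVE rows are flipped to EXPIRED even when nobody logs in.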