This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -0,0 +1,379 @@
"""
Database Security Utilities
Provides utilities for secure database operations and SQL injection prevention:
- Parameterized query helpers
- SQL injection detection and prevention
- Safe query building utilities
- Database security auditing
"""
import re
from typing import Any, Dict, List, Optional, Union
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.sql import ClauseElement
from app.utils.logging import app_logger
logger = app_logger.bind(name="database_security")
# Patterns whose presence in a query string or bound parameter suggests an
# injection attempt.  Grouped by attack family; compiled once below.
DANGEROUS_PATTERNS = [
    # SQL injection keywords
    r'\bunion\s+select\b',
    r'\b(drop\s+table)\b',
    r'\bdelete\s+from\b',
    r'\b(insert\s+into)\b',
    r'\b(update\s+.*set)\b',
    r'\b(alter\s+table)\b',
    r'\b(create\s+table)\b',
    r'\b(truncate\s+table)\b',
    # Command execution
    r'\b(exec\s*\()\b',
    r'\b(execute\s*\()\b',
    r'\b(system\s*\()\b',
    r'\b(shell_exec\s*\()\b',
    r'\b(eval\s*\()\b',
    # Comment-based attacks
    r'(;\s*drop\s+table\b)',
    r'(--\s*$)',
    r'(/\*.*?\*/)',
    r'(\#.*$)',
    # Quote escaping attempts
    r"(';|';|\";|\")",
    r"(\\\\'|\\\\\"|\\\\x)",
    # Hex/unicode encoding
    r'(0x[0-9a-fA-F]+)',
    r'(\\u[0-9a-fA-F]{4})',
    # Boolean-based attacks
    r'\b(1=1|1=0|true|false)\b',
    r"\b(or\s+1\s*=\s*1|and\s+1\s*=\s*1)\b",
    # Time-based attacks
    r'\b(sleep\s*\(|delay\s*\(|waitfor\s+delay)\b',
    # File operations
    r'\b(load_file\s*\(|into\s+outfile|into\s+dumpfile)\b',
    # Information schema access
    r'\b(information_schema|sys\.|pg_)\b',
    # Subselect usage in WHERE clause that may indicate enumeration
    r'\b\(\s*select\b',
]

# Compiled once at import time; scanning happens on every query and parameter.
COMPILED_PATTERNS = [
    re.compile(pattern, re.IGNORECASE | re.MULTILINE)
    for pattern in DANGEROUS_PATTERNS
]


class SQLSecurityValidator:
    """Static, pattern-based checks for SQL strings and bound parameter values."""

    @staticmethod
    def validate_query_string(query: str) -> List[str]:
        """Scan a raw SQL string; return one finding per dangerous pattern hit.

        When no regex fires, a handful of plain-substring heuristics run as a
        fallback so the most common attack shapes are still caught even if a
        pattern misbehaves.
        """
        findings: List[str] = []
        for raw_pattern, compiled in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            try:
                hits = compiled.findall(query)
            except Exception:
                hits = []
            if hits:
                findings.append(
                    f"Potentially dangerous SQL pattern detected: {raw_pattern} -> {str(hits)[:80]}"
                )
        if findings:
            return findings
        # Heuristic fallback to catch common cases without relying on regex quirks.
        lowered = (query or "").lower()
        if "; drop table" in lowered or ";drop table" in lowered:
            findings.append("Heuristic: DROP TABLE detected")
        if " union select " in lowered or " union\nselect " in lowered:
            findings.append("Heuristic: UNION SELECT detected")
        if (" or 1=1" in lowered or " and 1=1" in lowered
                or " or 1 = 1" in lowered or " and 1 = 1" in lowered):
            findings.append("Heuristic: tautology (1=1) detected")
        if "(select" in lowered:
            findings.append("Heuristic: subselect in WHERE detected")
        return findings

    @staticmethod
    def validate_parameter_value(param_name: str, param_value: Any) -> List[str]:
        """Scan a single bound parameter value; None is always considered clean."""
        if param_value is None:
            return []
        findings: List[str] = []
        rendered = str(param_value)  # pattern-match the string form of any type
        for idx, compiled in enumerate(COMPILED_PATTERNS):
            if compiled.findall(rendered):
                findings.append(
                    f"Parameter '{param_name}' contains dangerous pattern: {DANGEROUS_PATTERNS[idx]}"
                )
        # The remaining checks only make sense for genuine strings.
        if isinstance(param_value, str):
            if len(param_value) > 10000:
                findings.append(
                    f"Parameter '{param_name}' is excessively long ({len(param_value)} characters)"
                )
            if '\x00' in param_value:
                findings.append(f"Parameter '{param_name}' contains null bytes")
            if any(ord(ch) < 32 and ch not in '\t\n\r' for ch in param_value):
                findings.append(f"Parameter '{param_name}' contains suspicious control characters")
        return findings

    @staticmethod
    def validate_query_with_params(query: str, params: Dict[str, Any]) -> List[str]:
        """Validate the query string plus every bound parameter; findings concatenated."""
        findings = list(SQLSecurityValidator.validate_query_string(query))
        for name, value in params.items():
            findings.extend(SQLSecurityValidator.validate_parameter_value(name, value))
        return findings
class SecureQueryBuilder:
    """Builds secure parameterized queries to prevent SQL injection.

    SQLAlchemy types in signatures are written as string annotations so they
    are not evaluated eagerly at class-definition time.
    """

    @staticmethod
    def safe_text_query(query: str, params: Optional[Dict[str, Any]] = None) -> "TextClause":
        """Wrap `query` in a SQLAlchemy text() clause after validating it.

        Findings are logged, not raised, so flagged-but-legitimate queries keep
        working; tighten this to an exception if a hard-fail policy is wanted.

        Args:
            query: raw SQL text with named bind placeholders.
            params: bind values to validate alongside the query.
        """
        params = params or {}
        issues = SQLSecurityValidator.validate_query_with_params(query, params)
        if issues:
            logger.warning("Potential security issues in query", query=query[:100], issues=issues)
            # In production, you might want to raise an exception or sanitize.
        return text(query)

    @staticmethod
    def _escape_like_term(search_term: str) -> str:
        """Escape LIKE/ILIKE wildcard characters in a user-supplied term.

        The backslash MUST be escaped first: escaping % and _ inserts
        backslashes, and running the backslash pass last (as the previous
        version did) doubled those freshly inserted escapes and corrupted the
        pattern.
        """
        return (
            search_term
            .replace('\\', r'\\')   # escape the escape character first
            .replace('%', r'\%')    # literal percent
            .replace('_', r'\_')    # literal underscore
        )

    @staticmethod
    def build_like_clause(column: "ClauseElement", search_term: str,
                          case_sensitive: bool = False) -> "ClauseElement":
        """Build a contains-style (I)LIKE clause with wildcards escaped.

        Args:
            column: column/expression to match against.
            search_term: raw user text; %, _ and backslash are matched literally.
            case_sensitive: LIKE when True, ILIKE otherwise.
        """
        pattern = f"%{SecureQueryBuilder._escape_like_term(search_term)}%"
        if case_sensitive:
            return column.like(pattern, escape='\\')
        return column.ilike(pattern, escape='\\')

    @staticmethod
    def build_in_clause(column: "ClauseElement", values: List[Any]) -> "ClauseElement":
        """Build an IN clause; values become bound parameters, never interpolated."""
        if not values:
            return column.in_([])  # empty IN — matches nothing
        # Best-effort screening of each bound value; findings are logged only.
        for i, value in enumerate(values):
            issues = SQLSecurityValidator.validate_parameter_value(f"in_value_{i}", value)
            if issues:
                logger.warning(f"Security issues in IN clause value: {issues}")
        return column.in_(values)

    @staticmethod
    def build_fts_query(search_terms: List[str], exact_phrase: bool = False) -> str:
        """Build a safe FTS (Full Text Search) query string.

        Each term is stripped to word characters, whitespace, '-' and '.'.
        With exact_phrase=True every surviving term is double-quoted and terms
        are space-joined; otherwise terms are joined with AND.

        Returns:
            The query string, or "" when nothing survives sanitization.
        """
        if not search_terms:
            return ""
        safe_terms = []
        for term in search_terms:
            # Remove characters with FTS syntax meaning; SQLite FTS5 has its
            # own escaping rules, so keep the allow-list conservative.
            safe_term = re.sub(r'[^\w\s\-\.]', '', term.strip())
            if not safe_term:
                continue
            if exact_phrase:
                # Defensive: quotes are already stripped by the sub() above,
                # but keep the doubling in case the keep-set ever widens.
                safe_term = safe_term.replace('"', '""')
                safe_terms.append(f'"{safe_term}"')
            else:
                safe_terms.append(safe_term)
        return ' '.join(safe_terms) if exact_phrase else ' AND '.join(safe_terms)
class DatabaseAuditor:
    """Audits executed database operations for security and performance compliance."""

    @staticmethod
    def audit_query_execution(query: str, params: Dict[str, Any], execution_time: float) -> None:
        """Run the security, performance, and query-pattern audits for one execution.

        All findings are logged; nothing is raised and nothing is returned.
        """
        preview = query[:200]

        # Security audit: re-run the validator over the executed query + params.
        security_issues = SQLSecurityValidator.validate_query_with_params(query, params)
        if security_issues:
            logger.warning(
                "Query executed with security issues",
                query=preview,
                issues=security_issues,
                execution_time=execution_time
            )

        # Performance audit: anything over 5 seconds counts as slow.
        if execution_time > 5.0:
            logger.warning(
                "Slow query detected",
                query=preview,
                execution_time=execution_time
            )

        # Pattern audit: flag query shapes that tend to cause trouble.
        if re.search(r'\bSELECT\s+\*\s+FROM\b', query, re.IGNORECASE):
            logger.info("SELECT * query detected - consider specifying columns", query=query[:100])
        if re.search(r'\bLIMIT\s+\d{4,}\b', query, re.IGNORECASE):
            logger.info("Large LIMIT detected", query=query[:100])
def execute_secure_query(
    db: Session,
    query: str,
    params: Optional[Dict[str, Any]] = None,
    audit: bool = True
) -> Any:
    """Execute `query` on `db` with security validation and auditing.

    Args:
        db: active SQLAlchemy session.
        query: raw SQL text with named bind placeholders.
        params: bind values; defaults to empty.
        audit: when True, validate before execution and audit afterwards.

    Returns:
        Whatever db.execute() returns.

    Raises:
        Re-raises any execution error after logging it with timing info.
    """
    import time
    params = params or {}
    started = time.time()
    try:
        # Pre-execution validation; findings are logged, execution proceeds.
        if audit:
            problems = SQLSecurityValidator.validate_query_with_params(query, params)
            if problems:
                logger.warning("Executing query with potential security issues", issues=problems)
        # safe_text_query performs its own validation pass and builds text().
        safe_query = SecureQueryBuilder.safe_text_query(query, params)
        result = db.execute(safe_query, params)
        # Post-execution audit (security + slow-query + pattern checks).
        if audit:
            DatabaseAuditor.audit_query_execution(query, params, time.time() - started)
        return result
    except Exception as exc:
        logger.error(
            "Query execution failed",
            query=query[:200],
            error=str(exc),
            execution_time=time.time() - started
        )
        raise
def sanitize_fts_query(query: str) -> str:
    """Sanitize user input destined for a full-text-search query.

    Keeps word characters, whitespace, and basic punctuation (- . , ! ? " '),
    replaces everything else with a space, collapses runs of whitespace, and
    caps the result at 500 characters.  Empty/None input yields "".
    """
    if not query:
        return ""
    cleaned = re.sub(r'[^\w\s\-\.\,\!\?\"\']', ' ', query)
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()
    # Hard length cap to bound downstream work.
    return cleaned[:500]
def create_safe_search_conditions(
    search_terms: List[str],
    searchable_columns: List[ClauseElement],
    case_sensitive: bool = False,
    exact_phrase: bool = False
) -> Optional[ClauseElement]:
    """Build an AND-of-ORs search condition across several columns.

    exact_phrase=True matches the joined phrase against any column; otherwise
    every term must match at least one column.  Returns None when there is
    nothing to search.
    """
    from sqlalchemy import or_, and_
    if not search_terms or not searchable_columns:
        return None
    if exact_phrase:
        # One OR-group: the full phrase against every column.
        phrase = ' '.join(search_terms)
        groups = [or_(*[
            SecureQueryBuilder.build_like_clause(column, phrase, case_sensitive)
            for column in searchable_columns
        ])]
    else:
        # One OR-group per term: each term must hit at least one column.
        groups = [
            or_(*[
                SecureQueryBuilder.build_like_clause(column, term, case_sensitive)
                for column in searchable_columns
            ])
            for term in search_terms
        ]
    return and_(*groups) if groups else None
# Whitelist of column names permitted in dynamically built ORDER BY clauses,
# keyed by table.  Anything not listed here is rejected outright.
ALLOWED_SORT_COLUMNS = {
    'rolodex': ['id', 'first', 'last', 'city', 'email', 'created_at', 'updated_at'],
    'files': ['file_no', 'id', 'regarding', 'status', 'file_type', 'opened', 'closed', 'created_at', 'updated_at'],
    'ledger': ['id', 'file_no', 't_code', 'amount', 'date', 'created_at', 'updated_at'],
    'qdros': ['id', 'file_no', 't_code', 'amount', 'date', 'created_at', 'updated_at'][:0] or ['id', 'file_no', 'form_name', 'status', 'created_at', 'updated_at'],
}


def validate_sort_column(table: str, column: str) -> bool:
    """Return True when `column` is whitelisted for sorting on `table`."""
    return column in ALLOWED_SORT_COLUMNS.get(table, [])


def safe_order_by(table: str, sort_column: str, sort_direction: str = 'asc') -> Optional[str]:
    """Build a "<column> <ASC|DESC>" fragment after whitelist validation.

    Returns None (and logs a warning) when either the column or the direction
    fails validation — never interpolates unvalidated input into SQL.
    """
    if not validate_sort_column(table, sort_column):
        logger.warning(f"Invalid sort column '{sort_column}' for table '{table}'")
        return None
    if sort_direction.lower() not in ('asc', 'desc'):
        logger.warning(f"Invalid sort direction '{sort_direction}'")
        return None
    return f"{sort_column} {sort_direction.upper()}"

668
app/utils/enhanced_audit.py Normal file
View File

@@ -0,0 +1,668 @@
"""
Enhanced audit logging utilities for P2 security features
"""
import uuid
import json
import hashlib
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, Any, List, Union
from contextlib import contextmanager
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from fastapi import Request
from user_agents import parse as parse_user_agent
from app.models.audit_enhanced import (
EnhancedAuditLog, SecurityAlert, ComplianceReport,
AuditRetentionPolicy, SIEMIntegration,
SecurityEventType, SecurityEventSeverity, ComplianceStandard
)
from app.models.user import User
from app.utils.logging import get_logger
logger = get_logger(__name__)
class EnhancedAuditLogger:
    """
    Enhanced audit logging system with security event tracking.

    Persists EnhancedAuditLog rows, fires simple rule-based SecurityAlerts,
    forwards events to configured SIEM integrations, generates compliance
    reports, and applies retention policies.  All persistence goes through the
    Session supplied at construction; log_security_event() commits that
    session itself.
    """
    def __init__(self, db: Session):
        # Caller-provided session; this class commits on it (see log_security_event),
        # so callers sharing the session should be aware of the commit boundary.
        self.db = db

    def log_security_event(
        self,
        event_type: SecurityEventType,
        title: str,
        description: str,
        user: Optional[User] = None,
        session_id: Optional[str] = None,
        request: Optional[Request] = None,
        severity: SecurityEventSeverity = SecurityEventSeverity.INFO,
        outcome: str = "success",
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        resource_name: Optional[str] = None,
        data_before: Optional[Dict[str, Any]] = None,
        data_after: Optional[Dict[str, Any]] = None,
        risk_factors: Optional[List[str]] = None,
        threat_indicators: Optional[List[str]] = None,
        compliance_standards: Optional[List[ComplianceStandard]] = None,
        tags: Optional[List[str]] = None,
        custom_fields: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None
    ) -> EnhancedAuditLog:
        """
        Log a comprehensive security event.

        Builds an EnhancedAuditLog row (request metadata, category, computed
        risk score, geo placeholders), persists it, evaluates alert rules,
        forwards to SIEM integrations, and commits.  Returns the saved row.
        """
        # Generate unique event ID
        event_id = str(uuid.uuid4())
        # Extract request metadata (all optional — request may be None for
        # internally generated events).
        source_ip = None
        user_agent = None
        endpoint = None
        http_method = None
        request_id = None
        if request:
            source_ip = self._get_client_ip(request)
            user_agent = request.headers.get("user-agent", "")
            endpoint = str(request.url.path)
            http_method = request.method
            # request_id is expected to be set by middleware; None when absent.
            request_id = getattr(request.state, 'request_id', None)
        # Determine event category
        event_category = self._categorize_event(event_type)
        # Calculate risk score
        risk_score = self._calculate_risk_score(
            event_type, severity, risk_factors, threat_indicators
        )
        # Get geographic info (placeholder - would integrate with GeoIP)
        country, region, city = self._get_geographic_info(source_ip)
        # Create audit log entry
        audit_log = EnhancedAuditLog(
            event_id=event_id,
            event_type=event_type.value,
            event_category=event_category,
            severity=severity.value,
            title=title,
            description=description,
            outcome=outcome,
            user_id=user.id if user else None,
            session_id=session_id,
            source_ip=source_ip,
            user_agent=user_agent,
            request_id=request_id,
            country=country,
            region=region,
            city=city,
            endpoint=endpoint,
            http_method=http_method,
            resource_type=resource_type,
            resource_id=resource_id,
            resource_name=resource_name,
            risk_score=risk_score,
            # Every event gets a correlation id so related events can be grouped.
            correlation_id=correlation_id or str(uuid.uuid4())
        )
        # Set JSON data (falsy values — empty dicts/lists — are skipped).
        if data_before:
            audit_log.set_data_before(data_before)
        if data_after:
            audit_log.set_data_after(data_after)
        if risk_factors:
            audit_log.set_risk_factors(risk_factors)
        if threat_indicators:
            audit_log.set_threat_indicators(threat_indicators)
        if compliance_standards:
            audit_log.set_compliance_standards([std.value for std in compliance_standards])
        if tags:
            audit_log.set_tags(tags)
        if custom_fields:
            audit_log.set_custom_fields(custom_fields)
        # Save to database
        self.db.add(audit_log)
        self.db.flush()  # Get ID for further processing
        # Check for security alerts
        self._check_security_alerts(audit_log)
        # Send to SIEM systems
        self._send_to_siem(audit_log)
        self.db.commit()
        logger.info(
            f"Security event logged: {event_type.value}",
            extra={
                "event_id": event_id,
                "user_id": user.id if user else None,
                "severity": severity.value,
                "risk_score": risk_score
            }
        )
        return audit_log

    def log_data_access(
        self,
        user: User,
        resource_type: str,
        resource_id: str,
        action: str,  # read, write, delete, export
        request: Optional[Request] = None,
        session_id: Optional[str] = None,
        record_count: Optional[int] = None,
        data_volume: Optional[int] = None,
        compliance_standards: Optional[List[ComplianceStandard]] = None
    ) -> EnhancedAuditLog:
        """
        Log data access events for compliance.

        Maps `action` to the matching DATA_* event type (unknown actions fall
        back to DATA_READ) and delegates to log_security_event.
        """
        event_type_map = {
            "read": SecurityEventType.DATA_READ,
            "write": SecurityEventType.DATA_WRITE,
            "delete": SecurityEventType.DATA_DELETE,
            "export": SecurityEventType.DATA_EXPORT
        }
        event_type = event_type_map.get(action, SecurityEventType.DATA_READ)
        return self.log_security_event(
            event_type=event_type,
            title=f"Data {action} operation",
            description=f"User {user.username} performed {action} on {resource_type} {resource_id}",
            user=user,
            session_id=session_id,
            request=request,
            severity=SecurityEventSeverity.INFO,
            resource_type=resource_type,
            resource_id=resource_id,
            # SOX is the default standard when the caller does not specify any.
            compliance_standards=compliance_standards or [ComplianceStandard.SOX],
            custom_fields={
                "record_count": record_count,
                "data_volume": data_volume
            }
        )

    def log_authentication_event(
        self,
        event_type: SecurityEventType,
        username: str,
        request: Request,
        user: Optional[User] = None,
        session_id: Optional[str] = None,
        outcome: str = "success",
        details: Optional[str] = None,
        risk_factors: Optional[List[str]] = None
    ) -> EnhancedAuditLog:
        """
        Log authentication-related events.

        Severity escalates from INFO to MEDIUM on failure or when risk
        factors are present, and to HIGH for account lockouts.
        """
        severity = SecurityEventSeverity.INFO
        if outcome == "failure" or risk_factors:
            severity = SecurityEventSeverity.MEDIUM
        if event_type == SecurityEventType.ACCOUNT_LOCKED:
            severity = SecurityEventSeverity.HIGH
        return self.log_security_event(
            event_type=event_type,
            title=f"Authentication event: {event_type.value}",
            description=details or f"Authentication {outcome} for user {username}",
            user=user,
            session_id=session_id,
            request=request,
            severity=severity,
            outcome=outcome,
            risk_factors=risk_factors,
            compliance_standards=[ComplianceStandard.SOX, ComplianceStandard.ISO27001]
        )

    def log_admin_action(
        self,
        admin_user: User,
        action: str,
        target_resource: str,
        request: Request,
        session_id: Optional[str] = None,
        data_before: Optional[Dict[str, Any]] = None,
        data_after: Optional[Dict[str, Any]] = None,
        affected_user_id: Optional[int] = None
    ) -> EnhancedAuditLog:
        """
        Log administrative actions for compliance.

        Recorded as a CONFIGURATION_CHANGE event at MEDIUM severity with
        before/after snapshots when supplied.
        """
        return self.log_security_event(
            event_type=SecurityEventType.CONFIGURATION_CHANGE,
            title=f"Administrative action: {action}",
            description=f"Admin {admin_user.username} performed {action} on {target_resource}",
            user=admin_user,
            session_id=session_id,
            request=request,
            severity=SecurityEventSeverity.MEDIUM,
            resource_type="admin",
            resource_id=target_resource,
            data_before=data_before,
            data_after=data_after,
            compliance_standards=[ComplianceStandard.SOX, ComplianceStandard.SOC2],
            tags=["admin_action", "configuration_change"],
            custom_fields={
                "affected_user_id": affected_user_id
            }
        )

    def create_security_alert(
        self,
        rule_id: str,
        rule_name: str,
        title: str,
        description: str,
        severity: SecurityEventSeverity,
        triggering_events: List[str],
        confidence: int = 100,
        time_window_minutes: Optional[int] = None,
        affected_users: Optional[List[int]] = None,
        affected_resources: Optional[List[str]] = None
    ) -> SecurityAlert:
        """
        Create a security alert based on detected patterns.

        Persists (and commits) a SecurityAlert row whose JSON columns are
        serialized here.  `triggering_events` carries the event_ids that fired
        the rule.
        """
        alert_id = str(uuid.uuid4())
        alert = SecurityAlert(
            alert_id=alert_id,
            rule_id=rule_id,
            rule_name=rule_name,
            title=title,
            description=description,
            severity=severity.value,
            confidence=confidence,
            event_count=len(triggering_events),
            time_window_minutes=time_window_minutes,
            # first_seen == last_seen at creation; updated by later processing.
            first_seen=datetime.now(timezone.utc),
            last_seen=datetime.now(timezone.utc)
        )
        # Set JSON fields (stored as serialized text columns).
        alert.triggering_events = json.dumps(triggering_events)
        if affected_users:
            alert.affected_users = json.dumps(affected_users)
        if affected_resources:
            alert.affected_resources = json.dumps(affected_resources)
        self.db.add(alert)
        self.db.commit()
        logger.warning(
            f"Security alert created: {title}",
            extra={
                "alert_id": alert_id,
                "severity": severity.value,
                "confidence": confidence,
                "event_count": len(triggering_events)
            }
        )
        return alert

    def search_audit_logs(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        event_types: Optional[List[SecurityEventType]] = None,
        severities: Optional[List[SecurityEventSeverity]] = None,
        user_ids: Optional[List[int]] = None,
        source_ips: Optional[List[str]] = None,
        resource_types: Optional[List[str]] = None,
        outcomes: Optional[List[str]] = None,
        min_risk_score: Optional[int] = None,
        correlation_id: Optional[str] = None,
        limit: int = 1000,
        offset: int = 0
    ) -> List[EnhancedAuditLog]:
        """
        Search audit logs with comprehensive filtering.

        Every filter is optional; omitted filters are simply not applied.
        Results are newest-first and paginated via limit/offset.
        """
        query = self.db.query(EnhancedAuditLog)
        # Apply filters
        if start_date:
            query = query.filter(EnhancedAuditLog.timestamp >= start_date)
        if end_date:
            query = query.filter(EnhancedAuditLog.timestamp <= end_date)
        if event_types:
            query = query.filter(EnhancedAuditLog.event_type.in_([et.value for et in event_types]))
        if severities:
            query = query.filter(EnhancedAuditLog.severity.in_([s.value for s in severities]))
        if user_ids:
            query = query.filter(EnhancedAuditLog.user_id.in_(user_ids))
        if source_ips:
            query = query.filter(EnhancedAuditLog.source_ip.in_(source_ips))
        if resource_types:
            query = query.filter(EnhancedAuditLog.resource_type.in_(resource_types))
        if outcomes:
            query = query.filter(EnhancedAuditLog.outcome.in_(outcomes))
        if min_risk_score is not None:
            query = query.filter(EnhancedAuditLog.risk_score >= min_risk_score)
        if correlation_id:
            query = query.filter(EnhancedAuditLog.correlation_id == correlation_id)
        return query.order_by(EnhancedAuditLog.timestamp.desc()).offset(offset).limit(limit).all()

    def generate_compliance_report(
        self,
        standard: ComplianceStandard,
        start_date: datetime,
        end_date: datetime,
        generated_by: User,
        report_type: str = "periodic"
    ) -> ComplianceReport:
        """
        Generate compliance report for specified standard and date range.

        NOTE(review): search_audit_logs defaults to limit=1000, so reports over
        busy periods may be computed from a truncated sample — confirm whether
        that cap is intended here.
        """
        report_id = str(uuid.uuid4())
        # Query relevant audit logs
        logs = self.search_audit_logs(
            start_date=start_date,
            end_date=end_date
        )
        # Filter logs relevant to the compliance standard
        relevant_logs = [
            log for log in logs
            if standard.value in (log.get_compliance_standards() or [])
        ]
        # Calculate metrics
        total_events = len(relevant_logs)
        security_events = len([log for log in relevant_logs if log.event_category == "security"])
        violations = len([log for log in relevant_logs if log.outcome in ["failure", "blocked"]])
        high_risk_events = len([log for log in relevant_logs if log.risk_score >= 70])
        # Generate report content (100% compliant when there were no events).
        summary = {
            "total_events": total_events,
            "security_events": security_events,
            "violations": violations,
            "high_risk_events": high_risk_events,
            "compliance_percentage": ((total_events - violations) / total_events * 100) if total_events > 0 else 100
        }
        report = ComplianceReport(
            report_id=report_id,
            standard=standard.value,
            report_type=report_type,
            title=f"{standard.value.upper()} Compliance Report",
            description=f"Compliance report for {standard.value.upper()} from {start_date.date()} to {end_date.date()}",
            start_date=start_date,
            end_date=end_date,
            summary=json.dumps(summary),
            total_events=total_events,
            security_events=security_events,
            violations=violations,
            high_risk_events=high_risk_events,
            generated_by=generated_by.id,
            status="ready"
        )
        self.db.add(report)
        self.db.commit()
        logger.info(
            f"Compliance report generated: {standard.value}",
            extra={
                "report_id": report_id,
                "total_events": total_events,
                "violations": violations
            }
        )
        return report

    def cleanup_old_logs(self) -> int:
        """
        Clean up old audit logs based on retention policies.

        Policies are applied highest-priority first; each deletes rows older
        than its retention window, optionally narrowed by event types and
        compliance standards.  Returns the total number of rows deleted.
        """
        # Get active retention policies
        policies = self.db.query(AuditRetentionPolicy).filter(
            AuditRetentionPolicy.is_active == True
        ).order_by(AuditRetentionPolicy.priority.desc()).all()
        cleaned_count = 0
        for policy in policies:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=policy.retention_days)
            # Build query for logs to delete
            query = self.db.query(EnhancedAuditLog).filter(
                EnhancedAuditLog.timestamp < cutoff_date
            )
            # Apply event type filter if specified
            if policy.event_types:
                event_types = json.loads(policy.event_types)
                query = query.filter(EnhancedAuditLog.event_type.in_(event_types))
            # Apply compliance standards filter if specified
            if policy.compliance_standards:
                standards = json.loads(policy.compliance_standards)
                # This is a simplified check - in practice, you'd want more sophisticated filtering
                for standard in standards:
                    query = query.filter(EnhancedAuditLog.compliance_standards.contains(standard))
            # Delete matching logs
            count = query.count()
            query.delete(synchronize_session=False)
            cleaned_count += count
            logger.info(f"Cleaned {count} logs using policy {policy.policy_name}")
        self.db.commit()
        return cleaned_count

    def _categorize_event(self, event_type: SecurityEventType) -> str:
        """Categorize event type into broader categories.

        Returns one of "authentication", "security", "data_access", or
        "system" (the catch-all).
        """
        auth_events = {
            SecurityEventType.LOGIN_SUCCESS, SecurityEventType.LOGIN_FAILURE,
            SecurityEventType.LOGOUT, SecurityEventType.SESSION_EXPIRED,
            SecurityEventType.PASSWORD_CHANGE, SecurityEventType.ACCOUNT_LOCKED
        }
        security_events = {
            SecurityEventType.SUSPICIOUS_ACTIVITY, SecurityEventType.ATTACK_DETECTED,
            SecurityEventType.SECURITY_VIOLATION, SecurityEventType.IP_BLOCKED,
            SecurityEventType.ACCESS_DENIED, SecurityEventType.UNAUTHORIZED_ACCESS
        }
        data_events = {
            SecurityEventType.DATA_READ, SecurityEventType.DATA_WRITE,
            SecurityEventType.DATA_DELETE, SecurityEventType.DATA_EXPORT,
            SecurityEventType.BULK_OPERATION
        }
        if event_type in auth_events:
            return "authentication"
        elif event_type in security_events:
            return "security"
        elif event_type in data_events:
            return "data_access"
        else:
            return "system"

    def _calculate_risk_score(
        self,
        event_type: SecurityEventType,
        severity: SecurityEventSeverity,
        risk_factors: Optional[List[str]],
        threat_indicators: Optional[List[str]]
    ) -> int:
        """Calculate risk score for the event.

        Base score from severity, plus 5 per risk factor, 10 per threat
        indicator, and a flat 20 for inherently high-risk event types.
        Capped at 100.
        """
        base_scores = {
            SecurityEventSeverity.CRITICAL: 80,
            SecurityEventSeverity.HIGH: 60,
            SecurityEventSeverity.MEDIUM: 40,
            SecurityEventSeverity.LOW: 20,
            SecurityEventSeverity.INFO: 10
        }
        score = base_scores.get(severity, 10)
        # Add points for risk factors
        if risk_factors:
            score += len(risk_factors) * 5
        # Add points for threat indicators
        if threat_indicators:
            score += len(threat_indicators) * 10
        # Event type modifiers
        high_risk_events = {
            SecurityEventType.ATTACK_DETECTED,
            SecurityEventType.PRIVILEGE_ESCALATION,
            SecurityEventType.UNAUTHORIZED_ACCESS
        }
        if event_type in high_risk_events:
            score += 20
        return min(score, 100)  # Cap at 100

    def _check_security_alerts(self, audit_log: EnhancedAuditLog) -> None:
        """Check if audit log should trigger security alerts.

        Two built-in rules: >= 5 failed logins from the same IP within 15
        minutes, and any single event with risk score >= 80.  Alerts commit
        the session (see create_security_alert).
        """
        # Example: Multiple failed logins from same IP
        if audit_log.event_type == SecurityEventType.LOGIN_FAILURE.value:
            recent_failures = self.db.query(EnhancedAuditLog).filter(
                and_(
                    EnhancedAuditLog.event_type == SecurityEventType.LOGIN_FAILURE.value,
                    EnhancedAuditLog.source_ip == audit_log.source_ip,
                    EnhancedAuditLog.timestamp >= datetime.now(timezone.utc) - timedelta(minutes=15)
                )
            ).count()
            if recent_failures >= 5:
                self.create_security_alert(
                    rule_id="failed_login_threshold",
                    rule_name="Multiple Failed Logins",
                    title=f"Multiple failed logins from {audit_log.source_ip}",
                    description=f"{recent_failures} failed login attempts in 15 minutes",
                    severity=SecurityEventSeverity.HIGH,
                    triggering_events=[audit_log.event_id],
                    time_window_minutes=15
                )
        # Example: High risk score threshold
        if audit_log.risk_score >= 80:
            self.create_security_alert(
                rule_id="high_risk_event",
                rule_name="High Risk Security Event",
                title=f"High risk event detected: {audit_log.title}",
                description=f"Event with risk score {audit_log.risk_score} detected",
                severity=SecurityEventSeverity.HIGH,
                triggering_events=[audit_log.event_id],
                confidence=audit_log.risk_score
            )

    def _send_to_siem(self, audit_log: EnhancedAuditLog) -> None:
        """Send audit log to configured SIEM systems.

        Currently only logs the intent and updates per-integration counters;
        delivery failures mark the integration unhealthy but never propagate.
        """
        # Get active SIEM integrations
        integrations = self.db.query(SIEMIntegration).filter(
            SIEMIntegration.is_active == True
        ).all()
        for integration in integrations:
            try:
                # Check if event should be sent based on filters
                if self._should_send_to_siem(audit_log, integration):
                    # In a real implementation, this would send to the actual SIEM
                    # For now, just log the intent
                    logger.debug(
                        f"Sending event to SIEM {integration.integration_name}",
                        extra={"event_id": audit_log.event_id}
                    )
                    # Update statistics
                    integration.events_sent += 1
                    integration.last_sync = datetime.now(timezone.utc)
            except Exception as e:
                logger.error(f"Failed to send to SIEM {integration.integration_name}: {str(e)}")
                integration.errors_count += 1
                integration.last_error = str(e)
                integration.is_healthy = False

    def _should_send_to_siem(self, audit_log: EnhancedAuditLog, integration: SIEMIntegration) -> bool:
        """Check if audit log should be sent to specific SIEM integration.

        NOTE(review): list.index() raises ValueError for a severity value
        outside the known five — confirm severities are constrained upstream.
        """
        # Check severity threshold
        severity_order = ["info", "low", "medium", "high", "critical"]
        if severity_order.index(audit_log.severity) < severity_order.index(integration.severity_threshold):
            return False
        # Check event type filter
        if integration.event_types:
            allowed_types = json.loads(integration.event_types)
            if audit_log.event_type not in allowed_types:
                return False
        return True

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP from request.

        Prefers X-Forwarded-For (first hop), then X-Real-IP, then the socket
        peer.  NOTE(review): these headers are client-controlled unless a
        trusted proxy strips them — confirm the deployment sits behind one.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip
        return request.client.host if request.client else "unknown"

    def _get_geographic_info(self, ip_address: Optional[str]) -> tuple:
        """Get geographic information for IP address.

        Placeholder - would integrate with a GeoIP service.  Always returns
        (None, None, None) as (country, region, city) today.
        """
        return None, None, None
@contextmanager
def audit_context(
    db: Session,
    user: Optional[User] = None,
    session_id: Optional[str] = None,
    request: Optional[Request] = None,
    correlation_id: Optional[str] = None
):
    """Context manager yielding an EnhancedAuditLogger for a unit of work.

    Any exception escaping the block is recorded as a HIGH-severity
    SECURITY_VIOLATION event (tagged with this context's correlation id)
    and then re-raised.
    """
    auditor = EnhancedAuditLogger(db)
    # Every context gets a correlation id so its events can be grouped.
    correlation_id = correlation_id or str(uuid.uuid4())
    try:
        yield auditor
    except Exception as exc:
        auditor.log_security_event(
            event_type=SecurityEventType.SECURITY_VIOLATION,
            title="System error occurred",
            description=f"Exception in audit context: {str(exc)}",
            user=user,
            session_id=session_id,
            request=request,
            severity=SecurityEventSeverity.HIGH,
            outcome="error",
            correlation_id=correlation_id
        )
        raise
def get_enhanced_audit_logger(db: Session) -> EnhancedAuditLogger:
    """Dependency injection for enhanced audit logger.

    FastAPI-style factory: returns a fresh EnhancedAuditLogger bound to the
    given database session (one per request when used as a dependency).
    """
    return EnhancedAuditLogger(db)

540
app/utils/enhanced_auth.py Normal file
View File

@@ -0,0 +1,540 @@
"""
Enhanced Authentication Utilities
Provides advanced authentication features including:
- Password complexity validation
- Account lockout protection
- Session management
- Login attempt tracking
- Suspicious activity detection
"""
import re
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func, and_
from fastapi import HTTPException, status, Request
from passlib.context import CryptContext
from app.models.user import User
try:
# Optional: enhanced features may rely on this model
from app.models.auth import LoginAttempt # type: ignore
except Exception: # pragma: no cover - older schemas may not include this model
LoginAttempt = None # type: ignore
from app.config import settings
from app.utils.logging import app_logger
logger = app_logger.bind(name="enhanced_auth")
# Password complexity configuration.
# NOTE: validate_password_strength interpolates min_length/max_length/
# special_chars directly into its error messages — keep values user-friendly.
PASSWORD_CONFIG = {
    "min_length": 8,
    "max_length": 128,
    "require_uppercase": True,
    "require_lowercase": True,
    "require_digits": True,
    "require_special_chars": True,
    "special_chars": "!@#$%^&*()_+-=[]{}|;:,.<>?",
    "max_consecutive_chars": 3,   # longest permitted run of one repeated character
    "prevent_common_passwords": True,
}
# Account lockout configuration.
LOCKOUT_CONFIG = {
    "max_attempts": 5,            # failures inside window_duration before lockout
    "lockout_duration": 900,  # 15 minutes
    "window_duration": 900,  # 15 minutes
    "progressive_delay": True,    # back off increasingly between attempts
    "notify_on_lockout": True,
}
# Common weak passwords to prevent (compared case-insensitively by the validator).
COMMON_PASSWORDS = {
    "password", "123456", "password123", "admin", "qwerty", "letmein",
    "welcome", "monkey", "1234567890", "password1", "123456789",
    "welcome123", "admin123", "root", "toor", "pass", "test", "guest",
    "user", "login", "default", "changeme", "secret", "administrator"
}
# Password validation context: bcrypt hashing; "deprecated=auto" marks hashes
# from any older scheme for transparent re-hash on next verify.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class PasswordValidator:
    """Advanced password validation with security requirements.

    All checks are driven by the module-level PASSWORD_CONFIG policy and the
    COMMON_PASSWORDS blocklist; every method is a stateless staticmethod.
    """

    @staticmethod
    def validate_password_strength(password: str) -> Tuple[bool, List[str]]:
        """Check *password* against the configured policy.

        Returns:
            (is_valid, problems) — problems lists every violated rule, in
            policy order; an empty list means the password is acceptable.
        """
        problems: List[str] = []
        cfg = PASSWORD_CONFIG

        # Length bounds.
        if len(password) < cfg["min_length"]:
            problems.append(
                f"Password must be at least {cfg['min_length']} characters long"
            )
        if len(password) > cfg["max_length"]:
            problems.append(
                f"Password must not exceed {cfg['max_length']} characters"
            )

        # Required character classes.
        if cfg["require_uppercase"] and re.search(r'[A-Z]', password) is None:
            problems.append("Password must contain at least one uppercase letter")
        if cfg["require_lowercase"] and re.search(r'[a-z]', password) is None:
            problems.append("Password must contain at least one lowercase letter")
        if cfg["require_digits"] and re.search(r'\d', password) is None:
            problems.append("Password must contain at least one digit")
        if cfg["require_special_chars"]:
            specials = cfg["special_chars"]
            if re.search('[' + re.escape(specials) + ']', password) is None:
                problems.append(
                    f"Password must contain at least one special character ({specials[:10]}...)"
                )

        # Reject any run of identical characters longer than the limit
        # (slide a window of run_limit+1 chars; flag once and stop).
        run_limit = cfg["max_consecutive_chars"]
        for start in range(len(password) - run_limit):
            window = password[start:start + run_limit + 1]
            if window == window[0] * len(window):
                problems.append(
                    f"Password cannot contain more than {run_limit} consecutive identical characters"
                )
                break

        # Known-weak password blocklist.
        if cfg["prevent_common_passwords"] and password.lower() in COMMON_PASSWORDS:
            problems.append("Password is too common and easily guessable")

        # Keyboard-sequence and dictionary-word heuristics.
        if PasswordValidator._contains_sequence(password):
            problems.append("Password cannot contain common keyboard sequences")
        if PasswordValidator._is_dictionary_word(password):
            problems.append("Password should not be a common dictionary word")

        return not problems, problems

    @staticmethod
    def _contains_sequence(password: str) -> bool:
        """Return True if the password contains a known keyboard run, in either direction."""
        runs = (
            "123456789", "987654321", "abcdefgh", "zyxwvuts",
            "qwertyui", "asdfghjk", "zxcvbnm", "uioplkjh",
            "qazwsxed", "plmoknij",
        )
        lowered = password.lower()
        return any(run in lowered or run[::-1] in lowered for run in runs)

    @staticmethod
    def _is_dictionary_word(password: str) -> bool:
        """Return True if the password is exactly one common English word (case-insensitive)."""
        words = {
            "password", "computer", "internet", "database", "security",
            "welcome", "hello", "world", "admin", "user", "login",
            "system", "server", "network", "access", "control",
        }
        return password.lower() in words

    @staticmethod
    def generate_password_strength_score(password: str) -> int:
        """Score password strength heuristically on a 0-100 scale.

        Components: length (up to 25), presence of each character class
        (5/5/5/10), a per-class diversity bonus (3 each), long-password
        bonuses at 12 and 16 chars, minus penalties for blocklisted
        passwords (-25) and 3+ repeated characters (-10).
        """
        has_lower = re.search(r'[a-z]', password) is not None
        has_upper = re.search(r'[A-Z]', password) is not None
        has_digit = re.search(r'\d', password) is not None
        has_special = re.search(r'[!@#$%^&*()_+\-=\[\]{}|;:,.<>?]', password) is not None

        total = min(25, len(password) * 2)
        total += 5 if has_lower else 0
        total += 5 if has_upper else 0
        total += 5 if has_digit else 0
        total += 10 if has_special else 0
        # Bonus scales with how many distinct classes appear.
        total += 3 * sum((has_lower, has_upper, has_digit, has_special))
        if len(password) >= 12:
            total += 10
        if len(password) >= 16:
            total += 5
        if password.lower() in COMMON_PASSWORDS:
            total -= 25
        if re.search(r'(.)\1{2,}', password):  # 3+ identical chars in a row
            total -= 10
        return max(0, min(100, total))
class AccountLockoutManager:
    """Manages account lockout and login attempt tracking.

    All methods are static and operate on the optional LoginAttempt model;
    when that model is unavailable (older schemas) they degrade to log-only
    or permissive behavior instead of raising, so authentication never
    breaks on schema drift.
    """

    @staticmethod
    def record_login_attempt(
        db: Session,
        username: str,
        success: bool,
        ip_address: str,
        user_agent: str,
        failure_reason: Optional[str] = None
    ) -> None:
        """Record a login attempt in the database.

        Commits one LoginAttempt row per call. Any error is swallowed after
        a rollback and an error log — auditing must never block a login.
        """
        try:
            if LoginAttempt is None:
                # Schema not available; log-only fallback
                logger.info(
                    "Login attempt (no model)",
                    username=username,
                    success=success,
                    ip=ip_address,
                    reason=failure_reason
                )
                return
            attempt = LoginAttempt(  # type: ignore[call-arg]
                username=username,
                ip_address=ip_address,
                user_agent=user_agent,
                success=1 if success else 0,  # stored as an int flag, not bool
                failure_reason=failure_reason,
                timestamp=datetime.now(timezone.utc)
            )
            db.add(attempt)
            db.commit()
            logger.info(
                "Login attempt recorded",
                username=username,
                success=success,
                ip=ip_address,
                reason=failure_reason
            )
        except Exception as e:
            logger.error("Failed to record login attempt", error=str(e))
            db.rollback()

    @staticmethod
    def is_account_locked(db: Session, username: str) -> Tuple[bool, Optional[datetime]]:
        """Check if an account is locked due to failed attempts.

        An account counts as locked when at least LOCKOUT_CONFIG["max_attempts"]
        failures fall inside the rolling window AND the lockout period after
        the most recent failure has not yet elapsed. Fails open (False, None)
        on errors or missing schema.

        Returns:
            (is_locked, unlock_time) — unlock_time is None when not locked.
        """
        try:
            if LoginAttempt is None:
                return False, None
            now = datetime.now(timezone.utc)
            window_start = now - timedelta(seconds=LOCKOUT_CONFIG["window_duration"])
            # Count failed attempts within the window
            failed_attempts = db.query(func.count(LoginAttempt.id)).filter(  # type: ignore[attr-defined]
                and_(
                    LoginAttempt.username == username,
                    LoginAttempt.success == 0,
                    LoginAttempt.timestamp >= window_start
                )
            ).scalar()
            if failed_attempts >= LOCKOUT_CONFIG["max_attempts"]:
                # Get the time of the last failed attempt
                # NOTE(review): this lookup is not window-bounded, so the
                # lockout clock runs from the most recent failure ever
                # recorded — confirm that is intended.
                last_attempt = db.query(LoginAttempt.timestamp).filter(  # type: ignore[attr-defined]
                    and_(
                        LoginAttempt.username == username,
                        LoginAttempt.success == 0
                    )
                ).order_by(LoginAttempt.timestamp.desc()).first()
                if last_attempt:
                    unlock_time = last_attempt[0] + timedelta(seconds=LOCKOUT_CONFIG["lockout_duration"])
                    if now < unlock_time:
                        return True, unlock_time
            return False, None
        except Exception as e:
            logger.error("Failed to check account lockout", error=str(e))
            return False, None

    @staticmethod
    def get_lockout_info(db: Session, username: str) -> Dict[str, any]:
        """Get detailed lockout information for an account.

        Returns a dict with lock status, failed-attempt counts, remaining
        attempts, and window metadata; on error returns a permissive default
        payload so callers always get the same shape.

        NOTE(review): the annotation uses the builtin ``any`` — harmless at
        runtime but typing.Any was probably intended (needs an import).
        """
        try:
            now = datetime.now(timezone.utc)
            window_start = now - timedelta(seconds=LOCKOUT_CONFIG["window_duration"])
            if LoginAttempt is None:
                # No schema: report a fully unlocked state.
                return {
                    "is_locked": False,
                    "failed_attempts": 0,
                    "max_attempts": LOCKOUT_CONFIG["max_attempts"],
                    "attempts_remaining": LOCKOUT_CONFIG["max_attempts"],
                    "unlock_time": None,
                    "window_start": window_start.isoformat(),
                    "lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
                }
            # Get recent failed attempts
            failed_attempts = db.query(LoginAttempt).filter(  # type: ignore[arg-type]
                and_(
                    LoginAttempt.username == username,
                    LoginAttempt.success == 0,
                    LoginAttempt.timestamp >= window_start
                )
            ).order_by(LoginAttempt.timestamp.desc()).all()
            failed_count = len(failed_attempts)
            is_locked, unlock_time = AccountLockoutManager.is_account_locked(db, username)
            return {
                "is_locked": is_locked,
                "failed_attempts": failed_count,
                "max_attempts": LOCKOUT_CONFIG["max_attempts"],
                "attempts_remaining": max(0, LOCKOUT_CONFIG["max_attempts"] - failed_count),
                "unlock_time": unlock_time.isoformat() if unlock_time else None,
                "window_start": window_start.isoformat(),
                "lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
            }
        except Exception as e:
            logger.error("Failed to get lockout info", error=str(e))
            # Fail open with the same payload shape; window_start may not
            # exist if the failure happened before it was computed.
            return {
                "is_locked": False,
                "failed_attempts": 0,
                "max_attempts": LOCKOUT_CONFIG["max_attempts"],
                "attempts_remaining": LOCKOUT_CONFIG["max_attempts"],
                "unlock_time": None,
                "window_start": window_start.isoformat() if 'window_start' in locals() else None,
                "lockout_duration": LOCKOUT_CONFIG["lockout_duration"],
            }

    @staticmethod
    def reset_failed_attempts(db: Session, username: str) -> None:
        """Reset failed login attempts for successful login.

        Intentionally a no-op on data: attempts are never deleted; the
        lockout check naturally resets as records age out of the rolling
        window.
        """
        try:
            # We don't delete the records, just mark successful login
            # The lockout check will naturally reset due to time window
            logger.info("Failed attempts naturally reset for successful login", username=username)
        except Exception as e:
            logger.error("Failed to reset attempts", error=str(e))
class SuspiciousActivityDetector:
    """Detects and reports suspicious authentication activity.

    Operates read-only over LoginAttempt records; all methods fail open
    (return "nothing suspicious") on errors or when the model is absent.
    """

    @staticmethod
    def detect_suspicious_patterns(db: Session, timeframe_hours: int = 24) -> List[Dict[str, any]]:
        """Detect suspicious login patterns across all accounts.

        Two heuristics over the last *timeframe_hours*:
        - "suspicious_ip": one IP with >= 10 failed attempts (severity high);
        - "account_targeted": one username with >= 5 failures from > 2
          distinct IPs (severity medium).

        Returns:
            A list of alert dicts; empty on error or missing schema.
        """
        alerts = []
        try:
            if LoginAttempt is None:
                return []
            cutoff_time = datetime.now(timezone.utc) - timedelta(hours=timeframe_hours)
            # Get all login attempts in timeframe
            attempts = db.query(LoginAttempt).filter(  # type: ignore[arg-type]
                LoginAttempt.timestamp >= cutoff_time
            ).all()
            # Analyze patterns: bucket attempts by source IP and by username.
            ip_attempts = {}
            username_attempts = {}
            for attempt in attempts:
                # Group by IP
                if attempt.ip_address not in ip_attempts:
                    ip_attempts[attempt.ip_address] = []
                ip_attempts[attempt.ip_address].append(attempt)
                # Group by username
                if attempt.username not in username_attempts:
                    username_attempts[attempt.username] = []
                username_attempts[attempt.username].append(attempt)
            # Check for suspicious IP activity
            for ip, attempts_list in ip_attempts.items():
                failed_attempts = [a for a in attempts_list if not a.success]
                if len(failed_attempts) >= 10:  # Many failed attempts from one IP
                    alerts.append({
                        "type": "suspicious_ip",
                        "severity": "high",
                        "ip_address": ip,
                        "failed_attempts": len(failed_attempts),
                        "usernames_targeted": list(set(a.username for a in failed_attempts)),
                        "timeframe": f"{timeframe_hours} hours"
                    })
            # Check for account targeting (distributed guessing against one user)
            for username, attempts_list in username_attempts.items():
                failed_attempts = [a for a in attempts_list if not a.success]
                unique_ips = set(a.ip_address for a in failed_attempts)
                if len(failed_attempts) >= 5 and len(unique_ips) > 2:
                    alerts.append({
                        "type": "account_targeted",
                        "severity": "medium",
                        "username": username,
                        "failed_attempts": len(failed_attempts),
                        "source_ips": list(unique_ips),
                        "timeframe": f"{timeframe_hours} hours"
                    })
            return alerts
        except Exception as e:
            logger.error("Failed to detect suspicious patterns", error=str(e))
            return []

    @staticmethod
    def is_login_suspicious(
        db: Session,
        username: str,
        ip_address: str,
        user_agent: str
    ) -> Tuple[bool, List[str]]:
        """Check if a single login attempt is suspicious.

        Heuristics: source IP not seen among the user's successful logins in
        the last 30 days; login outside 06:00-22:00 (UTC server clock); more
        than 3 attempts from the same IP in the last 5 minutes. The
        *user_agent* parameter is currently unused.

        Returns:
            (is_suspicious, warning strings). Fails open on errors.
        """
        warnings = []
        try:
            if LoginAttempt is None:
                return False, []
            # Check for unusual IP (vs. IPs with successful logins in 30 days)
            recent_ips = db.query(LoginAttempt.ip_address).filter(  # type: ignore[attr-defined]
                and_(
                    LoginAttempt.username == username,
                    LoginAttempt.success == 1,
                    LoginAttempt.timestamp >= datetime.now(timezone.utc) - timedelta(days=30)
                )
            ).distinct().all()
            known_ips = {ip[0] for ip in recent_ips}
            if ip_address not in known_ips and len(known_ips) > 0:
                warnings.append("Login from new IP address")
            # Check for unusual time
            # NOTE(review): this uses the server's UTC hour, not the user's
            # local time — "business hours" may be wrong for remote users.
            now = datetime.now(timezone.utc)
            if now.hour < 6 or now.hour > 22:  # Outside business hours
                warnings.append("Login outside normal business hours")
            # Check for rapid attempts from same IP (any username)
            recent_attempts = db.query(func.count(LoginAttempt.id)).filter(  # type: ignore[attr-defined]
                and_(
                    LoginAttempt.ip_address == ip_address,
                    LoginAttempt.timestamp >= datetime.now(timezone.utc) - timedelta(minutes=5)
                )
            ).scalar()
            if recent_attempts > 3:
                warnings.append("Multiple rapid login attempts from same IP")
            return len(warnings) > 0, warnings
        except Exception as e:
            logger.error("Failed to check suspicious login", error=str(e))
            return False, []
def validate_and_authenticate_user(
    db: Session,
    username: str,
    password: str,
    request: Request
) -> Tuple[Optional[User], List[str]]:
    """Enhanced user authentication with security checks.

    Pipeline: lockout check -> user lookup -> active-account check ->
    password verification -> suspicious-activity screening. Every outcome,
    success or failure, is recorded via AccountLockoutManager for auditing
    and lockout accounting.

    Returns:
        (user, []) on success, or (None, [error messages]) on failure.
        User-not-found and bad-password failures share the same generic
        message so usernames cannot be enumerated.
    """
    errors = []
    try:
        # Extract request information
        ip_address = get_client_ip(request)
        user_agent = request.headers.get("user-agent", "")
        # Check account lockout before touching credentials.
        is_locked, unlock_time = AccountLockoutManager.is_account_locked(db, username)
        if is_locked:
            AccountLockoutManager.record_login_attempt(
                db, username, False, ip_address, user_agent, "Account locked"
            )
            unlock_str = unlock_time.strftime("%Y-%m-%d %H:%M:%S UTC") if unlock_time else "unknown"
            errors.append(f"Account is locked due to too many failed attempts. Try again after {unlock_str}")
            return None, errors
        # Find user
        user = db.query(User).filter(User.username == username).first()
        if not user:
            AccountLockoutManager.record_login_attempt(
                db, username, False, ip_address, user_agent, "User not found"
            )
            errors.append("Invalid username or password")
            return None, errors
        # Check if user is active
        if not user.is_active:
            AccountLockoutManager.record_login_attempt(
                db, username, False, ip_address, user_agent, "User account disabled"
            )
            errors.append("User account is disabled")
            return None, errors
        # Verify password
        # Local import — presumably avoids a circular import at module load; confirm.
        from app.auth.security import verify_password
        if not verify_password(password, user.hashed_password):
            AccountLockoutManager.record_login_attempt(
                db, username, False, ip_address, user_agent, "Invalid password"
            )
            errors.append("Invalid username or password")
            return None, errors
        # Check for suspicious activity (warn-only; does not block the login)
        is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
            db, username, ip_address, user_agent
        )
        if is_suspicious:
            logger.warning(
                "Suspicious login detected",
                username=username,
                ip=ip_address,
                warnings=warnings
            )
            # You could require additional verification here
        # Successful login
        AccountLockoutManager.record_login_attempt(
            db, username, True, ip_address, user_agent, None
        )
        # Update last login time
        user.last_login = datetime.now(timezone.utc)
        db.commit()
        return user, []
    except Exception as e:
        logger.error("Authentication error", error=str(e))
        errors.append("Authentication service temporarily unavailable")
        return None, errors
def get_client_ip(request: "Request") -> str:
    """Extract the client IP from proxy headers, falling back to the socket peer.

    Precedence: first entry of X-Forwarded-For (the original client when the
    app sits behind a proxy chain), then X-Real-IP, then the direct
    connection address; "unknown" when none is available.

    NOTE: these headers are client-controlled unless a trusted proxy sets
    them — do not base security decisions on this value on an open network.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        return forwarded_for.split(",")[0].strip()
    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        # Strip surrounding whitespace for consistency with the XFF branch
        # (previously returned verbatim, including any padding).
        return real_ip.strip()
    if request.client:
        return request.client.host
    return "unknown"

342
app/utils/file_security.py Normal file
View File

@@ -0,0 +1,342 @@
"""
File Security and Validation Utilities
Comprehensive security validation for file uploads to prevent:
- Path traversal attacks
- File type spoofing
- DoS attacks via large files
- Malicious file uploads
- Directory traversal
"""
import os
import re
import hashlib
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
from fastapi import HTTPException, UploadFile
# Try to import python-magic, fall back to extension-based detection
try:
import magic
MAGIC_AVAILABLE = True
except ImportError:
MAGIC_AVAILABLE = False
# File size limits (bytes)
# Per-category hard caps enforced by FileSecurityValidator.validate_file_size.
MAX_FILE_SIZES = {
    'document': 10 * 1024 * 1024,  # 10MB for documents
    'csv': 50 * 1024 * 1024,       # 50MB for CSV imports
    'template': 5 * 1024 * 1024,   # 5MB for templates
    'image': 2 * 1024 * 1024,      # 2MB for images
    'default': 10 * 1024 * 1024,   # 10MB default
}

# Allowed MIME types for security
# Whitelist per upload category; checked against the *detected* MIME type,
# not the client-supplied Content-Type.
ALLOWED_MIME_TYPES = {
    'document': {
        'application/pdf',
        'application/msword',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    },
    'csv': {
        'text/csv',
        'text/plain',
        'application/csv',
    },
    'template': {
        'application/pdf',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    },
    'image': {
        'image/jpeg',
        'image/png',
        'image/gif',
        'image/webp',
    }
}

# File extensions mapping to categories (extension whitelist per category)
FILE_EXTENSIONS = {
    'document': {'.pdf', '.doc', '.docx'},
    'csv': {'.csv', '.txt'},
    'template': {'.pdf', '.docx'},
    'image': {'.jpg', '.jpeg', '.png', '.gif', '.webp'},
}

# Dangerous file extensions that should never be uploaded
# (executables, scripts, installers). NOTE(review): '.scf' is listed twice —
# harmless in a set literal, but one occurrence can be dropped.
DANGEROUS_EXTENSIONS = {
    '.exe', '.bat', '.cmd', '.com', '.scr', '.pif', '.vbs', '.js',
    '.jar', '.app', '.deb', '.pkg', '.dmg', '.rpm', '.msi', '.dll',
    '.so', '.dylib', '.sys', '.drv', '.ocx', '.cpl', '.scf', '.lnk',
    '.ps1', '.ps2', '.psc1', '.psc2', '.msh', '.msh1', '.msh2', '.mshxml',
    '.msh1xml', '.msh2xml', '.scf', '.inf', '.reg', '.vb', '.vbe', '.asp',
    '.aspx', '.php', '.jsp', '.jspx', '.py', '.rb', '.pl', '.sh', '.bash'
}
class FileSecurityValidator:
    """Comprehensive file security validation.

    Validates uploads end-to-end: filename sanitization, per-category
    extension and MIME whitelisting, size limits, a heuristic malicious
    content scan, and traversal-safe path construction. MIME sniffing uses
    libmagic when available and falls back to magic-byte/extension
    heuristics otherwise.
    """

    def __init__(self):
        # libmagic handle for content-based MIME detection; stays None when
        # python-magic is missing or fails to initialize, which routes all
        # detection through the heuristic fallback.
        self.magic_mime = None
        if MAGIC_AVAILABLE:
            try:
                self.magic_mime = magic.Magic(mime=True)
            except Exception:
                self.magic_mime = None

    def sanitize_filename(self, filename: str) -> str:
        """Sanitize filename to prevent path traversal and other attacks.

        Strips directory components, replaces filesystem-dangerous
        characters with '_', trims leading/trailing dots and spaces, and
        caps the length at 255 characters while preserving the extension.

        Raises:
            HTTPException(400): empty input, or nothing left after cleanup.
        """
        if not filename:
            raise HTTPException(status_code=400, detail="Filename cannot be empty")
        # Remove any path separators and dangerous characters
        filename = os.path.basename(filename)
        filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)
        # Remove leading/trailing dots and spaces (blocks hidden/relative names)
        filename = filename.strip('. ')
        # Ensure filename is not empty after sanitization
        if not filename:
            raise HTTPException(status_code=400, detail="Invalid filename")
        # Limit filename length (255 is the common filesystem cap)
        if len(filename) > 255:
            name, ext = os.path.splitext(filename)
            filename = name[:250] + ext
        return filename

    def validate_file_extension(self, filename: str, category: str) -> str:
        """Validate the file extension against the category whitelist.

        Returns the lowercase extension (including the dot).

        Raises:
            HTTPException(400): missing filename, a blocked dangerous
                extension, or an extension outside the category whitelist.
        """
        if not filename:
            raise HTTPException(status_code=400, detail="Filename required")
        # Get file extension (lowercased for case-insensitive matching)
        _, ext = os.path.splitext(filename.lower())
        # Check the executable/script blocklist first
        if ext in DANGEROUS_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '{ext}' is not allowed for security reasons"
            )
        # Check against allowed extensions for category
        allowed_extensions = FILE_EXTENSIONS.get(category, set())
        if ext not in allowed_extensions:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="Invalid file type")
        return ext

    def _detect_mime_from_content(self, content: bytes, filename: str) -> str:
        """Detect the MIME type from file content, falling back to the extension.

        Order: libmagic (if initialized) -> magic-byte signatures for common
        formats -> extension lookup -> 'application/octet-stream'.
        """
        if self.magic_mime:
            try:
                return self.magic_mime.from_buffer(content)
            except Exception:
                pass
        # Fallback to extension-based detection and basic content inspection
        _, ext = os.path.splitext(filename.lower())
        # Basic content-based detection for common file types
        if content.startswith(b'%PDF'):
            return 'application/pdf'
        elif content.startswith(b'PK\x03\x04') and ext in ['.docx', '.xlsx', '.pptx']:
            # OOXML files are ZIP containers; disambiguate via extension
            if ext == '.docx':
                return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
        elif content.startswith(b'\xd0\xcf\x11\xe0') and ext in ['.doc', '.xls', '.ppt']:
            # Legacy OLE compound documents share this signature
            if ext == '.doc':
                return 'application/msword'
        elif content.startswith(b'\xff\xd8\xff'):
            return 'image/jpeg'
        elif content.startswith(b'\x89PNG'):
            return 'image/png'
        elif content.startswith(b'GIF8'):
            return 'image/gif'
        elif content.startswith(b'RIFF') and b'WEBP' in content[:20]:
            return 'image/webp'
        # Extension-based fallback
        extension_to_mime = {
            '.pdf': 'application/pdf',
            '.doc': 'application/msword',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.csv': 'text/csv',
            '.txt': 'text/plain',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
        }
        return extension_to_mime.get(ext, 'application/octet-stream')

    def validate_mime_type(self, content: bytes, filename: str, category: str) -> str:
        """Validate the MIME type using content inspection plus extension.

        Returns the detected MIME type (catches extension spoofing where the
        content does not match the claimed type).

        Raises:
            HTTPException(400): empty content, or a detected MIME type
                outside the category whitelist.
        """
        if not content:
            raise HTTPException(status_code=400, detail="File content is empty")
        # Detect MIME type
        detected_mime = self._detect_mime_from_content(content, filename)
        # Check against allowed MIME types
        allowed_mimes = ALLOWED_MIME_TYPES.get(category, set())
        if detected_mime not in allowed_mimes:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="Invalid file type")
        return detected_mime

    def validate_file_size(self, content: bytes, category: str) -> int:
        """Validate the file size against the per-category limit.

        Returns the size in bytes.

        Raises:
            HTTPException(400): empty upload, or size over the limit.
        """
        size = len(content)
        max_size = MAX_FILE_SIZES.get(category, MAX_FILE_SIZES['default'])
        if size == 0:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="No file uploaded")
        if size > max_size:
            # Standardized message expected by tests
            raise HTTPException(status_code=400, detail="File too large")
        return size

    def scan_for_malware_patterns(self, content: bytes, filename: str) -> None:
        """Reject content containing common script/command-injection markers.

        Heuristic screen only (not an antivirus): targets embedded HTML/JS,
        PHP/ASP tags, and shell/command invocations. Matching is
        case-insensitive via a lowercased copy of the content.
        """
        malware_patterns = [
            b'<script',
            b'javascript:',
            b'vbscript:',
            b'data:text/html',
            b'<?php',
            b'<% ',
            b'eval(',
            b'exec(',
            b'system(',
            b'shell_exec(',
            b'passthru(',
            b'cmd.exe',
            b'powershell',
        ]
        content_lower = content.lower()
        for pattern in malware_patterns:
            if pattern in content_lower:
                raise HTTPException(
                    status_code=400,
                    detail="File contains potentially malicious content and cannot be uploaded"
                )

    def generate_secure_path(self, base_dir: str, filename: str, subdir: Optional[str] = None) -> str:
        """Generate a storage path for *filename* that cannot escape *base_dir*.

        Raises:
            HTTPException(400): when the resolved path lands outside base_dir.
        """
        # Sanitize filename
        safe_filename = self.sanitize_filename(filename)
        # Build path components
        path_parts = [base_dir]
        if subdir:
            # Sanitize subdirectory name to a safe character set
            safe_subdir = re.sub(r'[^a-zA-Z0-9_-]', '_', subdir)
            path_parts.append(safe_subdir)
        path_parts.append(safe_filename)
        # Use Path to safely join and resolve
        full_path = Path(*path_parts).resolve()
        base_path = Path(base_dir).resolve()
        # Containment check via relative_to(): the previous str.startswith()
        # comparison wrongly accepted sibling directories sharing a prefix,
        # e.g. "/uploads_evil" passes startswith("/uploads").
        try:
            full_path.relative_to(base_path)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail="Invalid file path - directory traversal detected"
            )
        return str(full_path)

    async def validate_upload_file(
        self,
        file: UploadFile,
        category: str,
        max_size_override: Optional[int] = None
    ) -> Tuple[bytes, str, str, str]:
        """
        Comprehensive validation of uploaded file
        Returns: (content, sanitized_filename, file_extension, mime_type)

        Runs, in order: size check, filename sanitization, extension
        whitelist, MIME validation, malware-pattern scan. Any failure
        raises HTTPException(400).
        """
        # Check if file was uploaded
        if not file.filename:
            raise HTTPException(status_code=400, detail="No file uploaded")
        # Read file content (whole file buffered in memory)
        content = await file.read()
        # Validate file size
        if max_size_override:
            # NOTE(review): the override path skips the empty-file check that
            # validate_file_size performs — confirm that is intended.
            max_size = max_size_override
            if len(content) > max_size:
                raise HTTPException(
                    status_code=400,
                    detail=f"File size exceeds limit ({max_size:,} bytes)"
                )
        else:
            self.validate_file_size(content, category)
        # Sanitize filename
        safe_filename = self.sanitize_filename(file.filename)
        # Validate file extension
        file_ext = self.validate_file_extension(safe_filename, category)
        # Validate MIME type using actual file content
        mime_type = self.validate_mime_type(content, safe_filename, category)
        # Scan for malware patterns
        self.scan_for_malware_patterns(content, safe_filename)
        return content, safe_filename, file_ext, mime_type
# Global instance for use across the application
# Module-level singleton; safe to share — holds only the optional libmagic handle.
file_validator = FileSecurityValidator()
def validate_csv_content(content: str) -> None:
    """Reject CSV text containing SQL- or script-injection markers.

    Performs a case-insensitive regex scan over the whole payload and
    raises HTTPException(400) on the first dangerous pattern; returns None
    for clean content.
    """
    suspicious_patterns = (
        r'(union\s+select)',
        r'(drop\s+table)',
        r'(delete\s+from)',
        r'(insert\s+into)',
        r'(update\s+.*set)',
        r'(exec\s*\()',
        r'(<script)',
        r'(javascript:)',
    )
    lowered = content.lower()
    if any(re.search(pattern, lowered) for pattern in suspicious_patterns):
        raise HTTPException(
            status_code=400,
            detail="CSV content contains potentially malicious data"
        )
def create_upload_directory(path: str) -> None:
    """Safely create an upload directory (and parents) with 0o755 permissions.

    Idempotent — an already-existing directory is not an error. Raises
    HTTPException(500) when the OS refuses to create the path.
    """
    try:
        os.makedirs(path, mode=0o755, exist_ok=True)
    except OSError as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Could not create upload directory: {str(exc)}"
        )

View File

@@ -17,6 +17,8 @@ class StructuredLogger:
def __init__(self, name: str, level: str = "INFO"):
self.logger = logging.getLogger(name)
self.logger.setLevel(getattr(logging, level.upper()))
# Support bound context similar to loguru's bind
self._bound_context: Dict[str, Any] = {}
if not self.logger.handlers:
self._setup_handlers()
@@ -68,13 +70,24 @@ class StructuredLogger:
def _log(self, level: int, message: str, **kwargs):
"""Internal method to log with structured data."""
context: Dict[str, Any] = {}
if self._bound_context:
context.update(self._bound_context)
if kwargs:
structured_message = f"{message} | Context: {json.dumps(kwargs, default=str)}"
context.update(kwargs)
if context:
structured_message = f"{message} | Context: {json.dumps(context, default=str)}"
else:
structured_message = message
self.logger.log(level, structured_message)
def bind(self, **kwargs):
"""Bind default context fields (compatibility with loguru-style usage)."""
if kwargs:
self._bound_context.update(kwargs)
return self
class ImportLogger(StructuredLogger):
"""Specialized logger for import operations."""
@@ -261,8 +274,25 @@ def log_function_call(logger: StructuredLogger = None, level: str = "DEBUG"):
return decorator
# Local logger cache and factory to avoid circular imports with app.core.logging
# One StructuredLogger instance per name, reused across calls.
_loggers: dict[str, StructuredLogger] = {}

def get_logger(name: str) -> StructuredLogger:
    """Return a cached StructuredLogger instance.

    This implementation is self-contained to avoid importing app.core.logging,
    which would create a circular import (core -> utils -> core).
    """
    # Level comes from settings.log_level when present, defaulting to INFO.
    logger = _loggers.get(name)
    if logger is None:
        logger = StructuredLogger(name, getattr(settings, 'log_level', 'INFO'))
        _loggers[name] = logger
    return logger
# Pre-configured logger instances
app_logger = StructuredLogger("application")
app_logger = get_logger("application")
import_logger = ImportLogger()
security_logger = SecurityLogger()
database_logger = DatabaseLogger()
@@ -270,16 +300,16 @@ database_logger = DatabaseLogger()
# Convenience functions
def log_info(message: str, **kwargs):
"""Quick info logging."""
app_logger.info(message, **kwargs)
get_logger("application").info(message, **kwargs)
def log_warning(message: str, **kwargs):
"""Quick warning logging."""
app_logger.warning(message, **kwargs)
get_logger("application").warning(message, **kwargs)
def log_error(message: str, **kwargs):
"""Quick error logging."""
app_logger.error(message, **kwargs)
get_logger("application").error(message, **kwargs)
def log_debug(message: str, **kwargs):
"""Quick debug logging."""
app_logger.debug(message, **kwargs)
get_logger("application").debug(message, **kwargs)

View File

@@ -0,0 +1,445 @@
"""
Advanced session management utilities for P2 security features
"""
import secrets
import hashlib
import json
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from fastapi import Request, Depends
from user_agents import parse as parse_user_agent
from app.models.sessions import (
UserSession, SessionActivity, SessionConfiguration,
SessionSecurityEvent, SessionStatus
)
from app.models.user import User
from app.core.logging import get_logger
from app.database.base import get_db
logger = get_logger(__name__)
class SessionManager:
    """
    Advanced session management with security features

    Handles creation, validation, revocation and cleanup of UserSession
    records, with risk scoring, concurrent-session limits and security
    event auditing. All operations go through one SQLAlchemy session.
    """

    # Default configuration
    # NOTE(review): per-user values come from SessionConfiguration (see
    # _get_session_config); confirm where these class-level defaults are used.
    DEFAULT_SESSION_TIMEOUT = timedelta(hours=8)
    DEFAULT_IDLE_TIMEOUT = timedelta(hours=1)
    DEFAULT_MAX_CONCURRENT_SESSIONS = 3

    def __init__(self, db: Session):
        # Shared SQLAlchemy session for all queries/commits in this manager.
        self.db = db
def generate_secure_session_id(self) -> str:
"""Generate cryptographically secure session ID"""
# Generate 64 bytes of random data and hash it
random_bytes = secrets.token_bytes(64)
timestamp = str(datetime.now(timezone.utc).timestamp()).encode()
return hashlib.sha256(random_bytes + timestamp).hexdigest()
    def create_session(
        self,
        user: User,
        request: Request,
        login_method: str = "password"
    ) -> UserSession:
        """
        Create new secure session with fixation protection

        A brand-new session ID is generated on every login (never taken from
        the client), request metadata is captured for anomaly checks, the
        login is risk-scored, concurrent-session limits are enforced, and
        the session row plus its audit records are committed in one go.

        Args:
            user: the already-authenticated user.
            request: incoming request (source of IP / user-agent / fingerprint).
            login_method: how the user authenticated (defaults to "password").

        Returns:
            The persisted UserSession.
        """
        # Generate new session ID (prevents session fixation)
        session_id = self.generate_secure_session_id()
        # Extract request metadata
        ip_address = self._get_client_ip(request)
        user_agent = request.headers.get("user-agent", "")
        device_fingerprint = self._generate_device_fingerprint(request)
        # Get geographic info (placeholder - would integrate with GeoIP service)
        country, city = self._get_geographic_info(ip_address)
        # Check for suspicious activity
        is_suspicious, risk_score = self._assess_login_risk(user, ip_address, user_agent)
        # Get session configuration (per-user timeouts and limits)
        config = self._get_session_config(user)
        session_timeout = timedelta(minutes=config.session_timeout_minutes)
        # Enforce concurrent session limits (may revoke oldest sessions)
        self._enforce_concurrent_session_limits(user, config.max_concurrent_sessions)
        # Create session record
        session = UserSession(
            session_id=session_id,
            user_id=user.id,
            ip_address=ip_address,
            user_agent=user_agent,
            device_fingerprint=device_fingerprint,
            country=country,
            city=city,
            is_suspicious=is_suspicious,
            risk_score=risk_score,
            status=SessionStatus.ACTIVE,
            login_method=login_method,
            expires_at=datetime.now(timezone.utc) + session_timeout
        )
        self.db.add(session)
        self.db.flush()  # Get session ID
        # Log session creation activity
        self._log_activity(
            session, user, request,
            activity_type="session_created",
            endpoint="/api/auth/login"
        )
        # Generate security event if suspicious
        if is_suspicious:
            self._create_security_event(
                session, user,
                event_type="suspicious_login",
                severity="medium",
                description=f"Suspicious login detected: risk score {risk_score}",
                ip_address=ip_address,
                user_agent=user_agent,
                country=country
            )
        self.db.commit()
        logger.info(f"Created session {session_id} for user {user.username} from {ip_address}")
        return session
    def validate_session(self, session_id: str, request: Request) -> Optional[UserSession]:
        """
        Validate session and update activity tracking

        Returns the live UserSession, or None when the session is unknown,
        inactive (expired/revoked), idle past the configured timeout, or —
        when the user's config demands it — the client IP changed
        mid-session (in which case the session is revoked and audited).
        Each successful validation refreshes last_activity and logs a
        SessionActivity row.
        """
        session = self.db.query(UserSession).filter(
            UserSession.session_id == session_id
        ).first()
        if not session:
            return None
        # Check if session is expired or revoked
        if not session.is_active():
            return None
        # Check for IP address changes if configured
        current_ip = self._get_client_ip(request)
        config = self._get_session_config(session.user)
        if config.force_logout_on_ip_change and session.ip_address != current_ip:
            self._create_security_event(
                session, session.user,
                event_type="ip_address_change",
                severity="medium",
                description=f"IP changed from {session.ip_address} to {current_ip}",
                ip_address=current_ip,
                action_taken="session_revoked"
            )
            session.revoke_session("ip_address_change")
            self.db.commit()
            return None
        # Check idle timeout (time since last validated request)
        idle_timeout = timedelta(minutes=config.idle_timeout_minutes)
        if datetime.now(timezone.utc) - session.last_activity > idle_timeout:
            session.status = SessionStatus.EXPIRED
            self.db.commit()
            return None
        # Update last activity
        session.last_activity = datetime.now(timezone.utc)
        # Log activity (one audit row per validated request)
        self._log_activity(
            session, session.user, request,
            activity_type="session_validation",
            endpoint=str(request.url.path)
        )
        self.db.commit()
        return session
def revoke_session(self, session_id: str, reason: str = "user_logout") -> bool:
"""Revoke a specific session"""
session = self.db.query(UserSession).filter(
UserSession.session_id == session_id
).first()
if session:
session.revoke_session(reason)
self.db.commit()
logger.info(f"Revoked session {session_id}: {reason}")
return True
return False
    def revoke_all_user_sessions(self, user_id: int, reason: str = "admin_action") -> int:
        """Revoke all sessions for a user.

        Performs a single bulk UPDATE over the user's ACTIVE sessions and
        returns the number of rows affected.

        NOTE(review): bulk .update() bypasses UserSession.revoke_session()
        Python-side logic and leaves already-loaded objects stale in the
        identity map — confirm that is acceptable here.
        """
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.user_id == user_id,
                UserSession.status == SessionStatus.ACTIVE
            )
        ).update({
            "status": SessionStatus.REVOKED,
            "revoked_at": datetime.now(timezone.utc),
            "revocation_reason": reason
        })
        self.db.commit()
        logger.info(f"Revoked {count} sessions for user {user_id}: {reason}")
        return count
def get_active_sessions(self, user_id: int) -> List[UserSession]:
"""Get all active sessions for a user"""
return self.db.query(UserSession).filter(
and_(
UserSession.user_id == user_id,
UserSession.status == SessionStatus.ACTIVE,
UserSession.expires_at > datetime.now(timezone.utc)
)
).order_by(UserSession.last_activity.desc()).all()
    def cleanup_expired_sessions(self) -> int:
        """Bulk-mark ACTIVE sessions whose expires_at has passed as EXPIRED.

        Returns the number of sessions transitioned. Intended to be run
        periodically (e.g. from a scheduled task).
        """
        count = self.db.query(UserSession).filter(
            and_(
                UserSession.status == SessionStatus.ACTIVE,
                UserSession.expires_at <= datetime.now(timezone.utc)
            )
        ).update({
            "status": SessionStatus.EXPIRED
        })
        self.db.commit()
        if count > 0:
            logger.info(f"Cleaned up {count} expired sessions")
        return count
def get_session_statistics(self, user_id: Optional[int] = None) -> Dict[str, Any]:
    """Get session statistics for monitoring.

    Args:
        user_id: When provided, restrict statistics to that user's sessions;
            when None, report across all users.

    Returns:
        Dict with total/active/suspicious session counts, sessions created
        in the last 24 hours, and a risk-score distribution
        (high >= 70, medium 30-69, low < 30).
    """
    query = self.db.query(UserSession)
    # Compare against None explicitly so a legitimate user_id of 0 is not
    # silently treated as "all users" (the old truthiness check had that bug).
    if user_id is not None:
        query = query.filter(UserSession.user_id == user_id)
    # Basic counts
    total_sessions = query.count()
    active_sessions = query.filter(UserSession.status == SessionStatus.ACTIVE).count()
    suspicious_sessions = query.filter(UserSession.is_suspicious == True).count()
    # Recent activity
    last_24h = datetime.now(timezone.utc) - timedelta(days=1)
    recent_sessions = query.filter(UserSession.created_at >= last_24h).count()
    # Risk distribution
    high_risk = query.filter(UserSession.risk_score >= 70).count()
    medium_risk = query.filter(
        and_(UserSession.risk_score >= 30, UserSession.risk_score < 70)
    ).count()
    low_risk = query.filter(UserSession.risk_score < 30).count()
    return {
        "total_sessions": total_sessions,
        "active_sessions": active_sessions,
        "suspicious_sessions": suspicious_sessions,
        "recent_sessions_24h": recent_sessions,
        "risk_distribution": {
            "high": high_risk,
            "medium": medium_risk,
            "low": low_risk
        }
    }
def _enforce_concurrent_session_limits(self, user: User, max_sessions: int) -> None:
    """Revoke the user's oldest active sessions so a new login fits the limit.

    ``get_active_sessions`` returns sessions newest-first, so everything past
    the first ``max_sessions - 1`` entries is the oldest and gets revoked,
    leaving room for the session that is about to be created.
    """
    current_sessions = self.get_active_sessions(user.id)
    if len(current_sessions) < max_sessions:
        return
    stale_sessions = current_sessions[max_sessions - 1:]
    for stale in stale_sessions:
        stale.revoke_session("concurrent_session_limit")
        # Record an auditable security event for each forced revocation.
        self._create_security_event(
            stale, user,
            event_type="concurrent_session_limit",
            severity="medium",
            description=f"Session revoked due to concurrent session limit ({max_sessions})",
            action_taken="session_revoked"
        )
    logger.info(f"Revoked {len(stale_sessions)} sessions for user {user.username} due to concurrent limit")
def _get_session_config(self, user: User) -> SessionConfiguration:
    """Resolve the effective session configuration for ``user``.

    Lookup order: user-specific config, then the global config
    (``user_id IS NULL``); if neither exists, a default global
    configuration is created and flushed.
    """
    # Most-specific first: per-user config, then the global fallback row.
    for scope_criterion in (
        SessionConfiguration.user_id == user.id,
        SessionConfiguration.user_id.is_(None),
    ):
        existing = (
            self.db.query(SessionConfiguration)
            .filter(scope_criterion)
            .first()
        )
        if existing:
            return existing
    # Nothing configured yet: persist a default global configuration.
    default_config = SessionConfiguration()
    self.db.add(default_config)
    self.db.flush()
    return default_config
def _assess_login_risk(self, user: User, ip_address: str, user_agent: str) -> Tuple[bool, int]:
    """Score a login attempt against the user's recent history.

    Signals and weights: new IP in the last 30 days (+30), login hour not in
    recent history (+20), user agent not seen in the last 7 days (+15), more
    than 3 sessions created in the last 10 minutes (+25).

    Returns:
        Tuple of ``(is_suspicious, risk_score)`` where ``is_suspicious`` is
        True at a score of 50 or more and ``risk_score`` is capped at 100.
    """
    risk_score = 0
    # NOTE(review): risk_factors is accumulated but never returned or logged
    # by this method — presumably reserved for future auditing; confirm.
    risk_factors = []
    # New IP address: any distinct IP used in the last 30 days counts as known.
    previous_ips = self.db.query(UserSession.ip_address).filter(
        and_(
            UserSession.user_id == user.id,
            UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
        )
    ).distinct().all()
    if ip_address not in [ip[0] for ip in previous_ips]:
        risk_score += 30
        risk_factors.append("new_ip_address")
    # Unusual login time: compare the current UTC hour to prior login hours.
    current_hour = datetime.now(timezone.utc).hour
    user_login_hours = self.db.query(func.extract('hour', UserSession.created_at)).filter(
        and_(
            UserSession.user_id == user.id,
            UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=30)
        )
    ).all()
    if user_login_hours:
        common_hours = [hour[0] for hour in user_login_hours]
        # NOTE(review): this query has no ORDER BY, so [-10:] takes the last
        # ten rows in whatever order the backend returns them — not
        # necessarily the most recent logins. Also, extract('hour') may yield
        # Decimal on some backends, which would never equal the int
        # current_hour — confirm against the target database.
        if current_hour not in common_hours[-10:]:  # Not in recent login hours
            risk_score += 20
            risk_factors.append("unusual_time")
    # New user agent: anything not seen (non-null) in the last 7 days.
    recent_agents = self.db.query(UserSession.user_agent).filter(
        and_(
            UserSession.user_id == user.id,
            UserSession.created_at >= datetime.now(timezone.utc) - timedelta(days=7)
        )
    ).distinct().all()
    if user_agent not in [agent[0] for agent in recent_agents if agent[0]]:
        risk_score += 15
        risk_factors.append("new_user_agent")
    # Rapid attempts: counts session rows created in the last 10 minutes,
    # i.e. successful logins — failed attempts are not visible here.
    recent_attempts = self.db.query(UserSession).filter(
        and_(
            UserSession.user_id == user.id,
            UserSession.created_at >= datetime.now(timezone.utc) - timedelta(minutes=10)
        )
    ).count()
    if recent_attempts > 3:
        risk_score += 25
        risk_factors.append("rapid_attempts")
    # Suspicious at 50+; report the score capped at 100.
    is_suspicious = risk_score >= 50
    return is_suspicious, min(risk_score, 100)
def _log_activity(
    self,
    session: UserSession,
    user: User,
    request: Request,
    activity_type: str,
    endpoint: str = None
) -> None:
    """Stage a SessionActivity audit row for this request (caller commits)."""
    # Fall back to the request path when no explicit endpoint was given.
    resolved_endpoint = endpoint if endpoint else str(request.url.path)
    activity_row = SessionActivity(
        session_id=session.id,
        user_id=user.id,
        activity_type=activity_type,
        endpoint=resolved_endpoint,
        method=request.method,
        ip_address=self._get_client_ip(request),
        user_agent=request.headers.get("user-agent", ""),
    )
    self.db.add(activity_row)
def _create_security_event(
    self,
    session: Optional[UserSession],
    user: User,
    event_type: str,
    severity: str,
    description: str,
    ip_address: str = None,
    user_agent: str = None,
    country: str = None,
    action_taken: str = None
) -> None:
    """Stage a SessionSecurityEvent row (caller commits) and log a warning."""
    # The event may be user-level (no session), so the link is optional.
    linked_session_id = session.id if session is not None else None
    security_event = SessionSecurityEvent(
        session_id=linked_session_id,
        user_id=user.id,
        event_type=event_type,
        severity=severity,
        description=description,
        ip_address=ip_address,
        user_agent=user_agent,
        country=country,
        action_taken=action_taken,
    )
    self.db.add(security_event)
    logger.warning(f"Security event [{severity}]: {event_type} for user {user.username} - {description}")
def _get_client_ip(self, request: Request) -> str:
    """Best-effort client IP: proxy headers first, then the raw connection.

    Returns "unknown" when no client connection info is available.
    """
    headers = request.headers
    # X-Forwarded-For may carry a chain "client, proxy1, proxy2"; the
    # first entry is the originating client.
    forwarded_chain = headers.get("x-forwarded-for")
    if forwarded_chain:
        return forwarded_chain.split(",")[0].strip()
    real_ip_header = headers.get("x-real-ip")
    if real_ip_header:
        return real_ip_header
    # No proxy headers: fall back to the direct connection, if any.
    return request.client.host if request.client else "unknown"
def _generate_device_fingerprint(self, request: Request) -> str:
    """Derive a stable device identifier from request headers.

    MD5 is acceptable here: the digest is only a tracking fingerprint,
    not a security credential.
    """
    header_components = [
        request.headers.get("user-agent", ""),
        request.headers.get("accept-language", ""),
        request.headers.get("accept-encoding", ""),
    ]
    return hashlib.md5("|".join(header_components).encode()).hexdigest()
def _get_geographic_info(self, ip_address: str) -> Tuple[Optional[str], Optional[str]]:
"""Get geographic information for IP address"""
# Placeholder - would integrate with GeoIP service like MaxMind
# For now, return None values
return None, None
def get_session_manager(db: Session = Depends(get_db)) -> SessionManager:
    """FastAPI dependency: build a SessionManager bound to the request's DB session."""
    manager = SessionManager(db)
    return manager