changes
This commit is contained in:
379
app/utils/database_security.py
Normal file
379
app/utils/database_security.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Database Security Utilities

Provides utilities for secure database operations and SQL injection prevention:
- Parameterized query helpers
- SQL injection detection and prevention
- Safe query building utilities
- Database security auditing
|
||||
"""
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import ClauseElement
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
logger = app_logger.bind(name="database_security")
|
||||
|
||||
# Dangerous SQL patterns that could indicate injection attempts.
#
# These are heuristics: a match means "worth a look", not "definitely an
# attack".  Each entry is a regular expression source string; the compiled
# counterparts live in COMPILED_PATTERNS at the same index.
#
# NOTE: ``\b`` only matches between a word character and a non-word character.
# It must therefore never sit directly next to punctuation such as ``(`` or
# ``.`` (or after ``_``, which is itself a word character): doing so silently
# demands an extra word character on the other side and makes the pattern miss
# ordinary payloads like ``exec('...')``, ``IN (SELECT ...)`` or ``pg_catalog``.
DANGEROUS_PATTERNS = [
    # SQL injection keywords
    r'\bunion\s+select\b',
    r'\b(drop\s+table)\b',
    r'\bdelete\s+from\b',
    r'\b(insert\s+into)\b',
    r'\b(update\s+.*set)\b',
    r'\b(alter\s+table)\b',
    r'\b(create\s+table)\b',
    r'\b(truncate\s+table)\b',

    # Command execution (pattern ends on "(", so no trailing \b)
    r'\b(exec\s*\()',
    r'\b(execute\s*\()',
    r'\b(system\s*\()',
    r'\b(shell_exec\s*\()',
    r'\b(eval\s*\()',

    # Comment-based attacks
    r'(;\s*drop\s+table\b)',
    r'(--\s*$)',
    r'(/\*.*?\*/)',
    r'(\#.*$)',

    # Quote escaping attempts
    r"(';|';|\";|\")",
    r"(\\\\'|\\\\\"|\\\\x)",

    # Hex/unicode encoding
    r'(0x[0-9a-fA-F]+)',
    r'(\\u[0-9a-fA-F]{4})',

    # Boolean-based attacks
    r'\b(1=1|1=0|true|false)\b',
    r"\b(or\s+1\s*=\s*1|and\s+1\s*=\s*1)\b",

    # Time-based attacks (only the keyword-terminated alternative keeps \b)
    r'\b(sleep\s*\(|delay\s*\(|waitfor\s+delay\b)',

    # File operations
    r'\b(load_file\s*\(|into\s+outfile\b|into\s+dumpfile\b)',

    # Information schema access ("sys." and "pg_" are prefixes, so they are
    # anchored on the left only)
    r'\b(information_schema\b|sys\.|pg_)',

    # Subselect usage in WHERE clause that may indicate enumeration
    # (no leading \b: "(" is usually preceded by whitespace)
    r'\(\s*select\b',
]

# Compiled regex patterns for performance.  IGNORECASE because SQL keywords
# are case-insensitive; MULTILINE so the ``$`` anchors in the comment patterns
# match at the end of every line, not just the end of the input.
COMPILED_PATTERNS = [re.compile(pattern, re.IGNORECASE | re.MULTILINE) for pattern in DANGEROUS_PATTERNS]
|
||||
|
||||
|
||||
class SQLSecurityValidator:
    """Validates SQL queries and parameters for security issues.

    Every method returns a list of human-readable issue descriptions; an empty
    list means nothing was flagged.  The checks are pattern heuristics driven
    by the module-level DANGEROUS_PATTERNS / COMPILED_PATTERNS lists.
    """

    @staticmethod
    def validate_query_string(query: str) -> List[str]:
        """Validate a SQL query string for potential injection attempts.

        Args:
            query: The SQL text to inspect.  ``None`` is treated as an empty
                query; any other non-string value is inspected via ``str()``.

        Returns:
            One message per matching dangerous pattern.  When no pattern
            matches, a small set of substring heuristics is applied instead
            and their messages (if any) are returned.
        """
        # Normalise once up front so both the regex pass and the heuristic
        # pass see the same text, instead of swallowing a TypeError per
        # pattern when the caller passes None.
        query_text = "" if query is None else str(query)
        issues: List[str] = []

        for raw_pattern, pattern in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            matches = pattern.findall(query_text)
            if matches:
                # Cap the echoed matches so a huge payload cannot flood logs.
                issues.append(
                    f"Potentially dangerous SQL pattern detected: {raw_pattern} -> {str(matches)[:80]}"
                )

        # Heuristic fallback to catch common cases without relying on regex quirks
        if not issues:
            ql = query_text.lower()
            if "; drop table" in ql or ";drop table" in ql:
                issues.append("Heuristic: DROP TABLE detected")
            if " union select " in ql or " union\nselect " in ql:
                issues.append("Heuristic: UNION SELECT detected")
            if " or 1=1" in ql or " and 1=1" in ql or " or 1 = 1" in ql or " and 1 = 1" in ql:
                issues.append("Heuristic: tautology (1=1) detected")
            if "(select" in ql:
                issues.append("Heuristic: subselect in WHERE detected")

        return issues

    @staticmethod
    def validate_parameter_value(param_name: str, param_value: Any) -> List[str]:
        """Validate a parameter value for potential injection attempts.

        Args:
            param_name: Name used to identify the parameter in the messages.
            param_value: The value to inspect.  ``None`` is never flagged.

        Returns:
            One message per dangerous pattern found in ``str(param_value)``,
            plus, for string values only, messages for excessive length
            (more than 10000 characters), embedded null bytes and control
            characters other than tab, newline and carriage return.
        """
        issues: List[str] = []

        if param_value is None:
            return issues

        # Convert to string for pattern matching
        str_value = str(param_value)

        # Check for dangerous patterns in parameter values
        for raw_pattern, pattern in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            if pattern.findall(str_value):
                issues.append(f"Parameter '{param_name}' contains dangerous pattern: {raw_pattern}")

        # Additional parameter-specific checks
        if isinstance(param_value, str):
            # Check for excessive length (potential buffer overflow)
            if len(param_value) > 10000:
                issues.append(f"Parameter '{param_name}' is excessively long ({len(param_value)} characters)")

            # Check for null bytes
            if '\x00' in param_value:
                issues.append(f"Parameter '{param_name}' contains null bytes")

            # Check for control characters (a null byte is also reported here,
            # since its code point is below 32)
            if any(ord(c) < 32 and c not in '\t\n\r' for c in param_value):
                issues.append(f"Parameter '{param_name}' contains suspicious control characters")

        return issues

    @staticmethod
    def validate_query_with_params(query: str, params: Dict[str, Any]) -> List[str]:
        """Validate a complete query with its parameters.

        Args:
            query: The SQL text to inspect.
            params: Mapping of parameter name to value; ``None`` is treated
                as "no parameters".

        Returns:
            The query-level issues followed by the issues of each parameter,
            in the mapping's iteration order.
        """
        issues: List[str] = []

        # Validate the query string
        issues.extend(SQLSecurityValidator.validate_query_string(query))

        # Validate each parameter
        for param_name, param_value in (params or {}).items():
            issues.extend(SQLSecurityValidator.validate_parameter_value(param_name, param_value))

        return issues
|
||||
|
||||
|
||||
class SecureQueryBuilder:
    """Builds secure parameterized queries to prevent SQL injection."""

    @staticmethod
    def safe_text_query(query: str, params: Optional[Dict[str, Any]] = None) -> text:
        """Create a text query after validating the query and its parameters.

        Args:
            query: SQL text, expected to use bind placeholders for values.
            params: Parameter values to validate alongside the query.  They
                are only inspected here; binding happens when the caller
                executes the returned clause.

        Returns:
            The result of ``sqlalchemy.text(query)``.  Validation findings are
            logged as a warning; they do not prevent the clause from being
            returned.
        """
        params = params or {}

        # Validate the query and parameters
        issues = SQLSecurityValidator.validate_query_with_params(query, params)

        if issues:
            logger.warning("Potential security issues in query", query=query[:100], issues=issues)
            # In production, you might want to raise an exception or sanitize
            # For now, we'll log and proceed with caution

        return text(query)

    @staticmethod
    def build_like_clause(column: ClauseElement, search_term: str, case_sensitive: bool = False) -> ClauseElement:
        """Build a "contains" LIKE clause that treats the term literally.

        Args:
            column: Column expression exposing ``like`` / ``ilike``.
            search_term: Raw text to search for; ``%``, ``_`` and ``\\`` in it
                are matched as ordinary characters.
            case_sensitive: Use ``like`` when true, ``ilike`` otherwise.

        Returns:
            ``column LIKE '%<escaped term>%' ESCAPE '\\'`` (or the ILIKE form).
        """
        # The escape character itself MUST be escaped first.  Doing it last
        # would double the backslashes just inserted in front of % and _,
        # turning "\%" into "\\%" (a literal backslash followed by a live
        # wildcard) and defeating the escaping entirely.
        escaped_term = (
            search_term.replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )
        like_pattern = f"%{escaped_term}%"

        if case_sensitive:
            return column.like(like_pattern, escape='\\')
        return column.ilike(like_pattern, escape='\\')

    @staticmethod
    def build_in_clause(column: ClauseElement, values: List[Any]) -> ClauseElement:
        """Build an IN clause, logging any value that looks suspicious.

        Args:
            column: Column expression exposing ``in_``.
            values: Candidate values for the IN list.

        Returns:
            ``column.in_(values)``; an empty or missing list yields
            ``column.in_([])``.  Suspicious values are logged, not removed.
        """
        if not values:
            return column.in_([])  # Empty list

        # Validate each value
        for i, value in enumerate(values):
            issues = SQLSecurityValidator.validate_parameter_value(f"in_value_{i}", value)
            if issues:
                logger.warning(f"Security issues in IN clause value: {issues}")

        return column.in_(values)

    @staticmethod
    def build_fts_query(search_terms: List[str], exact_phrase: bool = False) -> str:
        """Build a sanitised FTS (Full Text Search) query string.

        Args:
            search_terms: Raw user terms.
            exact_phrase: When true each surviving term is wrapped in double
                quotes and the terms are joined with spaces; otherwise the
                bare terms are joined with ``' AND '``.

        Returns:
            The combined query, or ``""`` when there are no terms or every
            term is empty after sanitising.
        """
        if not search_terms:
            return ""

        safe_terms = []
        for term in search_terms:
            # Keep only word characters, whitespace, "-" and "."; everything
            # else (including quotes) is dropped.
            # SQLite FTS5 has its own escaping rules
            safe_term = re.sub(r'[^\w\s\-\.]', '', term.strip())
            if not safe_term:
                continue
            # Quotes were already stripped above, so the term can be wrapped
            # directly without any further quote doubling.
            safe_terms.append(f'"{safe_term}"' if exact_phrase else safe_term)

        separator = ' ' if exact_phrase else ' AND '
        return separator.join(safe_terms)
|
||||
|
||||
|
||||
class DatabaseAuditor:
    """Audits database operations for security compliance"""

    @staticmethod
    def audit_query_execution(query: str, params: Dict[str, Any], execution_time: float) -> None:
        """Log security, performance and style findings for an executed query.

        Args:
            query: The SQL text that was run.
            params: The parameters it was run with.
            execution_time: Elapsed time, compared against a 5.0 threshold
                to decide whether the query counts as slow.

        Returns:
            None.  All findings are reported through the module logger.
        """
        # --- security -------------------------------------------------------
        findings = SQLSecurityValidator.validate_query_with_params(query, params)
        if findings:
            logger.warning(
                "Query executed with security issues",
                query=query[:200],
                issues=findings,
                execution_time=execution_time
            )

        # --- performance ----------------------------------------------------
        is_slow = execution_time > 5.0  # Slow query threshold
        if is_slow:
            logger.warning(
                "Slow query detected",
                query=query[:200],
                execution_time=execution_time
            )

        # --- query-shape advisories, checked in this fixed order --------------
        advisories = (
            (r'\bSELECT\s+\*\s+FROM\b', "SELECT * query detected - consider specifying columns"),
            (r'\bLIMIT\s+\d{4,}\b', "Large LIMIT detected"),
        )
        for shape, message in advisories:
            if re.search(shape, query, re.IGNORECASE):
                logger.info(message, query=query[:100])
|
||||
|
||||
|
||||
def execute_secure_query(
    db: Session,
    query: str,
    params: Optional[Dict[str, Any]] = None,
    audit: bool = True
) -> Any:
    """Execute a query with security validation and auditing.

    Args:
        db: Session whose ``execute`` method runs the query.
        query: SQL text, expected to use bind placeholders for values.
        params: Bind parameter values; ``None`` means no parameters.
        audit: When true, validate before executing and hand the query,
            parameters and elapsed time to ``DatabaseAuditor`` afterwards.

    Returns:
        Whatever ``db.execute`` returns.

    Raises:
        Exception: Any error raised while validating or executing is logged
            together with the elapsed time and then re-raised unchanged.
    """
    import time

    params = params or {}
    # perf_counter() is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted (unlike time.time()).
    start_time = time.perf_counter()

    try:
        # Validate query security
        if audit:
            issues = SQLSecurityValidator.validate_query_with_params(query, params)
            if issues:
                logger.warning("Executing query with potential security issues", issues=issues)

        # Create safe text query
        # NOTE(review): safe_text_query validates and logs again regardless of
        # ``audit``, and audit_query_execution validates a third time; the
        # duplicate warnings are kept so existing log output stays the same.
        safe_query = SecureQueryBuilder.safe_text_query(query, params)

        # Execute query
        result = db.execute(safe_query, params)

        # Audit execution
        if audit:
            execution_time = time.perf_counter() - start_time
            DatabaseAuditor.audit_query_execution(query, params, execution_time)

        return result

    except Exception as e:
        execution_time = time.perf_counter() - start_time
        logger.error(
            "Query execution failed",
            query=query[:200],
            error=str(e),
            execution_time=execution_time
        )
        raise
|
||||
|
||||
|
||||
def sanitize_fts_query(query: str) -> str:
    """Sanitize user input for FTS queries.

    Every character that is not a word character, whitespace or one of
    ``- . , ! ? " '`` is replaced by a space, runs of whitespace collapse to a
    single space, and the result is trimmed and limited to 500 characters.

    Args:
        query: Raw user input; ``None`` or an empty string yields ``""``.

    Returns:
        The sanitised text, with no leading or trailing whitespace.
    """
    if not query:
        return ""

    # Remove potentially dangerous characters
    # Keep alphanumeric, spaces, and basic punctuation
    sanitized = re.sub(r'[^\w\s\-\.\,\!\?\"\']', ' ', query)

    # Remove excessive whitespace
    sanitized = re.sub(r'\s+', ' ', sanitized).strip()

    # Limit length.  The cut can land right after a separating space, so trim
    # again to keep the "no trailing whitespace" guarantee.
    if len(sanitized) > 500:
        sanitized = sanitized[:500].rstrip()

    return sanitized
|
||||
|
||||
|
||||
def create_safe_search_conditions(
    search_terms: List[str],
    searchable_columns: List[ClauseElement],
    case_sensitive: bool = False,
    exact_phrase: bool = False
) -> Optional[ClauseElement]:
    """Combine LIKE conditions for several terms over several columns.

    Args:
        search_terms: Terms to look for.
        searchable_columns: Columns each term is matched against.
        case_sensitive: Passed through to ``SecureQueryBuilder.build_like_clause``.
        exact_phrase: When true, the terms are joined with single spaces and
            searched as one phrase; otherwise each term is searched separately.

    Returns:
        ``None`` when either input list is empty.  Otherwise an AND of one
        group per searched text, where each group is an OR over all columns.
    """
    from sqlalchemy import or_, and_

    if not (search_terms and searchable_columns):
        return None

    # Phrase mode collapses all terms into a single needle; otherwise every
    # term is its own needle and must match at least one column.
    needles = [' '.join(search_terms)] if exact_phrase else search_terms

    groups = [
        or_(*[
            SecureQueryBuilder.build_like_clause(col, needle, case_sensitive)
            for col in searchable_columns
        ])
        for needle in needles
    ]

    return and_(*groups) if groups else None
|
||||
|
||||
|
||||
# Whitelist of allowed column names for dynamic queries, keyed by table name.
# Only names listed here may be used as a sort column for that table.
ALLOWED_SORT_COLUMNS = {
    'rolodex': [
        'id', 'first', 'last', 'city', 'email',
        'created_at', 'updated_at',
    ],
    'files': [
        'file_no', 'id', 'regarding', 'status', 'file_type',
        'opened', 'closed', 'created_at', 'updated_at',
    ],
    'ledger': [
        'id', 'file_no', 't_code', 'amount', 'date',
        'created_at', 'updated_at',
    ],
    'qdros': [
        'id', 'file_no', 'form_name', 'status',
        'created_at', 'updated_at',
    ],
}


def validate_sort_column(table: str, column: str) -> bool:
    """Return True when ``column`` is whitelisted as a sort column of ``table``.

    A table without an entry in ALLOWED_SORT_COLUMNS has no allowed columns,
    so the result is False for it.
    """
    if table not in ALLOWED_SORT_COLUMNS:
        return False
    return column in ALLOWED_SORT_COLUMNS[table]
|
||||
|
||||
|
||||
def safe_order_by(table: str, sort_column: str, sort_direction: str = 'asc') -> Optional[str]:
    """Build an ORDER BY fragment from whitelisted pieces only.

    Args:
        table: Table whose sort-column whitelist is consulted.
        sort_column: Requested column; must pass ``validate_sort_column``.
        sort_direction: ``'asc'`` or ``'desc'`` in any letter case.

    Returns:
        ``"<sort_column> <DIRECTION>"`` with the direction upper-cased, or
        ``None`` (after logging a warning) when the column or the direction
        is rejected.
    """
    # Column first: an unknown column is reported before the direction is
    # even looked at.
    column_ok = validate_sort_column(table, sort_column)
    if not column_ok:
        logger.warning(f"Invalid sort column '{sort_column}' for table '{table}'")
        return None

    direction_ok = sort_direction.lower() in ['asc', 'desc']
    if not direction_ok:
        logger.warning(f"Invalid sort direction '{sort_direction}'")
        return None

    direction = sort_direction.upper()
    return f"{sort_column} {direction}"
|
||||
Reference in New Issue
Block a user