# Listing metadata: 380 lines, 13 KiB, Python
"""
|
|
Database Security Utilities
|
|
|
|
Provides utilities for secure database operations and SQL injection prevention:
|
|
- Parameterized query helpers
|
|
- SQL injection detection and prevention
|
|
- Safe query building utilities
|
|
- Database security auditing
|
|
"""
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from sqlalchemy import text
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql import ClauseElement
|
|
from app.utils.logging import app_logger
|
|
|
|
# Module-level logger: the shared app_logger with name="database_security" bound to it.
logger = app_logger.bind(name="database_security")
|
|
|
|
# Regex fragments that could indicate an SQL injection attempt.
#
# Each entry is compiled (case-insensitive, multi-line) into COMPILED_PATTERNS
# below. The two lists are index-aligned: position i of one corresponds to
# position i of the other, so entries must be added or removed in both views
# at once (which the comprehension guarantees).
#
# Word-boundary anchors (\b) are only placed next to word characters. A \b
# directly before or after a literal "(" only matches when the neighbouring
# character happens to be a word character, which previously let inputs such
# as "exec()", "sleep( 5 )", " (select ..." and "pg_catalog" go undetected.
DANGEROUS_PATTERNS = [
    # SQL injection keywords
    r'\bunion\s+select\b',
    r'\b(drop\s+table)\b',
    r'\bdelete\s+from\b',
    r'\b(insert\s+into)\b',
    r'\b(update\s+.*set)\b',
    r'\b(alter\s+table)\b',
    r'\b(create\s+table)\b',
    r'\b(truncate\s+table)\b',

    # Command execution (no trailing \b: the match ends on "(")
    r'\b(exec\s*\()',
    r'\b(execute\s*\()',
    r'\b(system\s*\()',
    r'\b(shell_exec\s*\()',
    r'\b(eval\s*\()',

    # Comment-based attacks
    r'(;\s*drop\s+table\b)',
    r'(--\s*$)',
    r'(/\*.*?\*/)',
    r'(\#.*$)',

    # Quote escaping attempts
    # NOTE: the last alternative flags ANY double quote character.
    r"(';|\";|\")",
    # NOTE(review): this matches TWO literal backslashes followed by ', " or x.
    # Confirm whether a single backslash was intended.
    r"(\\\\'|\\\\\"|\\\\x)",

    # Hex/unicode encoding
    r'(0x[0-9a-fA-F]+)',
    r'(\\u[0-9a-fA-F]{4})',

    # Boolean-based attacks
    r'\b(1=1|1=0|true|false)\b',
    r"\b(or\s+1\s*=\s*1|and\s+1\s*=\s*1)\b",

    # Time-based attacks
    r'\b(sleep\s*\(|delay\s*\(|waitfor\s+delay\b)',

    # File operations
    r'\b(load_file\s*\(|into\s+outfile\b|into\s+dumpfile\b)',

    # Information schema access ("sys." and "pg_" are prefixes, so they
    # must not be followed by a boundary requirement)
    r'\b(information_schema\b|sys\.|pg_)',

    # Subselect usage in WHERE clause that may indicate enumeration
    r'\(\s*select\b',
]

# Compiled regex patterns for performance
COMPILED_PATTERNS = [
    re.compile(pattern, re.IGNORECASE | re.MULTILINE)
    for pattern in DANGEROUS_PATTERNS
]
|
|
|
|
|
|
class SQLSecurityValidator:
    """Validates SQL queries and parameters for security issues.

    Every method returns a list of human-readable issue strings; an empty
    list means nothing was flagged. The methods only report findings and
    never raise on a suspicious value, so the caller decides what to do.
    """

    @staticmethod
    def validate_query_string(query: str) -> List[str]:
        """Validate a SQL query string for potential injection attempts.

        Args:
            query: Raw SQL text. ``None`` or an empty string is treated as
                an empty query and yields no issues.

        Returns:
            One message per matching entry of ``DANGEROUS_PATTERNS``. If no
            pattern matched, a small set of substring heuristics is applied
            as a fallback and their messages are returned instead.
        """
        # Normalize once instead of catching the TypeError that
        # ``findall(None)`` would raise for every pattern.
        sql = query or ""
        issues: List[str] = []

        for raw_pattern, compiled in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            matches = compiled.findall(sql)
            if matches:
                issues.append(
                    f"Potentially dangerous SQL pattern detected: {raw_pattern} -> {str(matches)[:80]}"
                )

        # Heuristic fallback to catch common cases without relying on regex quirks
        if not issues:
            ql = sql.lower()
            if "; drop table" in ql or ";drop table" in ql:
                issues.append("Heuristic: DROP TABLE detected")
            if " union select " in ql or " union\nselect " in ql:
                issues.append("Heuristic: UNION SELECT detected")
            if " or 1=1" in ql or " and 1=1" in ql or " or 1 = 1" in ql or " and 1 = 1" in ql:
                issues.append("Heuristic: tautology (1=1) detected")
            if "(select" in ql:
                issues.append("Heuristic: subselect in WHERE detected")

        return issues

    @staticmethod
    def validate_parameter_value(param_name: str, param_value: Any) -> List[str]:
        """Validate a parameter value for potential injection attempts.

        Args:
            param_name: Name used in the returned messages.
            param_value: Value to inspect. ``None`` is accepted as-is; any
                other value is checked via its ``str()`` representation.

        Returns:
            One message per matching dangerous pattern, plus messages for
            string-specific checks (length, null bytes, control characters).
        """
        issues: List[str] = []

        if param_value is None:
            return issues

        # Convert to string for pattern matching
        str_value = str(param_value)

        # Check for dangerous patterns in parameter values
        for raw_pattern, compiled in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            if compiled.findall(str_value):
                issues.append(f"Parameter '{param_name}' contains dangerous pattern: {raw_pattern}")

        # Additional parameter-specific checks
        if isinstance(param_value, str):
            # Flag unusually long values
            if len(param_value) > 10000:
                issues.append(f"Parameter '{param_name}' is excessively long ({len(param_value)} characters)")

            # Check for null bytes
            if '\x00' in param_value:
                issues.append(f"Parameter '{param_name}' contains null bytes")

            # Check for control characters other than tab/newline/carriage
            # return. A null byte also satisfies this test, so it is
            # reported by both checks.
            if any(ord(c) < 32 and c not in '\t\n\r' for c in param_value):
                issues.append(f"Parameter '{param_name}' contains suspicious control characters")

        return issues

    @staticmethod
    def validate_query_with_params(query: str, params: Dict[str, Any]) -> List[str]:
        """Validate a complete query with its parameters.

        Args:
            query: Raw SQL text.
            params: Mapping of bind-parameter names to values. ``None`` is
                treated as "no parameters".

        Returns:
            Issues for the query string followed by issues for each
            parameter, in the mapping's iteration order.
        """
        issues: List[str] = []

        # Validate the query string
        issues.extend(SQLSecurityValidator.validate_query_string(query))

        # Validate each parameter
        for param_name, param_value in (params or {}).items():
            issues.extend(
                SQLSecurityValidator.validate_parameter_value(param_name, param_value)
            )

        return issues
|
|
|
|
|
|
class SecureQueryBuilder:
    """Builds secure parameterized queries to prevent SQL injection."""

    @staticmethod
    def safe_text_query(query: str, params: Optional[Dict[str, Any]] = None) -> text:
        """Create a text query after running the security validator.

        Args:
            query: Raw SQL text, expected to use bind parameters.
            params: Bind-parameter values; only inspected, not bound here.

        Returns:
            The result of ``sqlalchemy.text(query)``. Validation findings
            are logged as a warning but do not block the query.
        """
        params = params or {}

        # Validate the query and parameters
        issues = SQLSecurityValidator.validate_query_with_params(query, params)

        if issues:
            logger.warning("Potential security issues in query", query=query[:100], issues=issues)
            # In production, you might want to raise an exception or sanitize
            # For now, we'll log and proceed with caution

        return text(query)

    @staticmethod
    def build_like_clause(column: ClauseElement, search_term: str, case_sensitive: bool = False) -> ClauseElement:
        """Build a "contains" LIKE/ILIKE clause with wildcard escaping.

        Args:
            column: Column expression to match against.
            search_term: Literal text to search for; ``%``, ``_`` and ``\\``
                are matched literally.
            case_sensitive: Use ``like`` when true, ``ilike`` otherwise.

        Returns:
            ``column.like(...)`` or ``column.ilike(...)`` with ``\\`` as the
            escape character.
        """
        # The escape character itself must be escaped FIRST. Doing it last
        # would double the backslashes just inserted in front of % and _,
        # turning "\%" into "\\%": an escaped backslash followed by a live
        # wildcard.
        escaped_term = (
            search_term
            .replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )
        like_pattern = f"%{escaped_term}%"

        if case_sensitive:
            return column.like(like_pattern, escape='\\')
        return column.ilike(like_pattern, escape='\\')

    @staticmethod
    def build_in_clause(column: ClauseElement, values: List[Any]) -> ClauseElement:
        """Build an IN clause, logging any value the validator flags.

        Args:
            column: Column expression to test.
            values: Candidate values; an empty or ``None`` value list
                produces ``column.in_([])``.

        Returns:
            ``column.in_(values)``. Flagged values are only logged, not
            removed.
        """
        if not values:
            return column.in_([])  # Empty list

        # Validate each value
        for i, value in enumerate(values):
            issues = SQLSecurityValidator.validate_parameter_value(f"in_value_{i}", value)
            if issues:
                logger.warning(f"Security issues in IN clause value: {issues}")

        return column.in_(values)

    @staticmethod
    def build_fts_query(search_terms: List[str], exact_phrase: bool = False) -> str:
        """Build a full-text-search query string from user terms.

        Every character other than word characters, whitespace, ``-`` and
        ``.`` is removed from each term. Terms that end up empty are
        dropped.

        Args:
            search_terms: Raw user terms.
            exact_phrase: When true each term is wrapped in double quotes
                and terms are joined with a space; otherwise terms are
                joined with ``" AND "``.

        Returns:
            The combined query string, or ``""`` when no usable term remains.
        """
        if not search_terms:
            return ""

        safe_terms = []
        for term in search_terms:
            # Strip AFTER removing characters: deleting punctuation can leave
            # only whitespace behind (e.g. "! !" -> " "), which would
            # otherwise be emitted as an empty operand.
            cleaned = re.sub(r'[^\w\s\-\.]', '', term).strip()
            if not cleaned:
                continue
            # Double quotes were removed above, so the term can be wrapped
            # without further quote escaping.
            safe_terms.append(f'"{cleaned}"' if exact_phrase else cleaned)

        separator = ' ' if exact_phrase else ' AND '
        return separator.join(safe_terms)
|
|
|
|
|
|
class DatabaseAuditor:
    """Logs security and performance observations about executed queries."""

    @staticmethod
    def audit_query_execution(query: str, params: Dict[str, Any], execution_time: float) -> None:
        """Record audit log entries for one executed query.

        Emits a warning when the validator reports issues, a warning when
        ``execution_time`` exceeds 5.0, and informational entries for
        ``SELECT *`` and for ``LIMIT`` values of four or more digits.
        """
        # --- security ---
        findings = SQLSecurityValidator.validate_query_with_params(query, params)
        if findings:
            logger.warning(
                "Query executed with security issues",
                query=query[:200],
                issues=findings,
                execution_time=execution_time,
            )

        # --- performance (slow query threshold) ---
        exceeded_threshold = execution_time > 5.0
        if exceeded_threshold:
            logger.warning(
                "Slow query detected",
                query=query[:200],
                execution_time=execution_time,
            )

        # --- query-shape advisories, checked in order ---
        advisories = (
            (r'\bSELECT\s+\*\s+FROM\b', "SELECT * query detected - consider specifying columns"),
            (r'\bLIMIT\s+\d{4,}\b', "Large LIMIT detected"),
        )
        for regex, message in advisories:
            if re.search(regex, query, re.IGNORECASE):
                logger.info(message, query=query[:100])
|
|
|
|
|
|
def execute_secure_query(
    db: Session,
    query: str,
    params: Optional[Dict[str, Any]] = None,
    audit: bool = True
) -> Any:
    """Execute a query with security validation and auditing.

    Args:
        db: Session used to execute the statement.
        query: Raw SQL text using bind parameters.
        params: Bind-parameter values; ``None`` means no parameters.
        audit: When true, validation findings are logged before execution
            and the execution is passed to ``DatabaseAuditor`` afterwards.

    Returns:
        Whatever ``db.execute`` returns.

    Raises:
        Exception: Any error raised while validating or executing is logged
            with the elapsed time and re-raised unchanged.
    """
    import time

    params = params or {}
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted.
    started = time.perf_counter()

    try:
        # Validate query security
        # NOTE: safe_text_query and audit_query_execution each run the same
        # validation again, so a flagged query is reported more than once.
        if audit:
            issues = SQLSecurityValidator.validate_query_with_params(query, params)
            if issues:
                logger.warning("Executing query with potential security issues", issues=issues)

        # Create safe text query
        safe_query = SecureQueryBuilder.safe_text_query(query, params)

        # Execute query
        result = db.execute(safe_query, params)

        # Audit execution
        if audit:
            execution_time = time.perf_counter() - started
            DatabaseAuditor.audit_query_execution(query, params, execution_time)

        return result

    except Exception as e:
        execution_time = time.perf_counter() - started
        logger.error(
            "Query execution failed",
            # str() keeps the error path from raising a second exception
            # (and hiding the real one) when ``query`` is not sliceable.
            query=str(query)[:200],
            error=str(e),
            execution_time=execution_time
        )
        raise
|
|
|
|
|
|
def sanitize_fts_query(query: str) -> str:
    """Sanitize user input for FTS queries.

    Characters other than word characters, whitespace and ``- . , ! ? " '``
    are replaced by a space, runs of whitespace collapse to one space, and
    the result is limited to 500 characters.

    Args:
        query: Raw user input; ``None`` or ``""`` yields ``""``.

    Returns:
        The sanitized text with no leading or trailing whitespace.
    """
    if not query:
        return ""

    # Remove potentially dangerous characters
    # Keep alphanumeric, spaces, and basic punctuation
    sanitized = re.sub(r'[^\w\s\-\.\,\!\?\"\']', ' ', query)

    # Remove excessive whitespace
    sanitized = re.sub(r'\s+', ' ', sanitized).strip()

    # Limit length. The cut can land right after a space, so trim again to
    # keep the "no trailing whitespace" guarantee.
    if len(sanitized) > 500:
        sanitized = sanitized[:500].rstrip()

    return sanitized
|
|
|
|
|
|
def create_safe_search_conditions(
    search_terms: List[str],
    searchable_columns: List[ClauseElement],
    case_sensitive: bool = False,
    exact_phrase: bool = False
) -> Optional[ClauseElement]:
    """Combine LIKE conditions for several terms across several columns.

    With ``exact_phrase`` the terms are joined by a single space and searched
    as one phrase; otherwise each term is searched separately. For every
    phrase/term the per-column clauses are OR-ed, and those groups are AND-ed.

    Returns:
        The combined clause, or ``None`` when there are no terms or columns.
    """
    from sqlalchemy import or_, and_

    if not (search_terms and searchable_columns):
        return None

    # One needle for phrase mode, otherwise one needle per term.
    needles = [' '.join(search_terms)] if exact_phrase else search_terms

    groups = [
        or_(*[
            SecureQueryBuilder.build_like_clause(col, needle, case_sensitive)
            for col in searchable_columns
        ])
        for needle in needles
    ]

    if not groups:
        return None
    return and_(*groups)
|
|
|
|
|
|
# Whitelist of column names that may appear in a dynamically built ORDER BY,
# keyed by table name. Anything not listed here is rejected by
# validate_sort_column / safe_order_by.
ALLOWED_SORT_COLUMNS = {
    'rolodex': [
        'id', 'first', 'last', 'city', 'email',
        'created_at', 'updated_at',
    ],
    'files': [
        'file_no', 'id', 'regarding', 'status', 'file_type',
        'opened', 'closed',
        'created_at', 'updated_at',
    ],
    'ledger': [
        'id', 'file_no', 't_code', 'amount', 'date',
        'created_at', 'updated_at',
    ],
    'qdros': [
        'id', 'file_no', 'form_name', 'status',
        'created_at', 'updated_at',
    ],
}
|
|
|
|
def validate_sort_column(table: str, column: str) -> bool:
    """Return True when ``column`` is whitelisted for ``table``.

    Tables missing from ``ALLOWED_SORT_COLUMNS`` have no allowed columns,
    so every column is rejected for them.
    """
    if table not in ALLOWED_SORT_COLUMNS:
        return False
    return column in ALLOWED_SORT_COLUMNS[table]
|
|
|
|
|
|
def safe_order_by(table: str, sort_column: str, sort_direction: str = 'asc') -> Optional[str]:
    """Build an ``ORDER BY`` fragment from whitelisted pieces only.

    Returns:
        ``"<column> <DIRECTION>"`` with the direction upper-cased, or
        ``None`` (after logging a warning) when the column is not allowed
        for the table or the direction is not asc/desc (case-insensitive).
    """
    # Column must be on the per-table whitelist.
    column_ok = validate_sort_column(table, sort_column)
    if not column_ok:
        logger.warning(f"Invalid sort column '{sort_column}' for table '{table}'")
        return None

    # Direction must be one of the two SQL keywords.
    direction_ok = sort_direction.lower() in ['asc', 'desc']
    if not direction_ok:
        logger.warning(f"Invalid sort direction '{sort_direction}'")
        return None

    return "{} {}".format(sort_column, sort_direction.upper())
|