"""
Database Security Utilities

Provides utilities for secure database operations and SQL injection prevention:
- Parameterized query helpers
- SQL injection detection and prevention
- Safe query building utilities
- Database security auditing
"""

import re
import time
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import and_, or_, text
from sqlalchemy.orm import Session
from sqlalchemy.sql import ClauseElement

from app.utils.logging import app_logger

logger = app_logger.bind(name="database_security")

# Dangerous SQL patterns that could indicate injection attempts.
# NOTE: a trailing \b after a non-word character such as "(" or "_" only
# matches when the NEXT character is a word character, and a leading \b before
# "(" only matches when the PREVIOUS character is a word character. Patterns
# below avoid placing \b next to non-word characters for that reason.
DANGEROUS_PATTERNS = [
    # SQL injection keywords
    r'\bunion\s+select\b',
    r'\b(drop\s+table)\b',
    r'\bdelete\s+from\b',
    r'\b(insert\s+into)\b',
    r'\b(update\s+.*set)\b',
    r'\b(alter\s+table)\b',
    r'\b(create\s+table)\b',
    r'\b(truncate\s+table)\b',
    # Command execution (no trailing \b after "(", so "exec('...')" is caught)
    r'\b(exec\s*\()',
    r'\b(execute\s*\()',
    r'\b(system\s*\()',
    r'\b(shell_exec\s*\()',
    r'\b(eval\s*\()',
    # Comment-based attacks
    r'(;\s*drop\s+table\b)',
    r'(--\s*$)',
    r'(/\*.*?\*/)',
    r'(\#.*$)',
    # Quote escaping attempts
    r"(';|';|\";|\")",
    r"(\\\\'|\\\\\"|\\\\x)",
    # Hex/unicode encoding
    r'(0x[0-9a-fA-F]+)',
    r'(\\u[0-9a-fA-F]{4})',
    # Boolean-based attacks
    r'\b(1=1|1=0|true|false)\b',
    r"\b(or\s+1\s*=\s*1|and\s+1\s*=\s*1)\b",
    # Time-based attacks (\b only after the word-terminated alternative)
    r'\b(sleep\s*\(|delay\s*\(|waitfor\s+delay\b)',
    # File operations
    r'\b(load_file\s*\(|into\s+outfile\b|into\s+dumpfile\b)',
    # Information schema access. "pg_" is normally followed by a word
    # character (pg_catalog, pg_user), so it must not end in \b.
    r'\b(information_schema\b|sys\.|pg_)',
    # Subselect usage in WHERE clause that may indicate enumeration
    # (no leading \b: "(" is not a word character)
    r'\(\s*select\b',
]

# Compiled regex patterns for performance (same order as DANGEROUS_PATTERNS)
COMPILED_PATTERNS = [re.compile(pattern, re.IGNORECASE | re.MULTILINE) for pattern in DANGEROUS_PATTERNS]


class SQLSecurityValidator:
    """Validates SQL queries and parameters for security issues.

    All checks are pattern-based heuristics: they report suspicious content
    but never block it, and they can produce false positives (for example the
    boolean-literal pattern matches any standalone ``true``/``false`` token).
    """

    @staticmethod
    def validate_query_string(query: str) -> List[str]:
        """Validate a SQL query string for potential injection attempts.

        Args:
            query: Raw SQL text to scan.

        Returns:
            One human-readable description per matching pattern; an empty
            list when nothing suspicious was found.
        """
        issues: List[str] = []
        for raw_pattern, pattern in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            try:
                matches = pattern.findall(query)
            except TypeError:
                # Non-string input (e.g. None); treat as "no match" here.
                matches = []
            if matches:
                issues.append(
                    f"Potentially dangerous SQL pattern detected: {raw_pattern} -> {str(matches)[:80]}"
                )

        # Heuristic fallback to catch common cases without relying on regex quirks
        if not issues:
            ql = (query or "").lower()
            if "; drop table" in ql or ";drop table" in ql:
                issues.append("Heuristic: DROP TABLE detected")
            if " union select " in ql or " union\nselect " in ql:
                issues.append("Heuristic: UNION SELECT detected")
            if " or 1=1" in ql or " and 1=1" in ql or " or 1 = 1" in ql or " and 1 = 1" in ql:
                issues.append("Heuristic: tautology (1=1) detected")
            if "(select" in ql:
                issues.append("Heuristic: subselect in WHERE detected")
        return issues

    @staticmethod
    def validate_parameter_value(param_name: str, param_value: Any) -> List[str]:
        """Validate a parameter value for potential injection attempts.

        Args:
            param_name: Name used to label the parameter in issue messages.
            param_value: Bound value; non-strings are checked via ``str()``.

        Returns:
            A list of issue descriptions (empty for ``None`` or clean values).
        """
        issues: List[str] = []
        if param_value is None:
            return issues

        # Convert to string for pattern matching
        str_value = str(param_value)

        # Check for dangerous patterns in parameter values
        for raw_pattern, pattern in zip(DANGEROUS_PATTERNS, COMPILED_PATTERNS):
            if pattern.search(str_value):
                issues.append(f"Parameter '{param_name}' contains dangerous pattern: {raw_pattern}")

        # Additional parameter-specific checks
        if isinstance(param_value, str):
            # Check for excessive length (potential buffer overflow)
            if len(param_value) > 10000:
                issues.append(f"Parameter '{param_name}' is excessively long ({len(param_value)} characters)")

            # Check for null bytes
            if '\x00' in param_value:
                issues.append(f"Parameter '{param_name}' contains null bytes")

            # Check for control characters (tab/newline/CR are allowed)
            if any(ord(c) < 32 and c not in '\t\n\r' for c in param_value):
                issues.append(f"Parameter '{param_name}' contains suspicious control characters")

        return issues

    @staticmethod
    def validate_query_with_params(query: str, params: Dict[str, Any]) -> List[str]:
        """Validate a complete query with its parameters.

        Returns:
            Query-string issues followed by per-parameter issues, in the
            iteration order of ``params``.
        """
        issues = SQLSecurityValidator.validate_query_string(query)
        for param_name, param_value in params.items():
            issues.extend(SQLSecurityValidator.validate_parameter_value(param_name, param_value))
        return issues


class SecureQueryBuilder:
    """Builds secure parameterized queries to prevent SQL injection"""

    @staticmethod
    def safe_text_query(query: str, params: Optional[Dict[str, Any]] = None) -> ClauseElement:
        """Create a text query after validating it and its parameters.

        Issues are logged as warnings only; the query is still returned.
        Values must be supplied as bound parameters at execution time, never
        interpolated into ``query``.

        Args:
            query: SQL text, using ``:name`` placeholders for values.
            params: Parameters that will be bound, used here for validation.

        Returns:
            The ``sqlalchemy.text()`` construct for ``query``.
        """
        params = params or {}

        # Validate the query and parameters
        issues = SQLSecurityValidator.validate_query_with_params(query, params)
        if issues:
            logger.warning("Potential security issues in query", query=query[:100], issues=issues)
            # In production, you might want to raise an exception or sanitize
            # For now, we'll log and proceed with caution

        return text(query)

    @staticmethod
    def build_like_clause(column: ClauseElement, search_term: str, case_sensitive: bool = False) -> ClauseElement:
        """Build a "contains" LIKE/ILIKE clause with the term's wildcards escaped.

        Args:
            column: Column expression to match against.
            search_term: Literal text to search for; ``%``, ``_`` and ``\\``
                in it are matched literally.
            case_sensitive: Use LIKE when True, ILIKE otherwise.

        Returns:
            A clause equivalent to ``column LIKE '%<term>%' ESCAPE '\\'``.
        """
        # The escape character itself must be escaped FIRST; otherwise the
        # backslashes added for % and _ would be doubled and stop escaping them.
        escaped_term = (
            search_term.replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )

        if case_sensitive:
            return column.like(f"%{escaped_term}%", escape='\\')
        return column.ilike(f"%{escaped_term}%", escape='\\')

    @staticmethod
    def build_in_clause(column: ClauseElement, values: List[Any]) -> ClauseElement:
        """Build an IN clause whose values are bound as parameters.

        Suspicious values are logged but still included.
        """
        if not values:
            return column.in_([])  # Empty list

        # Validate each value
        for i, value in enumerate(values):
            issues = SQLSecurityValidator.validate_parameter_value(f"in_value_{i}", value)
            if issues:
                logger.warning(f"Security issues in IN clause value: {issues}")

        return column.in_(values)

    @staticmethod
    def build_fts_query(search_terms: List[str], exact_phrase: bool = False) -> str:
        """Build a Full Text Search MATCH expression from user terms.

        Every character other than word characters, whitespace, ``-`` and
        ``.`` is removed from each term; terms that become empty are dropped.

        Args:
            search_terms: Raw user-supplied terms.
            exact_phrase: Quote each term and join with spaces; otherwise join
                the bare terms with `` AND ``.

        Returns:
            The MATCH expression, or ``""`` when no usable terms remain.
        """
        if not search_terms:
            return ""

        safe_terms = []
        for term in search_terms:
            # Remove or escape dangerous characters for FTS
            # SQLite FTS5 has its own escaping rules
            # NOTE(review): FTS5 barewords may not contain '-' or '.', so an
            # unquoted term containing them may be rejected in non-phrase
            # mode — confirm against the FTS version in use.
            safe_term = re.sub(r'[^\w\s\-\.]', '', term.strip())
            if safe_term:
                if exact_phrase:
                    # The regex above already strips '"'; doubling quotes is kept
                    # as a defensive measure should that character set change.
                    safe_term = safe_term.replace('"', '""')
                    safe_terms.append(f'"{safe_term}"')
                else:
                    safe_terms.append(safe_term)

        if exact_phrase:
            return ' '.join(safe_terms)
        return ' AND '.join(safe_terms)


class DatabaseAuditor:
    """Audits database operations for security compliance"""

    @staticmethod
    def audit_query_execution(query: str, params: Dict[str, Any], execution_time: float) -> None:
        """Log security, performance and pattern findings for an executed query.

        Args:
            query: The SQL text that was executed.
            params: The parameters it was executed with.
            execution_time: Elapsed time in seconds.
        """
        # Security audit
        security_issues = SQLSecurityValidator.validate_query_with_params(query, params)
        if security_issues:
            logger.warning(
                "Query executed with security issues",
                query=query[:200],
                issues=security_issues,
                execution_time=execution_time
            )

        # Performance audit
        if execution_time > 5.0:  # Slow query threshold (seconds)
            logger.warning(
                "Slow query detected",
                query=query[:200],
                execution_time=execution_time
            )

        # Pattern audit - detect potentially problematic queries
        if re.search(r'\bSELECT\s+\*\s+FROM\b', query, re.IGNORECASE):
            logger.info("SELECT * query detected - consider specifying columns", query=query[:100])

        if re.search(r'\bLIMIT\s+\d{4,}\b', query, re.IGNORECASE):
            logger.info("Large LIMIT detected", query=query[:100])


def execute_secure_query(
    db: Session,
    query: str,
    params: Optional[Dict[str, Any]] = None,
    audit: bool = True
) -> Any:
    """Execute a parameterized query with security validation and auditing.

    Args:
        db: Session used to execute the query.
        query: SQL text with ``:name`` placeholders.
        params: Values bound to the placeholders.
        audit: When True, run ``DatabaseAuditor.audit_query_execution`` after
            a successful execution.

    Returns:
        Whatever ``db.execute`` returns.

    Raises:
        Any exception raised during execution, re-raised after logging.
    """
    params = params or {}
    # perf_counter is monotonic, unlike time.time, so durations stay valid
    # across wall-clock adjustments.
    start_time = time.perf_counter()

    try:
        # safe_text_query validates the query and parameters and logs any
        # issues, so no separate pre-execution validation pass is needed.
        safe_query = SecureQueryBuilder.safe_text_query(query, params)

        # Execute query with bound parameters
        result = db.execute(safe_query, params)

        # Audit execution
        if audit:
            execution_time = time.perf_counter() - start_time
            DatabaseAuditor.audit_query_execution(query, params, execution_time)

        return result

    except Exception as e:
        execution_time = time.perf_counter() - start_time
        logger.error(
            "Query execution failed",
            query=query[:200],
            error=str(e),
            execution_time=execution_time
        )
        raise


def sanitize_fts_query(query: str) -> str:
    """Sanitize user input for FTS queries.

    Replaces every character except word characters, whitespace and
    ``- . , ! ? " '`` with a space, collapses whitespace runs, strips the
    ends and truncates to 500 characters.

    Returns:
        The sanitized string, or ``""`` for empty input.
    """
    if not query:
        return ""

    # Remove potentially dangerous characters
    # Keep alphanumeric, spaces, and basic punctuation
    sanitized = re.sub(r'[^\w\s\-\.\,\!\?\"\']', ' ', query)

    # Remove excessive whitespace
    sanitized = re.sub(r'\s+', ' ', sanitized).strip()

    # Limit length
    if len(sanitized) > 500:
        sanitized = sanitized[:500]

    return sanitized


def create_safe_search_conditions(
    search_terms: List[str],
    searchable_columns: List[ClauseElement],
    case_sensitive: bool = False,
    exact_phrase: bool = False
) -> Optional[ClauseElement]:
    """Create escaped LIKE search conditions across multiple columns.

    Args:
        search_terms: Terms to search for.
        searchable_columns: Columns any of which may contain a term.
        case_sensitive: Use LIKE instead of ILIKE.
        exact_phrase: Join all terms into one space-separated phrase that must
            appear in some column. Otherwise EVERY term must appear in at
            least one column.

    Returns:
        The combined condition, or ``None`` when there are no terms or columns.
    """
    if not search_terms or not searchable_columns:
        return None

    if exact_phrase:
        # Single phrase search across all columns
        phrase = ' '.join(search_terms)
        search_conditions = [
            or_(*[
                SecureQueryBuilder.build_like_clause(column, phrase, case_sensitive)
                for column in searchable_columns
            ])
        ]
    else:
        # Each term must match at least one column
        search_conditions = [
            or_(*[
                SecureQueryBuilder.build_like_clause(column, term, case_sensitive)
                for column in searchable_columns
            ])
            for term in search_terms
        ]

    return and_(*search_conditions) if search_conditions else None


# Whitelist of allowed column names for dynamic queries
ALLOWED_SORT_COLUMNS = {
    'rolodex': ['id', 'first', 'last', 'city', 'email', 'created_at', 'updated_at'],
    'files': ['file_no', 'id', 'regarding', 'status', 'file_type', 'opened', 'closed', 'created_at', 'updated_at'],
    'ledger': ['id', 'file_no', 't_code', 'amount', 'date', 'created_at', 'updated_at'],
    'qdros': ['id', 'file_no', 'form_name', 'status', 'created_at', 'updated_at'],
}


def validate_sort_column(table: str, column: str) -> bool:
    """Return True if ``column`` is whitelisted for sorting ``table``.

    Unknown tables have no allowed columns.
    """
    allowed_columns = ALLOWED_SORT_COLUMNS.get(table, [])
    return column in allowed_columns


def safe_order_by(table: str, sort_column: str, sort_direction: str = 'asc') -> Optional[str]:
    """Create an ORDER BY fragment using only whitelisted identifiers.

    Args:
        table: Key into ``ALLOWED_SORT_COLUMNS``.
        sort_column: Requested column; must be whitelisted for ``table``.
        sort_direction: ``asc`` or ``desc``, case-insensitive.

    Returns:
        ``"<column> ASC|DESC"``, or ``None`` (with a warning logged) when the
        column or direction is not allowed.
    """
    # Validate sort column
    if not validate_sort_column(table, sort_column):
        logger.warning(f"Invalid sort column '{sort_column}' for table '{table}'")
        return None

    # Validate sort direction
    if sort_direction.lower() not in ['asc', 'desc']:
        logger.warning(f"Invalid sort direction '{sort_direction}'")
        return None

    return f"{sort_column} {sort_direction.upper()}"