"""
Database transaction management utilities for consistent transaction handling.
"""

import inspect
import logging
import time
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Any, Optional, TypeVar, Type

from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError

from app.utils.exceptions import DatabaseError, handle_database_errors
from app.utils.logging import database_logger

T = TypeVar('T')


class TransactionManager:
    """Context manager for database transactions with automatic rollback on errors.

    On a clean exit of the ``with`` body the session is committed (when
    ``auto_commit``); when the body raises, the session is rolled back (when
    ``auto_rollback``) and the exception propagates. ``commit()`` and
    ``rollback()`` may also be called manually inside the body; each takes
    effect at most once per ``with``.
    """

    def __init__(self, db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
        self.db_session = db_session
        self.auto_commit = auto_commit
        self.auto_rollback = auto_rollback
        self.committed = False
        self.rolled_back = False

    def __enter__(self) -> Session:
        """Enter the transaction context and return the wrapped session."""
        # Reset per-transaction state so the same manager can be re-entered;
        # otherwise a second ``with`` would silently skip its commit/rollback.
        self.committed = False
        self.rolled_back = False
        database_logger.log_transaction_event("started", {
            "auto_commit": self.auto_commit,
            "auto_rollback": self.auto_rollback
        })
        return self.db_session

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the transaction context, committing or rolling back.

        Returns:
            False, so any exception raised in the ``with`` body propagates.

        Raises:
            DatabaseError: If the automatic commit or rollback fails.
        """
        try:
            if exc_type is not None:
                if self.auto_rollback and not self.rolled_back:
                    self.rollback()
                    database_logger.log_transaction_event("auto_rollback", {
                        "exception_type": exc_type.__name__,
                        "exception_message": str(exc_val) if exc_val else None
                    })
                return False  # Re-raise the original exception
            # Skip the commit (and its log event) when the body already
            # committed or rolled back manually; commit() would be a no-op.
            if self.auto_commit and not self.committed and not self.rolled_back:
                self.commit()
                database_logger.log_transaction_event("auto_commit")
        except Exception as e:
            # Error during commit/rollback
            database_logger.error(f"Error during transaction cleanup: {str(e)}")
            if not self.rolled_back:
                try:
                    self.rollback()
                except Exception:
                    pass  # Best effort rollback; the cleanup error is re-raised below
            raise
        return False

    def commit(self):
        """Commit the transaction once; no-op if already committed or rolled back.

        Raises:
            DatabaseError: If the commit fails. The transaction is rolled back
                (best effort) before raising.
        """
        if self.committed or self.rolled_back:
            return
        try:
            self.db_session.commit()
        except SQLAlchemyError as e:
            database_logger.error(f"Transaction commit failed: {str(e)}")
            try:
                self.rollback()
            except DatabaseError:
                # rollback() has already logged its own failure; surface the
                # commit error, which is the root cause.
                pass
            raise DatabaseError(f"Failed to commit transaction: {str(e)}") from e
        self.committed = True
        database_logger.log_transaction_event("manual_commit")

    def rollback(self):
        """Roll back the transaction once; no-op if already rolled back.

        Raises:
            DatabaseError: If the rollback fails.
        """
        if self.rolled_back:
            return
        try:
            self.db_session.rollback()
        except SQLAlchemyError as e:
            database_logger.error(f"Transaction rollback failed: {str(e)}")
            raise DatabaseError(f"Failed to rollback transaction: {str(e)}") from e
        self.rolled_back = True
        database_logger.log_transaction_event("manual_rollback")


@contextmanager
def db_transaction(db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
    """
    Context manager for database transactions.

    Args:
        db_session: SQLAlchemy session
        auto_commit: Whether to auto-commit on successful completion
        auto_rollback: Whether to auto-rollback on exceptions

    Yields:
        Session: The database session

    Example:
        with db_transaction(db) as session:
            session.add(new_record)
            # Auto-commits on exit if no exceptions
    """
    with TransactionManager(db_session, auto_commit, auto_rollback) as session:
        yield session


def transactional(auto_commit: bool = True, auto_rollback: bool = True):
    """
    Decorator to wrap functions in database transactions.

    Args:
        auto_commit: Whether to auto-commit on successful completion
        auto_rollback: Whether to auto-rollback on exceptions

    Note:
        The decorated function must receive its SQLAlchemy session through a
        parameter named ``db`` (positionally or by keyword). As a fallback,
        the first ``Session`` instance among the call's arguments is used.

    Raises:
        ValueError: At call time, if no session can be located.

    Example:
        @transactional()
        def create_user(user_data: dict, db: Session):
            user = User(**user_data)
            db.add(user)
            return user
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Resolve the signature once, at decoration time, not on every call.
        try:
            signature: Optional[inspect.Signature] = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None  # some callables expose no introspectable signature
        db_param = signature.parameters.get('db') if signature is not None else None
        db_is_positional = db_param is not None and db_param.kind in positional_kinds

        def find_session(args: tuple, kwargs: dict) -> Any:
            # 1. Explicit ``db=`` keyword argument.
            if 'db' in kwargs:
                return kwargs['db']
            # 2. A parameter named ``db`` passed positionally. This also finds
            #    session proxies that are not ``Session`` instances.
            if db_is_positional:
                try:
                    bound = signature.bind_partial(*args, **kwargs)
                except TypeError:
                    bound = None  # arguments don't fit the signature; scan instead
                if bound is not None and bound.arguments.get('db') is not None:
                    return bound.arguments['db']
            # 3. Fallback: the first Session instance among all arguments.
            for value in (*args, *kwargs.values()):
                if isinstance(value, Session):
                    return value
            return None

        @wraps(func)
        def wrapper(*args, **kwargs) -> T:
            db_session = find_session(args, kwargs)
            if db_session is None:
                raise ValueError("Function must have a 'db' parameter with SQLAlchemy Session")
            with db_transaction(db_session, auto_commit, auto_rollback):
                return func(*args, **kwargs)

        return wrapper

    return decorator


class BulkOperationManager:
    """Manager for bulk database operations with batching and progress tracking.

    ``processed_count`` and ``error_count`` are cumulative totals across every
    call made on this instance. Each call's returned summary reports only the
    records handled by that call.
    """

    def __init__(self, db_session: Session, batch_size: int = 1000):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.db_session = db_session
        self.batch_size = batch_size
        self.processed_count = 0
        self.error_count = 0

    def _run_batches(
        self,
        items: list,
        model_class: Type,
        apply_batch: Callable[[Session, list], None],
        verb: str,
        total_key: str,
    ) -> dict:
        """Apply ``apply_batch`` to ``items`` in chunks inside one transaction.

        Every batch is flushed, and the whole run is committed once at the end
        by ``db_transaction``. A failing batch rolls back the entire run.

        Args:
            items: Record/update dictionaries to process.
            model_class: SQLAlchemy model class (used for logging).
            apply_batch: Callable ``(session, batch)`` that stages one batch.
            verb: ``"insert"`` or ``"update"``; used in log/error messages.
            total_key: Summary key that holds the item count.

        Returns:
            dict: Summary of this call's results.

        Raises:
            DatabaseError: If any batch fails.
        """
        total = len(items)
        processed = 0  # per-call count; the instance counters stay cumulative
        database_logger.info(f"Starting bulk {verb} of {total} {model_class.__name__} records")

        try:
            with db_transaction(self.db_session) as session:
                for start in range(0, total, self.batch_size):
                    batch = items[start:start + self.batch_size]
                    try:
                        apply_batch(session, batch)
                        session.flush()  # Flush but don't commit yet
                    except SQLAlchemyError as e:
                        self.error_count += len(batch)
                        database_logger.error(
                            f"Batch {verb} failed for records {start}-{start + len(batch)}: {str(e)}"
                        )
                        raise DatabaseError(f"Bulk {verb} failed: {str(e)}") from e
                    processed += len(batch)
                    self.processed_count += len(batch)

                    # Log progress roughly every ten batches of this call.
                    if processed % (self.batch_size * 10) == 0:
                        database_logger.info(f"Bulk {verb} progress: {processed}/{total}")
                # Final commit happens automatically via context manager
        except Exception as e:
            database_logger.error(f"Bulk {verb} operation failed: {str(e)}")
            raise

        summary = {
            total_key: total,
            "processed": processed,
            # Any failed batch aborts the run with an exception, so a
            # returned summary never contains failed records.
            "errors": 0,
            "success_rate": (processed / total) * 100 if total > 0 else 0
        }
        database_logger.info(f"Bulk {verb} completed", **summary)
        return summary

    def bulk_insert(self, records: list, model_class: Type):
        """
        Perform bulk insert with batching and error handling.

        Args:
            records: List of record dictionaries
            model_class: SQLAlchemy model class

        Returns:
            dict: Summary of this call (total_records, processed, errors, success_rate)

        Raises:
            DatabaseError: If any batch fails; the whole insert is rolled back.
        """
        return self._run_batches(
            records,
            model_class,
            lambda session, batch: session.bulk_insert_mappings(model_class, batch),
            "insert",
            "total_records",
        )

    def bulk_update(self, updates: list, model_class: Type, key_field: str = 'id'):
        """
        Perform bulk update with batching and error handling.

        Args:
            updates: List of update dictionaries (should include the model's primary key)
            model_class: SQLAlchemy model class
            key_field: Currently unused. ``bulk_update_mappings`` always matches
                rows on the model's primary key. Kept for interface compatibility.

        Returns:
            dict: Summary of this call (total_updates, processed, errors, success_rate)

        Raises:
            DatabaseError: If any batch fails; the whole update is rolled back.
        """
        return self._run_batches(
            updates,
            model_class,
            lambda session, batch: session.bulk_update_mappings(model_class, batch),
            "update",
            "total_updates",
        )


def safe_db_operation(operation: Callable, db_session: Session, default_return: Any = None) -> Any:
    """
    Safely execute a database operation with automatic rollback on errors.

    Any exception (including commit failures) is logged and swallowed by
    design; callers receive ``default_return`` instead.

    Args:
        operation: Function to execute (should accept db_session as parameter)
        db_session: SQLAlchemy session
        default_return: Value to return on failure

    Returns:
        Result of operation or default_return on failure
    """
    try:
        with db_transaction(db_session, auto_rollback=True) as session:
            return operation(session)
    except Exception as e:
        database_logger.error(f"Database operation failed: {str(e)}")
        return default_return


def execute_with_retry(
    operation: Callable,
    db_session: Session,
    max_retries: int = 3,
    retry_delay: float = 1.0
) -> Any:
    """
    Execute database operation with retry logic for transient failures.

    Each attempt runs in its own transaction on the same session. A failed
    attempt is rolled back by the transaction manager before the next retry.

    Args:
        operation: Function to execute (receives the session)
        db_session: SQLAlchemy session
        max_retries: Maximum number of retry attempts (at least one attempt is always made)
        retry_delay: Delay between retries in seconds

    Returns:
        Result of successful operation

    Raises:
        DatabaseError: If all retry attempts fail
    """
    attempts = max(1, max_retries + 1)
    last_exception: Optional[Exception] = None

    for attempt in range(1, attempts + 1):
        try:
            with db_transaction(db_session) as session:
                return operation(session)
        # DatabaseError is included because TransactionManager.commit wraps
        # commit-time SQLAlchemyErrors (e.g. deadlocks) in it.
        except (SQLAlchemyError, DatabaseError) as e:
            last_exception = e
            if attempt < attempts:
                database_logger.warning(
                    f"Database operation failed (attempt {attempt}/{attempts}): {str(e)}"
                )
                time.sleep(retry_delay)
            else:
                database_logger.error(f"Database operation failed after {attempts} attempts: {str(e)}")

    raise DatabaseError(
        f"Operation failed after {attempts} attempts: {str(last_exception)}"
    ) from last_exception