326 lines
12 KiB
Python
326 lines
12 KiB
Python
"""
|
|
Database transaction management utilities for consistent transaction handling.
|
|
"""
|
|
from typing import Callable, Any, Optional, TypeVar, Type
|
|
from functools import wraps
|
|
from contextlib import contextmanager
|
|
import logging
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from app.utils.exceptions import DatabaseError, handle_database_errors
|
|
from app.utils.logging import database_logger
|
|
|
|
|
|
# Generic type variable; the `transactional` decorator uses it to annotate the
# return type of the wrapped function.
T = TypeVar('T')
|
|
|
|
|
|
class TransactionManager:
    """Context manager for database transactions with automatic rollback on errors.

    The manager wraps an existing session; it does not open or close it.
    On leaving the ``with`` block it commits when no exception occurred
    (if ``auto_commit``) or rolls back when one did (if ``auto_rollback``).
    ``commit()`` and ``rollback()`` may also be called manually inside the
    block; the ``committed`` / ``rolled_back`` flags make each of them run at
    most once per block.
    """

    def __init__(self, db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
        """Store the session and the commit/rollback policy.

        Args:
            db_session: SQLAlchemy session the transaction runs on.
            auto_commit: Commit on exit when the block raised nothing.
            auto_rollback: Roll back on exit when the block raised.
        """
        self.db_session = db_session
        self.auto_commit = auto_commit
        self.auto_rollback = auto_rollback
        self.committed = False
        self.rolled_back = False

    def __enter__(self) -> Session:
        """Enter transaction context and return the wrapped session."""
        # Reset per-block state so a manager instance that is reused for a
        # second ``with`` block does not skip its commit/rollback because of
        # flags left over from the previous block.
        self.committed = False
        self.rolled_back = False

        database_logger.log_transaction_event("started", {
            "auto_commit": self.auto_commit,
            "auto_rollback": self.auto_rollback
        })
        return self.db_session

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit transaction context with appropriate commit/rollback.

        Returns:
            False when the block raised, so the original exception propagates;
            None otherwise.

        Raises:
            Whatever ``commit()`` / ``rollback()`` raise if cleanup itself fails.
        """
        try:
            if exc_type is not None:
                # Exception occurred inside the block
                if self.auto_rollback and not self.rolled_back:
                    self.rollback()
                    database_logger.log_transaction_event("auto_rollback", {
                        "exception_type": exc_type.__name__,
                        "exception_message": str(exc_val) if exc_val else None
                    })
                return False  # Re-raise the exception

            # No exception
            if self.auto_commit and not self.committed:
                self.commit()
                # commit() is a no-op after a manual rollback; only report an
                # auto-commit when one actually happened.
                if self.committed:
                    database_logger.log_transaction_event("auto_commit")
        except Exception as e:
            # Error during commit/rollback
            database_logger.error(f"Error during transaction cleanup: {str(e)}")
            if not self.rolled_back:
                try:
                    self.rollback()
                except Exception:
                    pass  # Best effort rollback; the cleanup error below is what matters
            raise

    def commit(self):
        """Manually commit the transaction.

        Does nothing if the transaction was already committed or rolled back.

        Raises:
            DatabaseError: If the session commit fails; the session is rolled
                back first and the SQLAlchemy error is attached as the cause.
        """
        if self.committed or self.rolled_back:
            return

        try:
            self.db_session.commit()
            self.committed = True
            database_logger.log_transaction_event("manual_commit")
        except SQLAlchemyError as e:
            database_logger.error(f"Transaction commit failed: {str(e)}")
            self.rollback()
            raise DatabaseError(f"Failed to commit transaction: {str(e)}") from e

    def rollback(self):
        """Manually rollback the transaction.

        Does nothing if the transaction was already rolled back.

        Raises:
            DatabaseError: If the session rollback fails; the SQLAlchemy error
                is attached as the cause.
        """
        if self.rolled_back:
            return

        try:
            self.db_session.rollback()
            self.rolled_back = True
            database_logger.log_transaction_event("manual_rollback")
        except SQLAlchemyError as e:
            database_logger.error(f"Transaction rollback failed: {str(e)}")
            raise DatabaseError(f"Failed to rollback transaction: {str(e)}") from e
|
|
|
|
|
|
@contextmanager
def db_transaction(db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
    """
    Run a block of work inside a TransactionManager-controlled transaction.

    Args:
        db_session: SQLAlchemy session
        auto_commit: Commit when the block finishes without raising
        auto_rollback: Roll back when the block raises

    Yields:
        Session: The session handed back by TransactionManager

    Example:
        with db_transaction(db) as session:
            session.add(new_record)
            # Committed on exit when nothing was raised
    """
    manager = TransactionManager(db_session, auto_commit, auto_rollback)
    with manager as active_session:
        yield active_session
|
|
|
|
|
|
def transactional(auto_commit: bool = True, auto_rollback: bool = True):
    """
    Build a decorator that runs the wrapped function inside db_transaction.

    Args:
        auto_commit: Commit when the function returns without raising
        auto_rollback: Roll back when the function raises

    Note:
        The session is taken from the 'db' keyword argument when present;
        otherwise the first positional argument that is a SQLAlchemy Session
        is used. A ValueError is raised when neither yields a session.

    Example:
        @transactional()
        def create_user(user_data: dict, db: Session):
            user = User(**user_data)
            db.add(user)
            return user
    """
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(*args, **kwargs) -> T:
            # A 'db' keyword wins; positional arguments are only scanned
            # when that keyword is absent.
            if 'db' in kwargs:
                session = kwargs['db']
            else:
                session = next(
                    (candidate for candidate in args if isinstance(candidate, Session)),
                    None,
                )

            if session is None:
                raise ValueError("Function must have a 'db' parameter with SQLAlchemy Session")

            with db_transaction(session, auto_commit, auto_rollback):
                return func(*args, **kwargs)

        return wrapper
    return decorator
|
|
|
|
|
|
class BulkOperationManager:
    """Manager for bulk database operations with batching and progress tracking.

    ``processed_count`` and ``error_count`` accumulate over every call made on
    the same manager. The summary returned by each method describes that one
    call only.
    """

    def __init__(self, db_session: Session, batch_size: int = 1000):
        """Store the session and the number of rows sent per batch.

        Args:
            db_session: SQLAlchemy session used for every bulk call.
            batch_size: Maximum number of rows per batch; must be positive
                by the time a bulk method is called.
        """
        self.db_session = db_session
        self.batch_size = batch_size
        self.processed_count = 0
        self.error_count = 0

    def bulk_insert(self, records: list, model_class: Type):
        """
        Perform bulk insert with batching and error handling.

        All batches run in a single db_transaction, so a failing batch rolls
        back the whole insert.

        Args:
            records: List of record dictionaries
            model_class: SQLAlchemy model class

        Returns:
            dict: Summary with keys "total_records", "processed", "errors"
            and "success_rate" (percentage) for this call

        Raises:
            ValueError: If batch_size is not positive
            DatabaseError: If a batch fails with a SQLAlchemy error
        """
        return self._run_batched(
            rows=records,
            model_class=model_class,
            verb="insert",
            total_key="total_records",
            apply_batch=lambda session, batch: session.bulk_insert_mappings(model_class, batch),
        )

    def bulk_update(self, updates: list, model_class: Type, key_field: str = 'id'):
        """
        Perform bulk update with batching and error handling.

        All batches run in a single db_transaction, so a failing batch rolls
        back the whole update.

        Args:
            updates: List of update dictionaries (must include key_field)
            model_class: SQLAlchemy model class
            key_field: Field name to use as update key. NOTE(review): this
                value is accepted but not passed on to
                ``bulk_update_mappings`` - confirm whether it is still needed.

        Returns:
            dict: Summary with keys "total_updates", "processed", "errors"
            and "success_rate" (percentage) for this call

        Raises:
            ValueError: If batch_size is not positive
            DatabaseError: If a batch fails with a SQLAlchemy error
        """
        return self._run_batched(
            rows=updates,
            model_class=model_class,
            verb="update",
            total_key="total_updates",
            apply_batch=lambda session, batch: session.bulk_update_mappings(model_class, batch),
        )

    def _run_batched(self, rows: list, model_class: Type, verb: str, total_key: str, apply_batch: Callable):
        """Apply ``apply_batch`` to ``rows`` in batch_size slices inside one transaction.

        ``verb`` ("insert" / "update") is only used in log and error messages;
        ``total_key`` is the summary key that holds the row total.
        """
        if self.batch_size <= 0:
            # range() would raise on 0 and silently do nothing on a negative step.
            raise ValueError("batch_size must be a positive integer")

        total = len(rows)
        # Per-call counters: the summary must not include rows handled by
        # earlier calls on the same manager.
        processed = 0
        failed = 0
        database_logger.info(f"Starting bulk {verb} of {total} {model_class.__name__} records")

        try:
            with db_transaction(self.db_session) as session:
                for start in range(0, total, self.batch_size):
                    batch = rows[start:start + self.batch_size]

                    try:
                        apply_batch(session, batch)
                        session.flush()  # Flush but don't commit yet
                    except SQLAlchemyError as e:
                        failed += len(batch)
                        self.error_count += len(batch)
                        database_logger.error(
                            f"Batch {verb} failed for records {start}-{start + len(batch)}: {str(e)}"
                        )
                        raise DatabaseError(f"Bulk {verb} failed: {str(e)}") from e

                    processed += len(batch)
                    self.processed_count += len(batch)

                    # Log progress every ten full batches
                    if processed % (self.batch_size * 10) == 0:
                        database_logger.info(f"Bulk {verb} progress: {processed}/{total}")

                # Final commit happens automatically via context manager
        except Exception as e:
            database_logger.error(f"Bulk {verb} operation failed: {str(e)}")
            raise

        summary = {
            total_key: total,
            "processed": processed,
            "errors": failed,
            "success_rate": (processed / total) * 100 if total > 0 else 0
        }

        database_logger.info(f"Bulk {verb} completed", **summary)
        return summary
|
|
|
|
|
|
def safe_db_operation(operation: Callable, db_session: Session, default_return: Any = None) -> Any:
    """
    Run a database operation in a transaction and never let it raise.

    Args:
        operation: Callable invoked with the transaction's session
        db_session: SQLAlchemy session
        default_return: Value handed back when anything goes wrong

    Returns:
        Whatever operation returned, or default_return if the operation or
        the transaction cleanup raised
    """
    try:
        with db_transaction(db_session, auto_rollback=True) as txn_session:
            outcome = operation(txn_session)
    except Exception as failure:
        # Any failure, including one raised while leaving the transaction,
        # is logged and replaced by the fallback value.
        database_logger.error(f"Database operation failed: {str(failure)}")
        return default_return
    return outcome
|
|
|
|
|
|
def execute_with_retry(
    operation: Callable,
    db_session: Session,
    max_retries: int = 3,
    retry_delay: float = 1.0
) -> Any:
    """
    Execute database operation with retry logic for transient failures.

    Each attempt runs inside its own db_transaction. Only SQLAlchemyError
    triggers a retry; any other exception propagates immediately.

    NOTE(review): db_transaction reports a failed commit as DatabaseError, so
    commit failures are not retried here unless DatabaseError derives from
    SQLAlchemyError - confirm whether that is intended.

    Args:
        operation: Function to execute; called with the transaction's session
        db_session: SQLAlchemy session
        max_retries: Maximum number of retry attempts after the first try
        retry_delay: Delay between retries in seconds

    Returns:
        Result of successful operation

    Raises:
        DatabaseError: If all retry attempts fail; the last SQLAlchemy error
            is attached as the cause
    """
    import time

    last_exception = None

    for attempt in range(max_retries + 1):
        try:
            with db_transaction(db_session) as session:
                return operation(session)

        except SQLAlchemyError as e:
            # Keep a reference: ``e`` is unbound once the except block ends.
            last_exception = e

            if attempt < max_retries:
                database_logger.warning(
                    f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {str(e)}"
                )
                time.sleep(retry_delay)
            else:
                database_logger.error(f"Database operation failed after {max_retries + 1} attempts: {str(e)}")

    # Raised outside the except block, so the cause must be chained explicitly
    # or the original traceback is lost.
    raise DatabaseError(
        f"Operation failed after {max_retries + 1} attempts: {str(last_exception)}"
    ) from last_exception