remove old import
This commit is contained in:
326
app/utils/database.py
Normal file
326
app/utils/database.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Database transaction management utilities for consistent transaction handling.
|
||||
"""
|
||||
from typing import Callable, Any, Optional, TypeVar, Type
|
||||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from app.utils.exceptions import DatabaseError, handle_database_errors
|
||||
from app.utils.logging import database_logger
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class TransactionManager:
|
||||
"""Context manager for database transactions with automatic rollback on errors."""
|
||||
|
||||
def __init__(self, db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
|
||||
self.db_session = db_session
|
||||
self.auto_commit = auto_commit
|
||||
self.auto_rollback = auto_rollback
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
|
||||
def __enter__(self) -> Session:
|
||||
"""Enter transaction context."""
|
||||
database_logger.log_transaction_event("started", {
|
||||
"auto_commit": self.auto_commit,
|
||||
"auto_rollback": self.auto_rollback
|
||||
})
|
||||
return self.db_session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit transaction context with appropriate commit/rollback."""
|
||||
try:
|
||||
if exc_type is not None:
|
||||
# Exception occurred
|
||||
if self.auto_rollback and not self.rolled_back:
|
||||
self.rollback()
|
||||
database_logger.log_transaction_event("auto_rollback", {
|
||||
"exception_type": exc_type.__name__ if exc_type else None,
|
||||
"exception_message": str(exc_val) if exc_val else None
|
||||
})
|
||||
return False # Re-raise the exception
|
||||
else:
|
||||
# No exception
|
||||
if self.auto_commit and not self.committed:
|
||||
self.commit()
|
||||
database_logger.log_transaction_event("auto_commit")
|
||||
except Exception as e:
|
||||
# Error during commit/rollback
|
||||
database_logger.error(f"Error during transaction cleanup: {str(e)}")
|
||||
if not self.rolled_back:
|
||||
try:
|
||||
self.rollback()
|
||||
except Exception:
|
||||
pass # Best effort rollback
|
||||
raise
|
||||
|
||||
def commit(self):
|
||||
"""Manually commit the transaction."""
|
||||
if not self.committed and not self.rolled_back:
|
||||
try:
|
||||
self.db_session.commit()
|
||||
self.committed = True
|
||||
database_logger.log_transaction_event("manual_commit")
|
||||
except SQLAlchemyError as e:
|
||||
database_logger.error(f"Transaction commit failed: {str(e)}")
|
||||
self.rollback()
|
||||
raise DatabaseError(f"Failed to commit transaction: {str(e)}")
|
||||
|
||||
def rollback(self):
|
||||
"""Manually rollback the transaction."""
|
||||
if not self.rolled_back:
|
||||
try:
|
||||
self.db_session.rollback()
|
||||
self.rolled_back = True
|
||||
database_logger.log_transaction_event("manual_rollback")
|
||||
except SQLAlchemyError as e:
|
||||
database_logger.error(f"Transaction rollback failed: {str(e)}")
|
||||
raise DatabaseError(f"Failed to rollback transaction: {str(e)}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_transaction(db_session: Session, auto_commit: bool = True, auto_rollback: bool = True):
|
||||
"""
|
||||
Context manager for database transactions.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
auto_commit: Whether to auto-commit on successful completion
|
||||
auto_rollback: Whether to auto-rollback on exceptions
|
||||
|
||||
Yields:
|
||||
Session: The database session
|
||||
|
||||
Example:
|
||||
with db_transaction(db) as session:
|
||||
session.add(new_record)
|
||||
# Auto-commits on exit if no exceptions
|
||||
"""
|
||||
with TransactionManager(db_session, auto_commit, auto_rollback) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
"""
|
||||
Decorator to wrap functions in database transactions.
|
||||
|
||||
Args:
|
||||
auto_commit: Whether to auto-commit on successful completion
|
||||
auto_rollback: Whether to auto-rollback on exceptions
|
||||
|
||||
Note:
|
||||
The decorated function must accept a 'db' parameter that is a SQLAlchemy session.
|
||||
|
||||
Example:
|
||||
@transactional()
|
||||
def create_user(user_data: dict, db: Session):
|
||||
user = User(**user_data)
|
||||
db.add(user)
|
||||
return user
|
||||
"""
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> T:
|
||||
# Find the db session in arguments
|
||||
db_session = None
|
||||
|
||||
# Check kwargs first
|
||||
if 'db' in kwargs:
|
||||
db_session = kwargs['db']
|
||||
else:
|
||||
# Check args for Session instance
|
||||
for arg in args:
|
||||
if isinstance(arg, Session):
|
||||
db_session = arg
|
||||
break
|
||||
|
||||
if db_session is None:
|
||||
raise ValueError("Function must have a 'db' parameter with SQLAlchemy Session")
|
||||
|
||||
with db_transaction(db_session, auto_commit, auto_rollback):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class BulkOperationManager:
|
||||
"""Manager for bulk database operations with batching and progress tracking."""
|
||||
|
||||
def __init__(self, db_session: Session, batch_size: int = 1000):
|
||||
self.db_session = db_session
|
||||
self.batch_size = batch_size
|
||||
self.processed_count = 0
|
||||
self.error_count = 0
|
||||
|
||||
def bulk_insert(self, records: list, model_class: Type):
|
||||
"""
|
||||
Perform bulk insert with batching and error handling.
|
||||
|
||||
Args:
|
||||
records: List of record dictionaries
|
||||
model_class: SQLAlchemy model class
|
||||
|
||||
Returns:
|
||||
dict: Summary of operation results
|
||||
"""
|
||||
total_records = len(records)
|
||||
database_logger.info(f"Starting bulk insert of {total_records} {model_class.__name__} records")
|
||||
|
||||
try:
|
||||
with db_transaction(self.db_session) as session:
|
||||
for i in range(0, total_records, self.batch_size):
|
||||
batch = records[i:i + self.batch_size]
|
||||
|
||||
try:
|
||||
# Insert batch
|
||||
session.bulk_insert_mappings(model_class, batch)
|
||||
session.flush() # Flush but don't commit yet
|
||||
|
||||
self.processed_count += len(batch)
|
||||
|
||||
# Log progress
|
||||
if self.processed_count % (self.batch_size * 10) == 0:
|
||||
database_logger.info(f"Bulk insert progress: {self.processed_count}/{total_records}")
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
self.error_count += len(batch)
|
||||
database_logger.error(f"Batch insert failed for records {i}-{i+len(batch)}: {str(e)}")
|
||||
raise DatabaseError(f"Bulk insert failed: {str(e)}")
|
||||
|
||||
# Final commit happens automatically via context manager
|
||||
|
||||
except Exception as e:
|
||||
database_logger.error(f"Bulk insert operation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
summary = {
|
||||
"total_records": total_records,
|
||||
"processed": self.processed_count,
|
||||
"errors": self.error_count,
|
||||
"success_rate": (self.processed_count / total_records) * 100 if total_records > 0 else 0
|
||||
}
|
||||
|
||||
database_logger.info(f"Bulk insert completed", **summary)
|
||||
return summary
|
||||
|
||||
def bulk_update(self, updates: list, model_class: Type, key_field: str = 'id'):
|
||||
"""
|
||||
Perform bulk update with batching and error handling.
|
||||
|
||||
Args:
|
||||
updates: List of update dictionaries (must include key_field)
|
||||
model_class: SQLAlchemy model class
|
||||
key_field: Field name to use as update key
|
||||
|
||||
Returns:
|
||||
dict: Summary of operation results
|
||||
"""
|
||||
total_updates = len(updates)
|
||||
database_logger.info(f"Starting bulk update of {total_updates} {model_class.__name__} records")
|
||||
|
||||
try:
|
||||
with db_transaction(self.db_session) as session:
|
||||
for i in range(0, total_updates, self.batch_size):
|
||||
batch = updates[i:i + self.batch_size]
|
||||
|
||||
try:
|
||||
# Update batch
|
||||
session.bulk_update_mappings(model_class, batch)
|
||||
session.flush() # Flush but don't commit yet
|
||||
|
||||
self.processed_count += len(batch)
|
||||
|
||||
# Log progress
|
||||
if self.processed_count % (self.batch_size * 10) == 0:
|
||||
database_logger.info(f"Bulk update progress: {self.processed_count}/{total_updates}")
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
self.error_count += len(batch)
|
||||
database_logger.error(f"Batch update failed for records {i}-{i+len(batch)}: {str(e)}")
|
||||
raise DatabaseError(f"Bulk update failed: {str(e)}")
|
||||
|
||||
# Final commit happens automatically via context manager
|
||||
|
||||
except Exception as e:
|
||||
database_logger.error(f"Bulk update operation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
summary = {
|
||||
"total_updates": total_updates,
|
||||
"processed": self.processed_count,
|
||||
"errors": self.error_count,
|
||||
"success_rate": (self.processed_count / total_updates) * 100 if total_updates > 0 else 0
|
||||
}
|
||||
|
||||
database_logger.info(f"Bulk update completed", **summary)
|
||||
return summary
|
||||
|
||||
|
||||
def safe_db_operation(operation: Callable, db_session: Session, default_return: Any = None) -> Any:
|
||||
"""
|
||||
Safely execute a database operation with automatic rollback on errors.
|
||||
|
||||
Args:
|
||||
operation: Function to execute (should accept db_session as parameter)
|
||||
db_session: SQLAlchemy session
|
||||
default_return: Value to return on failure
|
||||
|
||||
Returns:
|
||||
Result of operation or default_return on failure
|
||||
"""
|
||||
try:
|
||||
with db_transaction(db_session, auto_rollback=True) as session:
|
||||
return operation(session)
|
||||
except Exception as e:
|
||||
database_logger.error(f"Database operation failed: {str(e)}")
|
||||
return default_return
|
||||
|
||||
|
||||
def execute_with_retry(
|
||||
operation: Callable,
|
||||
db_session: Session,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0
|
||||
) -> Any:
|
||||
"""
|
||||
Execute database operation with retry logic for transient failures.
|
||||
|
||||
Args:
|
||||
operation: Function to execute
|
||||
db_session: SQLAlchemy session
|
||||
max_retries: Maximum number of retry attempts
|
||||
retry_delay: Delay between retries in seconds
|
||||
|
||||
Returns:
|
||||
Result of successful operation
|
||||
|
||||
Raises:
|
||||
DatabaseError: If all retry attempts fail
|
||||
"""
|
||||
import time
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
with db_transaction(db_session) as session:
|
||||
return operation(session)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
last_exception = e
|
||||
|
||||
if attempt < max_retries:
|
||||
database_logger.warning(
|
||||
f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {str(e)}"
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
database_logger.error(f"Database operation failed after {max_retries + 1} attempts: {str(e)}")
|
||||
|
||||
raise DatabaseError(f"Operation failed after {max_retries + 1} attempts: {str(last_exception)}")
|
||||
Reference in New Issue
Block a user