""" Base classes for CSV import functionality """ from abc import ABC, abstractmethod from typing import Dict, List, Any, Optional, Tuple import csv import io from datetime import datetime, date import logging import uuid from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError, SQLAlchemyError from .logging_config import create_import_logger, ImportMetrics logger = logging.getLogger(__name__) class ImportResult: """Container for import operation results""" def __init__(self): self.success = False self.total_rows = 0 self.imported_rows = 0 self.skipped_rows = 0 self.error_rows = 0 self.errors: List[str] = [] self.warnings: List[str] = [] self.import_id = None def add_error(self, error: str): """Add an error message""" self.errors.append(error) self.error_rows += 1 def add_warning(self, warning: str): """Add a warning message""" self.warnings.append(warning) def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for JSON response""" return { "success": self.success, "total_rows": self.total_rows, "imported_rows": self.imported_rows, "skipped_rows": self.skipped_rows, "error_rows": self.error_rows, "errors": self.errors, "warnings": self.warnings, "import_id": self.import_id } class BaseCSVImporter(ABC): """Abstract base class for all CSV importers""" def __init__(self, db_session: Session, import_id: Optional[str] = None): self.db_session = db_session self.result = ImportResult() self.import_id = import_id or str(uuid.uuid4()) self.result.import_id = self.import_id self.import_logger = create_import_logger(self.import_id, self.table_name) self.metrics = ImportMetrics() @property @abstractmethod def table_name(self) -> str: """Name of the database table being imported to""" pass @property @abstractmethod def required_fields(self) -> List[str]: """List of required field names""" pass @property @abstractmethod def field_mapping(self) -> Dict[str, str]: """Mapping from CSV headers to database field names""" pass @abstractmethod def create_model_instance(self, row_data: Dict[str, Any]) -> Any: """Create a model instance from processed row data""" pass def parse_date(self, date_str: str) -> Optional[date]: """Parse date string to date object""" if not date_str or date_str.strip() == "": return None date_str = date_str.strip() # Try common date formats formats = [ "%Y-%m-%d", # ISO format "%m/%d/%Y", # US format "%m/%d/%y", # US format 2-digit year "%d/%m/%Y", # European format "%Y%m%d", # Compact format ] for fmt in formats: try: return datetime.strptime(date_str, fmt).date() except ValueError: continue raise ValueError(f"Unable to parse date: {date_str}") def parse_float(self, value_str: str) -> float: """Parse string to float, handling empty values""" if not value_str or value_str.strip() == "": return 0.0 value_str = value_str.strip().replace(",", "") # Remove commas try: return float(value_str) except ValueError: raise ValueError(f"Unable to parse float: {value_str}") def parse_int(self, value_str: str) -> int: """Parse string to int, handling empty values""" if not value_str or value_str.strip() == "": return 0 value_str = value_str.strip().replace(",", "") # Remove commas try: return int(float(value_str)) # Handle "1.0" format except ValueError: raise ValueError(f"Unable to parse integer: {value_str}") def normalize_string(self, value: str, max_length: Optional[int] = None) -> str: """Normalize string value""" if not value: return "" value = str(value).strip() if max_length and len(value) > max_length: self.result.add_warning(f"String truncated from {len(value)} to {max_length} characters: {value[:50]}...") value = value[:max_length] return value def detect_delimiter(self, csv_content: str) -> str: """Auto-detect CSV delimiter""" sample = csv_content[:1024] # Check first 1KB sniffer = csv.Sniffer() try: dialect = sniffer.sniff(sample, delimiters=",;\t|") return dialect.delimiter except: return "," # Default to comma def validate_headers(self, headers: List[str]) -> bool: """Validate that required headers are present""" missing_required = [] # Create case-insensitive mapping of headers header_map = {h.lower().strip(): h for h in headers} for required_field in self.required_fields: # Check direct match first if required_field in headers: continue # Check if there's a mapping for this field mapped_name = self.field_mapping.get(required_field, required_field) if mapped_name.lower() in header_map: continue missing_required.append(required_field) if missing_required: self.result.add_error(f"Missing required columns: {', '.join(missing_required)}") return False return True def map_row_data(self, row: Dict[str, str], headers: List[str]) -> Dict[str, Any]: """Map CSV row data to database field names""" mapped_data = {} # Create case-insensitive lookup row_lookup = {k.lower().strip(): v for k, v in row.items() if k} for db_field, csv_field in self.field_mapping.items(): csv_field_lower = csv_field.lower().strip() # Try exact match first if csv_field in row: mapped_data[db_field] = row[csv_field] # Try case-insensitive match elif csv_field_lower in row_lookup: mapped_data[db_field] = row_lookup[csv_field_lower] else: mapped_data[db_field] = "" return mapped_data def process_csv_content(self, csv_content: str, encoding: str = "utf-8") -> ImportResult: """Process CSV content and import data""" self.import_logger.info(f"Starting CSV import for {self.table_name}") try: # Detect delimiter delimiter = self.detect_delimiter(csv_content) self.import_logger.debug(f"Detected CSV delimiter: '{delimiter}'") # Parse CSV csv_reader = csv.DictReader( io.StringIO(csv_content), delimiter=delimiter ) headers = csv_reader.fieldnames or [] if not headers: error_msg = "No headers found in CSV file" self.result.add_error(error_msg) self.import_logger.error(error_msg) return self.result self.import_logger.info(f"Found headers: {headers}") # Validate headers if not self.validate_headers(headers): self.import_logger.error("Header validation failed") return self.result self.import_logger.info("Header validation passed") # Process rows imported_count = 0 total_count = 0 for row_num, row in enumerate(csv_reader, 1): total_count += 1 self.metrics.total_rows = total_count try: # Map CSV data to database fields mapped_data = self.map_row_data(row, headers) # Create model instance model_instance = self.create_model_instance(mapped_data) # Add to session self.db_session.add(model_instance) imported_count += 1 self.import_logger.log_row_processed(row_num, success=True) self.metrics.record_row_processed(success=True) except ImportValidationError as e: error_msg = f"Row {row_num}: {str(e)}" self.result.add_error(error_msg) self.import_logger.log_row_processed(row_num, success=False) self.import_logger.log_validation_error(row_num, "validation", row, str(e)) self.metrics.record_validation_error(row_num, str(e)) except Exception as e: error_msg = f"Row {row_num}: Unexpected error - {str(e)}" self.result.add_error(error_msg) self.import_logger.log_row_processed(row_num, success=False) self.import_logger.error(error_msg, row_number=row_num, exception_type=type(e).__name__) self.metrics.record_validation_error(row_num, str(e)) # Commit transaction try: self.db_session.commit() self.result.success = True self.result.imported_rows = imported_count self.import_logger.info(f"Successfully committed {imported_count} rows to database") logger.info(f"Successfully imported {imported_count} rows to {self.table_name}") except (IntegrityError, SQLAlchemyError) as e: self.db_session.rollback() error_msg = f"Database error during commit: {str(e)}" self.result.add_error(error_msg) self.import_logger.error(error_msg) self.metrics.record_database_error(str(e)) logger.error(f"Database error importing to {self.table_name}: {str(e)}") self.result.total_rows = total_count self.metrics.finalize() # Log final summary self.import_logger.log_import_summary( total_count, imported_count, self.result.error_rows ) except Exception as e: self.db_session.rollback() error_msg = f"Failed to process CSV: {str(e)}" self.result.add_error(error_msg) self.import_logger.error(error_msg, exception_type=type(e).__name__) self.metrics.record_database_error(str(e)) logger.error(f"CSV processing error for {self.table_name}: {str(e)}") return self.result class ImportValidationError(Exception): """Exception raised for validation errors during import""" pass