diff --git a/app/import_export/generic_importer.py b/app/import_export/generic_importer.py
index dddea89..6b80bca 100644
--- a/app/import_export/generic_importer.py
+++ b/app/import_export/generic_importer.py
@@ -15,6 +15,27 @@ from .base import BaseCSVImporter, ImportResult
 
 logger = logging.getLogger(__name__)
 
+# SQL reserved keywords that need to be quoted when used as column names
+SQL_RESERVED_KEYWORDS = {
+    'ABORT', 'ACTION', 'ADD', 'AFTER', 'ALL', 'ALTER', 'ALWAYS', 'ANALYZE', 'AND', 'AS', 'ASC',
+    'ATTACH', 'AUTOINCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', 'BY', 'CASCADE', 'CASE', 'CAST',
+    'CHECK', 'COLLATE', 'COLUMN', 'COMMIT', 'CONFLICT', 'CONSTRAINT', 'CREATE', 'CROSS',
+    'CURRENT', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'DATABASE', 'DEFAULT',
+    'DEFERRABLE', 'DEFERRED', 'DELETE', 'DESC', 'DETACH', 'DISTINCT', 'DO', 'DROP', 'EACH',
+    'ELSE', 'END', 'ESCAPE', 'EXCEPT', 'EXCLUDE', 'EXISTS', 'EXPLAIN', 'FAIL', 'FILTER',
+    'FIRST', 'FOLLOWING', 'FOR', 'FOREIGN', 'FROM', 'FULL', 'GENERATED', 'GLOB', 'GROUP',
+    'GROUPS', 'HAVING', 'IF', 'IGNORE', 'IMMEDIATE', 'IN', 'INDEX', 'INDEXED', 'INITIALLY',
+    'INNER', 'INSERT', 'INSTEAD', 'INTERSECT', 'INTO', 'IS', 'ISNULL', 'JOIN', 'KEY',
+    'LAST', 'LEFT', 'LIKE', 'LIMIT', 'MATCH', 'NATURAL', 'NO', 'NOT', 'NOTHING', 'NOTNULL',
+    'NULL', 'NULLS', 'OF', 'OFFSET', 'ON', 'OR', 'ORDER', 'OUTER', 'OVER', 'PARTITION',
+    'PLAN', 'PRAGMA', 'PRECEDING', 'PRIMARY', 'QUERY', 'RAISE', 'RECURSIVE', 'REFERENCES',
+    'REGEXP', 'REINDEX', 'RELEASE', 'RENAME', 'REPLACE', 'RESTRICT', 'RIGHT', 'ROLLBACK',
+    'ROW', 'ROWS', 'SAVEPOINT', 'SELECT', 'SET', 'TABLE', 'TEMP', 'TEMPORARY', 'THEN',
+    'TIES', 'TO', 'TRANSACTION', 'TRIGGER', 'UNBOUNDED', 'UNION', 'UNIQUE', 'UPDATE',
+    'USING', 'VACUUM', 'VALUES', 'VIEW', 'VIRTUAL', 'WHEN', 'WHERE', 'WINDOW', 'WITH',
+    'WITHOUT'
+}
+
 
 class GenericCSVImporter(BaseCSVImporter):
     """Generic importer that can handle any CSV structure by creating tables dynamically"""
@@ -69,12 +90,16 @@ class GenericCSVImporter(BaseCSVImporter):
         existing_tables = inspector.get_table_names()
 
         if safe_table_name in existing_tables:
-            logger.info(f"Table '{safe_table_name}' already exists, using unique table name")
-            # Instead of trying to drop, create a new table with timestamp suffix
-            import time
-            timestamp = str(int(time.time()))
-            safe_table_name = f"{safe_table_name}_{timestamp}"
-            logger.info(f"Creating new table with unique name: '{safe_table_name}'")
+            logger.info(f"Table '{safe_table_name}' already exists, will use existing table structure")
+            # Reflect the existing table to get its structure
+            metadata.reflect(bind=self.db_session.bind, only=[safe_table_name])
+            existing_table = metadata.tables[safe_table_name]
+
+            # Store the actual table name for use in data insertion
+            self.actual_table_name = safe_table_name
+            self._table_name = safe_table_name
+            logger.info(f"Using existing table: '{safe_table_name}'")
+            return existing_table
         else:
             logger.info(f"Creating new table: '{safe_table_name}'")
 
@@ -157,6 +182,12 @@ class GenericCSVImporter(BaseCSVImporter):
         elif safe_name and not (safe_name[0].isalpha() or safe_name[0] == '_'):
             safe_name = 'col_' + safe_name
         return safe_name.lower()
+
+    def _quote_column_name(self, column_name: str) -> str:
+        """Quote column name if it's a SQL reserved keyword"""
+        if column_name.upper() in SQL_RESERVED_KEYWORDS:
+            return f'"{column_name}"'
+        return column_name
 
     def _parse_date_value(self, value: str) -> Optional[str]:
         """Try to parse a date value and return it in ISO format"""
@@ -287,9 +318,8 @@ class GenericCSVImporter(BaseCSVImporter):
             self.result.add_warning("File contains headers only, no data rows found")
             return self.result
 
-        # Process all rows in a single transaction
+        # Process all rows (transaction managed by session)
         try:
-            self.db_session.begin()
             for row_num, row in enumerate(rows, start=2):
                 total_count += 1
 
@@ -297,33 +327,74 @@ class GenericCSVImporter(BaseCSVImporter):
                 try:
                     # Prepare row data
                     row_data = {}
+
+                    # Get existing table columns if using existing table
+                    existing_columns = set()
+                    if hasattr(self, 'dynamic_table') and self.dynamic_table is not None:
+                        # Convert column keys to strings for comparison
+                        existing_columns = set(str(col) for col in self.dynamic_table.columns.keys())
+
                     for header in self.csv_headers:
-                        safe_column_name = self._make_safe_name(header)
-                        # Handle 'id' column renaming for conflict avoidance
-                        if safe_column_name.lower() == 'id':
-                            safe_column_name = 'csv_id'
-                        value = row.get(header, '').strip() if row.get(header) else None
-                        # Convert empty strings to None for better database handling
-                        if value == '':
-                            value = None
-                        elif value and ('date' in header.lower() or 'time' in header.lower()):
-                            # Try to parse date values for better format consistency
-                            value = self._parse_date_value(value)
-                        row_data[safe_column_name] = value
+                        try:
+                            safe_column_name = self._make_safe_name(header)
+
+                            # Handle 'id' column mapping for existing tables
+                            if safe_column_name.lower() == 'id' and 'id' in existing_columns:
+                                # For existing tables, try to map the CSV 'id' to the actual 'id' column
+                                # Check if id column has autoincrement - but handle this safely
+                                try:
+                                    id_col = self.dynamic_table.columns.id
+                                    # Check if autoincrement is True (SQLAlchemy may not define this attribute)
+                                    is_autoincrement = getattr(id_col, 'autoincrement', False) is True
+                                    if is_autoincrement:
+                                        safe_column_name = 'csv_id'  # Avoid conflict with auto-increment
+                                    else:
+                                        safe_column_name = 'id'  # Use the actual id column
+                                except (AttributeError, TypeError):
+                                    # If we can't determine autoincrement, default to using 'id'
+                                    safe_column_name = 'id'
+                            elif safe_column_name.lower() == 'id':
+                                safe_column_name = 'csv_id'  # Default fallback
+
+                            # Only include columns that exist in the target table (if using existing table)
+                            if existing_columns and safe_column_name not in existing_columns:
+                                logger.debug(f"Skipping column '{safe_column_name}' (from '{header}') - not found in target table")
+                                continue
+
+                            value = row.get(header, '').strip() if row.get(header) else None
+                            # Convert empty strings to None for better database handling
+                            if value == '':
+                                value = None
+                            elif value and ('date' in header.lower() or 'time' in header.lower()):
+                                # Try to parse date values for better format consistency
+                                value = self._parse_date_value(value)
+                            row_data[safe_column_name] = value
+
+                        except Exception as header_error:
+                            logger.error(f"Error processing header '{header}': {header_error}")
+                            # Continue to next header instead of failing the whole row
+                            continue
 
                     # Insert into database with conflict resolution
                     # Use INSERT OR IGNORE to handle potential duplicates gracefully
                     # Use the actual table name (which may have timestamp suffix) instead of dynamic_table.name
                     table_name = getattr(self, 'actual_table_name', self.dynamic_table.name)
                     logger.debug(f"Inserting into table: '{table_name}' (original: '{self._table_name}', dynamic: '{self.dynamic_table.name}')")
+
+                    if not row_data:
+                        logger.warning(f"Row {row_num}: No valid columns found for insertion")
+                        continue
+
                     columns = list(row_data.keys())
                     values = list(row_data.values())
                     placeholders = ', '.join([':param' + str(i) for i in range(len(values))])
-                    column_names = ', '.join(columns)
-
+                    # Quote column names that are reserved keywords
+                    quoted_columns = [self._quote_column_name(col) for col in columns]
+                    column_names = ', '.join(quoted_columns)
+
                     # Create parameter dictionary for SQLAlchemy
                     params = {f'param{i}': value for i, value in enumerate(values)}
-
+
                     ignore_sql = f"INSERT OR IGNORE INTO {table_name} ({column_names}) VALUES ({placeholders})"
                     result = self.db_session.execute(text(ignore_sql), params)
@@ -352,13 +423,12 @@ class GenericCSVImporter(BaseCSVImporter):
                     logger.warning(f"Error importing row {row_num}: {e}")
                     continue
 
-            # Commit all changes
-            self.db_session.commit()
+            # Changes are automatically committed by the session manager
+            pass
 
         except Exception as transaction_error:
-            self.db_session.rollback()
-            logger.error(f"Transaction failed, rolled back: {transaction_error}")
-            self.result.add_error(f"Import failed: {str(transaction_error)}")
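
Note on the quoting change: SQLite rejects an unquoted reserved word in column position, which is why _quote_column_name wraps keyword collisions in double quotes before the INSERT OR IGNORE is built. The following is a minimal standalone sketch of that pattern, not part of the patch. It assumes a SQLite backend (INSERT OR IGNORE is SQLite syntax); the imports table, its columns, and the abbreviated keyword set are hypothetical, chosen only to illustrate the mechanism.

# Standalone sketch (hypothetical table and columns, abbreviated keyword set).
import sqlite3

SQL_RESERVED_KEYWORDS = {'ORDER', 'GROUP', 'WHERE'}  # abbreviated for the sketch

def quote_column_name(name: str) -> str:
    # Mirrors _quote_column_name: wrap reserved words in double quotes
    return f'"{name}"' if name.upper() in SQL_RESERVED_KEYWORDS else name

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE imports (id INTEGER PRIMARY KEY, "order" TEXT UNIQUE, notes TEXT)')

row_data = {'order': 'A-123', 'notes': 'first batch'}  # one parsed CSV row
columns = list(row_data.keys())
placeholders = ', '.join(f':param{i}' for i in range(len(columns)))
column_names = ', '.join(quote_column_name(c) for c in columns)
params = {f'param{i}': v for i, v in enumerate(row_data.values())}

# Unquoted, 'order' raises sqlite3.OperationalError near "order"; quoted it works.
sql = f'INSERT OR IGNORE INTO imports ({column_names}) VALUES ({placeholders})'
conn.execute(sql, params)
conn.execute(sql, params)  # duplicate violates the UNIQUE constraint and is ignored
conn.commit()
print(conn.execute('SELECT * FROM imports').fetchall())  # [(1, 'A-123', 'first batch')]

OR IGNORE drops the conflicting row instead of raising, which is what lets the importer keep streaming rows without a per-row rollback.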