diff --git a/app/api/import_data.py b/app/api/import_data.py
index 156a6c1..e6b0e57 100644
--- a/app/api/import_data.py
+++ b/app/api/import_data.py
@@ -11,10 +11,10 @@ from difflib import SequenceMatcher
 from datetime import datetime, date, timezone
 from decimal import Decimal
 from typing import List, Dict, Any, Optional, Tuple
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query, WebSocket
 from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session
-from app.database.base import get_db
+from app.database.base import get_db, SessionLocal
 from app.auth.security import get_current_user
 from app.models.user import User
 from app.models.rolodex import Rolodex, Phone
@@ -28,9 +28,14 @@ from app.models.flexible import FlexibleImport
 from app.models.audit import ImportAudit, ImportAuditFile
 from app.config import settings
 from app.utils.logging import import_logger
+from app.middleware.websocket_middleware import get_websocket_manager
+from app.services.websocket_pool import WebSocketMessage
 
 router = APIRouter(tags=["import"])
 
+# WebSocket manager for import progress notifications
+websocket_manager = get_websocket_manager()
+
 # Common encodings to try for legacy CSV files (order matters)
 ENCODINGS = [
     'utf-8-sig',
@@ -93,6 +98,82 @@ CSV_MODEL_MAPPING = {
     "RESULTS.csv": PensionResult
 }
+
+
+# -----------------------------
+# Progress aggregation helpers
+# -----------------------------
+def _aggregate_batch_progress(db: Session, audit: ImportAudit) -> Dict[str, Any]:
+    """Compute progress summary for the given audit row."""
+    processed_files = db.query(ImportAuditFile).filter(ImportAuditFile.audit_id == audit.id).count()
+    successful_files = db.query(ImportAuditFile).filter(
+        ImportAuditFile.audit_id == audit.id,
+        ImportAuditFile.status.in_(["success", "completed_with_errors", "skipped"])
+    ).count()
+    failed_files = db.query(ImportAuditFile).filter(
+        ImportAuditFile.audit_id == audit.id,
+        ImportAuditFile.status == "failed"
+    ).count()
+
+    total_files = audit.total_files or 0
+    percent_complete: float = 0.0
+    if total_files > 0:
+        try:
+            percent_complete = (processed_files / total_files) * 100.0
+        except Exception:
+            percent_complete = 0.0
+
+    data = {
+        "audit_id": audit.id,
+        "status": audit.status,
+        "total_files": total_files,
+        "processed_files": processed_files,
+        "successful_files": successful_files,
+        "failed_files": failed_files,
+        "started_at": audit.started_at.isoformat() if audit.started_at else None,
+        "finished_at": audit.finished_at.isoformat() if audit.finished_at else None,
+        "percent": percent_complete,
+        "message": audit.message,
+    }
+
+    try:
+        last_file = (
+            db.query(ImportAuditFile)
+            .filter(ImportAuditFile.audit_id == audit.id)
+            .order_by(ImportAuditFile.id.desc())
+            .first()
+        )
+        if last_file:
+            data["last_file"] = {
+                "file_type": last_file.file_type,
+                "status": last_file.status,
+                "imported_count": last_file.imported_count,
+                "errors": last_file.errors,
+                "message": last_file.message,
+                "created_at": last_file.created_at.isoformat() if last_file.created_at else None,
+            }
+    except Exception:
+        pass
+
+    return data
+
+
+async def _broadcast_import_progress(db: Session, audit_id: int) -> None:
+    """Broadcast current progress for audit_id to its WebSocket topic."""
+    audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
+    if not audit:
+        return
+    payload = _aggregate_batch_progress(db, audit)
+    topic = f"import_batch_progress_{audit_id}"
+    try:
+        await websocket_manager.broadcast_to_topic(
+            topic=topic,
+            message_type="progress",
+            data=payload,
+        )
+    except Exception:
+        # Non-fatal; continue API flow
+        pass
 
 # Minimal CSV template definitions (headers + one sample row) used for template downloads
 CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
     "FILES.csv": {
@@ -1920,6 +2001,11 @@ async def batch_import_csv_files(
     db.add(audit_row)
     db.commit()
     db.refresh(audit_row)
+    # Broadcast initial snapshot
+    try:
+        await _broadcast_import_progress(db, audit_row.id)
+    except Exception:
+        pass
 
     # Directory to persist uploaded files for this audit (for reruns)
     audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
@@ -2017,6 +2103,10 @@ async def batch_import_csv_files(
                     details={"saved_path": saved_path} if saved_path else {}
                 ))
                 db.commit()
+                try:
+                    await _broadcast_import_progress(db, audit_row.id)
+                except Exception:
+                    pass
             except Exception:
                 db.rollback()
             continue
@@ -2038,6 +2128,10 @@ async def batch_import_csv_files(
                     details={}
                 ))
                 db.commit()
+                try:
+                    await _broadcast_import_progress(db, audit_row.id)
+                except Exception:
+                    pass
             except Exception:
                 db.rollback()
             continue
@@ -2087,6 +2181,10 @@ async def batch_import_csv_files(
                     details={"saved_path": saved_path} if saved_path else {}
                 ))
                 db.commit()
+                try:
+                    await _broadcast_import_progress(db, audit_row.id)
+                except Exception:
+                    pass
             except Exception:
                 db.rollback()
             continue
@@ -2291,6 +2389,10 @@ async def batch_import_csv_files(
                     }
                 ))
                 db.commit()
+                try:
+                    await _broadcast_import_progress(db, audit_row.id)
+                except Exception:
+                    pass
             except Exception:
                 db.rollback()
 
@@ -2312,6 +2414,10 @@ async def batch_import_csv_files(
                     details={"saved_path": saved_path} if saved_path else {}
                 ))
                 db.commit()
+                try:
+                    await _broadcast_import_progress(db, audit_row.id)
+                except Exception:
+                    pass
             except Exception:
                 db.rollback()
 
@@ -2342,6 +2448,10 @@ async def batch_import_csv_files(
         }
         db.add(audit_row)
         db.commit()
+        try:
+            await _broadcast_import_progress(db, audit_row.id)
+        except Exception:
+            pass
     except Exception:
         db.rollback()
     return {
@@ -2351,6 +2461,44 @@ async def batch_import_csv_files(
     }
 
 
+@router.websocket("/batch-progress/ws/{audit_id}")
+async def ws_import_batch_progress(websocket: WebSocket, audit_id: int):
+    """WebSocket: subscribe to real-time updates for an import audit using the pool."""
+    topic = f"import_batch_progress_{audit_id}"
+
+    async def handle_ws_message(connection_id: str, message: WebSocketMessage):
+        # No-op for now; reserved for future client messages
+        import_logger.debug(
+            "Import WS msg",
+            connection_id=connection_id,
+            audit_id=audit_id,
+            type=message.type,
+        )
+
+    connection_id = await websocket_manager.handle_connection(
+        websocket=websocket,
+        topics={topic},
+        require_auth=True,
+        metadata={"audit_id": audit_id, "endpoint": "import_batch_progress"},
+        message_handler=handle_ws_message,
+    )
+
+    if connection_id:
+        db = SessionLocal()
+        try:
+            audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
+            payload = _aggregate_batch_progress(db, audit) if audit else None
+            if payload:
+                initial_message = WebSocketMessage(
+                    type="progress",
+                    topic=topic,
+                    data=payload,
+                )
+                await websocket_manager.pool._send_to_connection(connection_id, initial_message)
+        finally:
+            db.close()
+
+
 @router.get("/recent-batches")
 async def recent_batch_imports(
     limit: int = Query(5, ge=1, le=50),
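Reviewer note: before the template changes below, here is a minimal Python sketch for smoke-testing the new socket by hand. It assumes things this diff does not show: the import router mounted under `/api/import` (matching the frontend's `getUrl` below), the app on `localhost:8000`, and the pool's `require_auth=True` accepting a bearer token from the `Authorization` header — check `handle_connection` for the real auth scheme. The JSON shape (`type`/`data`) mirrors `WebSocketMessage`.

```python
# Smoke-test sketch for the new progress WebSocket (not part of this diff).
# Assumptions: app on localhost:8000, router mounted at /api/import, and
# bearer-token auth read from the Authorization header (verify against the
# pool's handle_connection implementation).
import asyncio
import json
import sys

import websockets  # pip install websockets

TERMINAL = {"success", "completed_with_errors", "failed"}

async def watch(audit_id: int, token: str) -> None:
    uri = f"ws://localhost:8000/api/import/batch-progress/ws/{audit_id}"
    # "extra_headers" is the pre-14 websockets name; newer releases call it
    # "additional_headers".
    headers = {"Authorization": f"Bearer {token}"}
    async with websockets.connect(uri, extra_headers=headers) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("type") != "progress":
                continue
            d = msg.get("data") or {}
            print(f"{d.get('processed_files', 0)}/{d.get('total_files', 0)} "
                  f"({d.get('percent', 0):.1f}%) {d.get('status', '')}")
            if d.get("status") in TERMINAL:
                break  # server-side status is terminal; stop watching

if __name__ == "__main__":
    asyncio.run(watch(int(sys.argv[1]), sys.argv[2]))
```

Because the server pushes a snapshot immediately after `handle_connection` returns, the client should see one message even when the batch is already finished.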
diff --git a/templates/import.html b/templates/import.html
index 5444cdb..51f769f 100644
--- a/templates/import.html
+++ b/templates/import.html
@@ -882,7 +882,7 @@ function showProgress(show, message = '', percent = null) {
 // Batch progress monitoring
 // -----------------------------
 const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']);
-let batchProgress = { timer: null, auditId: null };
+let batchProgress = { timer: null, auditId: null, wsMgr: null };
 
 async function fetchCurrentBatch() {
   try {
@@ -899,6 +899,8 @@ function stopBatchProgressPolling() {
     batchProgress.timer = null;
   }
   batchProgress.auditId = null;
+  try { if (batchProgress.wsMgr) { batchProgress.wsMgr.close(); } } catch (_) {}
+  batchProgress.wsMgr = null;
 }
 
 async function pollBatchProgressOnce(auditId) {
@@ -924,9 +926,48 @@
 function startBatchProgressPolling(auditId) {
   stopBatchProgressPolling();
   batchProgress.auditId = auditId;
-  // immediate + interval polling
-  pollBatchProgressOnce(auditId);
-  batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
+  // Try WebSocket first; fall back to polling
+  try {
+    const mgr = new (window.notifications && window.notifications.NotificationManager ? window.notifications.NotificationManager : null)({
+      getUrl: () => `/api/import/batch-progress/ws/${encodeURIComponent(auditId)}`,
+      onMessage: (msg) => {
+        if (!msg || !msg.type) return;
+        if (msg.type === 'progress') {
+          const p = msg.data || {};
+          const percent = Number(p.percent || 0);
+          const total = Number(p.total_files || 0);
+          const processed = Number(p.processed_files || 0);
+          const status = String(p.status || 'running');
+          const statusNice = status.replaceAll('_', ' ');
+          const msgText = total > 0
+            ? `Processing ${processed}/${total} (${percent.toFixed(1)}%) · ${statusNice}`
+            : `Processing… ${statusNice}`;
+          showProgress(true, msgText, percent);
+          if (TERMINAL_BATCH_STATUSES.has(status)) {
+            stopBatchProgressPolling();
+          }
+        }
+      },
+      onStateChange: (state) => {
+        if (state === 'error' || state === 'closed' || state === 'offline') {
+          // fall back to polling
+          if (!batchProgress.timer) {
+            pollBatchProgressOnce(auditId);
+            batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
+          }
+        }
+      },
+      autoConnect: true,
+      debug: false
+    });
+    batchProgress.wsMgr = mgr;
+    // Safety: also do an immediate HTTP fetch as a first snapshot
+    pollBatchProgressOnce(auditId);
+  } catch (_) {
+    // immediate + interval polling
+    pollBatchProgressOnce(auditId);
+    batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
+  }
 }
 
 async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) {
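One gap worth flagging: nothing in this diff exercises `_aggregate_batch_progress` directly. A rough pytest sketch follows, under loud assumptions: the declarative `Base` is importable from `app.database.base`, the audit models accept the column names referenced above as keyword arguments, and their column types run on in-memory SQLite. Note also that importing `app.api.import_data` now calls `get_websocket_manager()` at import time, so the test environment must tolerate that side effect.

```python
# Rough pytest sketch for _aggregate_batch_progress (not part of this diff).
# Assumptions: Base is the declarative base in app.database.base, the models
# accept these columns as keyword args, and their types work on SQLite.
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.api.import_data import _aggregate_batch_progress  # imports the module, builds the WS manager
from app.database.base import Base  # assumption: declarative base lives here
from app.models.audit import ImportAudit, ImportAuditFile


@pytest.fixture()
def db():
    # Fresh in-memory schema per test
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    try:
        yield session
    finally:
        session.close()


def test_percent_tracks_processed_files(db):
    audit = ImportAudit(status="running", total_files=4)  # hypothetical kwargs
    db.add(audit)
    db.commit()
    # Two of four files processed: one success, one failure
    for status in ("success", "failed"):
        db.add(ImportAuditFile(audit_id=audit.id, file_type="FILES.csv", status=status))
    db.commit()

    snap = _aggregate_batch_progress(db, audit)
    assert snap["processed_files"] == 2
    assert snap["successful_files"] == 1
    assert snap["failed_files"] == 1
    assert snap["percent"] == pytest.approx(50.0)
```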