feat(import): use WebSocket push for progress updates with polling fallback

This commit is contained in:
HotSwapp
2025-09-04 14:50:14 -05:00
parent 48ca876123
commit 032baf6e3e
2 changed files with 195 additions and 6 deletions

View File

@@ -11,10 +11,10 @@ from difflib import SequenceMatcher
from datetime import datetime, date, timezone from datetime import datetime, date, timezone
from decimal import Decimal from decimal import Decimal
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query, WebSocket
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database.base import get_db from app.database.base import get_db, SessionLocal
from app.auth.security import get_current_user from app.auth.security import get_current_user
from app.models.user import User from app.models.user import User
from app.models.rolodex import Rolodex, Phone from app.models.rolodex import Rolodex, Phone
@@ -28,9 +28,14 @@ from app.models.flexible import FlexibleImport
from app.models.audit import ImportAudit, ImportAuditFile from app.models.audit import ImportAudit, ImportAuditFile
from app.config import settings from app.config import settings
from app.utils.logging import import_logger from app.utils.logging import import_logger
from app.middleware.websocket_middleware import get_websocket_manager
from app.services.websocket_pool import WebSocketMessage
router = APIRouter(tags=["import"]) router = APIRouter(tags=["import"])
# WebSocket manager for import progress notifications
websocket_manager = get_websocket_manager()
# Common encodings to try for legacy CSV files (order matters) # Common encodings to try for legacy CSV files (order matters)
ENCODINGS = [ ENCODINGS = [
'utf-8-sig', 'utf-8-sig',
@@ -93,6 +98,82 @@ CSV_MODEL_MAPPING = {
"RESULTS.csv": PensionResult "RESULTS.csv": PensionResult
} }
# -----------------------------
# Progress aggregation helpers
# -----------------------------
def _aggregate_batch_progress(db: Session, audit: ImportAudit) -> Dict[str, Any]:
"""Compute progress summary for the given audit row."""
processed_files = db.query(ImportAuditFile).filter(ImportAuditFile.audit_id == audit.id).count()
successful_files = db.query(ImportAuditFile).filter(
ImportAuditFile.audit_id == audit.id,
ImportAuditFile.status.in_(["success", "completed_with_errors", "skipped"])
).count()
failed_files = db.query(ImportAuditFile).filter(
ImportAuditFile.audit_id == audit.id,
ImportAuditFile.status == "failed"
).count()
total_files = audit.total_files or 0
percent_complete: float = 0.0
if total_files > 0:
try:
percent_complete = (processed_files / total_files) * 100.0
except Exception:
percent_complete = 0.0
data = {
"audit_id": audit.id,
"status": audit.status,
"total_files": total_files,
"processed_files": processed_files,
"successful_files": successful_files,
"failed_files": failed_files,
"started_at": audit.started_at.isoformat() if audit.started_at else None,
"finished_at": audit.finished_at.isoformat() if audit.finished_at else None,
"percent": percent_complete,
"message": audit.message,
}
try:
last_file = (
db.query(ImportAuditFile)
.filter(ImportAuditFile.audit_id == audit.id)
.order_by(ImportAuditFile.id.desc())
.first()
)
if last_file:
data["last_file"] = {
"file_type": last_file.file_type,
"status": last_file.status,
"imported_count": last_file.imported_count,
"errors": last_file.errors,
"message": last_file.message,
"created_at": last_file.created_at.isoformat() if last_file.created_at else None,
}
except Exception:
pass
return data
async def _broadcast_import_progress(db: Session, audit_id: int) -> None:
"""Broadcast current progress for audit_id to its WebSocket topic."""
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
if not audit:
return
payload = _aggregate_batch_progress(db, audit)
topic = f"import_batch_progress_{audit_id}"
try:
await websocket_manager.broadcast_to_topic(
topic=topic,
message_type="progress",
data=payload,
)
except Exception:
# Non-fatal; continue API flow
pass
# Minimal CSV template definitions (headers + one sample row) used for template downloads # Minimal CSV template definitions (headers + one sample row) used for template downloads
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = { CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
"FILES.csv": { "FILES.csv": {
@@ -1920,6 +2001,11 @@ async def batch_import_csv_files(
db.add(audit_row) db.add(audit_row)
db.commit() db.commit()
db.refresh(audit_row) db.refresh(audit_row)
# Broadcast initial snapshot
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
# Directory to persist uploaded files for this audit (for reruns) # Directory to persist uploaded files for this audit (for reruns)
audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id)) audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
@@ -2017,6 +2103,10 @@ async def batch_import_csv_files(
details={"saved_path": saved_path} if saved_path else {} details={"saved_path": saved_path} if saved_path else {}
)) ))
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
continue continue
@@ -2038,6 +2128,10 @@ async def batch_import_csv_files(
details={} details={}
)) ))
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
continue continue
@@ -2087,6 +2181,10 @@ async def batch_import_csv_files(
details={"saved_path": saved_path} if saved_path else {} details={"saved_path": saved_path} if saved_path else {}
)) ))
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
continue continue
@@ -2291,6 +2389,10 @@ async def batch_import_csv_files(
} }
)) ))
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
@@ -2312,6 +2414,10 @@ async def batch_import_csv_files(
details={"saved_path": saved_path} if saved_path else {} details={"saved_path": saved_path} if saved_path else {}
)) ))
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
@@ -2342,6 +2448,10 @@ async def batch_import_csv_files(
} }
db.add(audit_row) db.add(audit_row)
db.commit() db.commit()
try:
await _broadcast_import_progress(db, audit_row.id)
except Exception:
pass
except Exception: except Exception:
db.rollback() db.rollback()
@@ -2351,6 +2461,44 @@ async def batch_import_csv_files(
} }
@router.websocket("/batch-progress/ws/{audit_id}")
async def ws_import_batch_progress(websocket: WebSocket, audit_id: int):
"""WebSocket: subscribe to real-time updates for an import audit using the pool."""
topic = f"import_batch_progress_{audit_id}"
async def handle_ws_message(connection_id: str, message: WebSocketMessage):
# No-op for now; reserved for future client messages
import_logger.debug(
"Import WS msg",
connection_id=connection_id,
audit_id=audit_id,
type=message.type,
)
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={topic},
require_auth=True,
metadata={"audit_id": audit_id, "endpoint": "import_batch_progress"},
message_handler=handle_ws_message,
)
if connection_id:
db = SessionLocal()
try:
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
payload = _aggregate_batch_progress(db, audit) if audit else None
if payload:
initial_message = WebSocketMessage(
type="progress",
topic=topic,
data=payload,
)
await websocket_manager.pool._send_to_connection(connection_id, initial_message)
finally:
db.close()
@router.get("/recent-batches") @router.get("/recent-batches")
async def recent_batch_imports( async def recent_batch_imports(
limit: int = Query(5, ge=1, le=50), limit: int = Query(5, ge=1, le=50),

View File

@@ -882,7 +882,7 @@ function showProgress(show, message = '', percent = null) {
// Batch progress monitoring // Batch progress monitoring
// ----------------------------- // -----------------------------
const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']); const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']);
let batchProgress = { timer: null, auditId: null }; let batchProgress = { timer: null, auditId: null, wsMgr: null };
async function fetchCurrentBatch() { async function fetchCurrentBatch() {
try { try {
@@ -899,6 +899,8 @@ function stopBatchProgressPolling() {
batchProgress.timer = null; batchProgress.timer = null;
} }
batchProgress.auditId = null; batchProgress.auditId = null;
try { if (batchProgress.wsMgr) { batchProgress.wsMgr.close(); } } catch (_) {}
batchProgress.wsMgr = null;
} }
async function pollBatchProgressOnce(auditId) { async function pollBatchProgressOnce(auditId) {
@@ -924,10 +926,49 @@ async function pollBatchProgressOnce(auditId) {
function startBatchProgressPolling(auditId) { function startBatchProgressPolling(auditId) {
stopBatchProgressPolling(); stopBatchProgressPolling();
batchProgress.auditId = auditId; batchProgress.auditId = auditId;
// Try WebSocket first; fallback to polling
try {
const mgr = new (window.notifications && window.notifications.NotificationManager ? window.notifications.NotificationManager : null)({
getUrl: () => `/api/import/batch-progress/ws/${encodeURIComponent(auditId)}`,
onMessage: (msg) => {
if (!msg || !msg.type) return;
if (msg.type === 'progress') {
const p = msg.data || {};
const percent = Number(p.percent || 0);
const total = Number(p.total_files || 0);
const processed = Number(p.processed_files || 0);
const status = String(p.status || 'running');
const statusNice = status.replaceAll('_', ' ');
const msgText = total > 0
? `Processing ${processed}/${total} (${percent.toFixed(1)}%) · ${statusNice}`
: `Processing… ${statusNice}`;
showProgress(true, msgText, percent);
if (TERMINAL_BATCH_STATUSES.has(status)) {
stopBatchProgressPolling();
}
}
},
onStateChange: (state) => {
if (state === 'error' || state === 'closed' || state === 'offline') {
// fallback to polling
if (!batchProgress.timer) {
pollBatchProgressOnce(auditId);
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
}
}
},
autoConnect: true,
debug: false
});
batchProgress.wsMgr = mgr;
// Safety: also do an immediate HTTP fetch as first snapshot
pollBatchProgressOnce(auditId);
} catch (_) {
// immediate + interval polling // immediate + interval polling
pollBatchProgressOnce(auditId); pollBatchProgressOnce(auditId);
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500); batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
} }
}
async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) { async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) {
for (let attempt = 0; attempt < maxAttempts; attempt++) { for (let attempt = 0; attempt < maxAttempts; attempt++) {