feat(import): use WebSocket push for progress updates with polling fallback
This commit is contained in:
@@ -11,10 +11,10 @@ from difflib import SequenceMatcher
|
||||
from datetime import datetime, date, timezone
|
||||
from decimal import Decimal
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query, WebSocket
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database.base import get_db
|
||||
from app.database.base import get_db, SessionLocal
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.rolodex import Rolodex, Phone
|
||||
@@ -28,9 +28,14 @@ from app.models.flexible import FlexibleImport
|
||||
from app.models.audit import ImportAudit, ImportAuditFile
|
||||
from app.config import settings
|
||||
from app.utils.logging import import_logger
|
||||
from app.middleware.websocket_middleware import get_websocket_manager
|
||||
from app.services.websocket_pool import WebSocketMessage
|
||||
|
||||
router = APIRouter(tags=["import"])
|
||||
|
||||
# WebSocket manager for import progress notifications
|
||||
websocket_manager = get_websocket_manager()
|
||||
|
||||
# Common encodings to try for legacy CSV files (order matters)
|
||||
ENCODINGS = [
|
||||
'utf-8-sig',
|
||||
@@ -93,6 +98,82 @@ CSV_MODEL_MAPPING = {
|
||||
"RESULTS.csv": PensionResult
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Progress aggregation helpers
|
||||
# -----------------------------
|
||||
def _aggregate_batch_progress(db: Session, audit: ImportAudit) -> Dict[str, Any]:
|
||||
"""Compute progress summary for the given audit row."""
|
||||
processed_files = db.query(ImportAuditFile).filter(ImportAuditFile.audit_id == audit.id).count()
|
||||
successful_files = db.query(ImportAuditFile).filter(
|
||||
ImportAuditFile.audit_id == audit.id,
|
||||
ImportAuditFile.status.in_(["success", "completed_with_errors", "skipped"])
|
||||
).count()
|
||||
failed_files = db.query(ImportAuditFile).filter(
|
||||
ImportAuditFile.audit_id == audit.id,
|
||||
ImportAuditFile.status == "failed"
|
||||
).count()
|
||||
|
||||
total_files = audit.total_files or 0
|
||||
percent_complete: float = 0.0
|
||||
if total_files > 0:
|
||||
try:
|
||||
percent_complete = (processed_files / total_files) * 100.0
|
||||
except Exception:
|
||||
percent_complete = 0.0
|
||||
|
||||
data = {
|
||||
"audit_id": audit.id,
|
||||
"status": audit.status,
|
||||
"total_files": total_files,
|
||||
"processed_files": processed_files,
|
||||
"successful_files": successful_files,
|
||||
"failed_files": failed_files,
|
||||
"started_at": audit.started_at.isoformat() if audit.started_at else None,
|
||||
"finished_at": audit.finished_at.isoformat() if audit.finished_at else None,
|
||||
"percent": percent_complete,
|
||||
"message": audit.message,
|
||||
}
|
||||
|
||||
try:
|
||||
last_file = (
|
||||
db.query(ImportAuditFile)
|
||||
.filter(ImportAuditFile.audit_id == audit.id)
|
||||
.order_by(ImportAuditFile.id.desc())
|
||||
.first()
|
||||
)
|
||||
if last_file:
|
||||
data["last_file"] = {
|
||||
"file_type": last_file.file_type,
|
||||
"status": last_file.status,
|
||||
"imported_count": last_file.imported_count,
|
||||
"errors": last_file.errors,
|
||||
"message": last_file.message,
|
||||
"created_at": last_file.created_at.isoformat() if last_file.created_at else None,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _broadcast_import_progress(db: Session, audit_id: int) -> None:
|
||||
"""Broadcast current progress for audit_id to its WebSocket topic."""
|
||||
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
|
||||
if not audit:
|
||||
return
|
||||
payload = _aggregate_batch_progress(db, audit)
|
||||
topic = f"import_batch_progress_{audit_id}"
|
||||
try:
|
||||
await websocket_manager.broadcast_to_topic(
|
||||
topic=topic,
|
||||
message_type="progress",
|
||||
data=payload,
|
||||
)
|
||||
except Exception:
|
||||
# Non-fatal; continue API flow
|
||||
pass
|
||||
|
||||
# Minimal CSV template definitions (headers + one sample row) used for template downloads
|
||||
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
|
||||
"FILES.csv": {
|
||||
@@ -1920,6 +2001,11 @@ async def batch_import_csv_files(
|
||||
db.add(audit_row)
|
||||
db.commit()
|
||||
db.refresh(audit_row)
|
||||
# Broadcast initial snapshot
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Directory to persist uploaded files for this audit (for reruns)
|
||||
audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
|
||||
@@ -2017,6 +2103,10 @@ async def batch_import_csv_files(
|
||||
details={"saved_path": saved_path} if saved_path else {}
|
||||
))
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
continue
|
||||
@@ -2038,6 +2128,10 @@ async def batch_import_csv_files(
|
||||
details={}
|
||||
))
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
continue
|
||||
@@ -2087,6 +2181,10 @@ async def batch_import_csv_files(
|
||||
details={"saved_path": saved_path} if saved_path else {}
|
||||
))
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
continue
|
||||
@@ -2291,6 +2389,10 @@ async def batch_import_csv_files(
|
||||
}
|
||||
))
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
|
||||
@@ -2312,6 +2414,10 @@ async def batch_import_csv_files(
|
||||
details={"saved_path": saved_path} if saved_path else {}
|
||||
))
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
|
||||
@@ -2342,6 +2448,10 @@ async def batch_import_csv_files(
|
||||
}
|
||||
db.add(audit_row)
|
||||
db.commit()
|
||||
try:
|
||||
await _broadcast_import_progress(db, audit_row.id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
db.rollback()
|
||||
|
||||
@@ -2351,6 +2461,44 @@ async def batch_import_csv_files(
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/batch-progress/ws/{audit_id}")
|
||||
async def ws_import_batch_progress(websocket: WebSocket, audit_id: int):
|
||||
"""WebSocket: subscribe to real-time updates for an import audit using the pool."""
|
||||
topic = f"import_batch_progress_{audit_id}"
|
||||
|
||||
async def handle_ws_message(connection_id: str, message: WebSocketMessage):
|
||||
# No-op for now; reserved for future client messages
|
||||
import_logger.debug(
|
||||
"Import WS msg",
|
||||
connection_id=connection_id,
|
||||
audit_id=audit_id,
|
||||
type=message.type,
|
||||
)
|
||||
|
||||
connection_id = await websocket_manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics={topic},
|
||||
require_auth=True,
|
||||
metadata={"audit_id": audit_id, "endpoint": "import_batch_progress"},
|
||||
message_handler=handle_ws_message,
|
||||
)
|
||||
|
||||
if connection_id:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
|
||||
payload = _aggregate_batch_progress(db, audit) if audit else None
|
||||
if payload:
|
||||
initial_message = WebSocketMessage(
|
||||
type="progress",
|
||||
topic=topic,
|
||||
data=payload,
|
||||
)
|
||||
await websocket_manager.pool._send_to_connection(connection_id, initial_message)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/recent-batches")
|
||||
async def recent_batch_imports(
|
||||
limit: int = Query(5, ge=1, le=50),
|
||||
|
||||
@@ -882,7 +882,7 @@ function showProgress(show, message = '', percent = null) {
|
||||
// Batch progress monitoring
|
||||
// -----------------------------
|
||||
const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']);
|
||||
let batchProgress = { timer: null, auditId: null };
|
||||
let batchProgress = { timer: null, auditId: null, wsMgr: null };
|
||||
|
||||
async function fetchCurrentBatch() {
|
||||
try {
|
||||
@@ -899,6 +899,8 @@ function stopBatchProgressPolling() {
|
||||
batchProgress.timer = null;
|
||||
}
|
||||
batchProgress.auditId = null;
|
||||
try { if (batchProgress.wsMgr) { batchProgress.wsMgr.close(); } } catch (_) {}
|
||||
batchProgress.wsMgr = null;
|
||||
}
|
||||
|
||||
async function pollBatchProgressOnce(auditId) {
|
||||
@@ -924,9 +926,48 @@ async function pollBatchProgressOnce(auditId) {
|
||||
function startBatchProgressPolling(auditId) {
|
||||
stopBatchProgressPolling();
|
||||
batchProgress.auditId = auditId;
|
||||
// Try WebSocket first; fallback to polling
|
||||
try {
|
||||
const mgr = new (window.notifications && window.notifications.NotificationManager ? window.notifications.NotificationManager : null)({
|
||||
getUrl: () => `/api/import/batch-progress/ws/${encodeURIComponent(auditId)}`,
|
||||
onMessage: (msg) => {
|
||||
if (!msg || !msg.type) return;
|
||||
if (msg.type === 'progress') {
|
||||
const p = msg.data || {};
|
||||
const percent = Number(p.percent || 0);
|
||||
const total = Number(p.total_files || 0);
|
||||
const processed = Number(p.processed_files || 0);
|
||||
const status = String(p.status || 'running');
|
||||
const statusNice = status.replaceAll('_', ' ');
|
||||
const msgText = total > 0
|
||||
? `Processing ${processed}/${total} (${percent.toFixed(1)}%) · ${statusNice}`
|
||||
: `Processing… ${statusNice}`;
|
||||
showProgress(true, msgText, percent);
|
||||
if (TERMINAL_BATCH_STATUSES.has(status)) {
|
||||
stopBatchProgressPolling();
|
||||
}
|
||||
}
|
||||
},
|
||||
onStateChange: (state) => {
|
||||
if (state === 'error' || state === 'closed' || state === 'offline') {
|
||||
// fallback to polling
|
||||
if (!batchProgress.timer) {
|
||||
pollBatchProgressOnce(auditId);
|
||||
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
|
||||
}
|
||||
}
|
||||
},
|
||||
autoConnect: true,
|
||||
debug: false
|
||||
});
|
||||
batchProgress.wsMgr = mgr;
|
||||
// Safety: also do an immediate HTTP fetch as first snapshot
|
||||
pollBatchProgressOnce(auditId);
|
||||
} catch (_) {
|
||||
// immediate + interval polling
|
||||
pollBatchProgressOnce(auditId);
|
||||
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
|
||||
}
|
||||
}
|
||||
|
||||
async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) {
|
||||
|
||||
Reference in New Issue
Block a user