feat(import): use WebSocket push for progress updates with polling fallback
This commit is contained in:
@@ -11,10 +11,10 @@ from difflib import SequenceMatcher
|
|||||||
from datetime import datetime, date, timezone
|
from datetime import datetime, date, timezone
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query, WebSocket
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.database.base import get_db
|
from app.database.base import get_db, SessionLocal
|
||||||
from app.auth.security import get_current_user
|
from app.auth.security import get_current_user
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.rolodex import Rolodex, Phone
|
from app.models.rolodex import Rolodex, Phone
|
||||||
@@ -28,9 +28,14 @@ from app.models.flexible import FlexibleImport
|
|||||||
from app.models.audit import ImportAudit, ImportAuditFile
|
from app.models.audit import ImportAudit, ImportAuditFile
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.utils.logging import import_logger
|
from app.utils.logging import import_logger
|
||||||
|
from app.middleware.websocket_middleware import get_websocket_manager
|
||||||
|
from app.services.websocket_pool import WebSocketMessage
|
||||||
|
|
||||||
router = APIRouter(tags=["import"])
|
router = APIRouter(tags=["import"])
|
||||||
|
|
||||||
|
# WebSocket manager for import progress notifications
|
||||||
|
websocket_manager = get_websocket_manager()
|
||||||
|
|
||||||
# Common encodings to try for legacy CSV files (order matters)
|
# Common encodings to try for legacy CSV files (order matters)
|
||||||
ENCODINGS = [
|
ENCODINGS = [
|
||||||
'utf-8-sig',
|
'utf-8-sig',
|
||||||
@@ -93,6 +98,82 @@ CSV_MODEL_MAPPING = {
|
|||||||
"RESULTS.csv": PensionResult
|
"RESULTS.csv": PensionResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Progress aggregation helpers
|
||||||
|
# -----------------------------
|
||||||
|
def _aggregate_batch_progress(db: Session, audit: ImportAudit) -> Dict[str, Any]:
|
||||||
|
"""Compute progress summary for the given audit row."""
|
||||||
|
processed_files = db.query(ImportAuditFile).filter(ImportAuditFile.audit_id == audit.id).count()
|
||||||
|
successful_files = db.query(ImportAuditFile).filter(
|
||||||
|
ImportAuditFile.audit_id == audit.id,
|
||||||
|
ImportAuditFile.status.in_(["success", "completed_with_errors", "skipped"])
|
||||||
|
).count()
|
||||||
|
failed_files = db.query(ImportAuditFile).filter(
|
||||||
|
ImportAuditFile.audit_id == audit.id,
|
||||||
|
ImportAuditFile.status == "failed"
|
||||||
|
).count()
|
||||||
|
|
||||||
|
total_files = audit.total_files or 0
|
||||||
|
percent_complete: float = 0.0
|
||||||
|
if total_files > 0:
|
||||||
|
try:
|
||||||
|
percent_complete = (processed_files / total_files) * 100.0
|
||||||
|
except Exception:
|
||||||
|
percent_complete = 0.0
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"audit_id": audit.id,
|
||||||
|
"status": audit.status,
|
||||||
|
"total_files": total_files,
|
||||||
|
"processed_files": processed_files,
|
||||||
|
"successful_files": successful_files,
|
||||||
|
"failed_files": failed_files,
|
||||||
|
"started_at": audit.started_at.isoformat() if audit.started_at else None,
|
||||||
|
"finished_at": audit.finished_at.isoformat() if audit.finished_at else None,
|
||||||
|
"percent": percent_complete,
|
||||||
|
"message": audit.message,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
last_file = (
|
||||||
|
db.query(ImportAuditFile)
|
||||||
|
.filter(ImportAuditFile.audit_id == audit.id)
|
||||||
|
.order_by(ImportAuditFile.id.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if last_file:
|
||||||
|
data["last_file"] = {
|
||||||
|
"file_type": last_file.file_type,
|
||||||
|
"status": last_file.status,
|
||||||
|
"imported_count": last_file.imported_count,
|
||||||
|
"errors": last_file.errors,
|
||||||
|
"message": last_file.message,
|
||||||
|
"created_at": last_file.created_at.isoformat() if last_file.created_at else None,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def _broadcast_import_progress(db: Session, audit_id: int) -> None:
|
||||||
|
"""Broadcast current progress for audit_id to its WebSocket topic."""
|
||||||
|
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
|
||||||
|
if not audit:
|
||||||
|
return
|
||||||
|
payload = _aggregate_batch_progress(db, audit)
|
||||||
|
topic = f"import_batch_progress_{audit_id}"
|
||||||
|
try:
|
||||||
|
await websocket_manager.broadcast_to_topic(
|
||||||
|
topic=topic,
|
||||||
|
message_type="progress",
|
||||||
|
data=payload,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Non-fatal; continue API flow
|
||||||
|
pass
|
||||||
|
|
||||||
# Minimal CSV template definitions (headers + one sample row) used for template downloads
|
# Minimal CSV template definitions (headers + one sample row) used for template downloads
|
||||||
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
|
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
|
||||||
"FILES.csv": {
|
"FILES.csv": {
|
||||||
@@ -1920,6 +2001,11 @@ async def batch_import_csv_files(
|
|||||||
db.add(audit_row)
|
db.add(audit_row)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(audit_row)
|
db.refresh(audit_row)
|
||||||
|
# Broadcast initial snapshot
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Directory to persist uploaded files for this audit (for reruns)
|
# Directory to persist uploaded files for this audit (for reruns)
|
||||||
audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
|
audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
|
||||||
@@ -2017,6 +2103,10 @@ async def batch_import_csv_files(
|
|||||||
details={"saved_path": saved_path} if saved_path else {}
|
details={"saved_path": saved_path} if saved_path else {}
|
||||||
))
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
continue
|
continue
|
||||||
@@ -2038,6 +2128,10 @@ async def batch_import_csv_files(
|
|||||||
details={}
|
details={}
|
||||||
))
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
continue
|
continue
|
||||||
@@ -2087,6 +2181,10 @@ async def batch_import_csv_files(
|
|||||||
details={"saved_path": saved_path} if saved_path else {}
|
details={"saved_path": saved_path} if saved_path else {}
|
||||||
))
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
continue
|
continue
|
||||||
@@ -2291,6 +2389,10 @@ async def batch_import_csv_files(
|
|||||||
}
|
}
|
||||||
))
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|
||||||
@@ -2312,6 +2414,10 @@ async def batch_import_csv_files(
|
|||||||
details={"saved_path": saved_path} if saved_path else {}
|
details={"saved_path": saved_path} if saved_path else {}
|
||||||
))
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|
||||||
@@ -2342,6 +2448,10 @@ async def batch_import_csv_files(
|
|||||||
}
|
}
|
||||||
db.add(audit_row)
|
db.add(audit_row)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
try:
|
||||||
|
await _broadcast_import_progress(db, audit_row.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|
||||||
@@ -2351,6 +2461,44 @@ async def batch_import_csv_files(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/batch-progress/ws/{audit_id}")
|
||||||
|
async def ws_import_batch_progress(websocket: WebSocket, audit_id: int):
|
||||||
|
"""WebSocket: subscribe to real-time updates for an import audit using the pool."""
|
||||||
|
topic = f"import_batch_progress_{audit_id}"
|
||||||
|
|
||||||
|
async def handle_ws_message(connection_id: str, message: WebSocketMessage):
|
||||||
|
# No-op for now; reserved for future client messages
|
||||||
|
import_logger.debug(
|
||||||
|
"Import WS msg",
|
||||||
|
connection_id=connection_id,
|
||||||
|
audit_id=audit_id,
|
||||||
|
type=message.type,
|
||||||
|
)
|
||||||
|
|
||||||
|
connection_id = await websocket_manager.handle_connection(
|
||||||
|
websocket=websocket,
|
||||||
|
topics={topic},
|
||||||
|
require_auth=True,
|
||||||
|
metadata={"audit_id": audit_id, "endpoint": "import_batch_progress"},
|
||||||
|
message_handler=handle_ws_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if connection_id:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
|
||||||
|
payload = _aggregate_batch_progress(db, audit) if audit else None
|
||||||
|
if payload:
|
||||||
|
initial_message = WebSocketMessage(
|
||||||
|
type="progress",
|
||||||
|
topic=topic,
|
||||||
|
data=payload,
|
||||||
|
)
|
||||||
|
await websocket_manager.pool._send_to_connection(connection_id, initial_message)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/recent-batches")
|
@router.get("/recent-batches")
|
||||||
async def recent_batch_imports(
|
async def recent_batch_imports(
|
||||||
limit: int = Query(5, ge=1, le=50),
|
limit: int = Query(5, ge=1, le=50),
|
||||||
|
|||||||
@@ -882,7 +882,7 @@ function showProgress(show, message = '', percent = null) {
|
|||||||
// Batch progress monitoring
|
// Batch progress monitoring
|
||||||
// -----------------------------
|
// -----------------------------
|
||||||
const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']);
|
const TERMINAL_BATCH_STATUSES = new Set(['success', 'completed_with_errors', 'failed']);
|
||||||
let batchProgress = { timer: null, auditId: null };
|
let batchProgress = { timer: null, auditId: null, wsMgr: null };
|
||||||
|
|
||||||
async function fetchCurrentBatch() {
|
async function fetchCurrentBatch() {
|
||||||
try {
|
try {
|
||||||
@@ -899,6 +899,8 @@ function stopBatchProgressPolling() {
|
|||||||
batchProgress.timer = null;
|
batchProgress.timer = null;
|
||||||
}
|
}
|
||||||
batchProgress.auditId = null;
|
batchProgress.auditId = null;
|
||||||
|
try { if (batchProgress.wsMgr) { batchProgress.wsMgr.close(); } } catch (_) {}
|
||||||
|
batchProgress.wsMgr = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function pollBatchProgressOnce(auditId) {
|
async function pollBatchProgressOnce(auditId) {
|
||||||
@@ -924,10 +926,49 @@ async function pollBatchProgressOnce(auditId) {
|
|||||||
function startBatchProgressPolling(auditId) {
|
function startBatchProgressPolling(auditId) {
|
||||||
stopBatchProgressPolling();
|
stopBatchProgressPolling();
|
||||||
batchProgress.auditId = auditId;
|
batchProgress.auditId = auditId;
|
||||||
|
// Try WebSocket first; fallback to polling
|
||||||
|
try {
|
||||||
|
const mgr = new (window.notifications && window.notifications.NotificationManager ? window.notifications.NotificationManager : null)({
|
||||||
|
getUrl: () => `/api/import/batch-progress/ws/${encodeURIComponent(auditId)}`,
|
||||||
|
onMessage: (msg) => {
|
||||||
|
if (!msg || !msg.type) return;
|
||||||
|
if (msg.type === 'progress') {
|
||||||
|
const p = msg.data || {};
|
||||||
|
const percent = Number(p.percent || 0);
|
||||||
|
const total = Number(p.total_files || 0);
|
||||||
|
const processed = Number(p.processed_files || 0);
|
||||||
|
const status = String(p.status || 'running');
|
||||||
|
const statusNice = status.replaceAll('_', ' ');
|
||||||
|
const msgText = total > 0
|
||||||
|
? `Processing ${processed}/${total} (${percent.toFixed(1)}%) · ${statusNice}`
|
||||||
|
: `Processing… ${statusNice}`;
|
||||||
|
showProgress(true, msgText, percent);
|
||||||
|
if (TERMINAL_BATCH_STATUSES.has(status)) {
|
||||||
|
stopBatchProgressPolling();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onStateChange: (state) => {
|
||||||
|
if (state === 'error' || state === 'closed' || state === 'offline') {
|
||||||
|
// fallback to polling
|
||||||
|
if (!batchProgress.timer) {
|
||||||
|
pollBatchProgressOnce(auditId);
|
||||||
|
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
autoConnect: true,
|
||||||
|
debug: false
|
||||||
|
});
|
||||||
|
batchProgress.wsMgr = mgr;
|
||||||
|
// Safety: also do an immediate HTTP fetch as first snapshot
|
||||||
|
pollBatchProgressOnce(auditId);
|
||||||
|
} catch (_) {
|
||||||
// immediate + interval polling
|
// immediate + interval polling
|
||||||
pollBatchProgressOnce(auditId);
|
pollBatchProgressOnce(auditId);
|
||||||
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
|
batchProgress.timer = setInterval(() => pollBatchProgressOnce(auditId), 1500);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) {
|
async function ensureAuditIdWithRetry(maxAttempts = 10, delayMs = 500) {
|
||||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||||
|
|||||||
Reference in New Issue
Block a user