feat(import): use WebSocket push for progress updates with polling fallback

HotSwapp committed 2025-09-04 14:50:14 -05:00
parent 48ca876123
commit 032baf6e3e
2 changed files with 195 additions and 6 deletions

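The diff below covers only the server side. For orientation, here is a minimal client sketch of the push-with-fallback flow named in the subject line. Assumptions not confirmed by this commit: the import router is mounted under /api/import, a REST polling endpoint GET /api/import/batch-progress/{audit_id} exists as the fallback, frames are JSON of the shape {"type", "topic", "data"}, and the terminal audit statuses listed below; authentication (which the pool requires via require_auth=True) is elided.

# Hedged client sketch: routes, frame shape, and terminal statuses
# below are assumptions, not confirmed by this commit.
import asyncio
import json

import httpx        # HTTP client used for the polling fallback
import websockets   # third-party WebSocket client library

TERMINAL = {"success", "failed", "completed_with_errors"}  # assumed terminal statuses


async def watch_import(audit_id: int, base: str = "localhost:8000") -> None:
    try:
        # Preferred path: consume "progress" frames pushed by the server.
        async with websockets.connect(
            f"ws://{base}/api/import/batch-progress/ws/{audit_id}"  # prefix assumed
        ) as ws:
            async for raw in ws:
                frame = json.loads(raw)   # frame shape assumed
                data = frame.get("data", {})
                print(f"{data.get('percent', 0.0):.1f}% "
                      f"({data.get('processed_files')}/{data.get('total_files')})")
                if data.get("status") in TERMINAL:
                    return
    except Exception:
        # Fallback path: poll a REST progress endpoint (hypothetical route).
        async with httpx.AsyncClient() as client:
            while True:
                resp = await client.get(f"http://{base}/api/import/batch-progress/{audit_id}")
                data = resp.json()
                print(f"{data.get('percent', 0.0):.1f}%")
                if data.get("status") in TERMINAL:
                    return
                await asyncio.sleep(2)


if __name__ == "__main__":
    asyncio.run(watch_import(1))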

@@ -11,10 +11,10 @@ from difflib import SequenceMatcher
 from datetime import datetime, date, timezone
 from decimal import Decimal
 from typing import List, Dict, Any, Optional, Tuple
-from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query, WebSocket
 from fastapi.responses import StreamingResponse
 from sqlalchemy.orm import Session
-from app.database.base import get_db
+from app.database.base import get_db, SessionLocal
 from app.auth.security import get_current_user
 from app.models.user import User
 from app.models.rolodex import Rolodex, Phone
@@ -28,9 +28,14 @@ from app.models.flexible import FlexibleImport
 from app.models.audit import ImportAudit, ImportAuditFile
 from app.config import settings
 from app.utils.logging import import_logger
+from app.middleware.websocket_middleware import get_websocket_manager
+from app.services.websocket_pool import WebSocketMessage
 
 router = APIRouter(tags=["import"])
 
+# WebSocket manager for import progress notifications
+websocket_manager = get_websocket_manager()
+
 # Common encodings to try for legacy CSV files (order matters)
 ENCODINGS = [
     'utf-8-sig',
@@ -93,6 +98,82 @@ CSV_MODEL_MAPPING = {
     "RESULTS.csv": PensionResult
 }
 
+# -----------------------------
+# Progress aggregation helpers
+# -----------------------------
+def _aggregate_batch_progress(db: Session, audit: ImportAudit) -> Dict[str, Any]:
+    """Compute progress summary for the given audit row."""
+    processed_files = db.query(ImportAuditFile).filter(ImportAuditFile.audit_id == audit.id).count()
+    successful_files = db.query(ImportAuditFile).filter(
+        ImportAuditFile.audit_id == audit.id,
+        ImportAuditFile.status.in_(["success", "completed_with_errors", "skipped"])
+    ).count()
+    failed_files = db.query(ImportAuditFile).filter(
+        ImportAuditFile.audit_id == audit.id,
+        ImportAuditFile.status == "failed"
+    ).count()
+    total_files = audit.total_files or 0
+    percent_complete: float = 0.0
+    if total_files > 0:
+        try:
+            percent_complete = (processed_files / total_files) * 100.0
+        except Exception:
+            percent_complete = 0.0
+    data = {
+        "audit_id": audit.id,
+        "status": audit.status,
+        "total_files": total_files,
+        "processed_files": processed_files,
+        "successful_files": successful_files,
+        "failed_files": failed_files,
+        "started_at": audit.started_at.isoformat() if audit.started_at else None,
+        "finished_at": audit.finished_at.isoformat() if audit.finished_at else None,
+        "percent": percent_complete,
+        "message": audit.message,
+    }
+    try:
+        last_file = (
+            db.query(ImportAuditFile)
+            .filter(ImportAuditFile.audit_id == audit.id)
+            .order_by(ImportAuditFile.id.desc())
+            .first()
+        )
+        if last_file:
+            data["last_file"] = {
+                "file_type": last_file.file_type,
+                "status": last_file.status,
+                "imported_count": last_file.imported_count,
+                "errors": last_file.errors,
+                "message": last_file.message,
+                "created_at": last_file.created_at.isoformat() if last_file.created_at else None,
+            }
+    except Exception:
+        pass
+    return data
+
+
+async def _broadcast_import_progress(db: Session, audit_id: int) -> None:
+    """Broadcast current progress for audit_id to its WebSocket topic."""
+    audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
+    if not audit:
+        return
+    payload = _aggregate_batch_progress(db, audit)
+    topic = f"import_batch_progress_{audit_id}"
+    try:
+        await websocket_manager.broadcast_to_topic(
+            topic=topic,
+            message_type="progress",
+            data=payload,
+        )
+    except Exception:
+        # Non-fatal; continue API flow
+        pass
+
 # Minimal CSV template definitions (headers + one sample row) used for template downloads
 CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
     "FILES.csv": {
@@ -1920,6 +2001,11 @@ async def batch_import_csv_files(
     db.add(audit_row)
     db.commit()
     db.refresh(audit_row)
+    # Broadcast initial snapshot
+    try:
+        await _broadcast_import_progress(db, audit_row.id)
+    except Exception:
+        pass
 
     # Directory to persist uploaded files for this audit (for reruns)
     audit_dir = Path(settings.upload_dir).joinpath("import_audits", str(audit_row.id))
@@ -2017,6 +2103,10 @@ async def batch_import_csv_files(
                 details={"saved_path": saved_path} if saved_path else {}
             ))
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
             continue
@@ -2038,6 +2128,10 @@ async def batch_import_csv_files(
                 details={}
             ))
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
             continue
@@ -2087,6 +2181,10 @@ async def batch_import_csv_files(
                 details={"saved_path": saved_path} if saved_path else {}
             ))
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
             continue
@@ -2291,6 +2389,10 @@ async def batch_import_csv_files(
                 }
             ))
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
@@ -2312,6 +2414,10 @@ async def batch_import_csv_files(
                 details={"saved_path": saved_path} if saved_path else {}
             ))
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
@@ -2342,6 +2448,10 @@ async def batch_import_csv_files(
             }
             db.add(audit_row)
             db.commit()
+            try:
+                await _broadcast_import_progress(db, audit_row.id)
+            except Exception:
+                pass
         except Exception:
             db.rollback()
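The guarded broadcast after each commit above repeats verbatim at every call site. A small wrapper like the following (hypothetical, not part of this commit) would capture the best-effort intent in one place:

async def _try_broadcast(db: Session, audit_id: int) -> None:
    # Hypothetical helper: the progress push is best-effort and must
    # never fail the import flow itself.
    try:
        await _broadcast_import_progress(db, audit_id)
    except Exception:
        pass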
@@ -2351,6 +2461,44 @@ async def batch_import_csv_files(
}
@router.websocket("/batch-progress/ws/{audit_id}")
async def ws_import_batch_progress(websocket: WebSocket, audit_id: int):
"""WebSocket: subscribe to real-time updates for an import audit using the pool."""
topic = f"import_batch_progress_{audit_id}"
async def handle_ws_message(connection_id: str, message: WebSocketMessage):
# No-op for now; reserved for future client messages
import_logger.debug(
"Import WS msg",
connection_id=connection_id,
audit_id=audit_id,
type=message.type,
)
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={topic},
require_auth=True,
metadata={"audit_id": audit_id, "endpoint": "import_batch_progress"},
message_handler=handle_ws_message,
)
if connection_id:
db = SessionLocal()
try:
audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first()
payload = _aggregate_batch_progress(db, audit) if audit else None
if payload:
initial_message = WebSocketMessage(
type="progress",
topic=topic,
data=payload,
)
await websocket_manager.pool._send_to_connection(connection_id, initial_message)
finally:
db.close()
@router.get("/recent-batches")
async def recent_batch_imports(
limit: int = Query(5, ge=1, le=50),