diff --git a/app/api/import_data.py b/app/api/import_data.py index e6b0e57..e7053b0 100644 --- a/app/api/import_data.py +++ b/app/api/import_data.py @@ -2464,6 +2464,43 @@ async def batch_import_csv_files( @router.websocket("/batch-progress/ws/{audit_id}") async def ws_import_batch_progress(websocket: WebSocket, audit_id: int): """WebSocket: subscribe to real-time updates for an import audit using the pool.""" + # Authenticate first (without accepting) to enforce authorization before subscribing + user = await websocket_manager.authenticate_websocket(websocket) + if not user: + # Authentication failure handled here to avoid accepting unauthorized connections + try: + await websocket.close(code=4401, reason="Authentication failed") + except Exception: + pass + return + + # Authorization: only initiating user or admins may subscribe to this audit stream + db = SessionLocal() + try: + audit = db.query(ImportAudit).filter(ImportAudit.id == audit_id).first() + if not audit: + try: + await websocket.close(code=4404, reason="Batch not found") + except Exception: + pass + return + + is_admin = bool(getattr(user, "is_admin", False)) + if not is_admin and getattr(user, "id", None) != getattr(audit, "initiated_by_user_id", None): + import_logger.warning( + "Unauthorized WS subscription attempt for import batch", + audit_id=audit_id, + user_id=getattr(user, "id", None), + username=getattr(user, "username", None), + ) + try: + await websocket.close(code=4403, reason="Not authorized to subscribe to this batch") + except Exception: + pass + return + finally: + db.close() + topic = f"import_batch_progress_{audit_id}" async def handle_ws_message(connection_id: str, message: WebSocketMessage):