From 41ffbc143020eed90788f3b357cc25c324eef4ac Mon Sep 17 00:00:00 2001 From: HotSwapp <47397945+HotSwapp@users.noreply.github.com> Date: Thu, 4 Sep 2025 15:08:09 -0500 Subject: [PATCH] feat(billing): restrict WS subscriptions to initiating user or admins --- app/api/billing.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/app/api/billing.py b/app/api/billing.py index 05a5a0f..39358db 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -473,6 +473,9 @@ class BatchProgress(BaseModel): success_rate: Optional[float] = None files: List[BatchProgressEntry] = Field(default_factory=list) error_message: Optional[str] = None + # Initiation metadata for authorization + initiated_by_user_id: Optional[int] = None + initiated_by_username: Optional[str] = None model_config = ConfigDict( json_schema_extra={ @@ -694,6 +697,38 @@ async def generate_statement( @router.websocket("/statements/batch-progress/ws/{batch_id}") async def ws_batch_progress(websocket: WebSocket, batch_id: str): """WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool.""" + # Authenticate first (without accepting) to enforce authorization before subscribing + user = await websocket_manager.authenticate_websocket(websocket) + if not user: + try: + await websocket.close(code=4401, reason="Authentication failed") + except Exception: + pass + return + + # Authorization: only initiating user or admins may subscribe to this batch stream + progress = await progress_store.get_progress(batch_id) + if not progress: + try: + await websocket.close(code=4404, reason="Batch not found") + except Exception: + pass + return + + is_admin = bool(getattr(user, "is_admin", False)) + if not is_admin and getattr(user, "id", None) != getattr(progress, "initiated_by_user_id", None): + billing_logger.warning( + "Unauthorized WS subscription attempt for billing batch", + batch_id=batch_id, + user_id=getattr(user, "id", None), + username=getattr(user, "username", None), + ) + try: + await websocket.close(code=4403, reason="Not authorized to subscribe to this batch") + except Exception: + pass + return + topic = f"batch_progress_{batch_id}" # Custom message handler for batch progress @@ -877,6 +912,8 @@ async def batch_generate_statements( failed_files=0, started_at=start_time.isoformat(), updated_at=start_time.isoformat(), + initiated_by_user_id=getattr(current_user, "id", None), + initiated_by_username=getattr(current_user, "username", None), files=[ BatchProgressEntry( file_no=file_no,