feat(billing): restrict WS subscriptions to initiating user or admins

This commit is contained in:
HotSwapp
2025-09-04 15:08:09 -05:00
parent 30e4c83618
commit 41ffbc1430

View File

@@ -473,6 +473,9 @@ class BatchProgress(BaseModel):
success_rate: Optional[float] = None success_rate: Optional[float] = None
files: List[BatchProgressEntry] = Field(default_factory=list) files: List[BatchProgressEntry] = Field(default_factory=list)
error_message: Optional[str] = None error_message: Optional[str] = None
# Initiation metadata for authorization
initiated_by_user_id: Optional[int] = None
initiated_by_username: Optional[str] = None
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
@@ -694,6 +697,38 @@ async def generate_statement(
@router.websocket("/statements/batch-progress/ws/{batch_id}") @router.websocket("/statements/batch-progress/ws/{batch_id}")
async def ws_batch_progress(websocket: WebSocket, batch_id: str): async def ws_batch_progress(websocket: WebSocket, batch_id: str):
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool.""" """WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
# Authenticate first (without accepting) to enforce authorization before subscribing
user = await websocket_manager.authenticate_websocket(websocket)
if not user:
try:
await websocket.close(code=4401, reason="Authentication failed")
except Exception:
pass
return
# Authorization: only initiating user or admins may subscribe to this batch stream
progress = await progress_store.get_progress(batch_id)
if not progress:
try:
await websocket.close(code=4404, reason="Batch not found")
except Exception:
pass
return
is_admin = bool(getattr(user, "is_admin", False))
if not is_admin and getattr(user, "id", None) != getattr(progress, "initiated_by_user_id", None):
billing_logger.warning(
"Unauthorized WS subscription attempt for billing batch",
batch_id=batch_id,
user_id=getattr(user, "id", None),
username=getattr(user, "username", None),
)
try:
await websocket.close(code=4403, reason="Not authorized to subscribe to this batch")
except Exception:
pass
return
topic = f"batch_progress_{batch_id}" topic = f"batch_progress_{batch_id}"
# Custom message handler for batch progress # Custom message handler for batch progress
@@ -877,6 +912,8 @@ async def batch_generate_statements(
failed_files=0, failed_files=0,
started_at=start_time.isoformat(), started_at=start_time.isoformat(),
updated_at=start_time.isoformat(), updated_at=start_time.isoformat(),
initiated_by_user_id=getattr(current_user, "id", None),
initiated_by_username=getattr(current_user, "username", None),
files=[ files=[
BatchProgressEntry( BatchProgressEntry(
file_no=file_no, file_no=file_no,