feat(billing): restrict WS subscriptions to initiating user or admins
This commit is contained in:
@@ -473,6 +473,9 @@ class BatchProgress(BaseModel):
|
||||
success_rate: Optional[float] = None
|
||||
files: List[BatchProgressEntry] = Field(default_factory=list)
|
||||
error_message: Optional[str] = None
|
||||
# Initiation metadata for authorization
|
||||
initiated_by_user_id: Optional[int] = None
|
||||
initiated_by_username: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -694,6 +697,38 @@ async def generate_statement(
|
||||
@router.websocket("/statements/batch-progress/ws/{batch_id}")
|
||||
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
|
||||
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
|
||||
# Authenticate first (without accepting) to enforce authorization before subscribing
|
||||
user = await websocket_manager.authenticate_websocket(websocket)
|
||||
if not user:
|
||||
try:
|
||||
await websocket.close(code=4401, reason="Authentication failed")
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# Authorization: only initiating user or admins may subscribe to this batch stream
|
||||
progress = await progress_store.get_progress(batch_id)
|
||||
if not progress:
|
||||
try:
|
||||
await websocket.close(code=4404, reason="Batch not found")
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
is_admin = bool(getattr(user, "is_admin", False))
|
||||
if not is_admin and getattr(user, "id", None) != getattr(progress, "initiated_by_user_id", None):
|
||||
billing_logger.warning(
|
||||
"Unauthorized WS subscription attempt for billing batch",
|
||||
batch_id=batch_id,
|
||||
user_id=getattr(user, "id", None),
|
||||
username=getattr(user, "username", None),
|
||||
)
|
||||
try:
|
||||
await websocket.close(code=4403, reason="Not authorized to subscribe to this batch")
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
topic = f"batch_progress_{batch_id}"
|
||||
|
||||
# Custom message handler for batch progress
|
||||
@@ -877,6 +912,8 @@ async def batch_generate_statements(
|
||||
failed_files=0,
|
||||
started_at=start_time.isoformat(),
|
||||
updated_at=start_time.isoformat(),
|
||||
initiated_by_user_id=getattr(current_user, "id", None),
|
||||
initiated_by_username=getattr(current_user, "username", None),
|
||||
files=[
|
||||
BatchProgressEntry(
|
||||
file_no=file_no,
|
||||
|
||||
Reference in New Issue
Block a user