feat(billing): restrict WS subscriptions to initiating user or admins
This commit is contained in:
@@ -473,6 +473,9 @@ class BatchProgress(BaseModel):
|
|||||||
success_rate: Optional[float] = None
|
success_rate: Optional[float] = None
|
||||||
files: List[BatchProgressEntry] = Field(default_factory=list)
|
files: List[BatchProgressEntry] = Field(default_factory=list)
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
|
# Initiation metadata for authorization
|
||||||
|
initiated_by_user_id: Optional[int] = None
|
||||||
|
initiated_by_username: Optional[str] = None
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -694,6 +697,38 @@ async def generate_statement(
|
|||||||
@router.websocket("/statements/batch-progress/ws/{batch_id}")
|
@router.websocket("/statements/batch-progress/ws/{batch_id}")
|
||||||
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
|
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
|
||||||
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
|
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
|
||||||
|
# Authenticate first (without accepting) to enforce authorization before subscribing
|
||||||
|
user = await websocket_manager.authenticate_websocket(websocket)
|
||||||
|
if not user:
|
||||||
|
try:
|
||||||
|
await websocket.close(code=4401, reason="Authentication failed")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
# Authorization: only initiating user or admins may subscribe to this batch stream
|
||||||
|
progress = await progress_store.get_progress(batch_id)
|
||||||
|
if not progress:
|
||||||
|
try:
|
||||||
|
await websocket.close(code=4404, reason="Batch not found")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
is_admin = bool(getattr(user, "is_admin", False))
|
||||||
|
if not is_admin and getattr(user, "id", None) != getattr(progress, "initiated_by_user_id", None):
|
||||||
|
billing_logger.warning(
|
||||||
|
"Unauthorized WS subscription attempt for billing batch",
|
||||||
|
batch_id=batch_id,
|
||||||
|
user_id=getattr(user, "id", None),
|
||||||
|
username=getattr(user, "username", None),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await websocket.close(code=4403, reason="Not authorized to subscribe to this batch")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
topic = f"batch_progress_{batch_id}"
|
topic = f"batch_progress_{batch_id}"
|
||||||
|
|
||||||
# Custom message handler for batch progress
|
# Custom message handler for batch progress
|
||||||
@@ -877,6 +912,8 @@ async def batch_generate_statements(
|
|||||||
failed_files=0,
|
failed_files=0,
|
||||||
started_at=start_time.isoformat(),
|
started_at=start_time.isoformat(),
|
||||||
updated_at=start_time.isoformat(),
|
updated_at=start_time.isoformat(),
|
||||||
|
initiated_by_user_id=getattr(current_user, "id", None),
|
||||||
|
initiated_by_username=getattr(current_user, "username", None),
|
||||||
files=[
|
files=[
|
||||||
BatchProgressEntry(
|
BatchProgressEntry(
|
||||||
file_no=file_no,
|
file_no=file_no,
|
||||||
|
|||||||
Reference in New Issue
Block a user