# delphi-database/app/api/billing.py
"""
Billing & Statements API endpoints
"""
from typing import List, Optional, Dict, Any, Set
from datetime import datetime, timezone, date, timedelta
import os
import re
from pathlib import Path
import asyncio
import logging
import threading
import time
from enum import Enum
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
from fastapi import Path as PathParam
from fastapi.responses import FileResponse
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError
from app.database.base import get_db, SessionLocal
from app.models.files import File
from app.models.ledger import Ledger
from app.models.rolodex import Rolodex
from app.models.user import User
from app.auth.security import get_current_user, verify_token
from app.utils.responses import BulkOperationResponse, ErrorDetail
from app.utils.logging import StructuredLogger
from app.services.cache import cache_get_json, cache_set_json
from app.models.billing import BillingBatch, BillingBatchFile
router = APIRouter()
# Initialize logger for billing operations
billing_logger = StructuredLogger("billing_operations", "INFO")
# Realtime WebSocket subscriber registry: batch_id -> set[WebSocket]
_subscribers_by_batch: Dict[str, Set[WebSocket]] = {}
_subscribers_lock = asyncio.Lock()
async def _notify_progress_subscribers(progress: "BatchProgress") -> None:
"""Broadcast latest progress to active subscribers of a batch."""
batch_id = progress.batch_id
message = {"type": "progress", "data": progress.model_dump()}
async with _subscribers_lock:
sockets = list(_subscribers_by_batch.get(batch_id, set()))
if not sockets:
return
dead: List[WebSocket] = []
for ws in sockets:
try:
await ws.send_json(message)
except Exception:
dead.append(ws)
if dead:
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket:
for ws in dead:
bucket.discard(ws)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
def _round(value: Optional[float]) -> float:
try:
return round(float(value or 0.0), 2)
except Exception:
return 0.0
class StatementEntry(BaseModel):
id: int
date: Optional[date]
t_code: str
t_type: str
description: Optional[str] = None
quantity: float = 0.0
rate: float = 0.0
amount: float
model_config = ConfigDict(from_attributes=True)
class StatementTotals(BaseModel):
charges_billed: float
charges_unbilled: float
charges_total: float
payments: float
trust_balance: float
current_balance: float
class StatementResponse(BaseModel):
file_no: str
client_name: Optional[str] = None
as_of: str
totals: StatementTotals
unbilled_entries: List[StatementEntry]
class BatchHistorySummary(BaseModel):
batch_id: str
status: str
total_files: int
successful_files: int
failed_files: int
    started_at: Optional[str] = None
    updated_at: Optional[str] = None
completed_at: Optional[str] = None
processing_time_seconds: Optional[float] = None
@router.get("/statements/batch-list", response_model=List[str])
async def list_active_batches(
current_user: User = Depends(get_current_user),
):
"""
List all currently active batch statement generation operations.
Returns batch IDs for operations that are currently pending or running.
Completed, failed, and cancelled operations are excluded.
**Returns:**
- List of active batch IDs that can be used with the progress endpoint
**Usage:**
Use this endpoint to discover active batch operations for progress monitoring.
"""
    # progress_store is a module-level singleton defined later in this file;
    # it is initialized at import time and available whenever requests run.
return await progress_store.list_active_batches()
@router.get("/statements/batch-progress/{batch_id}", response_model=Dict[str, Any])
async def get_batch_progress(
batch_id: str = PathParam(..., description="Batch operation identifier"),
current_user: User = Depends(get_current_user),
):
"""
Get real-time progress information for a batch statement generation operation.
Provides comprehensive progress tracking including:
- Overall batch status and completion percentage
- Individual file processing status and timing
- Current file being processed
- Estimated completion time based on processing rate
- Success/failure rates and error details
**Parameters:**
- **batch_id**: Unique identifier for the batch operation
**Returns:**
- Complete progress information including:
- Batch status (pending, running, completed, failed, cancelled)
- File counts (total, processed, successful, failed)
- Timing information and estimates
- Individual file details and results
- Error information if applicable
**Errors:**
- 404: Batch operation not found (may have expired or never existed)
"""
progress = await progress_store.get_progress(batch_id)
if not progress:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Batch operation {batch_id} not found. It may have expired or never existed."
)
return progress.model_dump()
@router.get("/statements/batch-history", response_model=List[BatchHistorySummary])
async def list_batch_history(
status_filter: Optional[str] = Query(None, description="Status filter: pending|running|completed|failed|cancelled"),
sort: Optional[str] = Query("updated_desc", description="Sort: updated_desc|updated_asc|started_desc|started_asc|completed_desc|completed_asc"),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0, le=10000),
start_date: Optional[str] = Query(None, description="ISO start bound (filters started_at)"),
end_date: Optional[str] = Query(None, description="ISO end bound (filters started_at)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List batch operations from persistent history with filters and pagination."""
q = db.query(BillingBatch)
if status_filter:
q = q.filter(BillingBatch.status == status_filter)
    def _parse(dt: Optional[str]) -> Optional[datetime]:
        if not dt:
            return None
        try:
            return datetime.fromisoformat(dt.replace('Z', '+00:00'))
        except Exception:
            return None
if start_date:
sd = _parse(start_date)
if sd:
q = q.filter(BillingBatch.started_at >= sd)
if end_date:
ed = _parse(end_date)
if ed:
q = q.filter(BillingBatch.started_at <= ed)
sort_map = {
"updated_desc": (BillingBatch.updated_at.desc(),),
"updated_asc": (BillingBatch.updated_at.asc(),),
"started_desc": (BillingBatch.started_at.desc(),),
"started_asc": (BillingBatch.started_at.asc(),),
"completed_desc": (BillingBatch.completed_at.desc(),),
"completed_asc": (BillingBatch.completed_at.asc(),),
}
q = q.order_by(*sort_map.get(sort or "updated_desc", sort_map["updated_desc"]))
rows = q.offset(offset).limit(limit).all()
items: List[BatchHistorySummary] = []
for r in rows:
items.append(BatchHistorySummary(
batch_id=r.batch_id,
status=r.status,
total_files=r.total_files,
successful_files=r.successful_files,
failed_files=r.failed_files,
started_at=r.started_at.isoformat() if r.started_at else None,
updated_at=r.updated_at.isoformat() if r.updated_at else None,
completed_at=r.completed_at.isoformat() if r.completed_at else None,
processing_time_seconds=r.processing_time_seconds,
))
return items
@router.get("/statements/{file_no}", response_model=StatementResponse)
async def get_statement_snapshot(
file_no: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return a computed statement snapshot for a file.
Includes totals (billed/unbilled charges, payments, trust balance, current balance)
and an itemized list of unbilled transactions.
"""
file_obj = (
db.query(File)
.options(joinedload(File.owner))
.filter(File.file_no == file_no)
.first()
)
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found",
)
# Load relevant ledger entries once
entries: List[Ledger] = db.query(Ledger).filter(Ledger.file_no == file_no).all()
# Charges are debits: hourly (2), flat (3), disbursements (4)
CHARGE_TYPES = {"2", "3", "4"}
charges_billed = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed == "Y")
charges_unbilled = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed != "Y")
charges_total = charges_billed + charges_unbilled
# Payments/credits are type 5
payments_total = sum(e.amount for e in entries if e.t_type == "5")
# Trust balance is tracked on File (kept in sync by ledger endpoints)
trust_balance = file_obj.trust_bal or 0.0
# Current balance is total charges minus payments
current_balance = charges_total - payments_total
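    # Worked example: a $100.00 billed hourly charge (type 2), a $50.00
    # unbilled disbursement (type 4), and a $75.00 payment (type 5) give
    # charges_total = 150.00 and current_balance = 150.00 - 75.00 = 75.00.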
# Itemized unbilled entries (charges only)
unbilled_entries = [
StatementEntry(
id=e.id,
date=e.date,
t_code=e.t_code,
t_type=e.t_type,
description=e.note,
quantity=e.quantity or 0.0,
rate=e.rate or 0.0,
amount=e.amount,
)
for e in entries
if e.t_type in CHARGE_TYPES and e.billed != "Y"
]
client_name = None
if file_obj.owner:
        client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last or ''}".strip()
response = StatementResponse(
file_no=file_no,
client_name=client_name or None,
as_of=datetime.now(timezone.utc).isoformat(),
totals=StatementTotals(
charges_billed=_round(charges_billed),
charges_unbilled=_round(charges_unbilled),
charges_total=_round(charges_total),
payments=_round(payments_total),
trust_balance=_round(trust_balance),
current_balance=_round(current_balance),
),
unbilled_entries=unbilled_entries,
)
return response
class GenerateStatementRequest(BaseModel):
file_no: str
period: Optional[str] = None # Supports YYYY-MM for monthly; optional
class GeneratedStatementMeta(BaseModel):
file_no: str
client_name: Optional[str] = None
as_of: str
period: Optional[str] = None
totals: StatementTotals
unbilled_count: int
export_path: str
filename: str
size: int
content_type: str = "text/html"
class BatchGenerateStatementRequest(BaseModel):
file_numbers: List[str] = Field(..., description="List of file numbers to generate statements for", max_length=50)
period: Optional[str] = Field(None, description="Optional period filter in YYYY-MM format")
model_config = ConfigDict(
json_schema_extra={
"example": {
"file_numbers": ["ABC-123", "DEF-456", "GHI-789"],
"period": "2024-01"
}
}
)
class BatchFileResult(BaseModel):
file_no: str
status: str # "success" or "failed"
message: Optional[str] = None
statement_meta: Optional[GeneratedStatementMeta] = None
error_details: Optional[str] = None
model_config = ConfigDict(
json_schema_extra={
"example": {
"file_no": "ABC-123",
"status": "success",
"message": "Statement generated successfully",
"statement_meta": {
"file_no": "ABC-123",
"filename": "statement_ABC-123_20240115_143022.html",
"size": 2048
}
}
}
)
class BatchGenerateStatementResponse(BaseModel):
batch_id: str = Field(..., description="Unique identifier for this batch operation")
total_files: int = Field(..., description="Total number of files requested")
successful: int = Field(..., description="Number of files processed successfully")
failed: int = Field(..., description="Number of files that failed processing")
success_rate: float = Field(..., description="Success rate as percentage")
started_at: str = Field(..., description="ISO timestamp when batch started")
completed_at: str = Field(..., description="ISO timestamp when batch completed")
processing_time_seconds: float = Field(..., description="Total processing time in seconds")
results: List[BatchFileResult] = Field(..., description="Individual file processing results")
model_config = ConfigDict(
json_schema_extra={
"example": {
"batch_id": "batch_20240115_143022_abc123",
"total_files": 3,
"successful": 2,
"failed": 1,
"success_rate": 66.67,
"started_at": "2024-01-15T14:30:22.123456+00:00",
"completed_at": "2024-01-15T14:30:27.654321+00:00",
"processing_time_seconds": 5.53,
"results": [
{
"file_no": "ABC-123",
"status": "success",
"message": "Statement generated successfully"
}
]
}
}
)
class BatchStatus(str, Enum):
"""Enumeration of batch operation statuses."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class BatchProgressEntry(BaseModel):
"""Progress information for a single file in a batch operation."""
file_no: str
status: str # "pending", "processing", "completed", "failed"
started_at: Optional[str] = None
completed_at: Optional[str] = None
error_message: Optional[str] = None
statement_meta: Optional[GeneratedStatementMeta] = None
model_config = ConfigDict(
json_schema_extra={
"example": {
"file_no": "ABC-123",
"status": "completed",
"started_at": "2024-01-15T14:30:22.123456+00:00",
"completed_at": "2024-01-15T14:30:25.654321+00:00",
"statement_meta": {
"file_no": "ABC-123",
"filename": "statement_ABC-123_20240115_143022.html",
"size": 2048
}
}
}
)
class BatchProgress(BaseModel):
"""Comprehensive progress information for a batch operation."""
batch_id: str
status: BatchStatus
total_files: int
processed_files: int
successful_files: int
failed_files: int
current_file: Optional[str] = None
started_at: str
updated_at: str
completed_at: Optional[str] = None
estimated_completion: Optional[str] = None
processing_time_seconds: Optional[float] = None
success_rate: Optional[float] = None
files: List[BatchProgressEntry] = Field(default_factory=list)
error_message: Optional[str] = None
model_config = ConfigDict(
json_schema_extra={
"example": {
"batch_id": "batch_20240115_143022_abc123",
"status": "running",
"total_files": 5,
"processed_files": 2,
"successful_files": 2,
"failed_files": 0,
"current_file": "ABC-123",
"started_at": "2024-01-15T14:30:22.123456+00:00",
"updated_at": "2024-01-15T14:30:24.789012+00:00",
"estimated_completion": "2024-01-15T14:30:30.000000+00:00",
"files": [
{
"file_no": "ABC-123",
"status": "processing",
"started_at": "2024-01-15T14:30:24.789012+00:00"
}
]
}
}
)
class BatchProgressStore:
"""
Thread-safe progress store for batch operations with caching support.
Uses Redis for distributed caching when available, falls back to in-memory storage.
Includes automatic cleanup of old progress data.
"""
def __init__(self):
self._lock = threading.RLock()
self._in_memory_store: Dict[str, BatchProgress] = {}
self._cleanup_interval = 3600 # 1 hour
self._retention_period = 86400 # 24 hours
self._last_cleanup = time.time()
def _should_cleanup(self) -> bool:
"""Check if cleanup should be performed."""
return time.time() - self._last_cleanup > self._cleanup_interval
async def _cleanup_old_entries(self) -> None:
"""Remove old progress entries based on retention policy."""
if not self._should_cleanup():
return
cutoff_time = datetime.now(timezone.utc) - timedelta(seconds=self._retention_period)
cutoff_str = cutoff_time.isoformat()
with self._lock:
# Clean up in-memory store
expired_keys = []
for batch_id, progress in self._in_memory_store.items():
if (progress.status in [BatchStatus.COMPLETED, BatchStatus.FAILED, BatchStatus.CANCELLED] and
progress.updated_at < cutoff_str):
expired_keys.append(batch_id)
for key in expired_keys:
del self._in_memory_store[key]
billing_logger.info(
"Cleaned up old batch progress entries",
cleaned_count=len(expired_keys),
cutoff_time=cutoff_str
)
self._last_cleanup = time.time()
async def get_progress(self, batch_id: str) -> Optional[BatchProgress]:
"""Get progress information for a batch operation."""
await self._cleanup_old_entries()
# Try cache first
try:
cached_data = await cache_get_json("batch_progress", None, {"batch_id": batch_id})
if cached_data:
return BatchProgress.model_validate(cached_data)
except Exception as e:
billing_logger.debug(f"Cache get failed for batch {batch_id}: {str(e)}")
# Fall back to in-memory store
with self._lock:
return self._in_memory_store.get(batch_id)
async def set_progress(self, progress: BatchProgress) -> None:
"""Store progress information for a batch operation."""
progress.updated_at = datetime.now(timezone.utc).isoformat()
# Store in cache with TTL
try:
await cache_set_json(
"batch_progress",
None,
{"batch_id": progress.batch_id},
progress.model_dump(),
self._retention_period
)
except Exception as e:
billing_logger.debug(f"Cache set failed for batch {progress.batch_id}: {str(e)}")
# Store in memory as backup
with self._lock:
self._in_memory_store[progress.batch_id] = progress
# Notify subscribers (best-effort)
try:
await _notify_progress_subscribers(progress)
except Exception:
pass
async def delete_progress(self, batch_id: str) -> None:
"""Delete progress information for a batch operation."""
# Note: The current cache service doesn't have a delete function
# We'll rely on TTL expiration for cache cleanup
# Just remove from in-memory store
with self._lock:
self._in_memory_store.pop(batch_id, None)
async def list_active_batches(self) -> List[str]:
"""List all active batch operations."""
await self._cleanup_old_entries()
active_batches = []
with self._lock:
for batch_id, progress in self._in_memory_store.items():
if progress.status in [BatchStatus.PENDING, BatchStatus.RUNNING]:
active_batches.append(batch_id)
return active_batches
# Global progress store instance
progress_store = BatchProgressStore()
def _parse_period_month(period: Optional[str]) -> Optional[tuple[date, date]]:
"""Parse period in the form YYYY-MM and return (start_date, end_date) inclusive.
Returns None when period is not provided or invalid.
"""
if not period:
return None
m = re.fullmatch(r"(\d{4})-(\d{2})", str(period).strip())
if not m:
return None
year = int(m.group(1))
month = int(m.group(2))
if month < 1 or month > 12:
return None
from calendar import monthrange
last_day = monthrange(year, month)[1]
return date(year, month, 1), date(year, month, last_day)
def _render_statement_html(
*,
file_no: str,
client_name: Optional[str],
matter: Optional[str],
as_of_iso: str,
period: Optional[str],
totals: StatementTotals,
unbilled_entries: List[StatementEntry],
) -> str:
"""Create a simple, self-contained HTML statement string."""
# Rows for unbilled entries
def _fmt(val: Optional[float]) -> str:
try:
return f"{float(val or 0):.2f}"
except Exception:
return "0.00"
rows = []
for e in unbilled_entries:
rows.append(
f"<tr><td>{e.date.isoformat() if e.date else ''}</td><td>{e.t_code}</td><td>{(e.description or '').replace('<','&lt;').replace('>','&gt;')}</td>"
f"<td style='text-align:right'>{_fmt(e.quantity)}</td><td style='text-align:right'>{_fmt(e.rate)}</td><td style='text-align:right'>{_fmt(e.amount)}</td></tr>"
)
rows_html = "\n".join(rows) if rows else "<tr><td colspan='6' style='text-align:center;color:#666'>No unbilled entries</td></tr>"
period_html = f"<div><strong>Period:</strong> {period}</div>" if period else ""
    html = f"""
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <title>Statement {file_no}</title>
  <style>
    body {{ font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica, Arial, sans-serif; margin: 24px; }}
    h1 {{ margin: 0 0 8px 0; }}
    .meta {{ color: #444; margin-bottom: 16px; }}
    table {{ border-collapse: collapse; width: 100%; }}
    th, td {{ border: 1px solid #ddd; padding: 8px; font-size: 14px; }}
    th {{ background: #f6f6f6; text-align: left; }}
    .totals {{ margin: 16px 0; display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 8px; }}
    .totals div {{ background: #fafafa; border: 1px solid #eee; padding: 8px; }}
  </style>
</head>
<body>
  <h1>Statement</h1>
  <div class="meta">
    <div><strong>File:</strong> {file_no}</div>
    <div><strong>Client:</strong> {(client_name or '').replace('<', '&lt;').replace('>', '&gt;')}</div>
    <div><strong>Matter:</strong> {(matter or '').replace('<', '&lt;').replace('>', '&gt;')}</div>
    <div><strong>As of:</strong> {as_of_iso}</div>
    {period_html}
  </div>
  <div class="totals">
    <div><strong>Charges (billed)</strong><br/>${_fmt(totals.charges_billed)}</div>
    <div><strong>Charges (unbilled)</strong><br/>${_fmt(totals.charges_unbilled)}</div>
    <div><strong>Charges (total)</strong><br/>${_fmt(totals.charges_total)}</div>
    <div><strong>Payments</strong><br/>${_fmt(totals.payments)}</div>
    <div><strong>Trust balance</strong><br/>${_fmt(totals.trust_balance)}</div>
    <div><strong>Current balance</strong><br/>${_fmt(totals.current_balance)}</div>
  </div>
  <h2>Unbilled Entries</h2>
  <table>
    <thead>
      <tr>
        <th>Date</th>
        <th>Code</th>
        <th>Description</th>
        <th style="text-align:right">Qty</th>
        <th style="text-align:right">Rate</th>
        <th style="text-align:right">Amount</th>
      </tr>
    </thead>
    <tbody>
      {rows_html}
    </tbody>
  </table>
</body>
</html>
"""
return html
def _generate_single_statement(
file_no: str,
period: Optional[str],
db: Session
) -> GeneratedStatementMeta:
"""
Internal helper to generate a statement for a single file.
Args:
file_no: File number to generate statement for
period: Optional period filter (YYYY-MM format)
db: Database session
Returns:
GeneratedStatementMeta with file metadata and export path
Raises:
HTTPException: If file not found or generation fails
"""
file_obj = (
db.query(File)
.options(joinedload(File.owner))
.filter(File.file_no == file_no)
.first()
)
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File {file_no} not found",
)
# Optional period filtering (YYYY-MM)
date_range = _parse_period_month(period)
q = db.query(Ledger).filter(Ledger.file_no == file_no)
if date_range:
start_date, end_date = date_range
q = q.filter(Ledger.date >= start_date).filter(Ledger.date <= end_date)
entries: List[Ledger] = q.all()
CHARGE_TYPES = {"2", "3", "4"}
charges_billed = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed == "Y")
charges_unbilled = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed != "Y")
charges_total = charges_billed + charges_unbilled
payments_total = sum(e.amount for e in entries if e.t_type == "5")
trust_balance = file_obj.trust_bal or 0.0
current_balance = charges_total - payments_total
unbilled_entries = [
StatementEntry(
id=e.id,
date=e.date,
t_code=e.t_code,
t_type=e.t_type,
description=e.note,
quantity=e.quantity or 0.0,
rate=e.rate or 0.0,
amount=e.amount,
)
for e in entries
if e.t_type in CHARGE_TYPES and e.billed != "Y"
]
client_name = None
if file_obj.owner:
        client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last or ''}".strip()
as_of_iso = datetime.now(timezone.utc).isoformat()
totals_model = StatementTotals(
charges_billed=_round(charges_billed),
charges_unbilled=_round(charges_unbilled),
charges_total=_round(charges_total),
payments=_round(payments_total),
trust_balance=_round(trust_balance),
current_balance=_round(current_balance),
)
# Render HTML
html = _render_statement_html(
file_no=file_no,
client_name=client_name or None,
matter=file_obj.regarding,
as_of_iso=as_of_iso,
period=period,
totals=totals_model,
unbilled_entries=unbilled_entries,
)
# Ensure exports directory and write file
exports_dir = Path("exports")
try:
exports_dir.mkdir(exist_ok=True)
except Exception:
        # Surface directory-creation failures to the caller as an internal error
raise HTTPException(status_code=500, detail="Unable to create exports directory")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
filename = f"statement_{safe_file_no}_{timestamp}.html"
export_path = exports_dir / filename
html_bytes = html.encode("utf-8")
with open(export_path, "wb") as f:
f.write(html_bytes)
size = export_path.stat().st_size
return GeneratedStatementMeta(
file_no=file_no,
client_name=client_name or None,
as_of=as_of_iso,
period=period,
totals=totals_model,
unbilled_count=len(unbilled_entries),
export_path=str(export_path),
filename=filename,
size=size,
content_type="text/html",
)
@router.post("/statements/generate", response_model=GeneratedStatementMeta)
async def generate_statement(
payload: GenerateStatementRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Generate a simple HTML statement and store it under exports/.
Returns metadata about the generated artifact.
"""
return _generate_single_statement(payload.file_no, payload.period, db)
async def _ws_authenticate(websocket: WebSocket) -> Optional[User]:
"""Authenticate WebSocket via JWT token in query (?token=) or Authorization header."""
token = websocket.query_params.get("token")
if not token:
try:
auth_header = dict(websocket.headers).get("authorization") or ""
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
except Exception:
token = None
if not token:
return None
username = verify_token(token)
if not username:
return None
db = SessionLocal()
try:
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
return None
return user
finally:
db.close()
async def _ws_keepalive(ws: WebSocket, stop_event: asyncio.Event) -> None:
try:
while not stop_event.is_set():
await asyncio.sleep(25)
try:
await ws.send_json({"type": "ping", "ts": datetime.now(timezone.utc).isoformat()})
except Exception:
break
finally:
stop_event.set()
@router.websocket("/statements/batch-progress/ws/{batch_id}")
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
"""WebSocket: subscribe to real-time updates for a batch_id."""
user = await _ws_authenticate(websocket)
if not user:
await websocket.close(code=4401)
return
await websocket.accept()
# Register
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if not bucket:
bucket = set()
_subscribers_by_batch[batch_id] = bucket
bucket.add(websocket)
# Send initial snapshot
try:
snapshot = await progress_store.get_progress(batch_id)
await websocket.send_json({"type": "progress", "data": snapshot.model_dump() if snapshot else None})
except Exception:
pass
# Keepalive + receive loop
stop_event: asyncio.Event = asyncio.Event()
ka_task = asyncio.create_task(_ws_keepalive(websocket, stop_event))
try:
while not stop_event.is_set():
try:
msg = await websocket.receive_text()
except WebSocketDisconnect:
break
except Exception:
break
if isinstance(msg, str) and msg.strip() == "ping":
try:
await websocket.send_text("pong")
except Exception:
break
finally:
stop_event.set()
try:
ka_task.cancel()
except Exception:
pass
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket and websocket in bucket:
bucket.discard(websocket)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
@router.delete("/statements/batch-progress/{batch_id}")
async def cancel_batch_operation(
batch_id: str = PathParam(..., description="Batch operation identifier to cancel"),
current_user: User = Depends(get_current_user),
):
"""
Cancel an active batch statement generation operation.
**Note:** This endpoint marks the batch as cancelled but does not interrupt
currently running file processing. Files already being processed will complete,
but pending files will be skipped.
**Parameters:**
- **batch_id**: Unique identifier for the batch operation to cancel
**Returns:**
- Success message confirming cancellation
**Errors:**
- 404: Batch operation not found
- 400: Batch operation cannot be cancelled (already completed/failed)
"""
progress = await progress_store.get_progress(batch_id)
if not progress:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Batch operation {batch_id} not found"
)
if progress.status not in [BatchStatus.PENDING, BatchStatus.RUNNING]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot cancel batch operation with status: {progress.status}"
)
# Mark as cancelled
progress.status = BatchStatus.CANCELLED
progress.completed_at = datetime.now(timezone.utc).isoformat()
progress.processing_time_seconds = (
datetime.fromisoformat(progress.completed_at.replace('Z', '+00:00')) -
datetime.fromisoformat(progress.started_at.replace('Z', '+00:00'))
).total_seconds()
await progress_store.set_progress(progress)
billing_logger.info(
"Batch operation cancelled",
batch_id=batch_id,
user_id=getattr(current_user, "id", None),
processed_files=progress.processed_files,
total_files=progress.total_files
)
return {"message": f"Batch operation {batch_id} has been cancelled"}
async def _calculate_estimated_completion(
progress: BatchProgress,
current_time: datetime
) -> Optional[str]:
"""Calculate estimated completion time based on current progress."""
if progress.processed_files == 0:
return None
start_time = datetime.fromisoformat(progress.started_at.replace('Z', '+00:00'))
elapsed_seconds = (current_time - start_time).total_seconds()
if elapsed_seconds <= 0:
return None
# Calculate average time per file
avg_time_per_file = elapsed_seconds / progress.processed_files
remaining_files = progress.total_files - progress.processed_files
if remaining_files <= 0:
return current_time.isoformat()
estimated_remaining_seconds = avg_time_per_file * remaining_files
estimated_completion = current_time + timedelta(seconds=estimated_remaining_seconds)
return estimated_completion.isoformat()
@router.post("/statements/batch-generate", response_model=BatchGenerateStatementResponse)
async def batch_generate_statements(
payload: BatchGenerateStatementRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
    Generate statements for multiple files in a single request, with real-time progress tracking and error handling.
    Processes up to 50 files sequentially; individual file failures do not stop the batch operation.
    Each file is processed independently, with detailed error reporting and real-time progress updates.
**Parameters:**
- **file_numbers**: List of file numbers to generate statements for (max 50)
- **period**: Optional period filter in YYYY-MM format for all files
**Returns:**
- Detailed batch operation results including:
- Total files processed
- Success/failure counts and rates
- Individual file results with error details
- Processing time metrics
- Unique batch identifier for progress tracking
**Features:**
- Real-time progress tracking via `/statements/batch-progress/{batch_id}`
- Individual file error handling - failures don't stop other files
- Estimated completion time calculations
- Detailed error reporting per file
- Batch operation identification for audit trails
- Automatic cleanup of progress data after completion
"""
# Validate request
if not payload.file_numbers:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one file number must be provided"
)
if len(payload.file_numbers) > 50:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Maximum 50 files allowed per batch operation"
)
# Remove duplicates while preserving order
unique_file_numbers = list(dict.fromkeys(payload.file_numbers))
# Generate batch ID and timing
start_time = datetime.now(timezone.utc)
    # Timestamp plus a random suffix; unlike hash(), uuid4 cannot collide for
    # identical file lists submitted within the same second
    batch_id = f"batch_{start_time.strftime('%Y%m%d_%H%M%S')}_{uuid4().hex[:6]}"
billing_logger.info(
"Starting batch statement generation",
batch_id=batch_id,
total_files=len(unique_file_numbers),
file_numbers=unique_file_numbers,
period=payload.period,
user_id=getattr(current_user, "id", None),
user_name=getattr(current_user, "username", None)
)
# Initialize progress tracking
progress = BatchProgress(
batch_id=batch_id,
status=BatchStatus.RUNNING,
total_files=len(unique_file_numbers),
processed_files=0,
successful_files=0,
failed_files=0,
started_at=start_time.isoformat(),
updated_at=start_time.isoformat(),
files=[
BatchProgressEntry(
file_no=file_no,
status="pending"
) for file_no in unique_file_numbers
]
)
# Store initial progress
await progress_store.set_progress(progress)
# Track results for final response
results: List[BatchFileResult] = []
successful = 0
failed = 0
try:
# Process each file
for idx, file_no in enumerate(unique_file_numbers):
current_time = datetime.now(timezone.utc)
            # Check whether the operation was cancelled via the cancel endpoint;
            # if so, record the cancelled state locally before breaking so the
            # completion block below does not overwrite it with COMPLETED
            current_progress = await progress_store.get_progress(batch_id)
            if current_progress and current_progress.status == BatchStatus.CANCELLED:
                progress.status = BatchStatus.CANCELLED
                billing_logger.info(
                    "Batch operation cancelled, skipping remaining files",
                    batch_id=batch_id,
                    file_no=file_no,
                    remaining_files=len(unique_file_numbers) - idx
                )
                break
# Update progress - mark current file as processing
progress.current_file = file_no
progress.files[idx].status = "processing"
progress.files[idx].started_at = current_time.isoformat()
progress.estimated_completion = await _calculate_estimated_completion(progress, current_time)
await progress_store.set_progress(progress)
billing_logger.info(
"Processing file statement",
batch_id=batch_id,
file_no=file_no,
progress=f"{idx + 1}/{len(unique_file_numbers)}",
progress_percent=round(((idx + 1) / len(unique_file_numbers)) * 100, 1)
)
try:
# Generate statement for this file
statement_meta = _generate_single_statement(file_no, payload.period, db)
# Success - update progress
completed_time = datetime.now(timezone.utc).isoformat()
progress.files[idx].status = "completed"
progress.files[idx].completed_at = completed_time
progress.files[idx].statement_meta = statement_meta
progress.processed_files += 1
progress.successful_files += 1
successful += 1
results.append(BatchFileResult(
file_no=file_no,
status="success",
message="Statement generated successfully",
statement_meta=statement_meta
))
billing_logger.info(
"File statement generated successfully",
batch_id=batch_id,
file_no=file_no,
filename=statement_meta.filename,
size=statement_meta.size
)
except HTTPException as e:
# HTTP errors (e.g., file not found)
error_msg = e.detail
completed_time = datetime.now(timezone.utc).isoformat()
progress.files[idx].status = "failed"
progress.files[idx].completed_at = completed_time
progress.files[idx].error_message = error_msg
progress.processed_files += 1
progress.failed_files += 1
failed += 1
results.append(BatchFileResult(
file_no=file_no,
status="failed",
message=f"Generation failed: {error_msg}",
error_details=str(e.detail)
))
billing_logger.warning(
"File statement generation failed (HTTP error)",
batch_id=batch_id,
file_no=file_no,
error=error_msg,
status_code=e.status_code
)
except SQLAlchemyError as e:
# Database errors
error_msg = f"Database error: {str(e)}"
completed_time = datetime.now(timezone.utc).isoformat()
progress.files[idx].status = "failed"
progress.files[idx].completed_at = completed_time
progress.files[idx].error_message = error_msg
progress.processed_files += 1
progress.failed_files += 1
failed += 1
results.append(BatchFileResult(
file_no=file_no,
status="failed",
                    message="Database error during generation",
error_details=error_msg
))
billing_logger.error(
"File statement generation failed (database error)",
batch_id=batch_id,
file_no=file_no,
error=str(e)
)
except Exception as e:
# Any other unexpected errors
error_msg = f"Unexpected error: {str(e)}"
completed_time = datetime.now(timezone.utc).isoformat()
progress.files[idx].status = "failed"
progress.files[idx].completed_at = completed_time
progress.files[idx].error_message = error_msg
progress.processed_files += 1
progress.failed_files += 1
failed += 1
results.append(BatchFileResult(
file_no=file_no,
status="failed",
message="Unexpected error during generation",
error_details=error_msg
))
billing_logger.error(
"File statement generation failed (unexpected error)",
batch_id=batch_id,
file_no=file_no,
error=str(e),
error_type=type(e).__name__
)
# Update progress after each file
await progress_store.set_progress(progress)
        # Mark the batch as completed unless it was cancelled mid-loop
        end_time = datetime.now(timezone.utc)
        if progress.status != BatchStatus.CANCELLED:
            progress.status = BatchStatus.COMPLETED
        progress.completed_at = end_time.isoformat()
        progress.current_file = None
        progress.processing_time_seconds = (end_time - start_time).total_seconds()
        progress.success_rate = (successful / len(unique_file_numbers) * 100) if unique_file_numbers else 0
        progress.estimated_completion = None
        await progress_store.set_progress(progress)
except Exception as e:
# Handle batch-level failures
end_time = datetime.now(timezone.utc)
progress.status = BatchStatus.FAILED
progress.completed_at = end_time.isoformat()
progress.error_message = f"Batch operation failed: {str(e)}"
progress.processing_time_seconds = (end_time - start_time).total_seconds()
await progress_store.set_progress(progress)
billing_logger.error(
"Batch statement generation failed",
batch_id=batch_id,
error=str(e),
error_type=type(e).__name__
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Batch operation failed: {str(e)}"
)
# Calculate final metrics
processing_time = (end_time - start_time).total_seconds()
total_files = len(unique_file_numbers)
success_rate = (successful / total_files * 100) if total_files > 0 else 0
billing_logger.info(
"Batch statement generation completed",
batch_id=batch_id,
total_files=total_files,
successful=successful,
failed=failed,
success_rate=success_rate,
processing_time_seconds=processing_time
)
# Persist batch summary and per-file results
try:
        def _parse_iso(dt: Optional[str]) -> Optional[datetime]:
            if not dt:
                return None
            try:
                return datetime.fromisoformat(dt.replace('Z', '+00:00'))
            except Exception:
                return None
batch_row = BillingBatch(
batch_id=batch_id,
            status=progress.status.value,
total_files=total_files,
successful_files=successful,
failed_files=failed,
started_at=_parse_iso(progress.started_at),
updated_at=_parse_iso(progress.updated_at),
completed_at=_parse_iso(progress.completed_at),
processing_time_seconds=processing_time,
success_rate=success_rate,
error_message=progress.error_message,
)
db.add(batch_row)
        for f in progress.files:
            # statement_meta is a GeneratedStatementMeta on success, None otherwise
            meta = f.statement_meta
            filename = meta.filename if meta is not None else None
            size = meta.size if meta is not None else None
db.add(BillingBatchFile(
batch_id=batch_id,
file_no=f.file_no,
status=str(f.status),
error_message=f.error_message,
filename=filename,
size=size,
started_at=_parse_iso(f.started_at),
completed_at=_parse_iso(f.completed_at),
))
db.commit()
except Exception:
try:
db.rollback()
except Exception:
pass
return BatchGenerateStatementResponse(
batch_id=batch_id,
total_files=total_files,
successful=successful,
failed=failed,
success_rate=round(success_rate, 2),
started_at=start_time.isoformat(),
completed_at=end_time.isoformat(),
processing_time_seconds=round(processing_time, 2),
results=results
)
class StatementFileMeta(BaseModel):
"""Metadata for a generated statement file."""
filename: str = Field(..., description="The filename of the generated statement")
size: int = Field(..., description="File size in bytes")
created: str = Field(..., description="ISO timestamp when the file was created")
model_config = ConfigDict(
json_schema_extra={
"example": {
"filename": "statement_ABC-123_20240115_143022.html",
"size": 2048,
"created": "2024-01-15T14:30:22.123456+00:00"
}
}
)
class DeleteStatementResponse(BaseModel):
"""Response for successful statement deletion."""
message: str = Field(..., description="Success message")
filename: str = Field(..., description="Name of the deleted file")
model_config = ConfigDict(
json_schema_extra={
"example": {
"message": "Statement deleted successfully",
"filename": "statement_ABC-123_20240115_143022.html"
}
}
)
@router.get("/statements/{file_no}/list", response_model=List[StatementFileMeta])
async def list_generated_statements(
file_no: str = PathParam(..., description="File number to list statements for"),
period: Optional[str] = Query(
None,
description="Optional period filter in YYYY-MM format (e.g., '2024-01')",
pattern=r"^\d{4}-\d{2}$"
),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""List generated statement files for a specific file number.
    Returns metadata for all generated statement HTML files, sorted newest first by file modification time.
Optionally filter by billing period using the period parameter.
**Parameters:**
- **file_no**: The file number to list statements for
- **period**: Optional filter for statements from a specific billing period (YYYY-MM format)
**Returns:**
- List of statement file metadata including filename, size, and creation timestamp
**Errors:**
- 404: File not found or no statements exist
"""
# Ensure file exists
file_obj = db.query(File).filter(File.file_no == file_no).first()
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found",
)
exports_dir = Path("exports")
if not exports_dir.exists():
return []
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
candidates = list(exports_dir.glob(f"statement_{safe_file_no}_*.html"))
if not candidates:
return []
# Optional filter by period by inspecting HTML content
if period:
filtered: List[Path] = []
search_token = f"Period:</strong> {period}</div>"
for path in candidates:
try:
with open(path, "r", encoding="utf-8") as f:
content = f.read()
if search_token in content:
filtered.append(path)
except Exception:
continue
candidates = filtered
# Sort newest first by modification time
candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
result: List[StatementFileMeta] = []
for path in candidates:
try:
st = path.stat()
created_iso = datetime.fromtimestamp(st.st_mtime, timezone.utc).isoformat()
result.append(StatementFileMeta(filename=path.name, size=st.st_size, created=created_iso))
except FileNotFoundError:
continue
return result
@router.delete("/statements/{file_no}/{filename}", response_model=DeleteStatementResponse)
async def delete_generated_statement(
file_no: str = PathParam(..., description="File number that owns the statement"),
filename: str = PathParam(..., description="Name of the statement file to delete"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Delete a specific generated statement file.
Securely deletes a statement HTML file that belongs to the specified file number.
Security constraints ensure users can only delete statements that belong to the specified file_no.
**Parameters:**
- **file_no**: The file number that owns the statement
- **filename**: Name of the statement file to delete (must match expected naming pattern)
**Returns:**
- Success message and deleted filename
**Security:**
- Only allows deletion of files matching the expected naming pattern for the file_no
- Prevents cross-file statement deletion and path traversal attacks
**Errors:**
- 404: File not found, statement file not found, or security validation failed
- 500: File deletion failed
"""
# Ensure file exists
file_obj = db.query(File).filter(File.file_no == file_no).first()
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found",
)
exports_dir = Path("exports")
if not exports_dir.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Statement not found")
# Security: ensure filename matches expected pattern for this file_no
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
expected_prefix = f"statement_{safe_file_no}_"
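    # e.g. for file_no "ABC-123" only names like
    # "statement_ABC-123_20240115_143022_123456.html" pass this check; names
    # belonging to other file numbers, or with traversal segments, get a 404.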
if not filename.startswith(expected_prefix) or not filename.endswith(".html"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Statement not found",
)
statement_path = exports_dir / filename
if not statement_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Statement not found",
)
try:
statement_path.unlink()
return DeleteStatementResponse(
message="Statement deleted successfully",
filename=filename
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete statement: {str(e)}",
)
@router.get("/statements/{file_no}/download", responses={
200: {
"description": "Statement HTML file",
"content": {"text/html": {}},
"headers": {
"content-disposition": {
"description": "Attachment header with filename",
"schema": {"type": "string"}
}
}
},
404: {"description": "File or statement not found"}
})
async def download_latest_statement(
file_no: str = PathParam(..., description="File number to download statement for"),
period: Optional[str] = Query(
None,
description="Optional period filter in YYYY-MM format (e.g., '2024-01')",
pattern=r"^\d{4}-\d{2}$"
),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Download the most recent generated statement HTML file for a file number.
Returns the newest statement file as an HTML attachment. Optionally filter to find
the newest statement from a specific billing period.
**Parameters:**
- **file_no**: The file number to download statement for
- **period**: Optional filter for statements from a specific billing period (YYYY-MM format)
**Returns:**
- HTML file as attachment with appropriate content-disposition header
**Errors:**
- 404: File not found, no statements exist, or no statements match period filter
"""
# Ensure file exists
file_obj = db.query(File).filter(File.file_no == file_no).first()
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found",
)
exports_dir = Path("exports")
if not exports_dir.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No statements found")
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
candidates = list(exports_dir.glob(f"statement_{safe_file_no}_*.html"))
if not candidates:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No statements found")
# Optional filter by period by inspecting HTML content
if period:
filtered = []
search_token = f"Period:</strong> {period}</div>"
for path in candidates:
try:
with open(path, "r", encoding="utf-8") as f:
content = f.read()
if search_token in content:
filtered.append(path)
except Exception:
# Skip unreadable files
continue
candidates = filtered
if not candidates:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No statements found for requested period",
)
# Choose latest by modification time
candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
latest_path = candidates[0]
return FileResponse(
latest_path,
media_type="text/html",
filename=latest_path.name,
)