This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions


@@ -34,6 +34,17 @@ from app.models.billing import (
BillingStatementItem, StatementStatus
)
from app.services.billing import BillingStatementService, StatementGenerationError
from app.services.statement_generation import (
generate_single_statement as _svc_generate_single_statement,
parse_period_month as _svc_parse_period_month,
render_statement_html as _svc_render_statement_html,
)
from app.services.batch_generation import (
prepare_batch_parameters as _svc_prepare_batch_parameters,
make_batch_id as _svc_make_batch_id,
compute_estimated_completion as _svc_compute_eta,
persist_batch_results as _svc_persist_batch_results,
)
router = APIRouter()
@@ -41,33 +52,29 @@ router = APIRouter()
# Initialize logger for billing operations
billing_logger = StructuredLogger("billing_operations", "INFO")
# Realtime WebSocket subscriber registry: batch_id -> set[WebSocket]
_subscribers_by_batch: Dict[str, Set[WebSocket]] = {}
_subscribers_lock = asyncio.Lock()
# Import WebSocket pool services
from app.middleware.websocket_middleware import get_websocket_manager
from app.services.websocket_pool import WebSocketMessage
# WebSocket manager for batch progress notifications
websocket_manager = get_websocket_manager()
async def _notify_progress_subscribers(progress: "BatchProgress") -> None:
"""Broadcast latest progress to active subscribers of a batch."""
"""Broadcast latest progress to active subscribers of a batch using WebSocket pool."""
batch_id = progress.batch_id
message = {"type": "progress", "data": progress.model_dump()}
async with _subscribers_lock:
sockets = list(_subscribers_by_batch.get(batch_id, set()))
if not sockets:
return
dead: List[WebSocket] = []
for ws in sockets:
try:
await ws.send_json(message)
except Exception:
dead.append(ws)
if dead:
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket:
for ws in dead:
bucket.discard(ws)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
topic = f"batch_progress_{batch_id}"
# Use the WebSocket manager to broadcast to topic
sent_count = await websocket_manager.broadcast_to_topic(
topic=topic,
message_type="progress",
data=progress.model_dump()
)
billing_logger.debug("Broadcast batch progress update",
batch_id=batch_id,
subscribers_notified=sent_count)
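For orientation, a minimal publish-side sketch of how the batch loop below is expected to use this helper (the wrapper name is illustrative; the `BatchProgress` fields mirror the ones referenced later in this file):

async def _example_publish(progress: "BatchProgress") -> None:
    # Persist the snapshot first so late subscribers can fetch it on connect,
    # then fan the update out to topic "batch_progress_<batch_id>".
    progress.updated_at = datetime.now(timezone.utc).isoformat()
    await progress_store.set_progress(progress)
    await _notify_progress_subscribers(progress)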
def _round(value: Optional[float]) -> float:
@@ -606,21 +613,8 @@ progress_store = BatchProgressStore()
def _parse_period_month(period: Optional[str]) -> Optional[tuple[date, date]]:
"""Parse period in the form YYYY-MM and return (start_date, end_date) inclusive.
Returns None when period is not provided or invalid.
"""
if not period:
return None
m = re.fullmatch(r"(\d{4})-(\d{2})", str(period).strip())
if not m:
return None
year = int(m.group(1))
month = int(m.group(2))
if month < 1 or month > 12:
return None
from calendar import monthrange
last_day = monthrange(year, month)[1]
return date(year, month, 1), date(year, month, last_day)
"""Parse YYYY-MM period; delegates to service helper for consistency."""
return _svc_parse_period_month(period)
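Based on the inline parser removed above, the delegated helper is expected to behave as follows (a behavioural sketch, not the service implementation):

# Expected behaviour, inferred from the removed inline parser:
assert _parse_period_month("2025-02") == (date(2025, 2, 1), date(2025, 2, 28))  # month clamped to its last day
assert _parse_period_month("2024-02") == (date(2024, 2, 1), date(2024, 2, 29))  # leap years handled via monthrange
assert _parse_period_month("2025-13") is None                                   # invalid month
assert _parse_period_month(None) is None                                        # missing period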
def _render_statement_html(
@@ -633,80 +627,25 @@ def _render_statement_html(
totals: StatementTotals,
unbilled_entries: List[StatementEntry],
) -> str:
"""Create a simple, self-contained HTML statement string."""
# Rows for unbilled entries
def _fmt(val: Optional[float]) -> str:
try:
return f"{float(val or 0):.2f}"
except Exception:
return "0.00"
rows = []
for e in unbilled_entries:
rows.append(
f"<tr><td>{e.date.isoformat() if e.date else ''}</td><td>{e.t_code}</td><td>{(e.description or '').replace('<','&lt;').replace('>','&gt;')}</td>"
f"<td style='text-align:right'>{_fmt(e.quantity)}</td><td style='text-align:right'>{_fmt(e.rate)}</td><td style='text-align:right'>{_fmt(e.amount)}</td></tr>"
)
rows_html = "\n".join(rows) if rows else "<tr><td colspan='6' style='text-align:center;color:#666'>No unbilled entries</td></tr>"
period_html = f"<div><strong>Period:</strong> {period}</div>" if period else ""
html = f"""
<!DOCTYPE html>
<html lang=\"en\">
<head>
<meta charset=\"utf-8\" />
<title>Statement {file_no}</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica, Arial, sans-serif; margin: 24px; }}
h1 {{ margin: 0 0 8px 0; }}
.meta {{ color: #444; margin-bottom: 16px; }}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; font-size: 14px; }}
th {{ background: #f6f6f6; text-align: left; }}
.totals {{ margin: 16px 0; display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 8px; }}
.totals div {{ background: #fafafa; border: 1px solid #eee; padding: 8px; }}
</style>
</head>
<body>
<h1>Statement</h1>
<div class=\"meta\">
<div><strong>File:</strong> {file_no}</div>
<div><strong>Client:</strong> {client_name or ''}</div>
<div><strong>Matter:</strong> {matter or ''}</div>
<div><strong>As of:</strong> {as_of_iso}</div>
{period_html}
</div>
<div class=\"totals\">
<div><strong>Charges (billed)</strong><br/>${_fmt(totals.charges_billed)}</div>
<div><strong>Charges (unbilled)</strong><br/>${_fmt(totals.charges_unbilled)}</div>
<div><strong>Charges (total)</strong><br/>${_fmt(totals.charges_total)}</div>
<div><strong>Payments</strong><br/>${_fmt(totals.payments)}</div>
<div><strong>Trust balance</strong><br/>${_fmt(totals.trust_balance)}</div>
<div><strong>Current balance</strong><br/>${_fmt(totals.current_balance)}</div>
</div>
<h2>Unbilled Entries</h2>
<table>
<thead>
<tr>
<th>Date</th>
<th>Code</th>
<th>Description</th>
<th style=\"text-align:right\">Qty</th>
<th style=\"text-align:right\">Rate</th>
<th style=\"text-align:right\">Amount</th>
</tr>
</thead>
<tbody>
{rows_html}
</tbody>
</table>
</body>
</html>
"""
return html
"""Create statement HTML via service helper while preserving API models."""
totals_dict: Dict[str, float] = {
"charges_billed": totals.charges_billed,
"charges_unbilled": totals.charges_unbilled,
"charges_total": totals.charges_total,
"payments": totals.payments,
"trust_balance": totals.trust_balance,
"current_balance": totals.current_balance,
}
entries_dict: List[Dict[str, Any]] = [e.model_dump() for e in (unbilled_entries or [])]
return _svc_render_statement_html(
file_no=file_no,
client_name=client_name,
matter=matter,
as_of_iso=as_of_iso,
period=period,
totals=totals_dict,
unbilled_entries=entries_dict,
)
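A minimal call sketch showing how the API models feed the service renderer (all values are placeholders):

example_totals = StatementTotals(
    charges_billed=100.0, charges_unbilled=25.0, charges_total=125.0,
    payments=50.0, trust_balance=10.0, current_balance=75.0,
)
html = _render_statement_html(
    file_no="2025-001",
    client_name="Jane Doe",
    matter="Estate planning",
    as_of_iso=datetime.now(timezone.utc).isoformat(),
    period="2025-08",
    totals=example_totals,
    unbilled_entries=[],  # StatementEntry models; the removed inline renderer showed a "No unbilled entries" row for this case
)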
def _generate_single_statement(
@@ -714,118 +653,28 @@ def _generate_single_statement(
period: Optional[str],
db: Session
) -> GeneratedStatementMeta:
"""
Internal helper to generate a statement for a single file.
Args:
file_no: File number to generate statement for
period: Optional period filter (YYYY-MM format)
db: Database session
Returns:
GeneratedStatementMeta with file metadata and export path
Raises:
HTTPException: If file not found or generation fails
"""
file_obj = (
db.query(File)
.options(joinedload(File.owner))
.filter(File.file_no == file_no)
.first()
)
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File {file_no} not found",
)
# Optional period filtering (YYYY-MM)
date_range = _parse_period_month(period)
q = db.query(Ledger).filter(Ledger.file_no == file_no)
if date_range:
start_date, end_date = date_range
q = q.filter(Ledger.date >= start_date).filter(Ledger.date <= end_date)
entries: List[Ledger] = q.all()
CHARGE_TYPES = {"2", "3", "4"}
charges_billed = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed == "Y")
charges_unbilled = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed != "Y")
charges_total = charges_billed + charges_unbilled
payments_total = sum(e.amount for e in entries if e.t_type == "5")
trust_balance = file_obj.trust_bal or 0.0
current_balance = charges_total - payments_total
unbilled_entries = [
StatementEntry(
id=e.id,
date=e.date,
t_code=e.t_code,
t_type=e.t_type,
description=e.note,
quantity=e.quantity or 0.0,
rate=e.rate or 0.0,
amount=e.amount,
)
for e in entries
if e.t_type in CHARGE_TYPES and e.billed != "Y"
]
client_name = None
if file_obj.owner:
client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last}".strip()
as_of_iso = datetime.now(timezone.utc).isoformat()
"""Generate a single statement via service and adapt to API response model."""
data = _svc_generate_single_statement(file_no, period, db)
totals = data.get("totals", {})
totals_model = StatementTotals(
charges_billed=_round(charges_billed),
charges_unbilled=_round(charges_unbilled),
charges_total=_round(charges_total),
payments=_round(payments_total),
trust_balance=_round(trust_balance),
current_balance=_round(current_balance),
charges_billed=float(totals.get("charges_billed", 0.0)),
charges_unbilled=float(totals.get("charges_unbilled", 0.0)),
charges_total=float(totals.get("charges_total", 0.0)),
payments=float(totals.get("payments", 0.0)),
trust_balance=float(totals.get("trust_balance", 0.0)),
current_balance=float(totals.get("current_balance", 0.0)),
)
# Render HTML
html = _render_statement_html(
file_no=file_no,
client_name=client_name or None,
matter=file_obj.regarding,
as_of_iso=as_of_iso,
period=period,
totals=totals_model,
unbilled_entries=unbilled_entries,
)
# Ensure exports directory and write file
exports_dir = Path("exports")
try:
exports_dir.mkdir(exist_ok=True)
except Exception:
# Best-effort: if the directory cannot be created, surface an internal error
raise HTTPException(status_code=500, detail="Unable to create exports directory")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
filename = f"statement_{safe_file_no}_{timestamp}.html"
export_path = exports_dir / filename
html_bytes = html.encode("utf-8")
with open(export_path, "wb") as f:
f.write(html_bytes)
size = export_path.stat().st_size
return GeneratedStatementMeta(
file_no=file_no,
client_name=client_name or None,
as_of=as_of_iso,
period=period,
file_no=str(data.get("file_no")),
client_name=data.get("client_name"),
as_of=str(data.get("as_of")),
period=data.get("period"),
totals=totals_model,
unbilled_count=len(unbilled_entries),
export_path=str(export_path),
filename=filename,
size=size,
content_type="text/html",
unbilled_count=int(data.get("unbilled_count", 0)),
export_path=str(data.get("export_path")),
filename=str(data.get("filename")),
size=int(data.get("size", 0)),
content_type=str(data.get("content_type", "text/html")),
)
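Judging from the keys this adapter reads, `_svc_generate_single_statement` is assumed to return a plain dict shaped roughly like this (values are illustrative):

# Assumed return value of _svc_generate_single_statement(file_no, period, db):
expected_payload = {
    "file_no": "2025-001",
    "client_name": "Jane Doe",
    "as_of": "2025-08-18T20:20:04+00:00",
    "period": "2025-08",
    "totals": {
        "charges_billed": 100.0,
        "charges_unbilled": 25.0,
        "charges_total": 125.0,
        "payments": 50.0,
        "trust_balance": 10.0,
        "current_balance": 75.0,
    },
    "unbilled_count": 3,
    "export_path": "exports/statement_2025-001_20250818_202004_000123.html",
    "filename": "statement_2025-001_20250818_202004_000123.html",
    "size": 4096,
    "content_type": "text/html",
}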
@@ -842,92 +691,48 @@ async def generate_statement(
return _generate_single_statement(payload.file_no, payload.period, db)
async def _ws_authenticate(websocket: WebSocket) -> Optional[User]:
"""Authenticate WebSocket via JWT token in query (?token=) or Authorization header."""
token = websocket.query_params.get("token")
if not token:
try:
auth_header = dict(websocket.headers).get("authorization") or ""
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
except Exception:
token = None
if not token:
return None
username = verify_token(token)
if not username:
return None
db = SessionLocal()
try:
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
return None
return user
finally:
db.close()
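Either credential form satisfies this check; a minimal client sketch using the query-parameter form (the host, URL prefix, and the `websockets` client dependency are assumptions):

import asyncio
import websockets  # assumed client-side dependency, not part of this service

async def _connect(batch_id: str, jwt: str) -> None:
    # Token in the query string; an "Authorization: Bearer <jwt>" header works as well.
    uri = f"wss://example.internal/statements/batch-progress/ws/{batch_id}?token={jwt}"
    async with websockets.connect(uri) as ws:
        print(await ws.recv())  # first frame should be the initial progress snapshot

asyncio.run(_connect("batch_20250818_202004_0001", "<jwt>"))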
async def _ws_keepalive(ws: WebSocket, stop_event: asyncio.Event) -> None:
try:
while not stop_event.is_set():
await asyncio.sleep(25)
try:
await ws.send_json({"type": "ping", "ts": datetime.now(timezone.utc).isoformat()})
except Exception:
break
finally:
stop_event.set()
@router.websocket("/statements/batch-progress/ws/{batch_id}")
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
"""WebSocket: subscribe to real-time updates for a batch_id."""
user = await _ws_authenticate(websocket)
if not user:
await websocket.close(code=4401)
return
await websocket.accept()
# Register this socket as a subscriber for the batch
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if not bucket:
bucket = set()
_subscribers_by_batch[batch_id] = bucket
bucket.add(websocket)
# Send initial snapshot
try:
snapshot = await progress_store.get_progress(batch_id)
await websocket.send_json({"type": "progress", "data": snapshot.model_dump() if snapshot else None})
except Exception:
pass
# Keepalive + receive loop
stop_event: asyncio.Event = asyncio.Event()
ka_task = asyncio.create_task(_ws_keepalive(websocket, stop_event))
try:
while not stop_event.is_set():
try:
msg = await websocket.receive_text()
except WebSocketDisconnect:
break
except Exception:
break
if isinstance(msg, str) and msg.strip() == "ping":
try:
await websocket.send_text("pong")
except Exception:
break
finally:
stop_event.set()
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
topic = f"batch_progress_{batch_id}"
# Custom message handler for batch progress
async def handle_batch_message(connection_id: str, message: WebSocketMessage):
"""Handle custom messages for batch progress"""
billing_logger.debug("Received batch progress message",
connection_id=connection_id,
batch_id=batch_id,
message_type=message.type)
# Handle any batch-specific message logic here if needed
# Use the WebSocket manager to handle the connection
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={topic},
require_auth=True,
metadata={"batch_id": batch_id, "endpoint": "batch_progress"},
message_handler=handle_batch_message
)
if connection_id:
# Send initial snapshot after connection is established
try:
ka_task.cancel()
except Exception:
pass
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket and websocket in bucket:
bucket.discard(websocket)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
snapshot = await progress_store.get_progress(batch_id)
pool = websocket_manager.pool
initial_message = WebSocketMessage(
type="progress",
topic=topic,
data=snapshot.model_dump() if snapshot else None
)
await pool._send_to_connection(connection_id, initial_message)
billing_logger.info("Sent initial batch progress snapshot",
connection_id=connection_id,
batch_id=batch_id)
except Exception as e:
billing_logger.error("Failed to send initial batch progress snapshot",
connection_id=connection_id,
batch_id=batch_id,
error=str(e))
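Subscribers receive JSON frames built from `WebSocketMessage`; the wire format is owned by the pool, so the shape assumed below is inferred from the fields passed above:

# Assumed frame for subscribers of topic "batch_progress_<batch_id>":
#   {"type": "progress", "topic": "batch_progress_<batch_id>", "data": {...BatchProgress...}}
import json

async def _consume(ws) -> None:
    # `ws` is a connected client socket (see the earlier connection sketch).
    async for raw in ws:
        frame = json.loads(raw)
        if frame.get("type") == "progress" and frame.get("data"):
            data = frame["data"]
            print(f"{data.get('processed_files')}/{data.get('total_files')} files processed")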
@router.delete("/statements/batch-progress/{batch_id}")
async def cancel_batch_operation(
@@ -1045,25 +850,12 @@ async def batch_generate_statements(
- Batch operation identification for audit trails
- Automatic cleanup of progress data after completion
"""
# Validate request
if not payload.file_numbers:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one file number must be provided"
)
if len(payload.file_numbers) > 50:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Maximum 50 files allowed per batch operation"
)
# Remove duplicates while preserving order
unique_file_numbers = list(dict.fromkeys(payload.file_numbers))
# Validate request and normalize inputs
unique_file_numbers = _svc_prepare_batch_parameters(payload.file_numbers)
# Generate batch ID and timing
start_time = datetime.now(timezone.utc)
batch_id = f"batch_{start_time.strftime('%Y%m%d_%H%M%S')}_{abs(hash(str(unique_file_numbers))) % 10000:04d}"
batch_id = _svc_make_batch_id(unique_file_numbers, start_time)
billing_logger.info(
"Starting batch statement generation",
@@ -1121,7 +913,12 @@ async def batch_generate_statements(
progress.current_file = file_no
progress.files[idx].status = "processing"
progress.files[idx].started_at = current_time.isoformat()
progress.estimated_completion = await _calculate_estimated_completion(progress, current_time)
progress.estimated_completion = _svc_compute_eta(
processed_files=progress.processed_files,
total_files=progress.total_files,
started_at_iso=progress.started_at,
now=current_time,
)
await progress_store.set_progress(progress)
billing_logger.info(
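The ETA helper's internals are not part of this diff; a plausible linear-extrapolation sketch matching the keyword arguments used above (purely illustrative):

from datetime import datetime
from typing import Optional

def example_compute_eta(processed_files: int, total_files: int,
                        started_at_iso: Optional[str], now: datetime) -> Optional[str]:
    # Illustrative only: average time per finished file, projected over the remainder.
    if not started_at_iso or processed_files <= 0 or processed_files >= total_files:
        return None
    started = datetime.fromisoformat(started_at_iso.replace("Z", "+00:00"))
    per_file = (now - started) / processed_files
    return (now + per_file * (total_files - processed_files)).isoformat()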
@@ -1288,53 +1085,13 @@ async def batch_generate_statements(
# Persist batch summary and per-file results
try:
def _parse_iso(dt: Optional[str]):
if not dt:
return None
try:
from datetime import datetime as _dt
return _dt.fromisoformat(dt.replace('Z', '+00:00'))
except Exception:
return None
batch_row = BillingBatch(
_svc_persist_batch_results(
db,
batch_id=batch_id,
status=str(progress.status),
total_files=total_files,
successful_files=successful,
failed_files=failed,
started_at=_parse_iso(progress.started_at),
updated_at=_parse_iso(progress.updated_at),
completed_at=_parse_iso(progress.completed_at),
progress=progress,
processing_time_seconds=processing_time,
success_rate=success_rate,
error_message=progress.error_message,
)
db.add(batch_row)
for f in progress.files:
meta = getattr(f, 'statement_meta', None)
filename = None
size = None
if meta is not None:
try:
filename = getattr(meta, 'filename', None)
size = getattr(meta, 'size', None)
except Exception:
pass
if filename is None and isinstance(meta, dict):
filename = meta.get('filename')
size = meta.get('size')
db.add(BillingBatchFile(
batch_id=batch_id,
file_no=f.file_no,
status=str(f.status),
error_message=f.error_message,
filename=filename,
size=size,
started_at=_parse_iso(f.started_at),
completed_at=_parse_iso(f.completed_at),
))
db.commit()
except Exception:
try:
db.rollback()
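For reference, the timestamp normalization the removed `_parse_iso` helper performed (the service is assumed to do the equivalent when it parses the timestamps on `progress`):

from datetime import datetime

# fromisoformat() on Python < 3.11 rejects a trailing "Z", hence the manual normalization.
datetime.fromisoformat("2025-08-18T20:20:04Z".replace("Z", "+00:00"))
# -> datetime(2025, 8, 18, 20, 20, 4, tzinfo=timezone.utc)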
@@ -1600,6 +1357,34 @@ async def download_latest_statement(
detail="No statements found for requested period",
)
# Filter out any statements created prior to the file's opened date (safety against collisions)
try:
opened_date = getattr(file_obj, "opened", None)
if opened_date:
filtered_by_opened: List[Path] = []
for path in candidates:
name = path.name
# Filename format: statement_{safe_file_no}_YYYYMMDD_HHMMSS_micro.html
m = re.match(rf"^statement_{re.escape(safe_file_no)}_(\d{{8}})_\d{{6}}_\d{{6}}\.html$", name)
if not m:
continue
ymd = m.group(1)
y, mo, d = int(ymd[0:4]), int(ymd[4:6]), int(ymd[6:8])
from datetime import date as _date
stmt_date = _date(y, mo, d)
if stmt_date >= opened_date:
filtered_by_opened.append(path)
if filtered_by_opened:
candidates = filtered_by_opened
else:
# If none meet the opened-date filter, treat as no statements
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No statements found")
except HTTPException:
raise
except Exception:
# On parse errors, continue with existing candidates
pass
# Choose latest by modification time
candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
latest_path = candidates[0]
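A worked example of the filename filter above (file number, timestamp, and opened date are placeholders):

import re
from datetime import date

safe_file_no = "2025-001"  # file_no with path separators replaced, as in the generator
name = "statement_2025-001_20250818_202004_000123.html"
m = re.match(rf"^statement_{re.escape(safe_file_no)}_(\d{{8}})_\d{{6}}_\d{{6}}\.html$", name)
assert m is not None
ymd = m.group(1)  # "20250818"
stmt_date = date(int(ymd[:4]), int(ymd[4:6]), int(ymd[6:8]))
assert stmt_date >= date(2024, 1, 1)  # kept only when not before the file's opened date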