""" Batch statement generation helpers. This module extracts request validation, batch ID construction, estimated completion calculation, and database persistence from the API layer. """ from __future__ import annotations from typing import List, Optional, Any, Dict, Tuple from datetime import datetime, timezone, timedelta from dataclasses import dataclass, field from fastapi import HTTPException, status from sqlalchemy.orm import Session from app.models.billing import BillingBatch, BillingBatchFile def prepare_batch_parameters(file_numbers: Optional[List[str]]) -> List[str]: """Validate incoming file numbers and return de-duplicated list, preserving order.""" if not file_numbers: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file number must be provided", ) if len(file_numbers) > 50: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Maximum 50 files allowed per batch operation", ) # Remove duplicates while preserving order return list(dict.fromkeys(file_numbers)) def make_batch_id(unique_file_numbers: List[str], start_time: datetime) -> str: """Create a stable batch ID matching the previous public behavior.""" return f"batch_{start_time.strftime('%Y%m%d_%H%M%S')}_{abs(hash(str(unique_file_numbers))) % 10000:04d}" def compute_estimated_completion( *, processed_files: int, total_files: int, started_at_iso: str, now: datetime, ) -> Optional[str]: """Calculate estimated completion time as ISO string based on average rate.""" if processed_files <= 0: return None try: start_time = datetime.fromisoformat(started_at_iso.replace("Z", "+00:00")) except Exception: return None elapsed_seconds = (now - start_time).total_seconds() if elapsed_seconds <= 0: return None remaining_files = max(total_files - processed_files, 0) if remaining_files == 0: return now.isoformat() avg_time_per_file = elapsed_seconds / processed_files estimated_remaining_seconds = avg_time_per_file * remaining_files estimated_completion = now + timedelta(seconds=estimated_remaining_seconds) return estimated_completion.isoformat() def persist_batch_results( db: Session, *, batch_id: str, progress: Any, processing_time_seconds: float, success_rate: float, ) -> None: """Persist batch summary and per-file results using the DB models. 

def persist_batch_results(
    db: Session,
    *,
    batch_id: str,
    progress: Any,
    processing_time_seconds: float,
    success_rate: float,
) -> None:
    """Persist batch summary and per-file results using the DB models.

    The `progress` object is expected to expose attributes consistent with the
    API's BatchProgress model:
    - status, total_files, successful_files, failed_files
    - started_at, updated_at, completed_at, error_message
    - files: list with {file_no, status, error_message, statement_meta,
      started_at, completed_at}
    """

    def _parse_iso(dt: Optional[str]):
        if not dt:
            return None
        try:
            from datetime import datetime as _dt

            return _dt.fromisoformat(str(dt).replace('Z', '+00:00'))
        except Exception:
            return None

    batch_row = BillingBatch(
        batch_id=batch_id,
        status=str(getattr(progress, "status", "")),
        total_files=int(getattr(progress, "total_files", 0)),
        successful_files=int(getattr(progress, "successful_files", 0)),
        failed_files=int(getattr(progress, "failed_files", 0)),
        started_at=_parse_iso(getattr(progress, "started_at", None)),
        updated_at=_parse_iso(getattr(progress, "updated_at", None)),
        completed_at=_parse_iso(getattr(progress, "completed_at", None)),
        processing_time_seconds=float(processing_time_seconds),
        success_rate=float(success_rate),
        error_message=getattr(progress, "error_message", None),
    )
    db.add(batch_row)

    for f in list(getattr(progress, "files", []) or []):
        meta = getattr(f, "statement_meta", None)
        filename = None
        size = None
        if meta is not None:
            try:
                filename = getattr(meta, "filename", None)
                size = getattr(meta, "size", None)
            except Exception:
                filename = None
                size = None
            if filename is None and isinstance(meta, dict):
                filename = meta.get("filename")
                size = meta.get("size")
        db.add(
            BillingBatchFile(
                batch_id=batch_id,
                file_no=getattr(f, "file_no", None),
                status=str(getattr(f, "status", "")),
                error_message=getattr(f, "error_message", None),
                filename=filename,
                size=size,
                started_at=_parse_iso(getattr(f, "started_at", None)),
                completed_at=_parse_iso(getattr(f, "completed_at", None)),
            )
        )

    db.commit()

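
# Illustrative sketch only: how a caller might hand a finished progress
# snapshot to persist_batch_results. The session is supplied by the caller,
# and the attribute values below are assumptions; any object exposing the
# documented attributes (here a SimpleNamespace) will do.
def _example_persist(db: Session) -> None:
    """Hypothetical sketch persisting a minimal single-file batch summary."""
    from types import SimpleNamespace

    progress = SimpleNamespace(
        status="completed",
        total_files=1,
        successful_files=1,
        failed_files=0,
        started_at="2024-01-01T00:00:00Z",
        updated_at="2024-01-01T00:00:05Z",
        completed_at="2024-01-01T00:00:05Z",
        error_message=None,
        files=[
            SimpleNamespace(
                file_no="F-1001",
                status="completed",
                error_message=None,
                statement_meta={"filename": "F-1001.pdf", "size": 1024},
                started_at="2024-01-01T00:00:00Z",
                completed_at="2024-01-01T00:00:05Z",
            )
        ],
    )
    persist_batch_results(
        db,
        batch_id="batch_20240101_000000_0001",
        progress=progress,
        processing_time_seconds=5.0,
        success_rate=100.0,
    )
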

@dataclass
class BatchProgressEntry:
    """Lightweight progress entry shape used in tests for compatibility."""

    file_no: str
    status: str
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    error_message: Optional[str] = None
    statement_meta: Optional[Dict[str, Any]] = None


@dataclass
class BatchProgress:
    """Lightweight batch progress shape used in tests for topic formatting checks."""

    batch_id: str
    status: str
    total_files: int
    processed_files: int
    successful_files: int
    failed_files: int
    current_file: Optional[str] = None
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None
    estimated_completion: Optional[datetime] = None
    processing_time_seconds: Optional[float] = None
    success_rate: Optional[float] = None
    files: List[BatchProgressEntry] = field(default_factory=list)
    error_message: Optional[str] = None

    def model_dump(self) -> Dict[str, Any]:
        """Provide a dict representation similar to Pydantic for broadcasting."""

        def _dt(v):
            if isinstance(v, datetime):
                return v.isoformat()
            return v

        return {
            "batch_id": self.batch_id,
            "status": self.status,
            "total_files": self.total_files,
            "processed_files": self.processed_files,
            "successful_files": self.successful_files,
            "failed_files": self.failed_files,
            "current_file": self.current_file,
            "started_at": _dt(self.started_at),
            "updated_at": _dt(self.updated_at),
            "completed_at": _dt(self.completed_at),
            "estimated_completion": _dt(self.estimated_completion),
            "processing_time_seconds": self.processing_time_seconds,
            "success_rate": self.success_rate,
            "files": [
                {
                    "file_no": f.file_no,
                    "status": f.status,
                    "started_at": f.started_at,
                    "completed_at": f.completed_at,
                    "error_message": f.error_message,
                    "statement_meta": f.statement_meta,
                }
                for f in self.files
            ],
            "error_message": self.error_message,
        }

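
# Illustrative sketch only: building the lightweight progress shape and dumping
# it for broadcast. The batch ID, file numbers, and timestamps are assumptions.
def _example_broadcast_payload() -> Dict[str, Any]:
    """Hypothetical sketch producing a broadcast payload via model_dump()."""
    progress = BatchProgress(
        batch_id="batch_20240101_000000_0001",
        status="processing",
        total_files=2,
        processed_files=1,
        successful_files=1,
        failed_files=0,
        current_file="F-1002",
        files=[
            BatchProgressEntry(
                file_no="F-1001",
                status="completed",
                started_at="2024-01-01T00:00:00Z",
                completed_at="2024-01-01T00:00:05Z",
            )
        ],
    )
    # datetime fields are serialized to ISO strings; entry fields pass through as-is.
    return progress.model_dump()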