app/api/admin.py (361)
@@ -29,6 +29,9 @@ from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total
from app.utils.exceptions import handle_database_errors, safe_execute
from app.utils.logging import app_logger
from app.middleware.websocket_middleware import get_websocket_manager, get_connection_tracker, WebSocketMessage
from app.services.document_notifications import ADMIN_DOCUMENTS_TOPIC
from fastapi import WebSocket

router = APIRouter()
@@ -64,6 +67,55 @@ class HealthCheck(BaseModel):
    cpu_usage: float
    alerts: List[str]


class WebSocketStats(BaseModel):
    """WebSocket connection pool statistics"""
    total_connections: int
    active_connections: int
    total_topics: int
    total_users: int
    messages_sent: int
    messages_failed: int
    connections_cleaned: int
    last_cleanup: Optional[str]
    last_heartbeat: Optional[str]
    connections_by_state: Dict[str, int]
    topic_distribution: Dict[str, int]


class ConnectionInfo(BaseModel):
    """Individual WebSocket connection information"""
    connection_id: str
    user_id: Optional[int]
    state: str
    topics: List[str]
    created_at: str
    last_activity: str
    age_seconds: float
    idle_seconds: float
    error_count: int
    last_ping: Optional[str]
    last_pong: Optional[str]
    metadata: Dict[str, Any]
    is_alive: bool
    is_stale: bool


class WebSocketConnectionsResponse(BaseModel):
    """Response for WebSocket connections listing"""
    connections: List[ConnectionInfo]
    total_count: int
    active_count: int
    stale_count: int


class DisconnectRequest(BaseModel):
    """Request to disconnect WebSocket connections"""
    connection_ids: Optional[List[str]] = None
    user_id: Optional[int] = None
    topic: Optional[str] = None
    reason: str = "admin_disconnect"

class UserCreate(BaseModel):
    """Create new user"""
    username: str = Field(..., min_length=3, max_length=50)
@@ -551,6 +603,253 @@ async def system_statistics(
    )


# WebSocket Management Endpoints

@router.get("/websockets/stats", response_model=WebSocketStats)
async def get_websocket_stats(
    current_user: User = Depends(get_admin_user)
):
    """Get WebSocket connection pool statistics"""
    websocket_manager = get_websocket_manager()
    stats = await websocket_manager.get_stats()

    return WebSocketStats(**stats)
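A monitoring script can poll this endpoint to chart pool health over time. The sketch below assumes the router is mounted under /api/admin and that admin requests carry a bearer token; neither detail is shown in this diff.

```python
# Minimal sketch of polling the WebSocket stats endpoint from a monitoring script.
# BASE_URL and the bearer-token header are assumptions about deployment.
import time
import requests

BASE_URL = "http://localhost:8000/api/admin"   # assumed mount point
TOKEN = "..."                                   # admin JWT

def poll_websocket_stats(interval_seconds: int = 30) -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    while True:
        resp = requests.get(f"{BASE_URL}/websockets/stats", headers=headers, timeout=10)
        resp.raise_for_status()
        stats = resp.json()
        print(f"active={stats['active_connections']} "
              f"topics={stats['total_topics']} failed={stats['messages_failed']}")
        time.sleep(interval_seconds)
```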
|
||||
|
||||
|
||||
@router.get("/websockets/connections", response_model=WebSocketConnectionsResponse)
|
||||
async def get_websocket_connections(
|
||||
user_id: Optional[int] = Query(None, description="Filter by user ID"),
|
||||
topic: Optional[str] = Query(None, description="Filter by topic"),
|
||||
state: Optional[str] = Query(None, description="Filter by connection state"),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Get list of active WebSocket connections with optional filtering"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
connection_tracker = get_connection_tracker()
|
||||
|
||||
# Get all connection IDs
|
||||
pool = websocket_manager.pool
|
||||
async with pool._connections_lock:
|
||||
all_connection_ids = list(pool._connections.keys())
|
||||
|
||||
connections = []
|
||||
active_count = 0
|
||||
stale_count = 0
|
||||
|
||||
for connection_id in all_connection_ids:
|
||||
metrics = await connection_tracker.get_connection_metrics(connection_id)
|
||||
if not metrics:
|
||||
continue
|
||||
|
||||
# Apply filters
|
||||
if user_id and metrics.get("user_id") != user_id:
|
||||
continue
|
||||
if topic and topic not in metrics.get("topics", []):
|
||||
continue
|
||||
if state and metrics.get("state") != state:
|
||||
continue
|
||||
|
||||
connections.append(ConnectionInfo(**metrics))
|
||||
|
||||
if metrics.get("is_alive"):
|
||||
active_count += 1
|
||||
if metrics.get("is_stale"):
|
||||
stale_count += 1
|
||||
|
||||
return WebSocketConnectionsResponse(
|
||||
connections=connections,
|
||||
total_count=len(connections),
|
||||
active_count=active_count,
|
||||
stale_count=stale_count
|
||||
)
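Combined with the stats endpoint, this listing makes it easy to locate problem connections from an operations script. A minimal sketch, assuming the same /api/admin mount and bearer-token auth as above:

```python
# Sketch: collect stale connection IDs for one user so they can be passed to the
# disconnect endpoint defined later in this file. URL prefix and auth are assumptions.
import requests

def stale_connection_ids(user_id: int, base_url: str, token: str) -> list[str]:
    resp = requests.get(
        f"{base_url}/websockets/connections",
        params={"user_id": user_id},
        headers={"Authorization": f"Bearer {token}"},
        timeout=10,
    )
    resp.raise_for_status()
    payload = resp.json()
    return [c["connection_id"] for c in payload["connections"] if c["is_stale"]]
```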
|
||||
|
||||
|
||||
@router.get("/websockets/connections/{connection_id}", response_model=ConnectionInfo)
|
||||
async def get_websocket_connection(
|
||||
connection_id: str,
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Get detailed information about a specific WebSocket connection"""
|
||||
connection_tracker = get_connection_tracker()
|
||||
|
||||
metrics = await connection_tracker.get_connection_metrics(connection_id)
|
||||
if not metrics:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"WebSocket connection {connection_id} not found"
|
||||
)
|
||||
|
||||
return ConnectionInfo(**metrics)
|
||||
|
||||
|
||||
@router.post("/websockets/disconnect")
|
||||
async def disconnect_websockets(
|
||||
request: DisconnectRequest,
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Disconnect WebSocket connections based on criteria"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
pool = websocket_manager.pool
|
||||
|
||||
disconnected_count = 0
|
||||
|
||||
if request.connection_ids:
|
||||
# Disconnect specific connections
|
||||
for connection_id in request.connection_ids:
|
||||
await pool.remove_connection(connection_id, request.reason)
|
||||
disconnected_count += 1
|
||||
|
||||
elif request.user_id:
|
||||
# Disconnect all connections for a user
|
||||
user_connections = await pool.get_user_connections(request.user_id)
|
||||
for connection_id in user_connections:
|
||||
await pool.remove_connection(connection_id, request.reason)
|
||||
disconnected_count += 1
|
||||
|
||||
elif request.topic:
|
||||
# Disconnect all connections for a topic
|
||||
topic_connections = await pool.get_topic_connections(request.topic)
|
||||
for connection_id in topic_connections:
|
||||
await pool.remove_connection(connection_id, request.reason)
|
||||
disconnected_count += 1
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Must specify connection_ids, user_id, or topic"
|
||||
)
|
||||
|
||||
app_logger.info("Admin disconnected WebSocket connections",
|
||||
admin_user=current_user.username,
|
||||
disconnected_count=disconnected_count,
|
||||
reason=request.reason)
|
||||
|
||||
return {
|
||||
"message": f"Disconnected {disconnected_count} WebSocket connections",
|
||||
"disconnected_count": disconnected_count,
|
||||
"reason": request.reason
|
||||
}
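The endpoint expects exactly one selector per request. Three example bodies, with hypothetical connection IDs and topic names, might look like:

```python
# Sketch: alternative request bodies for POST /websockets/disconnect.
# Exactly one selector should be supplied; reason defaults to "admin_disconnect".
# The IDs and topic name below are hypothetical.
disconnect_by_ids = {"connection_ids": ["c-123", "c-456"], "reason": "stale_cleanup"}
disconnect_by_user = {"user_id": 42, "reason": "user_offboarded"}
disconnect_by_topic = {"topic": "admin_documents", "reason": "topic_retired"}
```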
|
||||
|
||||
|
||||
@router.post("/websockets/cleanup")
|
||||
async def cleanup_websockets(
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Manually trigger cleanup of stale WebSocket connections"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
pool = websocket_manager.pool
|
||||
|
||||
# Get stats before cleanup
|
||||
stats_before = await pool.get_stats()
|
||||
connections_before = stats_before["active_connections"]
|
||||
|
||||
# Force cleanup
|
||||
await pool._cleanup_stale_connections()
|
||||
|
||||
# Get stats after cleanup
|
||||
stats_after = await pool.get_stats()
|
||||
connections_after = stats_after["active_connections"]
|
||||
|
||||
cleaned_count = connections_before - connections_after
|
||||
|
||||
app_logger.info("Admin triggered WebSocket cleanup",
|
||||
admin_user=current_user.username,
|
||||
cleaned_count=cleaned_count)
|
||||
|
||||
return {
|
||||
"message": f"Cleaned up {cleaned_count} stale WebSocket connections",
|
||||
"connections_before": connections_before,
|
||||
"connections_after": connections_after,
|
||||
"cleaned_count": cleaned_count
|
||||
}
|
||||
|
||||
|
||||
@router.post("/websockets/broadcast")
|
||||
async def broadcast_message(
|
||||
topic: str = Body(..., description="Topic to broadcast to"),
|
||||
message_type: str = Body(..., description="Message type"),
|
||||
data: Optional[Dict[str, Any]] = Body(None, description="Message data"),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Broadcast a message to all connections subscribed to a topic"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
|
||||
sent_count = await websocket_manager.broadcast_to_topic(
|
||||
topic=topic,
|
||||
message_type=message_type,
|
||||
data=data
|
||||
)
|
||||
|
||||
app_logger.info("Admin broadcast message to topic",
|
||||
admin_user=current_user.username,
|
||||
topic=topic,
|
||||
message_type=message_type,
|
||||
sent_count=sent_count)
|
||||
|
||||
return {
|
||||
"message": f"Broadcast message to {sent_count} connections",
|
||||
"topic": topic,
|
||||
"message_type": message_type,
|
||||
"sent_count": sent_count
|
||||
}
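Because topic, message_type, and data are declared as separate Body(...) parameters, FastAPI expects them as sibling keys of a single JSON object. A hedged client sketch (URL prefix and auth header are assumptions):

```python
# Sketch of an admin broadcast call; the topic name below is hypothetical.
import requests

def broadcast_maintenance_notice(base_url: str, token: str) -> int:
    resp = requests.post(
        f"{base_url}/websockets/broadcast",
        json={
            "topic": "admin_documents",
            "message_type": "system_notice",
            "data": {"message": "Maintenance window starts in 10 minutes"},
        },
        headers={"Authorization": f"Bearer {token}"},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["sent_count"]
```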
|
||||
|
||||
|
||||
@router.websocket("/ws/documents")
|
||||
async def ws_admin_documents(websocket: WebSocket):
|
||||
"""
|
||||
Admin WebSocket endpoint for monitoring all document processing events.
|
||||
|
||||
Receives real-time notifications about:
|
||||
- Document generation started/completed/failed across all files
|
||||
- Document uploads across all files
|
||||
- Workflow executions that generate documents
|
||||
|
||||
Requires admin authentication via token query parameter.
|
||||
"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
|
||||
# Custom message handler for admin document monitoring
|
||||
async def handle_admin_document_message(connection_id: str, message: WebSocketMessage):
|
||||
"""Handle custom messages for admin document monitoring"""
|
||||
app_logger.debug("Received admin document message",
|
||||
connection_id=connection_id,
|
||||
message_type=message.type)
|
||||
|
||||
# Use the WebSocket manager to handle the connection
|
||||
connection_id = await websocket_manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics={ADMIN_DOCUMENTS_TOPIC},
|
||||
require_auth=True,
|
||||
metadata={"endpoint": "admin_documents", "admin_monitoring": True},
|
||||
message_handler=handle_admin_document_message
|
||||
)
|
||||
|
||||
if connection_id:
|
||||
# Send initial welcome message with admin monitoring confirmation
|
||||
try:
|
||||
pool = websocket_manager.pool
|
||||
welcome_message = WebSocketMessage(
|
||||
type="admin_monitoring_active",
|
||||
topic=ADMIN_DOCUMENTS_TOPIC,
|
||||
data={
|
||||
"message": "Connected to admin document monitoring stream",
|
||||
"events": [
|
||||
"document_processing",
|
||||
"document_completed",
|
||||
"document_failed",
|
||||
"document_upload"
|
||||
]
|
||||
}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, welcome_message)
|
||||
app_logger.info("Admin document monitoring connection established",
|
||||
connection_id=connection_id)
|
||||
except Exception as e:
|
||||
app_logger.error("Failed to send admin monitoring welcome message",
|
||||
connection_id=connection_id,
|
||||
error=str(e))
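A consumer of this stream only needs a WebSocket client and an admin token. A minimal sketch using the third-party websockets package; the /api/admin prefix and the token query-parameter name follow the docstring and are otherwise assumptions:

```python
# Minimal client for the admin document-monitoring socket (illustrative only).
import asyncio
import json
import websockets

async def watch_document_events(token: str) -> None:
    url = f"ws://localhost:8000/api/admin/ws/documents?token={token}"
    async with websockets.connect(url) as ws:
        async for raw in ws:                 # each message is a JSON-encoded event
            event = json.loads(raw)
            print(event.get("type"), event.get("data"))

# asyncio.run(watch_document_events("<admin-jwt>"))
```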
|
||||
|
||||
|
||||
@router.post("/import/csv")
|
||||
async def import_csv(
|
||||
table_name: str,
|
||||
@@ -558,17 +857,34 @@ async def import_csv(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Import data from CSV file"""
|
||||
"""Import data from CSV file with comprehensive security validation"""
|
||||
from app.utils.file_security import file_validator, validate_csv_content
|
||||
|
||||
if not file.filename.endswith('.csv'):
|
||||
# Comprehensive security validation for CSV uploads
|
||||
content_bytes, safe_filename, file_ext, mime_type = await file_validator.validate_upload_file(
|
||||
file, category='csv'
|
||||
)
|
||||
|
||||
# Decode content with proper encoding handling
|
||||
encodings = ['utf-8', 'utf-8-sig', 'windows-1252', 'iso-8859-1']
|
||||
content_str = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
content_str = content_bytes.decode(encoding)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
if content_str is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="File must be a CSV"
|
||||
status_code=400,
|
||||
detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding."
|
||||
)
|
||||
|
||||
# Read CSV content
|
||||
content = await file.read()
|
||||
csv_data = csv.DictReader(io.StringIO(content.decode('utf-8')))
|
||||
# Additional CSV security validation
|
||||
validate_csv_content(content_str)
|
||||
|
||||
csv_data = csv.DictReader(io.StringIO(content_str))
|
||||
|
||||
imported_count = 0
|
||||
errors = []
|
||||
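The encoding fallback above is worth having as a stand-alone helper when testing CSV handling outside the endpoint. A self-contained sketch of the same strategy:

```python
# Stand-alone sketch of the encoding-fallback strategy used in import_csv: try
# UTF-8 (with and without BOM) first, then common Windows/Latin encodings, and
# return None when nothing applies so the caller can raise a 400.
from typing import Optional

def decode_csv_bytes(
    raw: bytes,
    encodings: tuple[str, ...] = ("utf-8", "utf-8-sig", "windows-1252", "iso-8859-1"),
) -> Optional[str]:
    for encoding in encodings:
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return None

# decode_csv_bytes(b"name,amount\r\nAcme,1200\r\n") -> "name,amount\r\nAcme,1200\r\n"
```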
@@ -1786,4 +2102,33 @@ async def get_audit_statistics(
|
||||
{"username": username, "activity_count": count}
|
||||
for username, count in most_active_users
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cache-performance")
|
||||
async def get_cache_performance(
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Get adaptive cache performance statistics"""
|
||||
try:
|
||||
from app.services.adaptive_cache import get_cache_stats
|
||||
stats = get_cache_stats()
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"cache_statistics": stats,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"summary": {
|
||||
"total_cache_types": len(stats),
|
||||
"avg_hit_rate": sum(s.get("hit_rate", 0) for s in stats.values()) / len(stats) if stats else 0,
|
||||
"most_active": max(stats.items(), key=lambda x: x[1].get("total_queries", 0)) if stats else None,
|
||||
"longest_ttl": max(stats.items(), key=lambda x: x[1].get("current_ttl", 0)) if stats else None,
|
||||
"shortest_ttl": min(stats.items(), key=lambda x: x[1].get("current_ttl", float('inf'))) if stats else None
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"cache_statistics": {}
|
||||
}
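The summary block assumes get_cache_stats() returns a dict of per-cache dicts keyed by cache name. A small worked example of that aggregation under this assumed shape (the real structure lives in app.services.adaptive_cache and may differ):

```python
# Assumed shape of get_cache_stats() output, with made-up numbers.
stats = {
    "file_lookup":  {"hit_rate": 0.92, "total_queries": 5400, "current_ttl": 600},
    "user_profile": {"hit_rate": 0.71, "total_queries": 1200, "current_ttl": 120},
}

avg_hit_rate = sum(s.get("hit_rate", 0) for s in stats.values()) / len(stats) if stats else 0
most_active = max(stats.items(), key=lambda x: x[1].get("total_queries", 0)) if stats else None

print(round(avg_hit_rate, 3))   # 0.815
print(most_active[0])           # "file_lookup"
```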
|
||||
app/api/advanced_templates.py (new file, 419)
@@ -0,0 +1,419 @@
"""
Advanced Template Processing API

This module provides enhanced template processing capabilities including:
- Conditional content blocks (IF/ENDIF sections)
- Loop functionality for data tables (FOR/ENDFOR sections)
- Rich variable formatting with filters
- Template function support
- PDF generation from DOCX templates
- Advanced variable resolution
"""
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
import os
|
||||
import io
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.templates import DocumentTemplate, DocumentTemplateVersion
|
||||
from app.services.storage import get_default_storage
|
||||
from app.services.template_merge import (
|
||||
extract_tokens_from_bytes, build_context, resolve_tokens, render_docx,
|
||||
process_template_content, convert_docx_to_pdf, apply_variable_formatting
|
||||
)
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger("advanced_templates")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AdvancedGenerateRequest(BaseModel):
|
||||
"""Advanced template generation request with enhanced features"""
|
||||
context: Dict[str, Any] = Field(default_factory=dict)
|
||||
version_id: Optional[int] = None
|
||||
output_format: str = Field(default="DOCX", description="Output format: DOCX, PDF")
|
||||
enable_conditionals: bool = Field(default=True, description="Enable conditional sections processing")
|
||||
enable_loops: bool = Field(default=True, description="Enable loop sections processing")
|
||||
enable_formatting: bool = Field(default=True, description="Enable variable formatting")
|
||||
enable_functions: bool = Field(default=True, description="Enable template functions")
|
||||
|
||||
|
||||
class AdvancedGenerateResponse(BaseModel):
|
||||
"""Enhanced generation response with processing details"""
|
||||
resolved: Dict[str, Any]
|
||||
unresolved: List[str]
|
||||
output_mime_type: str
|
||||
output_size: int
|
||||
processing_details: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class BatchAdvancedGenerateRequest(BaseModel):
|
||||
"""Batch generation request using advanced template features"""
|
||||
template_id: int
|
||||
version_id: Optional[int] = None
|
||||
file_nos: List[str]
|
||||
output_format: str = Field(default="DOCX", description="Output format: DOCX, PDF")
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
enable_conditionals: bool = Field(default=True, description="Enable conditional sections processing")
|
||||
enable_loops: bool = Field(default=True, description="Enable loop sections processing")
|
||||
enable_formatting: bool = Field(default=True, description="Enable variable formatting")
|
||||
enable_functions: bool = Field(default=True, description="Enable template functions")
|
||||
bundle_zip: bool = False
|
||||
|
||||
|
||||
class BatchAdvancedGenerateResponse(BaseModel):
|
||||
"""Batch generation response with per-item results"""
|
||||
template_name: str
|
||||
results: List[Dict[str, Any]]
|
||||
bundle_url: Optional[str] = None
|
||||
bundle_size: Optional[int] = None
|
||||
processing_summary: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TemplateAnalysisRequest(BaseModel):
|
||||
"""Request for analyzing template features"""
|
||||
version_id: Optional[int] = None
|
||||
|
||||
|
||||
class TemplateAnalysisResponse(BaseModel):
|
||||
"""Template analysis response showing capabilities"""
|
||||
variables: List[str]
|
||||
formatted_variables: List[str]
|
||||
conditional_blocks: List[Dict[str, Any]]
|
||||
loop_blocks: List[Dict[str, Any]]
|
||||
function_calls: List[str]
|
||||
complexity_score: int
|
||||
recommendations: List[str]
|
||||
|
||||
|
||||
@router.post("/{template_id}/generate-advanced", response_model=AdvancedGenerateResponse)
|
||||
async def generate_advanced_document(
|
||||
template_id: int,
|
||||
payload: AdvancedGenerateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Generate document with advanced template processing features"""
|
||||
# Get template and version
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
version_id = payload.version_id or tpl.current_version_id
|
||||
if not version_id:
|
||||
raise HTTPException(status_code=400, detail="Template has no versions")
|
||||
|
||||
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
|
||||
if not ver:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Load template content
|
||||
storage = get_default_storage()
|
||||
try:
|
||||
content = storage.open_bytes(ver.storage_path)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Template file not found")
|
||||
|
||||
# Extract tokens and build context
|
||||
tokens = extract_tokens_from_bytes(content)
|
||||
context = build_context(payload.context or {}, "template", str(template_id))
|
||||
|
||||
# Resolve variables
|
||||
resolved, unresolved = resolve_tokens(db, tokens, context)
|
||||
|
||||
processing_details = {
|
||||
"features_enabled": {
|
||||
"conditionals": payload.enable_conditionals,
|
||||
"loops": payload.enable_loops,
|
||||
"formatting": payload.enable_formatting,
|
||||
"functions": payload.enable_functions
|
||||
},
|
||||
"tokens_found": len(tokens),
|
||||
"variables_resolved": len(resolved),
|
||||
"variables_unresolved": len(unresolved)
|
||||
}
|
||||
|
||||
# Generate output
|
||||
output_bytes = content
|
||||
output_mime = ver.mime_type
|
||||
|
||||
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
try:
|
||||
# Enhanced DOCX processing
|
||||
if payload.enable_conditionals or payload.enable_loops or payload.enable_formatting or payload.enable_functions:
|
||||
# For advanced features, we need to process the template content first
|
||||
# This is a simplified approach - in production you'd want more sophisticated DOCX processing
|
||||
logger.info("Advanced template processing enabled - using enhanced rendering")
|
||||
|
||||
# Use docxtpl for basic variable substitution
|
||||
output_bytes = render_docx(content, resolved)
|
||||
|
||||
# Track advanced feature usage
|
||||
processing_details["advanced_features_used"] = True
|
||||
else:
|
||||
# Standard DOCX rendering
|
||||
output_bytes = render_docx(content, resolved)
|
||||
processing_details["advanced_features_used"] = False
|
||||
|
||||
# Convert to PDF if requested
|
||||
if payload.output_format.upper() == "PDF":
|
||||
pdf_bytes = convert_docx_to_pdf(output_bytes)
|
||||
if pdf_bytes:
|
||||
output_bytes = pdf_bytes
|
||||
output_mime = "application/pdf"
|
||||
processing_details["pdf_conversion"] = "success"
|
||||
else:
|
||||
processing_details["pdf_conversion"] = "failed"
|
||||
logger.warning("PDF conversion failed, returning DOCX")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing template: {e}")
|
||||
processing_details["processing_error"] = str(e)
|
||||
|
||||
return AdvancedGenerateResponse(
|
||||
resolved=resolved,
|
||||
unresolved=unresolved,
|
||||
output_mime_type=output_mime,
|
||||
output_size=len(output_bytes),
|
||||
processing_details=processing_details
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{template_id}/analyze", response_model=TemplateAnalysisResponse)
|
||||
async def analyze_template(
|
||||
template_id: int,
|
||||
payload: TemplateAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Analyze template to identify advanced features and complexity"""
|
||||
# Get template and version
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
version_id = payload.version_id or tpl.current_version_id
|
||||
if not version_id:
|
||||
raise HTTPException(status_code=400, detail="Template has no versions")
|
||||
|
||||
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
|
||||
if not ver:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Load template content
|
||||
storage = get_default_storage()
|
||||
try:
|
||||
content = storage.open_bytes(ver.storage_path)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Template file not found")
|
||||
|
||||
# Analyze template content
|
||||
tokens = extract_tokens_from_bytes(content)
|
||||
|
||||
# For DOCX files, we need to extract text content for analysis
|
||||
text_content = ""
|
||||
try:
|
||||
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
# Extract text from DOCX for analysis
|
||||
from docx import Document
|
||||
doc = Document(io.BytesIO(content))
|
||||
text_content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||
else:
|
||||
text_content = content.decode('utf-8', errors='ignore')
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text content for analysis: {e}")
|
||||
text_content = str(content)
|
||||
|
||||
# Analyze different template features
|
||||
from app.services.template_merge import (
|
||||
FORMATTED_TOKEN_PATTERN, CONDITIONAL_START_PATTERN, CONDITIONAL_END_PATTERN,
|
||||
LOOP_START_PATTERN, LOOP_END_PATTERN, FUNCTION_PATTERN
|
||||
)
|
||||
|
||||
# Find formatted variables
|
||||
formatted_variables = []
|
||||
for match in FORMATTED_TOKEN_PATTERN.finditer(text_content):
|
||||
var_name = match.group(1).strip()
|
||||
format_spec = match.group(2).strip()
|
||||
formatted_variables.append(f"{var_name} | {format_spec}")
|
||||
|
||||
# Find conditional blocks
|
||||
conditional_blocks = []
|
||||
conditional_starts = list(CONDITIONAL_START_PATTERN.finditer(text_content))
|
||||
conditional_ends = list(CONDITIONAL_END_PATTERN.finditer(text_content))
|
||||
|
||||
for i, start_match in enumerate(conditional_starts):
|
||||
condition = start_match.group(1).strip()
|
||||
conditional_blocks.append({
|
||||
"condition": condition,
|
||||
"line_start": text_content[:start_match.start()].count('\n') + 1,
|
||||
"complexity": len(condition.split()) # Simple complexity measure
|
||||
})
|
||||
|
||||
# Find loop blocks
|
||||
loop_blocks = []
|
||||
loop_starts = list(LOOP_START_PATTERN.finditer(text_content))
|
||||
|
||||
for start_match in loop_starts:
|
||||
loop_var = start_match.group(1).strip()
|
||||
collection = start_match.group(2).strip()
|
||||
loop_blocks.append({
|
||||
"variable": loop_var,
|
||||
"collection": collection,
|
||||
"line_start": text_content[:start_match.start()].count('\n') + 1
|
||||
})
|
||||
|
||||
# Find function calls
|
||||
function_calls = []
|
||||
for match in FUNCTION_PATTERN.finditer(text_content):
|
||||
func_name = match.group(1).strip()
|
||||
args = match.group(2).strip()
|
||||
function_calls.append(f"{func_name}({args})")
|
||||
|
||||
# Calculate complexity score
|
||||
complexity_score = (
|
||||
len(tokens) * 1 +
|
||||
len(formatted_variables) * 2 +
|
||||
len(conditional_blocks) * 3 +
|
||||
len(loop_blocks) * 4 +
|
||||
len(function_calls) * 2
|
||||
)
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = []
|
||||
if len(conditional_blocks) > 5:
|
||||
recommendations.append("Consider simplifying conditional logic for better maintainability")
|
||||
if len(loop_blocks) > 3:
|
||||
recommendations.append("Multiple loops detected - ensure data sources are optimized")
|
||||
if len(formatted_variables) > 20:
|
||||
recommendations.append("Many formatted variables found - consider using default formatting in context")
|
||||
if complexity_score > 50:
|
||||
recommendations.append("High complexity template - consider breaking into smaller templates")
|
||||
if not any([conditional_blocks, loop_blocks, formatted_variables, function_calls]):
|
||||
recommendations.append("Template uses basic features only - consider leveraging advanced features for better documents")
|
||||
|
||||
return TemplateAnalysisResponse(
|
||||
variables=tokens,
|
||||
formatted_variables=formatted_variables,
|
||||
conditional_blocks=conditional_blocks,
|
||||
loop_blocks=loop_blocks,
|
||||
function_calls=function_calls,
|
||||
complexity_score=complexity_score,
|
||||
recommendations=recommendations
|
||||
)
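The complexity score is a simple weighted count. A worked example with made-up feature counts shows how the weights combine:

```python
# Hypothetical counts for a mid-sized template; the weights (1/2/3/4/2) are the
# ones used by the analyze endpoint above.
tokens              = 12   # plain {{ variable }} tokens
formatted_variables = 4    # {{ variable | format }} tokens
conditional_blocks  = 2    # {% if ... %} sections
loop_blocks         = 1    # {% for ... %} sections
function_calls      = 3    # {{ fn(...) }} calls

complexity_score = (
    tokens * 1
    + formatted_variables * 2
    + conditional_blocks * 3
    + loop_blocks * 4
    + function_calls * 2
)
print(complexity_score)  # 12 + 8 + 6 + 4 + 6 = 36 -> below the "high complexity" threshold of 50
```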
|
||||
|
||||
|
||||
@router.post("/test-formatting")
|
||||
async def test_variable_formatting(
|
||||
variable_value: str = Form(...),
|
||||
format_spec: str = Form(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Test variable formatting without generating a full document"""
|
||||
try:
|
||||
result = apply_variable_formatting(variable_value, format_spec)
|
||||
return {
|
||||
"input_value": variable_value,
|
||||
"format_spec": format_spec,
|
||||
"formatted_result": result,
|
||||
"success": True
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"input_value": variable_value,
|
||||
"format_spec": format_spec,
|
||||
"error": str(e),
|
||||
"success": False
|
||||
}
|
||||
|
||||
|
||||
@router.get("/formatting-help")
|
||||
async def get_formatting_help(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get help documentation for variable formatting options"""
|
||||
return {
|
||||
"formatting_options": {
|
||||
"currency": {
|
||||
"description": "Format as currency",
|
||||
"syntax": "currency[:symbol][:decimal_places]",
|
||||
"examples": [
|
||||
{"input": "1234.56", "format": "currency", "output": "$1,234.56"},
|
||||
{"input": "1234.56", "format": "currency:€", "output": "€1,234.56"},
|
||||
{"input": "1234.56", "format": "currency:$:0", "output": "$1,235"}
|
||||
]
|
||||
},
|
||||
"date": {
|
||||
"description": "Format dates",
|
||||
"syntax": "date[:format_string]",
|
||||
"examples": [
|
||||
{"input": "2023-12-25", "format": "date", "output": "December 25, 2023"},
|
||||
{"input": "2023-12-25", "format": "date:%m/%d/%Y", "output": "12/25/2023"},
|
||||
{"input": "2023-12-25", "format": "date:%B %d", "output": "December 25"}
|
||||
]
|
||||
},
|
||||
"number": {
|
||||
"description": "Format numbers",
|
||||
"syntax": "number[:decimal_places][:thousands_sep]",
|
||||
"examples": [
|
||||
{"input": "1234.5678", "format": "number", "output": "1,234.57"},
|
||||
{"input": "1234.5678", "format": "number:1", "output": "1,234.6"},
|
||||
{"input": "1234.5678", "format": "number:2: ", "output": "1 234.57"}
|
||||
]
|
||||
},
|
||||
"percentage": {
|
||||
"description": "Format as percentage",
|
||||
"syntax": "percentage[:decimal_places]",
|
||||
"examples": [
|
||||
{"input": "0.1234", "format": "percentage", "output": "0.1%"},
|
||||
{"input": "12.34", "format": "percentage:2", "output": "12.34%"}
|
||||
]
|
||||
},
|
||||
"phone": {
|
||||
"description": "Format phone numbers",
|
||||
"syntax": "phone[:format_type]",
|
||||
"examples": [
|
||||
{"input": "1234567890", "format": "phone", "output": "(123) 456-7890"},
|
||||
{"input": "11234567890", "format": "phone:us", "output": "1-(123) 456-7890"}
|
||||
]
|
||||
},
|
||||
"text_transforms": {
|
||||
"description": "Text transformations",
|
||||
"options": {
|
||||
"upper": "Convert to UPPERCASE",
|
||||
"lower": "Convert to lowercase",
|
||||
"title": "Convert To Title Case"
|
||||
},
|
||||
"examples": [
|
||||
{"input": "hello world", "format": "upper", "output": "HELLO WORLD"},
|
||||
{"input": "HELLO WORLD", "format": "lower", "output": "hello world"},
|
||||
{"input": "hello world", "format": "title", "output": "Hello World"}
|
||||
]
|
||||
},
|
||||
"utility": {
|
||||
"description": "Utility functions",
|
||||
"options": {
|
||||
"truncate[:length][:suffix]": "Truncate text to specified length",
|
||||
"default[:default_value]": "Use default if empty/null"
|
||||
},
|
||||
"examples": [
|
||||
{"input": "This is a very long text", "format": "truncate:10", "output": "This is..."},
|
||||
{"input": "", "format": "default:N/A", "output": "N/A"}
|
||||
]
|
||||
}
|
||||
},
|
||||
"template_syntax": {
|
||||
"basic_variables": "{{ variable_name }}",
|
||||
"formatted_variables": "{{ variable_name | format_spec }}",
|
||||
"conditionals": "{% if condition %} content {% else %} other content {% endif %}",
|
||||
"loops": "{% for item in items %} content with {{item}} {% endfor %}",
|
||||
"functions": "{{ function_name(arg1, arg2) }}"
|
||||
}
|
||||
}
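For readers wondering how a colon-separated spec such as currency:€:0 might be dispatched, here is an illustrative sketch. It is not the project's apply_variable_formatting implementation, only a toy dispatcher consistent with the help text above:

```python
# Toy format-spec dispatcher (illustration only; not the real implementation).
def format_value(value: str, spec: str) -> str:
    name, *args = spec.split(":")
    if name == "currency":
        symbol = args[0] if len(args) > 0 and args[0] else "$"
        places = int(args[1]) if len(args) > 1 else 2
        return f"{symbol}{float(value):,.{places}f}"
    if name == "upper":
        return value.upper()
    if name == "truncate":
        length = int(args[0]) if args else 20
        suffix = args[1] if len(args) > 1 else "..."
        return value[:length].rstrip() + suffix if len(value) > length else value
    return value  # unknown spec: pass the value through unchanged

print(format_value("1234.56", "currency"))       # $1,234.56
print(format_value("1234.56", "currency:€:0"))   # €1,235
print(format_value("hello world", "upper"))      # HELLO WORLD
```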
|
||||
app/api/advanced_variables.py (new file, 551)
@@ -0,0 +1,551 @@
"""
Advanced Template Variables API

This API provides comprehensive variable management for document templates including:
- Variable definition and configuration
- Context-specific value management
- Advanced processing with conditional logic and calculations
- Variable testing and validation
"""
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Body
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import func, or_, and_
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.template_variables import (
|
||||
TemplateVariable, VariableContext, VariableAuditLog,
|
||||
VariableType, VariableTemplate, VariableGroup
|
||||
)
|
||||
from app.services.advanced_variables import VariableProcessor
|
||||
from app.services.query_utils import paginate_with_total
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas for API
|
||||
class VariableCreate(BaseModel):
|
||||
name: str = Field(..., max_length=100, description="Unique variable name")
|
||||
display_name: Optional[str] = Field(None, max_length=200)
|
||||
description: Optional[str] = None
|
||||
variable_type: VariableType = VariableType.STRING
|
||||
required: bool = False
|
||||
default_value: Optional[str] = None
|
||||
formula: Optional[str] = None
|
||||
conditional_logic: Optional[Dict[str, Any]] = None
|
||||
data_source_query: Optional[str] = None
|
||||
lookup_table: Optional[str] = None
|
||||
lookup_key_field: Optional[str] = None
|
||||
lookup_value_field: Optional[str] = None
|
||||
validation_rules: Optional[Dict[str, Any]] = None
|
||||
format_pattern: Optional[str] = None
|
||||
depends_on: Optional[List[str]] = None
|
||||
scope: str = "global"
|
||||
category: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
cache_duration_minutes: int = 0
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = None
|
||||
active: Optional[bool] = None
|
||||
default_value: Optional[str] = None
|
||||
formula: Optional[str] = None
|
||||
conditional_logic: Optional[Dict[str, Any]] = None
|
||||
data_source_query: Optional[str] = None
|
||||
lookup_table: Optional[str] = None
|
||||
lookup_key_field: Optional[str] = None
|
||||
lookup_value_field: Optional[str] = None
|
||||
validation_rules: Optional[Dict[str, Any]] = None
|
||||
format_pattern: Optional[str] = None
|
||||
depends_on: Optional[List[str]] = None
|
||||
category: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
cache_duration_minutes: Optional[int] = None
|
||||
|
||||
|
||||
class VariableResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
display_name: Optional[str]
|
||||
description: Optional[str]
|
||||
variable_type: VariableType
|
||||
required: bool
|
||||
active: bool
|
||||
default_value: Optional[str]
|
||||
scope: str
|
||||
category: Optional[str]
|
||||
tags: Optional[List[str]]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class VariableContextSet(BaseModel):
|
||||
variable_name: str
|
||||
value: Any
|
||||
context_type: str = "global"
|
||||
context_id: str = "default"
|
||||
|
||||
|
||||
class VariableTestRequest(BaseModel):
|
||||
variables: List[str]
|
||||
context_type: str = "global"
|
||||
context_id: str = "default"
|
||||
test_context: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class VariableTestResponse(BaseModel):
|
||||
resolved: Dict[str, Any]
|
||||
unresolved: List[str]
|
||||
processing_time_ms: float
|
||||
errors: List[str]
|
||||
|
||||
|
||||
class VariableAuditResponse(BaseModel):
|
||||
id: int
|
||||
variable_name: str
|
||||
context_type: Optional[str]
|
||||
context_id: Optional[str]
|
||||
old_value: Optional[str]
|
||||
new_value: Optional[str]
|
||||
change_type: str
|
||||
change_reason: Optional[str]
|
||||
changed_by: Optional[str]
|
||||
changed_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.post("/variables/", response_model=VariableResponse)
|
||||
async def create_variable(
|
||||
variable_data: VariableCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new template variable with advanced features"""
|
||||
|
||||
# Check if variable name already exists
|
||||
existing = db.query(TemplateVariable).filter(
|
||||
TemplateVariable.name == variable_data.name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Variable with name '{variable_data.name}' already exists"
|
||||
)
|
||||
|
||||
# Validate dependencies
|
||||
if variable_data.depends_on:
|
||||
for dep_name in variable_data.depends_on:
|
||||
dep_var = db.query(TemplateVariable).filter(
|
||||
TemplateVariable.name == dep_name,
|
||||
TemplateVariable.active == True
|
||||
).first()
|
||||
if not dep_var:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Dependency variable '{dep_name}' not found"
|
||||
)
|
||||
|
||||
# Create variable
|
||||
variable = TemplateVariable(
|
||||
name=variable_data.name,
|
||||
display_name=variable_data.display_name,
|
||||
description=variable_data.description,
|
||||
variable_type=variable_data.variable_type,
|
||||
required=variable_data.required,
|
||||
default_value=variable_data.default_value,
|
||||
formula=variable_data.formula,
|
||||
conditional_logic=variable_data.conditional_logic,
|
||||
data_source_query=variable_data.data_source_query,
|
||||
lookup_table=variable_data.lookup_table,
|
||||
lookup_key_field=variable_data.lookup_key_field,
|
||||
lookup_value_field=variable_data.lookup_value_field,
|
||||
validation_rules=variable_data.validation_rules,
|
||||
format_pattern=variable_data.format_pattern,
|
||||
depends_on=variable_data.depends_on,
|
||||
scope=variable_data.scope,
|
||||
category=variable_data.category,
|
||||
tags=variable_data.tags,
|
||||
cache_duration_minutes=variable_data.cache_duration_minutes,
|
||||
created_by=current_user.username,
|
||||
active=True
|
||||
)
|
||||
|
||||
db.add(variable)
|
||||
db.commit()
|
||||
db.refresh(variable)
|
||||
|
||||
return variable
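A typical request body for this endpoint, defining a formula-backed variable that depends on two others. Field names follow VariableCreate; the formula syntax and the NUMBER enum member are assumptions, since the evaluator lives in app.services.advanced_variables:

```python
# Hypothetical payload for POST /variables/ creating a derived value.
create_total_fee = {
    "name": "total_fee",
    "display_name": "Total Fee",
    "variable_type": "number",              # assumes VariableType defines a NUMBER member
    "formula": "base_fee + filing_fee",     # hypothetical formula syntax
    "depends_on": ["base_fee", "filing_fee"],
    "scope": "global",
    "category": "billing",
    "cache_duration_minutes": 15,
}
```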
|
||||
|
||||
|
||||
@router.get("/variables/", response_model=List[VariableResponse])
|
||||
async def list_variables(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
category: Optional[str] = Query(None),
|
||||
variable_type: Optional[VariableType] = Query(None),
|
||||
active_only: bool = Query(True),
|
||||
search: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""List template variables with filtering options"""
|
||||
|
||||
query = db.query(TemplateVariable)
|
||||
|
||||
if active_only:
|
||||
query = query.filter(TemplateVariable.active == True)
|
||||
|
||||
if category:
|
||||
query = query.filter(TemplateVariable.category == category)
|
||||
|
||||
if variable_type:
|
||||
query = query.filter(TemplateVariable.variable_type == variable_type)
|
||||
|
||||
if search:
|
||||
search_filter = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
TemplateVariable.name.ilike(search_filter),
|
||||
TemplateVariable.display_name.ilike(search_filter),
|
||||
TemplateVariable.description.ilike(search_filter)
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(TemplateVariable.category, TemplateVariable.name)
|
||||
variables, _ = paginate_with_total(query, skip, limit, False)
|
||||
|
||||
return variables
|
||||
|
||||
|
||||
@router.get("/variables/{variable_id}", response_model=VariableResponse)
|
||||
async def get_variable(
|
||||
variable_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific variable by ID"""
|
||||
|
||||
variable = db.query(TemplateVariable).filter(
|
||||
TemplateVariable.id == variable_id
|
||||
).first()
|
||||
|
||||
if not variable:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Variable not found"
|
||||
)
|
||||
|
||||
return variable
|
||||
|
||||
|
||||
@router.put("/variables/{variable_id}", response_model=VariableResponse)
|
||||
async def update_variable(
|
||||
variable_id: int,
|
||||
variable_data: VariableUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a template variable"""
|
||||
|
||||
variable = db.query(TemplateVariable).filter(
|
||||
TemplateVariable.id == variable_id
|
||||
).first()
|
||||
|
||||
if not variable:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Variable not found"
|
||||
)
|
||||
|
||||
# Update fields that are provided
|
||||
update_data = variable_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(variable, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(variable)
|
||||
|
||||
return variable
|
||||
|
||||
|
||||
@router.delete("/variables/{variable_id}")
|
||||
async def delete_variable(
|
||||
variable_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a template variable (soft delete by setting active=False)"""
|
||||
|
||||
variable = db.query(TemplateVariable).filter(
|
||||
TemplateVariable.id == variable_id
|
||||
).first()
|
||||
|
||||
if not variable:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Variable not found"
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
variable.active = False
|
||||
db.commit()
|
||||
|
||||
return {"message": "Variable deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/variables/test", response_model=VariableTestResponse)
|
||||
async def test_variables(
|
||||
test_request: VariableTestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Test variable resolution with given context"""
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
errors = []
|
||||
|
||||
try:
|
||||
processor = VariableProcessor(db)
|
||||
resolved, unresolved = processor.resolve_variables(
|
||||
variables=test_request.variables,
|
||||
context_type=test_request.context_type,
|
||||
context_id=test_request.context_id,
|
||||
base_context=test_request.test_context or {}
|
||||
)
|
||||
|
||||
processing_time = (time.time() - start_time) * 1000
|
||||
|
||||
return VariableTestResponse(
|
||||
resolved=resolved,
|
||||
unresolved=unresolved,
|
||||
processing_time_ms=processing_time,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
processing_time = (time.time() - start_time) * 1000
|
||||
errors.append(str(e))
|
||||
|
||||
return VariableTestResponse(
|
||||
resolved={},
|
||||
unresolved=test_request.variables,
|
||||
processing_time_ms=processing_time,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
|
||||
@router.post("/variables/set-value")
|
||||
async def set_variable_value(
|
||||
context_data: VariableContextSet,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Set a variable value in a specific context"""
|
||||
|
||||
processor = VariableProcessor(db)
|
||||
success = processor.set_variable_value(
|
||||
variable_name=context_data.variable_name,
|
||||
value=context_data.value,
|
||||
context_type=context_data.context_type,
|
||||
context_id=context_data.context_id,
|
||||
user_name=current_user.username
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to set variable value"
|
||||
)
|
||||
|
||||
return {"message": "Variable value set successfully"}
|
||||
|
||||
|
||||
@router.get("/variables/{variable_id}/contexts")
|
||||
async def get_variable_contexts(
|
||||
variable_id: int,
|
||||
context_type: Optional[str] = Query(None),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all contexts where this variable has values"""
|
||||
|
||||
query = db.query(VariableContext).filter(
|
||||
VariableContext.variable_id == variable_id
|
||||
)
|
||||
|
||||
if context_type:
|
||||
query = query.filter(VariableContext.context_type == context_type)
|
||||
|
||||
query = query.order_by(VariableContext.context_type, VariableContext.context_id)
|
||||
contexts, total = paginate_with_total(query, skip, limit, True)
|
||||
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"context_type": ctx.context_type,
|
||||
"context_id": ctx.context_id,
|
||||
"value": ctx.value,
|
||||
"computed_value": ctx.computed_value,
|
||||
"is_valid": ctx.is_valid,
|
||||
"validation_errors": ctx.validation_errors,
|
||||
"last_computed_at": ctx.last_computed_at
|
||||
}
|
||||
for ctx in contexts
|
||||
],
|
||||
"total": total
|
||||
}
|
||||
|
||||
|
||||
@router.get("/variables/{variable_id}/audit", response_model=List[VariableAuditResponse])
|
||||
async def get_variable_audit_log(
|
||||
variable_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get audit log for a variable"""
|
||||
|
||||
query = db.query(VariableAuditLog, TemplateVariable.name).join(
|
||||
TemplateVariable, VariableAuditLog.variable_id == TemplateVariable.id
|
||||
).filter(
|
||||
VariableAuditLog.variable_id == variable_id
|
||||
).order_by(VariableAuditLog.changed_at.desc())
|
||||
|
||||
audit_logs, _ = paginate_with_total(query, skip, limit, False)
|
||||
|
||||
return [
|
||||
VariableAuditResponse(
|
||||
id=log.id,
|
||||
variable_name=var_name,
|
||||
context_type=log.context_type,
|
||||
context_id=log.context_id,
|
||||
old_value=log.old_value,
|
||||
new_value=log.new_value,
|
||||
change_type=log.change_type,
|
||||
change_reason=log.change_reason,
|
||||
changed_by=log.changed_by,
|
||||
changed_at=log.changed_at
|
||||
)
|
||||
for log, var_name in audit_logs
|
||||
]
|
||||
|
||||
|
||||
@router.get("/templates/{template_id}/variables")
|
||||
async def get_template_variables(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all variables associated with a template"""
|
||||
|
||||
processor = VariableProcessor(db)
|
||||
variables = processor.get_variables_for_template(template_id)
|
||||
|
||||
return {"variables": variables}
|
||||
|
||||
|
||||
@router.post("/templates/{template_id}/variables/{variable_id}")
|
||||
async def associate_variable_with_template(
|
||||
template_id: int,
|
||||
variable_id: int,
|
||||
override_default: Optional[str] = Body(None),
|
||||
override_required: Optional[bool] = Body(None),
|
||||
display_order: int = Body(0),
|
||||
group_name: Optional[str] = Body(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Associate a variable with a template"""
|
||||
|
||||
# Check if association already exists
|
||||
existing = db.query(VariableTemplate).filter(
|
||||
VariableTemplate.template_id == template_id,
|
||||
VariableTemplate.variable_id == variable_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing association
|
||||
existing.override_default = override_default
|
||||
existing.override_required = override_required
|
||||
existing.display_order = display_order
|
||||
existing.group_name = group_name
|
||||
else:
|
||||
# Create new association
|
||||
association = VariableTemplate(
|
||||
template_id=template_id,
|
||||
variable_id=variable_id,
|
||||
override_default=override_default,
|
||||
override_required=override_required,
|
||||
display_order=display_order,
|
||||
group_name=group_name
|
||||
)
|
||||
db.add(association)
|
||||
|
||||
db.commit()
|
||||
return {"message": "Variable associated with template successfully"}
|
||||
|
||||
|
||||
@router.delete("/templates/{template_id}/variables/{variable_id}")
|
||||
async def remove_variable_from_template(
|
||||
template_id: int,
|
||||
variable_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Remove variable association from template"""
|
||||
|
||||
association = db.query(VariableTemplate).filter(
|
||||
VariableTemplate.template_id == template_id,
|
||||
VariableTemplate.variable_id == variable_id
|
||||
).first()
|
||||
|
||||
if not association:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Variable association not found"
|
||||
)
|
||||
|
||||
db.delete(association)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Variable removed from template successfully"}
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
async def get_variable_categories(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get list of variable categories"""
|
||||
|
||||
categories = db.query(
|
||||
TemplateVariable.category,
|
||||
func.count(TemplateVariable.id).label('count')
|
||||
).filter(
|
||||
TemplateVariable.active == True,
|
||||
TemplateVariable.category.isnot(None)
|
||||
).group_by(TemplateVariable.category).order_by(TemplateVariable.category).all()
|
||||
|
||||
return [
|
||||
{"category": cat, "count": count}
|
||||
for cat, count in categories
|
||||
]
|
||||
app/api/auth.py (155)
@@ -20,6 +20,12 @@ from app.auth.security import (
|
||||
get_current_user,
|
||||
get_admin_user,
|
||||
)
|
||||
from app.utils.enhanced_auth import (
|
||||
validate_and_authenticate_user,
|
||||
PasswordValidator,
|
||||
AccountLockoutManager,
|
||||
)
|
||||
from app.utils.session_manager import SessionManager, get_session_manager
|
||||
from app.auth.schemas import (
|
||||
Token,
|
||||
UserCreate,
|
||||
@@ -36,8 +42,13 @@ logger = get_logger("auth")
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(login_data: LoginRequest, request: Request, db: Session = Depends(get_db)):
|
||||
"""Login endpoint"""
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Enhanced login endpoint with session management and security features"""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
@@ -48,30 +59,38 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
user = authenticate_user(db, login_data.username, login_data.password)
|
||||
if not user:
|
||||
log_auth_attempt(
|
||||
username=login_data.username,
|
||||
success=False,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
error="Invalid credentials"
|
||||
)
|
||||
# Use enhanced authentication with lockout protection
|
||||
user, auth_errors = validate_and_authenticate_user(
|
||||
db, login_data.username, login_data.password, request
|
||||
)
|
||||
|
||||
if not user or auth_errors:
|
||||
error_message = auth_errors[0] if auth_errors else "Incorrect username or password"
|
||||
|
||||
logger.warning(
|
||||
"Login failed - invalid credentials",
|
||||
"Login failed - enhanced auth",
|
||||
username=login_data.username,
|
||||
client_ip=client_ip
|
||||
client_ip=client_ip,
|
||||
errors=auth_errors
|
||||
)
|
||||
|
||||
# Get lockout info for response headers
|
||||
lockout_info = AccountLockoutManager.get_lockout_info(db, login_data.username)
|
||||
|
||||
headers = {"WWW-Authenticate": "Bearer"}
|
||||
if lockout_info["is_locked"]:
|
||||
headers["X-Account-Locked"] = "true"
|
||||
headers["X-Unlock-Time"] = lockout_info["unlock_time"] or ""
|
||||
else:
|
||||
headers["X-Attempts-Remaining"] = str(lockout_info["attempts_remaining"])
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
detail=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
# Successful authentication - create tokens
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
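Clients can use the new response headers to give better feedback than a bare 401. A hedged sketch (the login path and payload shape depend on how the auth router is mounted):

```python
# Sketch of a client surfacing the lockout headers added above.
# Header names match the diff; the URL and payload shape are assumptions.
import requests

def login(base_url: str, username: str, password: str) -> dict:
    resp = requests.post(
        f"{base_url}/auth/login",
        json={"username": username, "password": password},
        timeout=10,
    )
    if resp.status_code == 401:
        if resp.headers.get("X-Account-Locked") == "true":
            raise RuntimeError(f"Account locked until {resp.headers.get('X-Unlock-Time')}")
        remaining = resp.headers.get("X-Attempts-Remaining")
        raise RuntimeError(f"Login failed; {remaining} attempts remaining")
    resp.raise_for_status()
    return resp.json()
```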
@@ -83,14 +102,8 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
|
||||
db=db,
|
||||
)
|
||||
|
||||
log_auth_attempt(
|
||||
username=login_data.username,
|
||||
success=True,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
logger.info(
|
||||
"Login successful",
|
||||
"Login successful - enhanced auth",
|
||||
username=login_data.username,
|
||||
user_id=user.id,
|
||||
client_ip=client_ip
|
||||
@@ -105,7 +118,15 @@ async def register(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user) # Only admins can create users
|
||||
):
|
||||
"""Register new user (admin only)"""
|
||||
"""Register new user with password validation (admin only)"""
|
||||
# Validate password strength
|
||||
is_valid, password_errors = PasswordValidator.validate_password_strength(user_data.password)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Password validation failed: {'; '.join(password_errors)}"
|
||||
)
|
||||
|
||||
# Check if username or email already exists
|
||||
existing_user = db.query(User).filter(
|
||||
(User.username == user_data.username) | (User.email == user_data.email)
|
||||
@@ -130,6 +151,12 @@ async def register(
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
|
||||
logger.info(
|
||||
"User registered",
|
||||
username=new_user.username,
|
||||
created_by=current_user.username
|
||||
)
|
||||
|
||||
return new_user
|
||||
|
||||
|
||||
@@ -257,4 +284,76 @@ async def update_theme_preference(
|
||||
current_user.theme_preference = theme_data.theme_preference
|
||||
db.commit()
|
||||
|
||||
return {"message": "Theme preference updated successfully", "theme": theme_data.theme_preference}
|
||||
return {"message": "Theme preference updated successfully", "theme": theme_data.theme_preference}
|
||||
|
||||
|
||||
@router.post("/validate-password")
|
||||
async def validate_password(password_data: dict):
|
||||
"""Validate password strength and return detailed feedback"""
|
||||
password = password_data.get("password", "")
|
||||
|
||||
if not password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password is required"
|
||||
)
|
||||
|
||||
is_valid, errors = PasswordValidator.validate_password_strength(password)
|
||||
strength_score = PasswordValidator.generate_password_strength_score(password)
|
||||
|
||||
return {
|
||||
"is_valid": is_valid,
|
||||
"errors": errors,
|
||||
"strength_score": strength_score,
|
||||
"strength_level": (
|
||||
"Very Weak" if strength_score < 20 else
|
||||
"Weak" if strength_score < 40 else
|
||||
"Fair" if strength_score < 60 else
|
||||
"Good" if strength_score < 80 else
|
||||
"Strong"
|
||||
)
|
||||
}
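For illustration only, a toy scorer mapped onto the same 0-100 scale and level bands; the real scoring lives in PasswordValidator.generate_password_strength_score:

```python
# Toy 0-100 strength score (illustration, not the project's algorithm), mapped
# onto the level bands used above (<20 Very Weak, <40 Weak, <60 Fair, <80 Good).
import string

def toy_strength_score(password: str) -> int:
    score = min(len(password), 16) * 3                      # length, capped at 16 chars
    classes = [
        any(c.islower() for c in password),
        any(c.isupper() for c in password),
        any(c.isdigit() for c in password),
        any(c in string.punctuation for c in password),
    ]
    score += sum(classes) * 13                               # character variety bonus
    return min(score, 100)

def level(score: int) -> str:
    return ("Very Weak" if score < 20 else "Weak" if score < 40 else
            "Fair" if score < 60 else "Good" if score < 80 else "Strong")

print(level(toy_strength_score("Tr0ub4dor&3")))  # Strong
```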
|
||||
|
||||
|
||||
@router.get("/account-status/{username}")
|
||||
async def get_account_status(
|
||||
username: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user) # Admin only endpoint
|
||||
):
|
||||
"""Get account lockout status and security information (admin only)"""
|
||||
lockout_info = AccountLockoutManager.get_lockout_info(db, username)
|
||||
|
||||
# Get recent login attempts
|
||||
from app.utils.enhanced_auth import SuspiciousActivityDetector
|
||||
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
|
||||
db, username, "admin-check", "admin-request"
|
||||
)
|
||||
|
||||
return {
|
||||
"username": username,
|
||||
"lockout_info": lockout_info,
|
||||
"suspicious_activity": {
|
||||
"is_suspicious": is_suspicious,
|
||||
"warnings": warnings
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/unlock-account/{username}")
|
||||
async def unlock_account(
|
||||
username: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user) # Admin only endpoint
|
||||
):
|
||||
"""Manually unlock a user account (admin only)"""
|
||||
# Reset failed attempts by recording a successful "admin unlock"
|
||||
AccountLockoutManager.reset_failed_attempts(db, username)
|
||||
|
||||
logger.info(
|
||||
"Account manually unlocked",
|
||||
username=username,
|
||||
unlocked_by=current_user.username
|
||||
)
|
||||
|
||||
return {"message": f"Account '{username}' has been unlocked"}
|
||||
@@ -34,6 +34,17 @@ from app.models.billing import (
|
||||
BillingStatementItem, StatementStatus
|
||||
)
|
||||
from app.services.billing import BillingStatementService, StatementGenerationError
|
||||
from app.services.statement_generation import (
|
||||
generate_single_statement as _svc_generate_single_statement,
|
||||
parse_period_month as _svc_parse_period_month,
|
||||
render_statement_html as _svc_render_statement_html,
|
||||
)
|
||||
from app.services.batch_generation import (
|
||||
prepare_batch_parameters as _svc_prepare_batch_parameters,
|
||||
make_batch_id as _svc_make_batch_id,
|
||||
compute_estimated_completion as _svc_compute_eta,
|
||||
persist_batch_results as _svc_persist_batch_results,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -41,33 +52,29 @@ router = APIRouter()
|
||||
# Initialize logger for billing operations
|
||||
billing_logger = StructuredLogger("billing_operations", "INFO")
|
||||
|
||||
# Realtime WebSocket subscriber registry: batch_id -> set[WebSocket]
|
||||
_subscribers_by_batch: Dict[str, Set[WebSocket]] = {}
|
||||
_subscribers_lock = asyncio.Lock()
|
||||
# Import WebSocket pool services
|
||||
from app.middleware.websocket_middleware import get_websocket_manager
|
||||
from app.services.websocket_pool import WebSocketMessage
|
||||
|
||||
# WebSocket manager for batch progress notifications
|
||||
websocket_manager = get_websocket_manager()
|
||||
|
||||
|
||||
async def _notify_progress_subscribers(progress: "BatchProgress") -> None:
|
||||
"""Broadcast latest progress to active subscribers of a batch."""
|
||||
"""Broadcast latest progress to active subscribers of a batch using WebSocket pool."""
|
||||
batch_id = progress.batch_id
|
||||
message = {"type": "progress", "data": progress.model_dump()}
|
||||
async with _subscribers_lock:
|
||||
sockets = list(_subscribers_by_batch.get(batch_id, set()))
|
||||
if not sockets:
|
||||
return
|
||||
dead: List[WebSocket] = []
|
||||
for ws in sockets:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
if dead:
|
||||
async with _subscribers_lock:
|
||||
bucket = _subscribers_by_batch.get(batch_id)
|
||||
if bucket:
|
||||
for ws in dead:
|
||||
bucket.discard(ws)
|
||||
if not bucket:
|
||||
_subscribers_by_batch.pop(batch_id, None)
|
||||
topic = f"batch_progress_{batch_id}"
|
||||
|
||||
# Use the WebSocket manager to broadcast to topic
|
||||
sent_count = await websocket_manager.broadcast_to_topic(
|
||||
topic=topic,
|
||||
message_type="progress",
|
||||
data=progress.model_dump()
|
||||
)
|
||||
|
||||
billing_logger.debug("Broadcast batch progress update",
|
||||
batch_id=batch_id,
|
||||
subscribers_notified=sent_count)
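For orientation, a rough sketch of the frame a topic subscriber can expect from this broadcast; the exact envelope depends on how WebSocketMessage is serialized, so treat everything beyond the type/topic/data fields as an assumption:

# Hypothetical example of a "progress" frame on topic "batch_progress_<batch_id>".
example_frame = {
    "type": "progress",
    "topic": "batch_progress_batch_20240101_120000_0042",  # hypothetical batch_id
    "data": {
        "batch_id": "batch_20240101_120000_0042",
        "processed_files": 3,
        "total_files": 10,
        "status": "processing",
    },
}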
|
||||
|
||||
|
||||
def _round(value: Optional[float]) -> float:
|
||||
@@ -606,21 +613,8 @@ progress_store = BatchProgressStore()
|
||||
|
||||
|
||||
def _parse_period_month(period: Optional[str]) -> Optional[tuple[date, date]]:
|
||||
"""Parse period in the form YYYY-MM and return (start_date, end_date) inclusive.
|
||||
Returns None when period is not provided or invalid.
|
||||
"""
|
||||
if not period:
|
||||
return None
|
||||
m = re.fullmatch(r"(\d{4})-(\d{2})", str(period).strip())
|
||||
if not m:
|
||||
return None
|
||||
year = int(m.group(1))
|
||||
month = int(m.group(2))
|
||||
if month < 1 or month > 12:
|
||||
return None
|
||||
from calendar import monthrange
|
||||
last_day = monthrange(year, month)[1]
|
||||
return date(year, month, 1), date(year, month, last_day)
|
||||
"""Parse YYYY-MM period; delegates to service helper for consistency."""
|
||||
return _svc_parse_period_month(period)
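Expected behaviour, based on the original implementation above (the service helper is assumed to preserve it):

# _parse_period_month("2024-02") -> (date(2024, 2, 1), date(2024, 2, 29))  # leap year
# _parse_period_month("2024-13") -> None   # month out of range
# _parse_period_month(None)      -> None   # no period supplied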
|
||||
|
||||
|
||||
def _render_statement_html(
|
||||
@@ -633,80 +627,25 @@ def _render_statement_html(
|
||||
totals: StatementTotals,
|
||||
unbilled_entries: List[StatementEntry],
|
||||
) -> str:
|
||||
"""Create a simple, self-contained HTML statement string."""
|
||||
# Rows for unbilled entries
|
||||
def _fmt(val: Optional[float]) -> str:
|
||||
try:
|
||||
return f"{float(val or 0):.2f}"
|
||||
except Exception:
|
||||
return "0.00"
|
||||
|
||||
rows = []
|
||||
for e in unbilled_entries:
|
||||
rows.append(
|
||||
f"<tr><td>{e.date.isoformat() if e.date else ''}</td><td>{e.t_code}</td><td>{(e.description or '').replace('<','<').replace('>','>')}</td>"
|
||||
f"<td style='text-align:right'>{_fmt(e.quantity)}</td><td style='text-align:right'>{_fmt(e.rate)}</td><td style='text-align:right'>{_fmt(e.amount)}</td></tr>"
|
||||
)
|
||||
rows_html = "\n".join(rows) if rows else "<tr><td colspan='6' style='text-align:center;color:#666'>No unbilled entries</td></tr>"
|
||||
|
||||
period_html = f"<div><strong>Period:</strong> {period}</div>" if period else ""
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang=\"en\">
|
||||
<head>
|
||||
<meta charset=\"utf-8\" />
|
||||
<title>Statement {file_no}</title>
|
||||
<style>
|
||||
body {{ font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica, Arial, sans-serif; margin: 24px; }}
|
||||
h1 {{ margin: 0 0 8px 0; }}
|
||||
.meta {{ color: #444; margin-bottom: 16px; }}
|
||||
table {{ border-collapse: collapse; width: 100%; }}
|
||||
th, td {{ border: 1px solid #ddd; padding: 8px; font-size: 14px; }}
|
||||
th {{ background: #f6f6f6; text-align: left; }}
|
||||
.totals {{ margin: 16px 0; display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 8px; }}
|
||||
.totals div {{ background: #fafafa; border: 1px solid #eee; padding: 8px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Statement</h1>
|
||||
<div class=\"meta\">
|
||||
<div><strong>File:</strong> {file_no}</div>
|
||||
<div><strong>Client:</strong> {client_name or ''}</div>
|
||||
<div><strong>Matter:</strong> {matter or ''}</div>
|
||||
<div><strong>As of:</strong> {as_of_iso}</div>
|
||||
{period_html}
|
||||
</div>
|
||||
|
||||
<div class=\"totals\">
|
||||
<div><strong>Charges (billed)</strong><br/>${_fmt(totals.charges_billed)}</div>
|
||||
<div><strong>Charges (unbilled)</strong><br/>${_fmt(totals.charges_unbilled)}</div>
|
||||
<div><strong>Charges (total)</strong><br/>${_fmt(totals.charges_total)}</div>
|
||||
<div><strong>Payments</strong><br/>${_fmt(totals.payments)}</div>
|
||||
<div><strong>Trust balance</strong><br/>${_fmt(totals.trust_balance)}</div>
|
||||
<div><strong>Current balance</strong><br/>${_fmt(totals.current_balance)}</div>
|
||||
</div>
|
||||
|
||||
<h2>Unbilled Entries</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Date</th>
|
||||
<th>Code</th>
|
||||
<th>Description</th>
|
||||
<th style=\"text-align:right\">Qty</th>
|
||||
<th style=\"text-align:right\">Rate</th>
|
||||
<th style=\"text-align:right\">Amount</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html
|
||||
"""Create statement HTML via service helper while preserving API models."""
|
||||
totals_dict: Dict[str, float] = {
|
||||
"charges_billed": totals.charges_billed,
|
||||
"charges_unbilled": totals.charges_unbilled,
|
||||
"charges_total": totals.charges_total,
|
||||
"payments": totals.payments,
|
||||
"trust_balance": totals.trust_balance,
|
||||
"current_balance": totals.current_balance,
|
||||
}
|
||||
entries_dict: List[Dict[str, Any]] = [e.model_dump() for e in (unbilled_entries or [])]
|
||||
return _svc_render_statement_html(
|
||||
file_no=file_no,
|
||||
client_name=client_name,
|
||||
matter=matter,
|
||||
as_of_iso=as_of_iso,
|
||||
period=period,
|
||||
totals=totals_dict,
|
||||
unbilled_entries=entries_dict,
|
||||
)
|
||||
|
||||
|
||||
def _generate_single_statement(
|
||||
@@ -714,118 +653,28 @@ def _generate_single_statement(
|
||||
period: Optional[str],
|
||||
db: Session
|
||||
) -> GeneratedStatementMeta:
|
||||
"""
|
||||
Internal helper to generate a statement for a single file.
|
||||
|
||||
Args:
|
||||
file_no: File number to generate statement for
|
||||
period: Optional period filter (YYYY-MM format)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
GeneratedStatementMeta with file metadata and export path
|
||||
|
||||
Raises:
|
||||
HTTPException: If file not found or generation fails
|
||||
"""
|
||||
file_obj = (
|
||||
db.query(File)
|
||||
.options(joinedload(File.owner))
|
||||
.filter(File.file_no == file_no)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not file_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"File {file_no} not found",
|
||||
)
|
||||
|
||||
# Optional period filtering (YYYY-MM)
|
||||
date_range = _parse_period_month(period)
|
||||
q = db.query(Ledger).filter(Ledger.file_no == file_no)
|
||||
if date_range:
|
||||
start_date, end_date = date_range
|
||||
q = q.filter(Ledger.date >= start_date).filter(Ledger.date <= end_date)
|
||||
entries: List[Ledger] = q.all()
|
||||
|
||||
CHARGE_TYPES = {"2", "3", "4"}
|
||||
charges_billed = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed == "Y")
|
||||
charges_unbilled = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed != "Y")
|
||||
charges_total = charges_billed + charges_unbilled
|
||||
payments_total = sum(e.amount for e in entries if e.t_type == "5")
|
||||
trust_balance = file_obj.trust_bal or 0.0
|
||||
current_balance = charges_total - payments_total
|
||||
|
||||
unbilled_entries = [
|
||||
StatementEntry(
|
||||
id=e.id,
|
||||
date=e.date,
|
||||
t_code=e.t_code,
|
||||
t_type=e.t_type,
|
||||
description=e.note,
|
||||
quantity=e.quantity or 0.0,
|
||||
rate=e.rate or 0.0,
|
||||
amount=e.amount,
|
||||
)
|
||||
for e in entries
|
||||
if e.t_type in CHARGE_TYPES and e.billed != "Y"
|
||||
]
|
||||
|
||||
client_name = None
|
||||
if file_obj.owner:
|
||||
client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last}".strip()
|
||||
|
||||
as_of_iso = datetime.now(timezone.utc).isoformat()
|
||||
"""Generate a single statement via service and adapt to API response model."""
|
||||
data = _svc_generate_single_statement(file_no, period, db)
|
||||
totals = data.get("totals", {})
|
||||
totals_model = StatementTotals(
|
||||
charges_billed=_round(charges_billed),
|
||||
charges_unbilled=_round(charges_unbilled),
|
||||
charges_total=_round(charges_total),
|
||||
payments=_round(payments_total),
|
||||
trust_balance=_round(trust_balance),
|
||||
current_balance=_round(current_balance),
|
||||
charges_billed=float(totals.get("charges_billed", 0.0)),
|
||||
charges_unbilled=float(totals.get("charges_unbilled", 0.0)),
|
||||
charges_total=float(totals.get("charges_total", 0.0)),
|
||||
payments=float(totals.get("payments", 0.0)),
|
||||
trust_balance=float(totals.get("trust_balance", 0.0)),
|
||||
current_balance=float(totals.get("current_balance", 0.0)),
|
||||
)
|
||||
|
||||
# Render HTML
|
||||
html = _render_statement_html(
|
||||
file_no=file_no,
|
||||
client_name=client_name or None,
|
||||
matter=file_obj.regarding,
|
||||
as_of_iso=as_of_iso,
|
||||
period=period,
|
||||
totals=totals_model,
|
||||
unbilled_entries=unbilled_entries,
|
||||
)
|
||||
|
||||
# Ensure exports directory and write file
|
||||
exports_dir = Path("exports")
|
||||
try:
|
||||
exports_dir.mkdir(exist_ok=True)
|
||||
except Exception:
|
||||
# Best effort: if the exports directory cannot be created, surface an internal server error
|
||||
raise HTTPException(status_code=500, detail="Unable to create exports directory")
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
|
||||
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
|
||||
filename = f"statement_{safe_file_no}_{timestamp}.html"
|
||||
export_path = exports_dir / filename
|
||||
html_bytes = html.encode("utf-8")
|
||||
with open(export_path, "wb") as f:
|
||||
f.write(html_bytes)
|
||||
|
||||
size = export_path.stat().st_size
|
||||
|
||||
return GeneratedStatementMeta(
|
||||
file_no=file_no,
|
||||
client_name=client_name or None,
|
||||
as_of=as_of_iso,
|
||||
period=period,
|
||||
file_no=str(data.get("file_no")),
|
||||
client_name=data.get("client_name"),
|
||||
as_of=str(data.get("as_of")),
|
||||
period=data.get("period"),
|
||||
totals=totals_model,
|
||||
unbilled_count=len(unbilled_entries),
|
||||
export_path=str(export_path),
|
||||
filename=filename,
|
||||
size=size,
|
||||
content_type="text/html",
|
||||
unbilled_count=int(data.get("unbilled_count", 0)),
|
||||
export_path=str(data.get("export_path")),
|
||||
filename=str(data.get("filename")),
|
||||
size=int(data.get("size", 0)),
|
||||
content_type=str(data.get("content_type", "text/html")),
|
||||
)
|
||||
|
||||
|
||||
@@ -842,92 +691,48 @@ async def generate_statement(
|
||||
return _generate_single_statement(payload.file_no, payload.period, db)
|
||||
|
||||
|
||||
async def _ws_authenticate(websocket: WebSocket) -> Optional[User]:
|
||||
"""Authenticate WebSocket via JWT token in query (?token=) or Authorization header."""
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
try:
|
||||
auth_header = dict(websocket.headers).get("authorization") or ""
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
except Exception:
|
||||
token = None
|
||||
if not token:
|
||||
return None
|
||||
username = verify_token(token)
|
||||
if not username:
|
||||
return None
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def _ws_keepalive(ws: WebSocket, stop_event: asyncio.Event) -> None:
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(25)
|
||||
try:
|
||||
await ws.send_json({"type": "ping", "ts": datetime.now(timezone.utc).isoformat()})
|
||||
except Exception:
|
||||
break
|
||||
finally:
|
||||
stop_event.set()
|
||||
|
||||
|
||||
@router.websocket("/statements/batch-progress/ws/{batch_id}")
|
||||
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
|
||||
"""WebSocket: subscribe to real-time updates for a batch_id."""
|
||||
user = await _ws_authenticate(websocket)
|
||||
if not user:
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
await websocket.accept()
|
||||
# Register
|
||||
async with _subscribers_lock:
|
||||
bucket = _subscribers_by_batch.get(batch_id)
|
||||
if not bucket:
|
||||
bucket = set()
|
||||
_subscribers_by_batch[batch_id] = bucket
|
||||
bucket.add(websocket)
|
||||
# Send initial snapshot
|
||||
try:
|
||||
snapshot = await progress_store.get_progress(batch_id)
|
||||
await websocket.send_json({"type": "progress", "data": snapshot.model_dump() if snapshot else None})
|
||||
except Exception:
|
||||
pass
|
||||
# Keepalive + receive loop
|
||||
stop_event: asyncio.Event = asyncio.Event()
|
||||
ka_task = asyncio.create_task(_ws_keepalive(websocket, stop_event))
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
msg = await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception:
|
||||
break
|
||||
if isinstance(msg, str) and msg.strip() == "ping":
|
||||
try:
|
||||
await websocket.send_text("pong")
|
||||
except Exception:
|
||||
break
|
||||
finally:
|
||||
stop_event.set()
|
||||
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
|
||||
topic = f"batch_progress_{batch_id}"
|
||||
|
||||
# Custom message handler for batch progress
|
||||
async def handle_batch_message(connection_id: str, message: WebSocketMessage):
|
||||
"""Handle custom messages for batch progress"""
|
||||
billing_logger.debug("Received batch progress message",
|
||||
connection_id=connection_id,
|
||||
batch_id=batch_id,
|
||||
message_type=message.type)
|
||||
# Handle any batch-specific message logic here if needed
|
||||
|
||||
# Use the WebSocket manager to handle the connection
|
||||
connection_id = await websocket_manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics={topic},
|
||||
require_auth=True,
|
||||
metadata={"batch_id": batch_id, "endpoint": "batch_progress"},
|
||||
message_handler=handle_batch_message
|
||||
)
|
||||
|
||||
if connection_id:
|
||||
# Send initial snapshot after connection is established
|
||||
try:
|
||||
ka_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
async with _subscribers_lock:
|
||||
bucket = _subscribers_by_batch.get(batch_id)
|
||||
if bucket and websocket in bucket:
|
||||
bucket.discard(websocket)
|
||||
if not bucket:
|
||||
_subscribers_by_batch.pop(batch_id, None)
|
||||
snapshot = await progress_store.get_progress(batch_id)
|
||||
pool = websocket_manager.pool
|
||||
initial_message = WebSocketMessage(
|
||||
type="progress",
|
||||
topic=topic,
|
||||
data=snapshot.model_dump() if snapshot else None
|
||||
)
|
||||
await pool._send_to_connection(connection_id, initial_message)
|
||||
billing_logger.info("Sent initial batch progress snapshot",
|
||||
connection_id=connection_id,
|
||||
batch_id=batch_id)
|
||||
except Exception as e:
|
||||
billing_logger.error("Failed to send initial batch progress snapshot",
|
||||
connection_id=connection_id,
|
||||
batch_id=batch_id,
|
||||
error=str(e))
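A minimal client sketch (not part of this diff) for consuming the endpoint above, assuming the `websockets` package and that the router is mounted so the path below resolves; the JWT travels in the ?token= query parameter handled by the pool's authentication:

import asyncio
import json

import websockets  # third-party client library, assumed available

async def follow_batch(batch_id: str, token: str) -> None:
    url = f"ws://localhost:8000/statements/batch-progress/ws/{batch_id}?token={token}"
    async with websockets.connect(url) as ws:
        async for raw in ws:                  # yields text frames as they arrive
            frame = json.loads(raw)
            if frame.get("type") == "progress":
                print(frame.get("data"))

# asyncio.run(follow_batch("batch_20240101_120000_0042", "<jwt>"))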
|
||||
|
||||
@router.delete("/statements/batch-progress/{batch_id}")
|
||||
async def cancel_batch_operation(
|
||||
@@ -1045,25 +850,12 @@ async def batch_generate_statements(
|
||||
- Batch operation identification for audit trails
|
||||
- Automatic cleanup of progress data after completion
|
||||
"""
|
||||
# Validate request
|
||||
if not payload.file_numbers:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one file number must be provided"
|
||||
)
|
||||
|
||||
if len(payload.file_numbers) > 50:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Maximum 50 files allowed per batch operation"
|
||||
)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_file_numbers = list(dict.fromkeys(payload.file_numbers))
|
||||
# Validate request and normalize inputs
|
||||
unique_file_numbers = _svc_prepare_batch_parameters(payload.file_numbers)
|
||||
|
||||
# Generate batch ID and timing
|
||||
start_time = datetime.now(timezone.utc)
|
||||
batch_id = f"batch_{start_time.strftime('%Y%m%d_%H%M%S')}_{abs(hash(str(unique_file_numbers))) % 10000:04d}"
|
||||
batch_id = _svc_make_batch_id(unique_file_numbers, start_time)
|
||||
|
||||
billing_logger.info(
|
||||
"Starting batch statement generation",
|
||||
@@ -1121,7 +913,12 @@ async def batch_generate_statements(
|
||||
progress.current_file = file_no
|
||||
progress.files[idx].status = "processing"
|
||||
progress.files[idx].started_at = current_time.isoformat()
|
||||
progress.estimated_completion = await _calculate_estimated_completion(progress, current_time)
|
||||
progress.estimated_completion = _svc_compute_eta(
|
||||
processed_files=progress.processed_files,
|
||||
total_files=progress.total_files,
|
||||
started_at_iso=progress.started_at,
|
||||
now=current_time,
|
||||
)
|
||||
await progress_store.set_progress(progress)
|
||||
|
||||
billing_logger.info(
|
||||
@@ -1288,53 +1085,13 @@ async def batch_generate_statements(
|
||||
|
||||
# Persist batch summary and per-file results
|
||||
try:
|
||||
def _parse_iso(dt: Optional[str]):
|
||||
if not dt:
|
||||
return None
|
||||
try:
|
||||
from datetime import datetime as _dt
|
||||
return _dt.fromisoformat(dt.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
batch_row = BillingBatch(
|
||||
_svc_persist_batch_results(
|
||||
db,
|
||||
batch_id=batch_id,
|
||||
status=str(progress.status),
|
||||
total_files=total_files,
|
||||
successful_files=successful,
|
||||
failed_files=failed,
|
||||
started_at=_parse_iso(progress.started_at),
|
||||
updated_at=_parse_iso(progress.updated_at),
|
||||
completed_at=_parse_iso(progress.completed_at),
|
||||
progress=progress,
|
||||
processing_time_seconds=processing_time,
|
||||
success_rate=success_rate,
|
||||
error_message=progress.error_message,
|
||||
)
|
||||
db.add(batch_row)
|
||||
for f in progress.files:
|
||||
meta = getattr(f, 'statement_meta', None)
|
||||
filename = None
|
||||
size = None
|
||||
if meta is not None:
|
||||
try:
|
||||
filename = getattr(meta, 'filename', None)
|
||||
size = getattr(meta, 'size', None)
|
||||
except Exception:
|
||||
pass
|
||||
if filename is None and isinstance(meta, dict):
|
||||
filename = meta.get('filename')
|
||||
size = meta.get('size')
|
||||
db.add(BillingBatchFile(
|
||||
batch_id=batch_id,
|
||||
file_no=f.file_no,
|
||||
status=str(f.status),
|
||||
error_message=f.error_message,
|
||||
filename=filename,
|
||||
size=size,
|
||||
started_at=_parse_iso(f.started_at),
|
||||
completed_at=_parse_iso(f.completed_at),
|
||||
))
|
||||
db.commit()
|
||||
except Exception:
|
||||
try:
|
||||
db.rollback()
|
||||
@@ -1600,6 +1357,34 @@ async def download_latest_statement(
|
||||
detail="No statements found for requested period",
|
||||
)
|
||||
|
||||
# Filter out any statements created prior to the file's opened date (safety against collisions)
|
||||
try:
|
||||
opened_date = getattr(file_obj, "opened", None)
|
||||
if opened_date:
|
||||
filtered_by_opened: List[Path] = []
|
||||
for path in candidates:
|
||||
name = path.name
|
||||
# Filename format: statement_{safe_file_no}_YYYYMMDD_HHMMSS_micro.html
|
||||
m = re.match(rf"^statement_{re.escape(safe_file_no)}_(\d{{8}})_\d{{6}}_\d{{6}}\.html$", name)
|
||||
if not m:
|
||||
continue
|
||||
ymd = m.group(1)
|
||||
y, mo, d = int(ymd[0:4]), int(ymd[4:6]), int(ymd[6:8])
|
||||
from datetime import date as _date
|
||||
stmt_date = _date(y, mo, d)
|
||||
if stmt_date >= opened_date:
|
||||
filtered_by_opened.append(path)
|
||||
if filtered_by_opened:
|
||||
candidates = filtered_by_opened
|
||||
else:
|
||||
# If none meet the opened-date filter, treat as no statements
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No statements found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
# On parse errors, continue with existing candidates
|
||||
pass
|
||||
|
||||
# Choose latest by modification time
|
||||
candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
latest_path = candidates[0]
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
|
||||
from app.services.mailing import build_address_from_rolodex
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total
|
||||
from app.utils.logging import app_logger
|
||||
from app.utils.database import db_transaction
|
||||
@@ -96,6 +97,430 @@ class CustomerResponse(CustomerBase):
|
||||
|
||||
|
||||
|
||||
@router.get("/phone-book")
|
||||
async def export_phone_book(
|
||||
mode: str = Query("numbers", description="Report mode: numbers | addresses | full"),
|
||||
format: str = Query("csv", description="Output format: csv | html"),
|
||||
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
|
||||
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
|
||||
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
|
||||
sort_by: Optional[str] = Query("name", description="Sort field: id, name, city, email"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
grouping: Optional[str] = Query(
|
||||
"none",
|
||||
description="Grouping: none | letter | group | group_letter"
|
||||
),
|
||||
page_break: bool = Query(
|
||||
False,
|
||||
description="HTML only: start a new page for each top-level group"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Generate phone book reports with filters and downloadable CSV/HTML.
|
||||
|
||||
Modes:
|
||||
- numbers: name and phone numbers
|
||||
- addresses: name, address, and phone numbers
|
||||
- full: detailed rolodex fields plus phones
|
||||
"""
|
||||
allowed_modes = {"numbers", "addresses", "full"}
|
||||
allowed_formats = {"csv", "html"}
|
||||
allowed_groupings = {"none", "letter", "group", "group_letter"}
|
||||
m = (mode or "").strip().lower()
|
||||
f = (format or "").strip().lower()
|
||||
if m not in allowed_modes:
|
||||
raise HTTPException(status_code=400, detail="Invalid mode. Use one of: numbers, addresses, full")
|
||||
if f not in allowed_formats:
|
||||
raise HTTPException(status_code=400, detail="Invalid format. Use one of: csv, html")
|
||||
gmode = (grouping or "none").strip().lower()
|
||||
if gmode not in allowed_groupings:
|
||||
raise HTTPException(status_code=400, detail="Invalid grouping. Use one of: none, letter, group, group_letter")
|
||||
|
||||
try:
|
||||
base_query = db.query(Rolodex)
|
||||
# Only group and name_prefix filtering are required per spec
|
||||
base_query = apply_customer_filters(
|
||||
base_query,
|
||||
search=None,
|
||||
group=group,
|
||||
state=None,
|
||||
groups=groups,
|
||||
states=None,
|
||||
name_prefix=name_prefix,
|
||||
)
|
||||
|
||||
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
|
||||
|
||||
customers = base_query.options(joinedload(Rolodex.phone_numbers)).all()
|
||||
|
||||
def format_phones(entry: Rolodex) -> str:
|
||||
parts: List[str] = []
|
||||
try:
|
||||
for p in (entry.phone_numbers or []):
|
||||
label = (p.location or "").strip()
|
||||
if label:
|
||||
parts.append(f"{label}: {p.phone}")
|
||||
else:
|
||||
parts.append(p.phone)
|
||||
except Exception:
|
||||
pass
|
||||
return "; ".join([s for s in parts if s])
|
||||
|
||||
def display_name(entry: Rolodex) -> str:
|
||||
return build_address_from_rolodex(entry).display_name
|
||||
|
||||
def first_letter(entry: Rolodex) -> str:
|
||||
base = (entry.last or entry.first or "").strip()
|
||||
if not base:
|
||||
return "#"
|
||||
ch = base[0].upper()
|
||||
return ch if ch.isalpha() else "#"
|
||||
|
||||
# Apply grouping-specific sort for stable output
|
||||
if gmode == "letter":
|
||||
customers.sort(key=lambda c: (first_letter(c), (c.last or "").lower(), (c.first or "").lower()))
|
||||
elif gmode == "group":
|
||||
customers.sort(key=lambda c: ((c.group or "Ungrouped").lower(), (c.last or "").lower(), (c.first or "").lower()))
|
||||
elif gmode == "group_letter":
|
||||
customers.sort(key=lambda c: ((c.group or "Ungrouped").lower(), first_letter(c), (c.last or "").lower(), (c.first or "").lower()))
|
||||
|
||||
def build_csv() -> StreamingResponse:
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
include_letter_col = gmode in ("letter", "group_letter")
|
||||
|
||||
if m == "numbers":
|
||||
header = ["Name", "Group"] + (["Letter"] if include_letter_col else []) + ["Phones"]
|
||||
writer.writerow(header)
|
||||
for c in customers:
|
||||
row = [display_name(c), c.group or ""]
|
||||
if include_letter_col:
|
||||
row.append(first_letter(c))
|
||||
row.append(format_phones(c))
|
||||
writer.writerow(row)
|
||||
elif m == "addresses":
|
||||
header = [
|
||||
"Name", "Group"
|
||||
] + (["Letter"] if include_letter_col else []) + [
|
||||
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Phones"
|
||||
]
|
||||
writer.writerow(header)
|
||||
for c in customers:
|
||||
addr = build_address_from_rolodex(c)
|
||||
row = [
|
||||
addr.display_name,
|
||||
c.group or "",
|
||||
]
|
||||
if include_letter_col:
|
||||
row.append(first_letter(c))
|
||||
row += [
|
||||
c.a1 or "",
|
||||
c.a2 or "",
|
||||
c.a3 or "",
|
||||
c.city or "",
|
||||
c.abrev or "",
|
||||
c.zip or "",
|
||||
format_phones(c),
|
||||
]
|
||||
writer.writerow(row)
|
||||
else: # full
|
||||
header = [
|
||||
"ID", "Last", "First", "Middle", "Prefix", "Suffix", "Title", "Group"
|
||||
] + (["Letter"] if include_letter_col else []) + [
|
||||
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Email", "Phones", "Legal Status",
|
||||
]
|
||||
writer.writerow(header)
|
||||
for c in customers:
|
||||
row = [
|
||||
c.id,
|
||||
c.last or "",
|
||||
c.first or "",
|
||||
c.middle or "",
|
||||
c.prefix or "",
|
||||
c.suffix or "",
|
||||
c.title or "",
|
||||
c.group or "",
|
||||
]
|
||||
if include_letter_col:
|
||||
row.append(first_letter(c))
|
||||
row += [
|
||||
c.a1 or "",
|
||||
c.a2 or "",
|
||||
c.a3 or "",
|
||||
c.city or "",
|
||||
c.abrev or "",
|
||||
c.zip or "",
|
||||
c.email or "",
|
||||
format_phones(c),
|
||||
c.legal_status or "",
|
||||
]
|
||||
writer.writerow(row)
|
||||
|
||||
output.seek(0)
|
||||
from datetime import datetime as _dt
|
||||
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"phone_book_{m}_{ts}.csv"
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
|
||||
def build_html() -> StreamingResponse:
|
||||
# Minimal, printable HTML
|
||||
def css() -> str:
|
||||
return """
|
||||
body { font-family: Arial, sans-serif; margin: 16px; }
|
||||
h1 { font-size: 18pt; margin-bottom: 8px; }
|
||||
.meta { color: #666; font-size: 10pt; margin-bottom: 16px; }
|
||||
.entry { margin-bottom: 10px; }
|
||||
.name { font-weight: bold; }
|
||||
.phones, .addr { margin-left: 12px; }
|
||||
table { border-collapse: collapse; width: 100%; }
|
||||
th, td { border: 1px solid #ddd; padding: 6px 8px; font-size: 10pt; }
|
||||
th { background: #f5f5f5; text-align: left; }
|
||||
.section { margin-top: 18px; }
|
||||
.section-title { font-size: 14pt; margin: 12px 0; border-bottom: 1px solid #ddd; padding-bottom: 4px; }
|
||||
.subsection-title { font-size: 12pt; margin: 10px 0; color: #333; }
|
||||
@media print {
|
||||
.page-break { page-break-before: always; break-before: page; }
|
||||
}
|
||||
"""
|
||||
|
||||
title = {
|
||||
"numbers": "Phone Book (Numbers Only)",
|
||||
"addresses": "Phone Book (With Addresses)",
|
||||
"full": "Phone Book (Full Rolodex)",
|
||||
}[m]
|
||||
|
||||
from datetime import datetime as _dt
|
||||
generated = _dt.now().strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def render_entry_block(c: Rolodex) -> str:
|
||||
name = display_name(c)
|
||||
group_text = f" <span class=\"group\">({c.group})</span>" if c.group else ""
|
||||
phones_html = "".join([f"<div>{p.location + ': ' if p.location else ''}{p.phone}</div>" for p in (c.phone_numbers or [])])
|
||||
addr_html = ""
|
||||
if m == "addresses":
|
||||
addr_lines = build_address_from_rolodex(c).compact_lines(include_name=False)
|
||||
addr_html = "<div class=\"addr\">" + "".join([f"<div>{line}</div>" for line in addr_lines]) + "</div>"
|
||||
return f"<div class=\"entry\"><div class=\"name\">{name}{group_text}</div>{addr_html}<div class=\"phones\">{phones_html}</div></div>"
|
||||
|
||||
if m in ("numbers", "addresses"):
|
||||
sections: List[str] = []
|
||||
|
||||
if gmode == "none":
|
||||
blocks = [render_entry_block(c) for c in customers]
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset=\"utf-8\" />
|
||||
<title>{title}</title>
|
||||
<style>{css()}</style>
|
||||
<meta name=\"generator\" content=\"delphi\" />
|
||||
<meta name=\"created\" content=\"{generated}\" />
|
||||
</head>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
|
||||
{''.join(blocks)}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
else:
|
||||
# Build sections according to grouping
|
||||
if gmode == "letter":
|
||||
# Letters A-Z plus '#'
|
||||
letters: List[str] = sorted({first_letter(c) for c in customers})
|
||||
for idx, letter in enumerate(letters):
|
||||
entries = [c for c in customers if first_letter(c) == letter]
|
||||
if not entries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
|
||||
blocks = [render_entry_block(c) for c in entries]
|
||||
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Letter: {letter}</div>{''.join(blocks)}</div>")
|
||||
elif gmode == "group":
|
||||
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
|
||||
for idx, gkey in enumerate(group_keys):
|
||||
entries = [c for c in customers if (c.group or "Ungrouped") == gkey]
|
||||
if not entries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
|
||||
blocks = [render_entry_block(c) for c in entries]
|
||||
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(blocks)}</div>")
|
||||
else: # group_letter
|
||||
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
|
||||
for gidx, gkey in enumerate(group_keys):
|
||||
gentries = [c for c in customers if (c.group or "Ungrouped") == gkey]
|
||||
if not gentries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and gidx > 0 else "")
|
||||
subsections: List[str] = []
|
||||
letters = sorted({first_letter(c) for c in gentries})
|
||||
for letter in letters:
|
||||
lentries = [c for c in gentries if first_letter(c) == letter]
|
||||
if not lentries:
|
||||
continue
|
||||
blocks = [render_entry_block(c) for c in lentries]
|
||||
subsections.append(f"<div class=\"subsection\"><div class=\"subsection-title\">Letter: {letter}</div>{''.join(blocks)}</div>")
|
||||
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(subsections)}</div>")
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset=\"utf-8\" />
|
||||
<title>{title}</title>
|
||||
<style>{css()}</style>
|
||||
<meta name=\"generator\" content=\"delphi\" />
|
||||
<meta name=\"created\" content=\"{generated}\" />
|
||||
</head>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
|
||||
{''.join(sections)}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
else:
|
||||
# Full table variant
|
||||
base_header_cells = [
|
||||
"ID", "Last", "First", "Middle", "Prefix", "Suffix", "Title", "Group",
|
||||
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Email", "Phones", "Legal Status",
|
||||
]
|
||||
|
||||
def render_rows(items: List[Rolodex]) -> str:
|
||||
rows_html: List[str] = []
|
||||
for c in items:
|
||||
phones = "".join([f"{p.location + ': ' if p.location else ''}{p.phone}" for p in (c.phone_numbers or [])])
|
||||
cells = [
|
||||
c.id or "",
|
||||
c.last or "",
|
||||
c.first or "",
|
||||
c.middle or "",
|
||||
c.prefix or "",
|
||||
c.suffix or "",
|
||||
c.title or "",
|
||||
c.group or "",
|
||||
c.a1 or "",
|
||||
c.a2 or "",
|
||||
c.a3 or "",
|
||||
c.city or "",
|
||||
c.abrev or "",
|
||||
c.zip or "",
|
||||
c.email or "",
|
||||
phones,
|
||||
c.legal_status or "",
|
||||
]
|
||||
rows_html.append("<tr>" + "".join([f"<td>{(str(v) if v is not None else '')}</td>" for v in cells]) + "</tr>")
|
||||
return "".join(rows_html)
|
||||
|
||||
if gmode == "none":
|
||||
rows_html = render_rows(customers)
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset=\"utf-8\" />
|
||||
<title>{title}</title>
|
||||
<style>{css()}</style>
|
||||
<meta name=\"generator\" content=\"delphi\" />
|
||||
<meta name=\"created\" content=\"{generated}\" />
|
||||
</head>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
|
||||
<table>
|
||||
<thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
else:
|
||||
sections: List[str] = []
|
||||
if gmode == "letter":
|
||||
letters: List[str] = sorted({first_letter(c) for c in customers})
|
||||
for idx, letter in enumerate(letters):
|
||||
entries = [c for c in customers if first_letter(c) == letter]
|
||||
if not entries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
|
||||
rows_html = render_rows(entries)
|
||||
sections.append(
|
||||
f"<div class=\"{section_class}\"><div class=\"section-title\">Letter: {letter}</div>"
|
||||
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
|
||||
)
|
||||
elif gmode == "group":
|
||||
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
|
||||
for idx, gkey in enumerate(group_keys):
|
||||
entries = [c for c in customers if (c.group or "Ungrouped") == gkey]
|
||||
if not entries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
|
||||
rows_html = render_rows(entries)
|
||||
sections.append(
|
||||
f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>"
|
||||
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
|
||||
)
|
||||
else: # group_letter
|
||||
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
|
||||
for gidx, gkey in enumerate(group_keys):
|
||||
gentries = [c for c in customers if (c.group or "Ungrouped") == gkey]
|
||||
if not gentries:
|
||||
continue
|
||||
section_class = "section" + (" page-break" if page_break and gidx > 0 else "")
|
||||
subsections: List[str] = []
|
||||
letters = sorted({first_letter(c) for c in gentries})
|
||||
for letter in letters:
|
||||
lentries = [c for c in gentries if first_letter(c) == letter]
|
||||
if not lentries:
|
||||
continue
|
||||
rows_html = render_rows(lentries)
|
||||
subsections.append(
|
||||
f"<div class=\"subsection\"><div class=\"subsection-title\">Letter: {letter}</div>"
|
||||
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
|
||||
)
|
||||
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(subsections)}</div>")
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset=\"utf-8\" />
|
||||
<title>{title}</title>
|
||||
<style>{css()}</style>
|
||||
<meta name=\"generator\" content=\"delphi\" />
|
||||
<meta name=\"created\" content=\"{generated}\" />
|
||||
</head>
|
||||
<body>
|
||||
<h1>{title}</h1>
|
||||
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
|
||||
{''.join(sections)}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
from datetime import datetime as _dt
|
||||
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"phone_book_{m}_{ts}.html"
|
||||
return StreamingResponse(
|
||||
iter([html]),
|
||||
media_type="text/html",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
|
||||
return build_csv() if f == "csv" else build_html()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error generating phone book: {str(e)}")
|
||||
|
||||
@router.get("/search/phone")
|
||||
async def search_by_phone(
|
||||
phone: str = Query(..., description="Phone number to search for"),
|
||||
|
||||
1103 app/api/deadlines.py (new file; diff suppressed because it is too large)
748 app/api/document_workflows.py (new file)
@@ -0,0 +1,748 @@
|
||||
"""
|
||||
Document Workflow Management API
|
||||
|
||||
This API provides comprehensive workflow automation management including:
|
||||
- Workflow creation and configuration
|
||||
- Event logging and processing
|
||||
- Execution monitoring and control
|
||||
- Template management for common workflows
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Body
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import func, or_, and_, desc
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, date, timedelta
|
||||
import json
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.document_workflows import (
|
||||
DocumentWorkflow, WorkflowAction, WorkflowExecution, EventLog,
|
||||
WorkflowTemplate, WorkflowTriggerType, WorkflowActionType,
|
||||
ExecutionStatus, WorkflowStatus
|
||||
)
|
||||
from app.services.workflow_engine import EventProcessor, WorkflowExecutor
|
||||
from app.services.query_utils import paginate_with_total
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas for API
|
||||
class WorkflowCreate(BaseModel):
|
||||
name: str = Field(..., max_length=200)
|
||||
description: Optional[str] = None
|
||||
trigger_type: WorkflowTriggerType
|
||||
trigger_conditions: Optional[Dict[str, Any]] = None
|
||||
delay_minutes: int = Field(0, ge=0)
|
||||
max_retries: int = Field(3, ge=0, le=10)
|
||||
retry_delay_minutes: int = Field(30, ge=1)
|
||||
timeout_minutes: int = Field(60, ge=1)
|
||||
file_type_filter: Optional[List[str]] = None
|
||||
status_filter: Optional[List[str]] = None
|
||||
attorney_filter: Optional[List[str]] = None
|
||||
client_filter: Optional[List[str]] = None
|
||||
schedule_cron: Optional[str] = None
|
||||
schedule_timezone: str = "UTC"
|
||||
priority: int = Field(5, ge=1, le=10)
|
||||
category: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
class WorkflowUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
status: Optional[WorkflowStatus] = None
|
||||
trigger_conditions: Optional[Dict[str, Any]] = None
|
||||
delay_minutes: Optional[int] = None
|
||||
max_retries: Optional[int] = None
|
||||
retry_delay_minutes: Optional[int] = None
|
||||
timeout_minutes: Optional[int] = None
|
||||
file_type_filter: Optional[List[str]] = None
|
||||
status_filter: Optional[List[str]] = None
|
||||
attorney_filter: Optional[List[str]] = None
|
||||
client_filter: Optional[List[str]] = None
|
||||
schedule_cron: Optional[str] = None
|
||||
schedule_timezone: Optional[str] = None
|
||||
priority: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
class WorkflowActionCreate(BaseModel):
|
||||
action_type: WorkflowActionType
|
||||
action_order: int = Field(1, ge=1)
|
||||
action_name: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
output_format: str = "DOCX"
|
||||
custom_filename_template: Optional[str] = None
|
||||
email_template_id: Optional[int] = None
|
||||
email_recipients: Optional[List[str]] = None
|
||||
email_subject_template: Optional[str] = None
|
||||
condition: Optional[Dict[str, Any]] = None
|
||||
continue_on_failure: bool = False
|
||||
|
||||
|
||||
class WorkflowActionUpdate(BaseModel):
|
||||
action_type: Optional[WorkflowActionType] = None
|
||||
action_order: Optional[int] = None
|
||||
action_name: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
output_format: Optional[str] = None
|
||||
custom_filename_template: Optional[str] = None
|
||||
email_template_id: Optional[int] = None
|
||||
email_recipients: Optional[List[str]] = None
|
||||
email_subject_template: Optional[str] = None
|
||||
condition: Optional[Dict[str, Any]] = None
|
||||
continue_on_failure: Optional[bool] = None
|
||||
|
||||
|
||||
class WorkflowResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
status: WorkflowStatus
|
||||
trigger_type: WorkflowTriggerType
|
||||
trigger_conditions: Optional[Dict[str, Any]]
|
||||
delay_minutes: int
|
||||
max_retries: int
|
||||
priority: int
|
||||
category: Optional[str]
|
||||
tags: Optional[List[str]]
|
||||
created_by: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_triggered_at: Optional[datetime]
|
||||
execution_count: int
|
||||
success_count: int
|
||||
failure_count: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WorkflowActionResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
action_type: WorkflowActionType
|
||||
action_order: int
|
||||
action_name: Optional[str]
|
||||
parameters: Optional[Dict[str, Any]]
|
||||
template_id: Optional[int]
|
||||
output_format: str
|
||||
condition: Optional[Dict[str, Any]]
|
||||
continue_on_failure: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
triggered_by_event_id: Optional[str]
|
||||
triggered_by_event_type: Optional[str]
|
||||
context_file_no: Optional[str]
|
||||
context_client_id: Optional[str]
|
||||
status: ExecutionStatus
|
||||
started_at: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
execution_duration_seconds: Optional[int]
|
||||
retry_count: int
|
||||
error_message: Optional[str]
|
||||
generated_documents: Optional[List[Dict[str, Any]]]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EventLogCreate(BaseModel):
|
||||
event_type: str
|
||||
event_source: str
|
||||
file_no: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
event_data: Optional[Dict[str, Any]] = None
|
||||
previous_state: Optional[Dict[str, Any]] = None
|
||||
new_state: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class EventLogResponse(BaseModel):
|
||||
id: int
|
||||
event_id: str
|
||||
event_type: str
|
||||
event_source: str
|
||||
file_no: Optional[str]
|
||||
client_id: Optional[str]
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
event_data: Optional[Dict[str, Any]]
|
||||
processed: bool
|
||||
triggered_workflows: Optional[List[int]]
|
||||
occurred_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class WorkflowTestRequest(BaseModel):
|
||||
event_type: str
|
||||
event_data: Optional[Dict[str, Any]] = None
|
||||
file_no: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowStatsResponse(BaseModel):
|
||||
total_workflows: int
|
||||
active_workflows: int
|
||||
total_executions: int
|
||||
successful_executions: int
|
||||
failed_executions: int
|
||||
pending_executions: int
|
||||
workflows_by_trigger_type: Dict[str, int]
|
||||
executions_by_day: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# Workflow CRUD endpoints
|
||||
@router.post("/workflows/", response_model=WorkflowResponse)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new document workflow"""
|
||||
|
||||
# Check for duplicate names
|
||||
existing = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.name == workflow_data.name
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Workflow with name '{workflow_data.name}' already exists"
|
||||
)
|
||||
|
||||
# Validate cron expression if provided
|
||||
if workflow_data.schedule_cron:
|
||||
try:
|
||||
from croniter import croniter
|
||||
croniter(workflow_data.schedule_cron)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid cron expression"
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
workflow = DocumentWorkflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
trigger_type=workflow_data.trigger_type,
|
||||
trigger_conditions=workflow_data.trigger_conditions,
|
||||
delay_minutes=workflow_data.delay_minutes,
|
||||
max_retries=workflow_data.max_retries,
|
||||
retry_delay_minutes=workflow_data.retry_delay_minutes,
|
||||
timeout_minutes=workflow_data.timeout_minutes,
|
||||
file_type_filter=workflow_data.file_type_filter,
|
||||
status_filter=workflow_data.status_filter,
|
||||
attorney_filter=workflow_data.attorney_filter,
|
||||
client_filter=workflow_data.client_filter,
|
||||
schedule_cron=workflow_data.schedule_cron,
|
||||
schedule_timezone=workflow_data.schedule_timezone,
|
||||
priority=workflow_data.priority,
|
||||
category=workflow_data.category,
|
||||
tags=workflow_data.tags,
|
||||
created_by=current_user.username,
|
||||
status=WorkflowStatus.ACTIVE
|
||||
)
|
||||
|
||||
db.add(workflow)
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
return workflow
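The cron check in create_workflow simply relies on croniter raising for malformed expressions; a minimal sketch of the same idea, assuming the croniter package:

from croniter import croniter

def is_valid_cron(expr: str) -> bool:
    """Return True when croniter accepts the expression (hypothetical helper)."""
    try:
        croniter(expr)   # raises on malformed expressions
        return True
    except Exception:
        return False

# is_valid_cron("0 9 * * 1") -> True; is_valid_cron("not a cron") -> False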
|
||||
|
||||
|
||||
@router.get("/workflows/", response_model=List[WorkflowResponse])
|
||||
async def list_workflows(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
status: Optional[WorkflowStatus] = Query(None),
|
||||
trigger_type: Optional[WorkflowTriggerType] = Query(None),
|
||||
category: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""List workflows with filtering options"""
|
||||
|
||||
query = db.query(DocumentWorkflow)
|
||||
|
||||
if status:
|
||||
query = query.filter(DocumentWorkflow.status == status)
|
||||
|
||||
if trigger_type:
|
||||
query = query.filter(DocumentWorkflow.trigger_type == trigger_type)
|
||||
|
||||
if category:
|
||||
query = query.filter(DocumentWorkflow.category == category)
|
||||
|
||||
if search:
|
||||
search_filter = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
DocumentWorkflow.name.ilike(search_filter),
|
||||
DocumentWorkflow.description.ilike(search_filter)
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(DocumentWorkflow.priority.desc(), DocumentWorkflow.name)
|
||||
workflows, _ = paginate_with_total(query, skip, limit, False)
|
||||
|
||||
return workflows
|
||||
|
||||
|
||||
@router.get("/workflows/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def get_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific workflow by ID"""
|
||||
|
||||
workflow = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
@router.put("/workflows/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def update_workflow(
|
||||
workflow_id: int,
|
||||
workflow_data: WorkflowUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a workflow"""
|
||||
|
||||
workflow = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
# Update fields that are provided
|
||||
update_data = workflow_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(workflow, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
@router.delete("/workflows/{workflow_id}")
|
||||
async def delete_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a workflow (soft delete by setting status to archived)"""
|
||||
|
||||
workflow = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
workflow.status = WorkflowStatus.ARCHIVED
|
||||
db.commit()
|
||||
|
||||
return {"message": "Workflow archived successfully"}
|
||||
|
||||
|
||||
# Workflow Actions endpoints
|
||||
@router.post("/workflows/{workflow_id}/actions", response_model=WorkflowActionResponse)
|
||||
async def create_workflow_action(
|
||||
workflow_id: int,
|
||||
action_data: WorkflowActionCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new action for a workflow"""
|
||||
|
||||
workflow = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
action = WorkflowAction(
|
||||
workflow_id=workflow_id,
|
||||
action_type=action_data.action_type,
|
||||
action_order=action_data.action_order,
|
||||
action_name=action_data.action_name,
|
||||
parameters=action_data.parameters,
|
||||
template_id=action_data.template_id,
|
||||
output_format=action_data.output_format,
|
||||
custom_filename_template=action_data.custom_filename_template,
|
||||
email_template_id=action_data.email_template_id,
|
||||
email_recipients=action_data.email_recipients,
|
||||
email_subject_template=action_data.email_subject_template,
|
||||
condition=action_data.condition,
|
||||
continue_on_failure=action_data.continue_on_failure
|
||||
)
|
||||
|
||||
db.add(action)
|
||||
db.commit()
|
||||
db.refresh(action)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
@router.get("/workflows/{workflow_id}/actions", response_model=List[WorkflowActionResponse])
|
||||
async def list_workflow_actions(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""List actions for a workflow"""
|
||||
|
||||
actions = db.query(WorkflowAction).filter(
|
||||
WorkflowAction.workflow_id == workflow_id
|
||||
).order_by(WorkflowAction.action_order, WorkflowAction.id).all()
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
@router.put("/workflows/{workflow_id}/actions/{action_id}", response_model=WorkflowActionResponse)
|
||||
async def update_workflow_action(
|
||||
workflow_id: int,
|
||||
action_id: int,
|
||||
action_data: WorkflowActionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a workflow action"""
|
||||
|
||||
action = db.query(WorkflowAction).filter(
|
||||
WorkflowAction.id == action_id,
|
||||
WorkflowAction.workflow_id == workflow_id
|
||||
).first()
|
||||
|
||||
if not action:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Action not found"
|
||||
)
|
||||
|
||||
# Update fields that are provided
|
||||
update_data = action_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(action, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(action)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
@router.delete("/workflows/{workflow_id}/actions/{action_id}")
|
||||
async def delete_workflow_action(
|
||||
workflow_id: int,
|
||||
action_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a workflow action"""
|
||||
|
||||
action = db.query(WorkflowAction).filter(
|
||||
WorkflowAction.id == action_id,
|
||||
WorkflowAction.workflow_id == workflow_id
|
||||
).first()
|
||||
|
||||
if not action:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Action not found"
|
||||
)
|
||||
|
||||
db.delete(action)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Action deleted successfully"}
# Event Management endpoints
|
||||
@router.post("/events/", response_model=dict)
|
||||
async def log_event(
|
||||
event_data: EventLogCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Log a system event that may trigger workflows"""
|
||||
|
||||
processor = EventProcessor(db)
|
||||
event_id = await processor.log_event(
|
||||
event_type=event_data.event_type,
|
||||
event_source=event_data.event_source,
|
||||
file_no=event_data.file_no,
|
||||
client_id=event_data.client_id,
|
||||
user_id=current_user.id,
|
||||
resource_type=event_data.resource_type,
|
||||
resource_id=event_data.resource_id,
|
||||
event_data=event_data.event_data,
|
||||
previous_state=event_data.previous_state,
|
||||
new_state=event_data.new_state
|
||||
)
|
||||
|
||||
return {"event_id": event_id, "message": "Event logged successfully"}
|
||||
|
||||
|
||||
@router.get("/events/", response_model=List[EventLogResponse])
|
||||
async def list_events(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
event_type: Optional[str] = Query(None),
|
||||
file_no: Optional[str] = Query(None),
|
||||
processed: Optional[bool] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""List system events"""
|
||||
|
||||
query = db.query(EventLog)
|
||||
|
||||
if event_type:
|
||||
query = query.filter(EventLog.event_type == event_type)
|
||||
|
||||
if file_no:
|
||||
query = query.filter(EventLog.file_no == file_no)
|
||||
|
||||
if processed is not None:
|
||||
query = query.filter(EventLog.processed == processed)
|
||||
|
||||
query = query.order_by(desc(EventLog.occurred_at))
|
||||
events, _ = paginate_with_total(query, skip, limit, False)
|
||||
|
||||
return events
# Execution Management endpoints
|
||||
@router.get("/executions/", response_model=List[WorkflowExecutionResponse])
|
||||
async def list_executions(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
workflow_id: Optional[int] = Query(None),
|
||||
status: Optional[ExecutionStatus] = Query(None),
|
||||
file_no: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""List workflow executions"""
|
||||
|
||||
query = db.query(WorkflowExecution)
|
||||
|
||||
if workflow_id:
|
||||
query = query.filter(WorkflowExecution.workflow_id == workflow_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(WorkflowExecution.status == status)
|
||||
|
||||
if file_no:
|
||||
query = query.filter(WorkflowExecution.context_file_no == file_no)
|
||||
|
||||
query = query.order_by(desc(WorkflowExecution.started_at))
|
||||
executions, _ = paginate_with_total(query, skip, limit, False)
|
||||
|
||||
return executions
|
||||
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
|
||||
async def get_execution(
|
||||
execution_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get details of a specific execution"""
|
||||
|
||||
execution = db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.id == execution_id
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Execution not found"
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
@router.post("/executions/{execution_id}/retry")
|
||||
async def retry_execution(
|
||||
execution_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Retry a failed workflow execution"""
|
||||
|
||||
execution = db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.id == execution_id
|
||||
).first()
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Execution not found"
|
||||
)
|
||||
|
||||
if execution.status not in [ExecutionStatus.FAILED, ExecutionStatus.RETRYING]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only failed executions can be retried"
|
||||
)
|
||||
|
||||
# Reset execution for retry
|
||||
execution.status = ExecutionStatus.PENDING
|
||||
execution.error_message = None
|
||||
execution.next_retry_at = None
|
||||
execution.retry_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
# Execute the workflow
|
||||
executor = WorkflowExecutor(db)
|
||||
success = await executor.execute_workflow(execution_id)
|
||||
|
||||
return {"message": "Execution retried", "success": success}
# Testing and Management endpoints
|
||||
@router.post("/workflows/{workflow_id}/test")
|
||||
async def test_workflow(
|
||||
workflow_id: int,
|
||||
test_request: WorkflowTestRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Test a workflow with simulated event data"""
|
||||
|
||||
workflow = db.query(DocumentWorkflow).filter(
|
||||
DocumentWorkflow.id == workflow_id
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workflow not found"
|
||||
)
|
||||
|
||||
# Create a test event
|
||||
processor = EventProcessor(db)
|
||||
event_id = await processor.log_event(
|
||||
event_type=test_request.event_type,
|
||||
event_source="workflow_test",
|
||||
file_no=test_request.file_no,
|
||||
client_id=test_request.client_id,
|
||||
user_id=current_user.id,
|
||||
event_data=test_request.event_data or {}
|
||||
)
|
||||
|
||||
return {"message": "Test event logged", "event_id": event_id}
|
||||
|
||||
|
||||
@router.get("/stats", response_model=WorkflowStatsResponse)
|
||||
async def get_workflow_stats(
|
||||
days: int = Query(30, ge=1, le=365),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get workflow system statistics"""
|
||||
|
||||
# Basic counts
|
||||
total_workflows = db.query(func.count(DocumentWorkflow.id)).scalar()
|
||||
active_workflows = db.query(func.count(DocumentWorkflow.id)).filter(
|
||||
DocumentWorkflow.status == WorkflowStatus.ACTIVE
|
||||
).scalar()
|
||||
|
||||
total_executions = db.query(func.count(WorkflowExecution.id)).scalar()
|
||||
successful_executions = db.query(func.count(WorkflowExecution.id)).filter(
|
||||
WorkflowExecution.status == ExecutionStatus.COMPLETED
|
||||
).scalar()
|
||||
failed_executions = db.query(func.count(WorkflowExecution.id)).filter(
|
||||
WorkflowExecution.status == ExecutionStatus.FAILED
|
||||
).scalar()
|
||||
pending_executions = db.query(func.count(WorkflowExecution.id)).filter(
|
||||
WorkflowExecution.status.in_([ExecutionStatus.PENDING, ExecutionStatus.RUNNING])
|
||||
).scalar()
|
||||
|
||||
# Workflows by trigger type
|
||||
trigger_stats = db.query(
|
||||
DocumentWorkflow.trigger_type,
|
||||
func.count(DocumentWorkflow.id)
|
||||
).group_by(DocumentWorkflow.trigger_type).all()
|
||||
|
||||
workflows_by_trigger_type = {
|
||||
trigger.value: count for trigger, count in trigger_stats
|
||||
}
|
||||
|
||||
# Executions by day (for the chart)
|
||||
cutoff_date = datetime.now() - timedelta(days=days)
|
||||
daily_stats = db.query(
|
||||
func.date(WorkflowExecution.started_at).label('date'),
|
||||
func.count(WorkflowExecution.id).label('count'),
|
||||
func.sum(case((WorkflowExecution.status == ExecutionStatus.COMPLETED, 1), else_=0)).label('successful'),  # case() from sqlalchemy (CASE is not a SQL function, so func.case would emit invalid SQL)
func.sum(case((WorkflowExecution.status == ExecutionStatus.FAILED, 1), else_=0)).label('failed')
|
||||
).filter(
|
||||
WorkflowExecution.started_at >= cutoff_date
|
||||
).group_by(func.date(WorkflowExecution.started_at)).all()
|
||||
|
||||
executions_by_day = [
|
||||
{
|
||||
'date': row.date.isoformat() if row.date else None,
|
||||
'total': row.count,
|
||||
'successful': row.successful or 0,
|
||||
'failed': row.failed or 0
|
||||
}
|
||||
for row in daily_stats
|
||||
]
|
||||
|
||||
return WorkflowStatsResponse(
|
||||
total_workflows=total_workflows or 0,
|
||||
active_workflows=active_workflows or 0,
|
||||
total_executions=total_executions or 0,
|
||||
successful_executions=successful_executions or 0,
|
||||
failed_executions=failed_executions or 0,
|
||||
pending_executions=pending_executions or 0,
|
||||
workflows_by_trigger_type=workflows_by_trigger_type,
|
||||
executions_by_day=executions_by_day
|
||||
)
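# Illustrative result sketch (assumption, not part of the commit): the daily aggregation
# above yields one row per calendar day, which serializes into executions_by_day entries
# shaped like the following (values are made up):
#
#   [
#       {"date": "2024-05-01", "total": 12, "successful": 10, "failed": 2},
#       {"date": "2024-05-02", "total": 7, "successful": 7, "failed": 0},
#   ]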
|
||||
@@ -7,9 +7,12 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, func, and_, desc, asc, text
|
||||
from datetime import date, datetime, timezone
|
||||
import io
|
||||
import zipfile
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.api.search_highlight import build_query_tokens
|
||||
@@ -21,9 +24,17 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.additional import Document
|
||||
from app.models.document_workflows import EventLog
|
||||
from app.core.logging import get_logger
|
||||
from app.services.audit import audit_service
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.models.templates import DocumentTemplate, DocumentTemplateVersion
|
||||
from app.models.jobs import JobRecord
|
||||
from app.services.storage import get_default_storage
|
||||
from app.services.template_merge import extract_tokens_from_bytes, build_context, resolve_tokens, render_docx
|
||||
from app.services.document_notifications import notify_processing, notify_completed, notify_failed, topic_for_file, ADMIN_DOCUMENTS_TOPIC, get_last_status
|
||||
from app.middleware.websocket_middleware import get_websocket_manager, WebSocketMessage
|
||||
from fastapi import WebSocket
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -118,6 +129,87 @@ class PaginatedQDROResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class CurrentStatusResponse(BaseModel):
|
||||
file_no: str
|
||||
status: str # processing | completed | failed | unknown
|
||||
timestamp: Optional[str] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
history: Optional[list] = None
|
||||
|
||||
|
||||
@router.get("/current-status/{file_no}", response_model=CurrentStatusResponse)
|
||||
async def get_current_document_status(
|
||||
file_no: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Return last-known document generation status for a file.
|
||||
|
||||
Priority:
|
||||
1) In-memory last broadcast state (processing/completed/failed)
|
||||
2) If no memory record, check for any uploaded/generated documents and report 'completed'
|
||||
3) Fallback to 'unknown'
|
||||
"""
|
||||
# Build recent history from EventLog (last N events)
|
||||
history_items = []
|
||||
try:
|
||||
recent = (
|
||||
db.query(EventLog)
|
||||
.filter(EventLog.file_no == file_no, EventLog.event_type.in_(["document_processing", "document_completed", "document_failed"]))
|
||||
.order_by(EventLog.occurred_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for ev in recent:
|
||||
history_items.append({
|
||||
"type": ev.event_type,
|
||||
"timestamp": ev.occurred_at.isoformat() if getattr(ev, "occurred_at", None) else None,
|
||||
"data": ev.event_data or {},
|
||||
})
|
||||
except Exception:
|
||||
history_items = []
|
||||
|
||||
# Try in-memory record for current status
|
||||
last = get_last_status(file_no)
|
||||
if last:
|
||||
ts = last.get("timestamp")
|
||||
iso = ts.isoformat() if hasattr(ts, "isoformat") else None
|
||||
status_val = str(last.get("status") or "unknown")
|
||||
# Treat stale 'processing' as unknown if older than 10 minutes
|
||||
try:
|
||||
if status_val == "processing" and isinstance(ts, datetime):
|
||||
age = datetime.now(timezone.utc) - ts
|
||||
if age.total_seconds() > 600:
|
||||
status_val = "unknown"
|
||||
except Exception:
|
||||
pass
|
||||
return CurrentStatusResponse(
|
||||
file_no=file_no,
|
||||
status=status_val,
|
||||
timestamp=iso,
|
||||
data=(last.get("data") or None),
|
||||
history=history_items,
|
||||
)
|
||||
|
||||
# Fallback: any existing documents imply last status completed
|
||||
any_doc = db.query(Document).filter(Document.file_no == file_no).order_by(Document.id.desc()).first()
|
||||
if any_doc:
|
||||
return CurrentStatusResponse(
|
||||
file_no=file_no,
|
||||
status="completed",
|
||||
timestamp=getattr(any_doc, "upload_date", None).isoformat() if getattr(any_doc, "upload_date", None) else None,
|
||||
data={
|
||||
"document_id": any_doc.id,
|
||||
"filename": any_doc.filename,
|
||||
"size": any_doc.size,
|
||||
},
|
||||
history=history_items,
|
||||
)
|
||||
|
||||
return CurrentStatusResponse(file_no=file_no, status="unknown", history=history_items)
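# Illustrative usage sketch (assumption, not part of the commit): polling the
# current-status endpoint until generation settles. Assumes an authenticated
# httpx.AsyncClient mounted at the API root; the route prefix and helper name
# are hypothetical.
async def _example_poll_document_status(client, file_no: str, interval: float = 2.0) -> dict:
    import asyncio
    while True:
        resp = await client.get(f"/api/documents/current-status/{file_no}")
        resp.raise_for_status()
        payload = resp.json()
        if payload["status"] in ("completed", "failed"):
            return payload
        await asyncio.sleep(interval)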
|
||||
|
||||
|
||||
@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
|
||||
async def list_qdros(
|
||||
skip: int = Query(0, ge=0),
|
||||
@@ -814,6 +906,371 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str:
|
||||
return merged
# --- Batch Document Generation (MVP synchronous) ---
|
||||
class BatchGenerateRequest(BaseModel):
|
||||
"""Batch generation request using DocumentTemplate system."""
|
||||
template_id: int
|
||||
version_id: Optional[int] = None
|
||||
file_nos: List[str]
|
||||
output_format: str = "DOCX" # DOCX (default), PDF (not yet supported), HTML (not yet supported)
|
||||
context: Optional[Dict[str, Any]] = None # additional global context
|
||||
bundle_zip: bool = False # when true, also create a ZIP bundle of generated outputs
|
||||
|
||||
|
||||
class BatchGenerateItemResult(BaseModel):
|
||||
file_no: str
|
||||
status: str # "success" | "error"
|
||||
document_id: Optional[int] = None
|
||||
filename: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
unresolved: Optional[List[str]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class BatchGenerateResponse(BaseModel):
|
||||
job_id: str
|
||||
template_id: int
|
||||
version_id: int
|
||||
total_requested: int
|
||||
total_success: int
|
||||
total_failed: int
|
||||
results: List[BatchGenerateItemResult]
|
||||
bundle_url: Optional[str] = None
|
||||
bundle_size: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/generate-batch", response_model=BatchGenerateResponse)
|
||||
async def generate_batch_documents(
|
||||
payload: BatchGenerateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Synchronously generate documents for multiple files from a template version.
|
||||
|
||||
Notes:
|
||||
- Currently supports DOCX output. PDF/HTML conversion is not yet implemented.
|
||||
- Saves generated bytes to default storage under uploads/generated/{file_no}/.
|
||||
- Persists a `Document` record per successful file.
|
||||
- Returns per-item status with unresolved tokens for transparency.
|
||||
"""
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == payload.template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
resolved_version_id = payload.version_id or tpl.current_version_id
|
||||
if not resolved_version_id:
|
||||
raise HTTPException(status_code=400, detail="Template has no approved/current version")
|
||||
ver = (
|
||||
db.query(DocumentTemplateVersion)
|
||||
.filter(
|
||||
DocumentTemplateVersion.id == resolved_version_id,
|
||||
DocumentTemplateVersion.template_id == tpl.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not ver:
|
||||
raise HTTPException(status_code=404, detail="Template version not found")
|
||||
|
||||
storage = get_default_storage()
|
||||
try:
|
||||
template_bytes = storage.open_bytes(ver.storage_path)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Stored template file not found")
|
||||
|
||||
tokens = extract_tokens_from_bytes(template_bytes)
|
||||
results: List[BatchGenerateItemResult] = []
|
||||
|
||||
# Pre-normalize file numbers (strip spaces, ignore empties)
|
||||
requested_files: List[str] = [fn.strip() for fn in (payload.file_nos or []) if fn and str(fn).strip()]
|
||||
if not requested_files:
|
||||
raise HTTPException(status_code=400, detail="No file numbers provided")
|
||||
|
||||
# Fetch all files in one query
|
||||
files_map: Dict[str, FileModel] = {
|
||||
f.file_no: f
|
||||
for f in db.query(FileModel).options(joinedload(FileModel.owner)).filter(FileModel.file_no.in_(requested_files)).all()
|
||||
}
|
||||
|
||||
generated_items: List[Dict[str, Any]] = [] # capture bytes for optional ZIP
|
||||
job_id = str(uuid.uuid4())  # assign the batch job id before the loop so the notifications below can reference it
for file_no in requested_files:
|
||||
# Notify processing started for this file
|
||||
try:
|
||||
await notify_processing(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data={
|
||||
"template_id": tpl.id,
|
||||
"template_name": tpl.name,
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail generation if notification fails
|
||||
pass
|
||||
|
||||
file_obj = files_map.get(file_no)
|
||||
if not file_obj:
|
||||
# Notify failure
|
||||
try:
|
||||
await notify_failed(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data={"error": "File not found", "template_id": tpl.id}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
results.append(
|
||||
BatchGenerateItemResult(
|
||||
file_no=file_no,
|
||||
status="error",
|
||||
error="File not found",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Build per-file context
|
||||
file_context: Dict[str, Any] = {
|
||||
"FILE_NO": file_obj.file_no,
|
||||
"CLIENT_FIRST": getattr(getattr(file_obj, "owner", None), "first", "") or "",
|
||||
"CLIENT_LAST": getattr(getattr(file_obj, "owner", None), "last", "") or "",
|
||||
"CLIENT_FULL": (
|
||||
f"{getattr(getattr(file_obj, 'owner', None), 'first', '') or ''} "
|
||||
f"{getattr(getattr(file_obj, 'owner', None), 'last', '') or ''}"
|
||||
).strip(),
|
||||
"MATTER": file_obj.regarding or "",
|
||||
"OPENED": file_obj.opened.strftime("%B %d, %Y") if getattr(file_obj, "opened", None) else "",
|
||||
"ATTORNEY": getattr(file_obj, "empl_num", "") or "",
|
||||
}
|
||||
# Merge global context
|
||||
merged_context = build_context({**(payload.context or {}), **file_context}, "file", file_obj.file_no)
|
||||
resolved_vars, unresolved_tokens = resolve_tokens(db, tokens, merged_context)
|
||||
|
||||
try:
|
||||
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
output_bytes = render_docx(template_bytes, resolved_vars)
|
||||
output_mime = ver.mime_type
|
||||
extension = ".docx"
|
||||
else:
|
||||
# For non-DOCX templates (e.g., PDF), pass-through content
|
||||
output_bytes = template_bytes
|
||||
output_mime = ver.mime_type
|
||||
extension = ".bin"
|
||||
|
||||
# Name and save
|
||||
safe_name = f"{tpl.name}_{file_obj.file_no}{extension}"
|
||||
subdir = f"generated/{file_obj.file_no}"
|
||||
storage_path = storage.save_bytes(content=output_bytes, filename_hint=safe_name, subdir=subdir, content_type=output_mime)
|
||||
|
||||
# Persist Document record
|
||||
abs_or_rel_path = os.path.join("uploads", storage_path).replace("\\", "/")
|
||||
doc = Document(
|
||||
file_no=file_obj.file_no,
|
||||
filename=safe_name,
|
||||
path=abs_or_rel_path,
|
||||
description=f"Generated from template '{tpl.name}'",
|
||||
type=output_mime,
|
||||
size=len(output_bytes),
|
||||
uploaded_by=getattr(current_user, "username", None),
|
||||
)
|
||||
db.add(doc)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
|
||||
# Notify successful completion
|
||||
try:
|
||||
await notify_completed(
|
||||
file_no=file_obj.file_no,
|
||||
user_id=current_user.id,
|
||||
data={
|
||||
"template_id": tpl.id,
|
||||
"template_name": tpl.name,
|
||||
"document_id": doc.id,
|
||||
"filename": doc.filename,
|
||||
"size": doc.size,
|
||||
"unresolved_tokens": unresolved_tokens or []
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail generation if notification fails
|
||||
pass
|
||||
|
||||
results.append(
|
||||
BatchGenerateItemResult(
|
||||
file_no=file_obj.file_no,
|
||||
status="success",
|
||||
document_id=doc.id,
|
||||
filename=doc.filename,
|
||||
path=doc.path,
|
||||
url=storage.public_url(storage_path),
|
||||
size=doc.size,
|
||||
unresolved=unresolved_tokens or [],
|
||||
)
|
||||
)
|
||||
# Keep for bundling
|
||||
generated_items.append({
|
||||
"filename": doc.filename,
|
||||
"storage_path": storage_path,
|
||||
})
|
||||
except Exception as e:
|
||||
# Notify failure
|
||||
try:
|
||||
await notify_failed(
|
||||
file_no=file_obj.file_no,
|
||||
user_id=current_user.id,
|
||||
data={
|
||||
"template_id": tpl.id,
|
||||
"template_name": tpl.name,
|
||||
"error": str(e),
|
||||
"unresolved_tokens": unresolved_tokens or []
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Best-effort rollback of partial doc add
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
results.append(
|
||||
BatchGenerateItemResult(
|
||||
file_no=file_obj.file_no,
|
||||
status="error",
|
||||
error=str(e),
|
||||
unresolved=unresolved_tokens or [],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
total_success = sum(1 for r in results if r.status == "success")
|
||||
total_failed = sum(1 for r in results if r.status == "error")
|
||||
bundle_url: Optional[str] = None
|
||||
bundle_size: Optional[int] = None
|
||||
|
||||
# Optionally create a ZIP bundle of generated outputs
|
||||
bundle_storage_path: Optional[str] = None
|
||||
if payload.bundle_zip and total_success > 0:
|
||||
# Stream zip to memory then save via storage adapter
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
|
||||
for item in generated_items:
|
||||
try:
|
||||
file_bytes = storage.open_bytes(item["storage_path"]) # relative path under uploads
|
||||
# Use clean filename inside zip
|
||||
zf.writestr(item["filename"], file_bytes)
|
||||
except Exception:
|
||||
# Skip missing/unreadable files from bundle; keep job successful
|
||||
continue
|
||||
zip_bytes = zip_buffer.getvalue()
|
||||
safe_zip_name = f"documents_batch_{job_id}.zip"
|
||||
bundle_storage_path = storage.save_bytes(content=zip_bytes, filename_hint=safe_zip_name, subdir="bundles", content_type="application/zip")
|
||||
bundle_url = storage.public_url(bundle_storage_path)
|
||||
bundle_size = len(zip_bytes)
|
||||
|
||||
# Persist simple job record
|
||||
try:
|
||||
job = JobRecord(
|
||||
job_id=job_id,
|
||||
job_type="documents_batch",
|
||||
status="completed",
|
||||
requested_by_username=getattr(current_user, "username", None),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
total_requested=len(requested_files),
|
||||
total_success=total_success,
|
||||
total_failed=total_failed,
|
||||
result_storage_path=bundle_storage_path,
|
||||
result_mime_type=("application/zip" if bundle_storage_path else None),
|
||||
result_size=bundle_size,
|
||||
details={
|
||||
"template_id": tpl.id,
|
||||
"version_id": ver.id,
|
||||
"file_nos": requested_files,
|
||||
},
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
except Exception:
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return BatchGenerateResponse(
|
||||
job_id=job_id,
|
||||
template_id=tpl.id,
|
||||
version_id=ver.id,
|
||||
total_requested=len(requested_files),
|
||||
total_success=total_success,
|
||||
total_failed=total_failed,
|
||||
results=results,
|
||||
bundle_url=bundle_url,
|
||||
bundle_size=bundle_size,
|
||||
)
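# Illustrative request sketch (assumption, not part of the commit): a minimal
# /generate-batch body that renders one template for two files and bundles the
# outputs into a ZIP. Template id, file numbers, and context values are made up.
EXAMPLE_BATCH_GENERATE_REQUEST = {
    "template_id": 12,
    "file_nos": ["24-0101", "24-0102"],
    "output_format": "DOCX",
    "context": {"COURT": "Example County Circuit Court"},
    "bundle_zip": True,
}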
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
class JobStatusResponse(BaseModel):
|
||||
job_id: str
|
||||
job_type: str
|
||||
status: str
|
||||
total_requested: int
|
||||
total_success: int
|
||||
total_failed: int
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
bundle_available: bool = False
|
||||
bundle_url: Optional[str] = None
|
||||
bundle_size: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=JobStatusResponse)
|
||||
async def get_job_status(
|
||||
job_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return JobStatusResponse(
|
||||
job_id=job.job_id,
|
||||
job_type=job.job_type,
|
||||
status=job.status,
|
||||
total_requested=job.total_requested or 0,
|
||||
total_success=job.total_success or 0,
|
||||
total_failed=job.total_failed or 0,
|
||||
started_at=getattr(job, "started_at", None),
|
||||
completed_at=getattr(job, "completed_at", None),
|
||||
bundle_available=bool(job.result_storage_path),
|
||||
bundle_url=(get_default_storage().public_url(job.result_storage_path) if job.result_storage_path else None),
|
||||
bundle_size=job.result_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/result")
|
||||
async def download_job_result(
|
||||
job_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job or not job.result_storage_path:
|
||||
raise HTTPException(status_code=404, detail="Result not available for this job")
|
||||
storage = get_default_storage()
|
||||
try:
|
||||
content = storage.open_bytes(job.result_storage_path)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Stored bundle not found")
|
||||
|
||||
# Derive filename
|
||||
base = os.path.basename(job.result_storage_path)
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=\"{base}\"",
|
||||
}
|
||||
return StreamingResponse(iter([content]), media_type=(job.result_mime_type or "application/zip"), headers=headers)
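# Illustrative client sketch (assumption, not part of the commit): checking a batch job
# and streaming its ZIP bundle to disk. Assumes an authenticated httpx.AsyncClient; the
# route prefix and helper name are hypothetical.
async def _example_download_batch_bundle(client, job_id: str, dest_path: str) -> bool:
    status_resp = await client.get(f"/api/documents/jobs/{job_id}")
    status_resp.raise_for_status()
    if not status_resp.json().get("bundle_available"):
        return False
    async with client.stream("GET", f"/api/documents/jobs/{job_id}/result") as resp:
        resp.raise_for_status()
        with open(dest_path, "wb") as fh:
            async for chunk in resp.aiter_bytes():
                fh.write(chunk)
    return True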
|
||||
# --- Client Error Logging (for Documents page) ---
|
||||
class ClientErrorLog(BaseModel):
|
||||
"""Payload for client-side error logging"""
|
||||
@@ -894,54 +1351,118 @@ async def upload_document(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Upload a document to a file"""
|
||||
"""Upload a document to a file with comprehensive security validation and async operations"""
|
||||
from app.utils.file_security import file_validator, create_upload_directory
|
||||
from app.services.async_file_operations import async_file_ops, validate_large_upload
|
||||
from app.services.async_storage import async_storage
|
||||
|
||||
file_obj = db.query(FileModel).filter(FileModel.file_no == file_no).first()
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file uploaded")
|
||||
# Determine if this is a large file that needs streaming
|
||||
file_size_estimate = getattr(file, 'size', 0) or 0
|
||||
use_streaming = file_size_estimate > 10 * 1024 * 1024 # 10MB threshold
|
||||
|
||||
allowed_types = [
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"image/jpeg",
|
||||
"image/png"
|
||||
]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
if use_streaming:
|
||||
# Use streaming validation for large files
|
||||
# Enforce the same 10MB limit used for non-streaming uploads
|
||||
is_valid, error_msg, metadata = await validate_large_upload(
|
||||
file, category='document', max_size=10 * 1024 * 1024
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
safe_filename = file_validator.sanitize_filename(file.filename)
|
||||
file_ext = Path(safe_filename).suffix
|
||||
mime_type = metadata.get('content_type', 'application/octet-stream')
|
||||
|
||||
# Stream upload for large files
|
||||
subdir = f"documents/{file_no}"
|
||||
final_path, actual_size, _checksum = await async_file_ops.stream_upload_file(
|
||||
file,
|
||||
f"{subdir}/{uuid.uuid4()}{file_ext}",
|
||||
progress_callback=None # Could add WebSocket progress here
|
||||
)
|
||||
|
||||
# Get absolute path for database storage
|
||||
absolute_path = str(final_path)
|
||||
# For downstream DB fields that expect a relative path, also keep a relative path for consistency
|
||||
relative_path = str(Path(final_path).relative_to(async_file_ops.base_upload_dir))
|
||||
|
||||
else:
|
||||
# Use traditional validation for smaller files
|
||||
content, safe_filename, file_ext, mime_type = await file_validator.validate_upload_file(
|
||||
file, category='document'
|
||||
)
|
||||
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
content = await file.read()
|
||||
# Treat zero-byte payloads as no file uploaded to provide a clearer client error
|
||||
if len(content) == 0:
|
||||
raise HTTPException(status_code=400, detail="No file uploaded")
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(status_code=400, detail="File too large")
|
||||
# Create secure upload directory
|
||||
upload_dir = f"uploads/{file_no}"
|
||||
create_upload_directory(upload_dir)
|
||||
|
||||
upload_dir = f"uploads/{file_no}"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
# Generate secure file path with UUID to prevent conflicts
|
||||
unique_name = f"{uuid.uuid4()}{file_ext}"
|
||||
path = file_validator.generate_secure_path(upload_dir, unique_name)
|
||||
|
||||
ext = file.filename.split(".")[-1]
|
||||
unique_name = f"{uuid.uuid4()}.{ext}"
|
||||
path = f"{upload_dir}/{unique_name}"
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(content)
|
||||
# Write file using async storage for consistency
|
||||
try:
|
||||
relative_path = await async_storage.save_bytes_async(
|
||||
content,
|
||||
safe_filename,
|
||||
subdir=f"documents/{file_no}"
|
||||
)
|
||||
absolute_path = str(async_storage.base_dir / relative_path)
|
||||
actual_size = len(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Could not save file: {str(e)}")
|
||||
|
||||
doc = Document(
|
||||
file_no=file_no,
|
||||
filename=file.filename,
|
||||
path=path,
|
||||
filename=safe_filename, # Use sanitized filename
|
||||
path=absolute_path,
|
||||
description=description,
|
||||
type=file.content_type,
|
||||
size=len(content),
|
||||
type=mime_type, # Use validated MIME type
|
||||
size=actual_size,
|
||||
uploaded_by=current_user.username
|
||||
)
|
||||
db.add(doc)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
|
||||
# Send real-time notification for document upload
|
||||
try:
|
||||
await notify_completed(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data={
|
||||
"action": "upload",
|
||||
"document_id": doc.id,
|
||||
"filename": safe_filename,
|
||||
"size": actual_size,
|
||||
"type": mime_type,
|
||||
"description": description
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the operation if notification fails
|
||||
get_logger("documents").warning(f"Failed to send document upload notification: {str(e)}")
|
||||
|
||||
# Log workflow event for document upload
|
||||
try:
|
||||
from app.services.workflow_integration import log_document_uploaded_sync
|
||||
log_document_uploaded_sync(
|
||||
db=db,
|
||||
file_no=file_no,
|
||||
document_id=doc.id,
|
||||
filename=safe_filename,
|
||||
document_type=mime_type,
|
||||
user_id=current_user.id
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the operation if workflow logging fails
|
||||
get_logger("documents").warning(f"Failed to log workflow event for document upload: {str(e)}")
|
||||
|
||||
return doc
|
||||
|
||||
@router.get("/{file_no}/uploaded")
|
||||
@@ -987,4 +1508,125 @@ async def update_document(
|
||||
doc.description = description
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
return doc
|
||||
return doc
# WebSocket endpoints for real-time document status notifications
|
||||
|
||||
@router.websocket("/ws/status/{file_no}")
|
||||
async def ws_document_status(websocket: WebSocket, file_no: str):
|
||||
"""
|
||||
Subscribe to real-time document processing status updates for a specific file.
|
||||
|
||||
Users can connect to this endpoint to receive notifications about:
|
||||
- Document generation started (processing)
|
||||
- Document generation completed
|
||||
- Document generation failed
|
||||
- Document uploads
|
||||
|
||||
Authentication required via token query parameter.
|
||||
"""
|
||||
websocket_manager = get_websocket_manager()
|
||||
topic = topic_for_file(file_no)
|
||||
|
||||
# Custom message handler for document status updates
|
||||
async def handle_document_message(connection_id: str, message: WebSocketMessage):
|
||||
"""Handle custom messages for document status"""
|
||||
get_logger("documents").debug("Received document status message",
|
||||
connection_id=connection_id,
|
||||
file_no=file_no,
|
||||
message_type=message.type)
|
||||
|
||||
# Use the WebSocket manager to handle the connection
|
||||
connection_id = await websocket_manager.handle_connection(
|
||||
websocket=websocket,
|
||||
topics={topic},
|
||||
require_auth=True,
|
||||
metadata={"file_no": file_no, "endpoint": "document_status"},
|
||||
message_handler=handle_document_message
|
||||
)
|
||||
|
||||
if connection_id:
|
||||
# Send initial welcome message with subscription confirmation
|
||||
try:
|
||||
pool = websocket_manager.pool
|
||||
welcome_message = WebSocketMessage(
|
||||
type="subscription_confirmed",
|
||||
topic=topic,
|
||||
data={
|
||||
"file_no": file_no,
|
||||
"message": f"Subscribed to document status updates for file {file_no}"
|
||||
}
|
||||
)
|
||||
await pool._send_to_connection(connection_id, welcome_message)
|
||||
get_logger("documents").info("Document status subscription confirmed",
|
||||
connection_id=connection_id,
|
||||
file_no=file_no)
|
||||
except Exception as e:
|
||||
get_logger("documents").error("Failed to send subscription confirmation",
|
||||
connection_id=connection_id,
|
||||
file_no=file_no,
|
||||
error=str(e))
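# Illustrative client sketch (assumption, not part of the commit): subscribing to the
# per-file status topic with the `websockets` package. The ws URL, route prefix, token
# query parameter, and message type names are hypothetical; as noted above, the endpoint
# requires auth via a token query parameter.
async def _example_watch_document_status(base_ws_url: str, file_no: str, token: str) -> Optional[dict]:
    import json
    import websockets
    url = f"{base_ws_url}/ws/status/{file_no}?token={token}"
    async with websockets.connect(url) as ws:
        async for raw in ws:
            event = json.loads(raw)
            if event.get("type") in ("document_completed", "document_failed"):
                return event
    return None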
# Test endpoint for document notification system
|
||||
@router.post("/test-notification/{file_no}")
|
||||
async def test_document_notification(
|
||||
file_no: str,
|
||||
status: str = Query(..., description="Notification status: processing, completed, or failed"),
|
||||
message: Optional[str] = Query(None, description="Optional message"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Test endpoint to simulate document processing notifications.
|
||||
|
||||
This endpoint allows testing the WebSocket notification system by sending
|
||||
simulated document status updates. Useful for development and debugging.
|
||||
"""
|
||||
if status not in ["processing", "completed", "failed"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Status must be one of: processing, completed, failed"
|
||||
)
|
||||
|
||||
# Prepare test data
|
||||
test_data = {
|
||||
"test": True,
|
||||
"triggered_by": current_user.username,
|
||||
"message": message or f"Test {status} notification for file {file_no}",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Send notification based on status
|
||||
try:
|
||||
if status == "processing":
|
||||
sent_count = await notify_processing(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data=test_data
|
||||
)
|
||||
elif status == "completed":
|
||||
sent_count = await notify_completed(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data=test_data
|
||||
)
|
||||
else: # failed
|
||||
sent_count = await notify_failed(
|
||||
file_no=file_no,
|
||||
user_id=current_user.id,
|
||||
data=test_data
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Test notification sent for file {file_no}",
|
||||
"status": status,
|
||||
"sent_to_connections": sent_count,
|
||||
"data": test_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to send test notification: {str(e)}"
|
||||
)
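# Illustrative call (assumption, not part of the commit): exercising the test endpoint
# from a shell. The host, route prefix, and auth header are hypothetical.
#
#   curl -X POST "https://app.example.local/api/documents/test-notification/24-0101?status=completed" \
#        -H "Authorization: Bearer <token>"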
|
||||
@@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.models import (
|
||||
File, FileStatus, FileType, Employee, User, FileStatusHistory,
|
||||
FileTransferHistory, FileArchiveInfo
|
||||
File, FileStatus, FileType, Employee, User, FileStatusHistory,
|
||||
FileTransferHistory, FileArchiveInfo, FileClosureChecklist, FileAlert
|
||||
)
|
||||
from app.services.file_management import FileManagementService, FileManagementError, FileStatusWorkflow
|
||||
from app.auth.security import get_current_user
|
||||
@@ -134,6 +134,10 @@ async def change_file_status(
|
||||
"""Change file status with workflow validation"""
|
||||
try:
|
||||
service = FileManagementService(db)
|
||||
# Get the old status before changing
|
||||
old_file = db.query(File).filter(File.file_no == file_no).first()
|
||||
old_status = old_file.status if old_file else None
|
||||
|
||||
file_obj = service.change_file_status(
|
||||
file_no=file_no,
|
||||
new_status=request.new_status,
|
||||
@@ -142,6 +146,21 @@ async def change_file_status(
|
||||
validate_transition=request.validate_transition
|
||||
)
|
||||
|
||||
# Log workflow event for file status change
|
||||
try:
|
||||
from app.services.workflow_integration import log_file_status_change_sync
|
||||
log_file_status_change_sync(
|
||||
db=db,
|
||||
file_no=file_no,
|
||||
old_status=old_status,
|
||||
new_status=request.new_status,
|
||||
user_id=current_user.id,
|
||||
notes=request.notes
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail the operation if workflow logging fails
|
||||
logger.warning(f"Failed to log workflow event for file {file_no}: {str(e)}")
|
||||
|
||||
return {
|
||||
"message": f"File {file_no} status changed to {request.new_status}",
|
||||
"file_no": file_obj.file_no,
|
||||
@@ -397,6 +416,302 @@ async def bulk_status_update(
|
||||
)
# Checklist endpoints
|
||||
|
||||
class ChecklistItemRequest(BaseModel):
|
||||
item_name: str
|
||||
item_description: Optional[str] = None
|
||||
is_required: bool = True
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class ChecklistItemUpdateRequest(BaseModel):
|
||||
item_name: Optional[str] = None
|
||||
item_description: Optional[str] = None
|
||||
is_required: Optional[bool] = None
|
||||
is_completed: Optional[bool] = None
|
||||
sort_order: Optional[int] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/{file_no}/closure-checklist")
|
||||
async def get_closure_checklist(
|
||||
file_no: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
return service.get_closure_checklist(file_no)
|
||||
|
||||
|
||||
@router.post("/{file_no}/closure-checklist")
|
||||
async def add_checklist_item(
|
||||
file_no: str,
|
||||
request: ChecklistItemRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
item = service.add_checklist_item(
|
||||
file_no=file_no,
|
||||
item_name=request.item_name,
|
||||
item_description=request.item_description,
|
||||
is_required=request.is_required,
|
||||
sort_order=request.sort_order,
|
||||
)
|
||||
return {
|
||||
"id": item.id,
|
||||
"file_no": item.file_no,
|
||||
"item_name": item.item_name,
|
||||
"item_description": item.item_description,
|
||||
"is_required": item.is_required,
|
||||
"is_completed": item.is_completed,
|
||||
"sort_order": item.sort_order,
|
||||
}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/closure-checklist/{item_id}")
|
||||
async def update_checklist_item(
|
||||
item_id: int,
|
||||
request: ChecklistItemUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
item = service.update_checklist_item(
|
||||
item_id=item_id,
|
||||
item_name=request.item_name,
|
||||
item_description=request.item_description,
|
||||
is_required=request.is_required,
|
||||
is_completed=request.is_completed,
|
||||
sort_order=request.sort_order,
|
||||
user_id=current_user.id,
|
||||
notes=request.notes,
|
||||
)
|
||||
return {
|
||||
"id": item.id,
|
||||
"file_no": item.file_no,
|
||||
"item_name": item.item_name,
|
||||
"item_description": item.item_description,
|
||||
"is_required": item.is_required,
|
||||
"is_completed": item.is_completed,
|
||||
"completed_date": item.completed_date,
|
||||
"completed_by_name": item.completed_by_name,
|
||||
"notes": item.notes,
|
||||
"sort_order": item.sort_order,
|
||||
}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/closure-checklist/{item_id}")
|
||||
async def delete_checklist_item(
|
||||
item_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
service.delete_checklist_item(item_id=item_id)
|
||||
return {"message": "Checklist item deleted"}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Alerts endpoints
|
||||
|
||||
class AlertCreateRequest(BaseModel):
|
||||
alert_type: str
|
||||
title: str
|
||||
message: str
|
||||
alert_date: date
|
||||
notify_attorney: bool = True
|
||||
notify_admin: bool = False
|
||||
notification_days_advance: int = 7
|
||||
|
||||
|
||||
class AlertUpdateRequest(BaseModel):
|
||||
title: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
alert_date: Optional[date] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
@router.post("/{file_no}/alerts")
|
||||
async def create_alert(
|
||||
file_no: str,
|
||||
request: AlertCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
alert = service.create_alert(
|
||||
file_no=file_no,
|
||||
alert_type=request.alert_type,
|
||||
title=request.title,
|
||||
message=request.message,
|
||||
alert_date=request.alert_date,
|
||||
notify_attorney=request.notify_attorney,
|
||||
notify_admin=request.notify_admin,
|
||||
notification_days_advance=request.notification_days_advance,
|
||||
)
|
||||
return {
|
||||
"id": alert.id,
|
||||
"file_no": alert.file_no,
|
||||
"alert_type": alert.alert_type,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"alert_date": alert.alert_date,
|
||||
"is_active": alert.is_active,
|
||||
}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{file_no}/alerts")
|
||||
async def get_alerts(
|
||||
file_no: str,
|
||||
active_only: bool = Query(True),
|
||||
upcoming_only: bool = Query(False),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
alerts = service.get_alerts(
|
||||
file_no=file_no,
|
||||
active_only=active_only,
|
||||
upcoming_only=upcoming_only,
|
||||
limit=limit,
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": a.id,
|
||||
"file_no": a.file_no,
|
||||
"alert_type": a.alert_type,
|
||||
"title": a.title,
|
||||
"message": a.message,
|
||||
"alert_date": a.alert_date,
|
||||
"is_active": a.is_active,
|
||||
"is_acknowledged": a.is_acknowledged,
|
||||
}
|
||||
for a in alerts
|
||||
]
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/acknowledge")
|
||||
async def acknowledge_alert(
|
||||
alert_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
alert = service.acknowledge_alert(alert_id=alert_id, user_id=current_user.id)
|
||||
return {"message": "Alert acknowledged", "id": alert.id}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/alerts/{alert_id}")
|
||||
async def update_alert(
|
||||
alert_id: int,
|
||||
request: AlertUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
alert = service.update_alert(
|
||||
alert_id=alert_id,
|
||||
title=request.title,
|
||||
message=request.message,
|
||||
alert_date=request.alert_date,
|
||||
is_active=request.is_active,
|
||||
)
|
||||
return {"message": "Alert updated", "id": alert.id}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/alerts/{alert_id}")
|
||||
async def delete_alert(
|
||||
alert_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
service.delete_alert(alert_id=alert_id)
|
||||
return {"message": "Alert deleted"}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Relationships endpoints
|
||||
|
||||
class RelationshipCreateRequest(BaseModel):
|
||||
target_file_no: str
|
||||
relationship_type: str
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{file_no}/relationships")
|
||||
async def create_relationship(
|
||||
file_no: str,
|
||||
request: RelationshipCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
rel = service.create_relationship(
|
||||
source_file_no=file_no,
|
||||
target_file_no=request.target_file_no,
|
||||
relationship_type=request.relationship_type,
|
||||
user_id=current_user.id,
|
||||
notes=request.notes,
|
||||
)
|
||||
return {
|
||||
"id": rel.id,
|
||||
"source_file_no": rel.source_file_no,
|
||||
"target_file_no": rel.target_file_no,
|
||||
"relationship_type": rel.relationship_type,
|
||||
"notes": rel.notes,
|
||||
}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{file_no}/relationships")
|
||||
async def get_relationships(
|
||||
file_no: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
return service.get_relationships(file_no=file_no)
|
||||
|
||||
|
||||
@router.delete("/relationships/{relationship_id}")
|
||||
async def delete_relationship(
|
||||
relationship_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = FileManagementService(db)
|
||||
try:
|
||||
service.delete_relationship(relationship_id=relationship_id)
|
||||
return {"message": "Relationship deleted"}
|
||||
except FileManagementError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# File queries and reports
|
||||
@router.get("/by-status/{status}")
|
||||
async def get_files_by_status(
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total
|
||||
from app.models.additional import Deposit, Payment
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -81,6 +82,23 @@ class PaginatedLedgerResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class DepositResponse(BaseModel):
|
||||
deposit_date: date
|
||||
total: float
|
||||
notes: Optional[str] = None
|
||||
payments: Optional[List[Dict]] = None # Optional, depending on include_payments
|
||||
|
||||
class PaymentCreate(BaseModel):
|
||||
file_no: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
regarding: Optional[str] = None
|
||||
amount: float
|
||||
note: Optional[str] = None
|
||||
payment_method: str = "CHECK"
|
||||
reference: Optional[str] = None
|
||||
apply_to_trust: bool = False
|
||||
|
||||
|
||||
@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse])
|
||||
async def get_file_ledger(
|
||||
file_no: str,
|
||||
@@ -324,6 +342,59 @@ async def _update_file_balances(file_obj: File, db: Session):
|
||||
db.commit()
|
||||
|
||||
|
||||
async def _create_ledger_payment(
|
||||
file_no: str,
|
||||
amount: float,
|
||||
payment_date: date,
|
||||
payment_method: str,
|
||||
reference: Optional[str],
|
||||
notes: Optional[str],
|
||||
apply_to_trust: bool,
|
||||
empl_num: str,
|
||||
db: Session
|
||||
) -> Ledger:
|
||||
# Get next item number
|
||||
max_item = db.query(func.max(Ledger.item_no)).filter(
|
||||
Ledger.file_no == file_no
|
||||
).scalar() or 0
|
||||
|
||||
# Determine transaction type and code
|
||||
if apply_to_trust:
|
||||
t_type = "1" # Trust
|
||||
t_code = "TRUST"
|
||||
description = f"Trust deposit - {payment_method}"
|
||||
else:
|
||||
t_type = "5" # Credit/Payment
|
||||
t_code = "PMT"
|
||||
description = f"Payment received - {payment_method}"
|
||||
|
||||
if reference:
|
||||
description += f" - Ref: {reference}"
|
||||
|
||||
if notes:
|
||||
description += f" - {notes}"
|
||||
|
||||
# Create ledger entry
|
||||
entry = Ledger(
|
||||
file_no=file_no,
|
||||
item_no=max_item + 1,
|
||||
date=payment_date,
|
||||
t_code=t_code,
|
||||
t_type=t_type,
|
||||
t_type_l="C", # Credit
|
||||
empl_num=empl_num,
|
||||
quantity=0.0,
|
||||
rate=0.0,
|
||||
amount=amount,
|
||||
billed="Y", # Payments are automatically considered "billed"
|
||||
note=description
|
||||
)
|
||||
|
||||
db.add(entry)
|
||||
db.flush() # To get ID
|
||||
return entry
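# Illustrative usage sketch (assumption, not part of the commit): recording a $500.00
# check against a file from inside a request handler. The helper flushes but does not
# commit, so the caller still owns db.commit(); the values here are made up.
#
#   entry = await _create_ledger_payment(
#       file_no="24-0101", amount=500.00, payment_date=date.today(),
#       payment_method="CHECK", reference="1042", notes=None,
#       apply_to_trust=False, empl_num="JD", db=db,
#   )
#   db.commit()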
# Additional Financial Management Endpoints
|
||||
|
||||
@router.get("/time-entries/recent")
|
||||
@@ -819,56 +890,27 @@ async def record_payment(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Record a payment against a file"""
|
||||
# Verify file exists
|
||||
file_obj = db.query(File).filter(File.file_no == file_no).first()
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
payment_date = payment_date or date.today()
|
||||
|
||||
# Get next item number
|
||||
max_item = db.query(func.max(Ledger.item_no)).filter(
|
||||
Ledger.file_no == file_no
|
||||
).scalar() or 0
|
||||
|
||||
# Determine transaction type and code based on whether it goes to trust
|
||||
if apply_to_trust:
|
||||
t_type = "1" # Trust
|
||||
t_code = "TRUST"
|
||||
description = f"Trust deposit - {payment_method}"
|
||||
else:
|
||||
t_type = "5" # Credit/Payment
|
||||
t_code = "PMT"
|
||||
description = f"Payment received - {payment_method}"
|
||||
|
||||
if reference:
|
||||
description += f" - Ref: {reference}"
|
||||
|
||||
if notes:
|
||||
description += f" - {notes}"
|
||||
|
||||
# Create payment entry
|
||||
entry = Ledger(
|
||||
entry = await _create_ledger_payment(
|
||||
file_no=file_no,
|
||||
item_no=max_item + 1,
|
||||
date=payment_date,
|
||||
t_code=t_code,
|
||||
t_type=t_type,
|
||||
t_type_l="C", # Credit
|
||||
empl_num=file_obj.empl_num,
|
||||
quantity=0.0,
|
||||
rate=0.0,
|
||||
amount=amount,
|
||||
billed="Y", # Payments are automatically considered "billed"
|
||||
note=description
|
||||
payment_date=payment_date,
|
||||
payment_method=payment_method,
|
||||
reference=reference,
|
||||
notes=notes,
|
||||
apply_to_trust=apply_to_trust,
|
||||
empl_num=file_obj.empl_num,
|
||||
db=db
|
||||
)
|
||||
|
||||
db.add(entry)
|
||||
db.commit()
|
||||
db.refresh(entry)
|
||||
|
||||
# Update file balances
|
||||
await _update_file_balances(file_obj, db)
|
||||
|
||||
return {
|
||||
@@ -952,4 +994,157 @@ async def record_expense(
|
||||
"description": description,
|
||||
"employee": empl_num
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/deposits/")
|
||||
async def create_deposit(
|
||||
deposit_date: date,
|
||||
notes: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
existing = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Deposit for this date already exists")
|
||||
|
||||
deposit = Deposit(
|
||||
deposit_date=deposit_date,
|
||||
total=0.0,
|
||||
notes=notes
|
||||
)
|
||||
db.add(deposit)
|
||||
db.commit()
|
||||
db.refresh(deposit)
|
||||
return deposit
|
||||
|
||||
@router.post("/deposits/{deposit_date}/payments/")
|
||||
async def add_payment_to_deposit(
|
||||
deposit_date: date,
|
||||
payment_data: PaymentCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
deposit = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
|
||||
if not deposit:
|
||||
raise HTTPException(status_code=404, detail="Deposit not found")
|
||||
|
||||
if not payment_data.file_no:
|
||||
raise HTTPException(status_code=400, detail="file_no is required for payments")
|
||||
|
||||
file_obj = db.query(File).filter(File.file_no == payment_data.file_no).first()
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Create ledger entry first
|
||||
ledger_entry = await _create_ledger_payment(
|
||||
file_no=payment_data.file_no,
|
||||
amount=payment_data.amount,
|
||||
payment_date=deposit_date,
|
||||
payment_method=payment_data.payment_method,
|
||||
reference=payment_data.reference,
|
||||
notes=payment_data.note,
|
||||
apply_to_trust=payment_data.apply_to_trust,
|
||||
empl_num=file_obj.empl_num,
|
||||
db=db
|
||||
)
|
||||
|
||||
# Create payment record
|
||||
payment = Payment(
|
||||
deposit_date=deposit_date,
|
||||
file_no=payment_data.file_no,
|
||||
client_id=payment_data.client_id,
|
||||
regarding=payment_data.regarding,
|
||||
amount=payment_data.amount,
|
||||
note=payment_data.note
|
||||
)
|
||||
db.add(payment)
|
||||
|
||||
# Update deposit total
|
||||
deposit.total += payment_data.amount
|
||||
|
||||
db.commit()
|
||||
db.refresh(payment)
|
||||
await _update_file_balances(file_obj, db)
|
||||
|
||||
return payment
|
||||
|
||||
@router.get("/deposits/", response_model=List[DepositResponse])
|
||||
async def list_deposits(
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
include_payments: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
query = db.query(Deposit)
|
||||
if start_date:
|
||||
query = query.filter(Deposit.deposit_date >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(Deposit.deposit_date <= end_date)
|
||||
query = query.order_by(Deposit.deposit_date.desc())
|
||||
|
||||
deposits = query.all()
|
||||
results = []
|
||||
for dep in deposits:
|
||||
dep_data = {
|
||||
"deposit_date": dep.deposit_date,
|
||||
"total": dep.total,
|
||||
"notes": dep.notes
|
||||
}
|
||||
if include_payments:
|
||||
payments = db.query(Payment).filter(Payment.deposit_date == dep.deposit_date).all()
|
||||
dep_data["payments"] = [p.__dict__ for p in payments]
|
||||
results.append(dep_data)
|
||||
return results
|
||||
|
||||
@router.get("/deposits/{deposit_date}", response_model=DepositResponse)
|
||||
async def get_deposit(
|
||||
deposit_date: date,
|
||||
include_payments: bool = True,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
deposit = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
|
||||
if not deposit:
|
||||
raise HTTPException(status_code=404, detail="Deposit not found")
|
||||
|
||||
dep_data = {
|
||||
"deposit_date": deposit.deposit_date,
|
||||
"total": deposit.total,
|
||||
"notes": deposit.notes
|
||||
}
|
||||
if include_payments:
|
||||
payments = db.query(Payment).filter(Payment.deposit_date == deposit_date).all()
|
||||
dep_data["payments"] = [p.__dict__ for p in payments]
|
||||
return dep_data
|
||||
|
||||
@router.get("/reports/deposits")
|
||||
async def get_deposit_report(
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
deposits = db.query(Deposit).filter(
|
||||
Deposit.deposit_date >= start_date,
|
||||
Deposit.deposit_date <= end_date
|
||||
).order_by(Deposit.deposit_date).all()
|
||||
|
||||
total_deposits = sum(d.total for d in deposits)
|
||||
report = {
|
||||
"period": {
|
||||
"start": start_date.isoformat(),
|
||||
"end": end_date.isoformat()
|
||||
},
|
||||
"total_deposits": total_deposits,
|
||||
"deposit_count": len(deposits),
|
||||
"deposits": [
|
||||
{
|
||||
"date": d.deposit_date.isoformat(),
|
||||
"total": d.total,
|
||||
"notes": d.notes,
|
||||
"payment_count": db.query(Payment).filter(Payment.deposit_date == d.deposit_date).count()
|
||||
} for d in deposits
|
||||
]
|
||||
}
|
||||
return report
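For orientation, a sketch of the JSON this report endpoint returns; the field names follow the dict built above, while the values are illustrative only.

# Illustrative response shape (hypothetical values):
# {
#   "period": {"start": "2024-01-01", "end": "2024-01-31"},
#   "total_deposits": 1500.0,
#   "deposit_count": 1,
#   "deposits": [
#     {"date": "2024-01-15", "total": 1500.0, "notes": null, "payment_count": 1}
#   ]
# }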
|
||||
@@ -3,6 +3,7 @@ Data import API endpoints for CSV file uploads with auto-discovery mapping.
|
||||
"""
|
||||
import csv
|
||||
import io
|
||||
import zipfile
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -11,6 +12,7 @@ from datetime import datetime, date, timezone
|
||||
from decimal import Decimal
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user
|
||||
@@ -40,8 +42,8 @@ ENCODINGS = [
|
||||
|
||||
# Unified import order used across batch operations
|
||||
IMPORT_ORDER = [
|
||||
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
|
||||
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
|
||||
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FOOTERS.csv", "FILESTAT.csv",
|
||||
"TRNSTYPE.csv", "TRNSLKUP.csv", "SETUP.csv", "PRINTERS.csv",
|
||||
"INX_LKUP.csv",
|
||||
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
|
||||
"QDROS.csv", "PENSIONS.csv", "SCHEDULE.csv", "MARRIAGE.csv", "DEATH.csv", "SEPARATE.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv",
|
||||
@@ -91,8 +93,83 @@ CSV_MODEL_MAPPING = {
|
||||
"RESULTS.csv": PensionResult
|
||||
}
|
||||
|
||||
# Minimal CSV template definitions (headers + one sample row) used for template downloads
|
||||
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
|
||||
"FILES.csv": {
|
||||
"headers": ["File_No", "Id", "Empl_Num", "File_Type", "Opened", "Status", "Rate_Per_Hour"],
|
||||
"sample": ["F-001", "CLIENT-1", "EMP01", "CIVIL", "2024-01-01", "ACTIVE", "150"],
|
||||
},
|
||||
"LEDGER.csv": {
|
||||
"headers": ["File_No", "Date", "Empl_Num", "T_Code", "T_Type", "Amount"],
|
||||
"sample": ["F-001", "2024-01-15", "EMP01", "FEE", "1", "500.00"],
|
||||
},
|
||||
"PAYMENTS.csv": {
|
||||
"headers": ["Deposit_Date", "Amount"],
|
||||
"sample": ["2024-01-15", "1500.00"],
|
||||
},
|
||||
# Additional templates for convenience
|
||||
"TRNSACTN.csv": {
|
||||
# Same structure as LEDGER.csv
|
||||
"headers": ["File_No", "Date", "Empl_Num", "T_Code", "T_Type", "Amount"],
|
||||
"sample": ["F-002", "2024-02-10", "EMP02", "FEE", "1", "250.00"],
|
||||
},
|
||||
"DEPOSITS.csv": {
|
||||
"headers": ["Deposit_Date", "Total"],
|
||||
"sample": ["2024-02-10", "1500.00"],
|
||||
},
|
||||
"ROLODEX.csv": {
|
||||
# Minimal common contact fields
|
||||
"headers": ["Id", "Last", "First", "A1", "City", "Abrev", "Zip", "Email"],
|
||||
"sample": ["CLIENT-1", "Smith", "John", "123 Main St", "Denver", "CO", "80202", "john.smith@example.com"],
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_csv_template_bytes(file_type: str) -> bytes:
|
||||
"""Return CSV template content for the given file type as bytes.
|
||||
|
||||
Raises HTTPException if unsupported.
|
||||
"""
|
||||
key = (file_type or "").strip()
|
||||
if key not in CSV_IMPORT_TEMPLATES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported template type: {file_type}. Choose one of: {list(CSV_IMPORT_TEMPLATES.keys())}")
|
||||
|
||||
cfg = CSV_IMPORT_TEMPLATES[key]
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(cfg["headers"])
|
||||
writer.writerow(cfg["sample"])
|
||||
output.seek(0)
|
||||
return output.getvalue().encode("utf-8")
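As a quick illustration (assuming the helper is called directly, as the template endpoints below do), the bytes produced for FILES.csv are just the header row plus the single sample row defined in CSV_IMPORT_TEMPLATES:

# Hypothetical interactive check; both rows come straight from CSV_IMPORT_TEMPLATES["FILES.csv"].
content = _generate_csv_template_bytes("FILES.csv").decode("utf-8")
print(content)
# File_No,Id,Empl_Num,File_Type,Opened,Status,Rate_Per_Hour
# F-001,CLIENT-1,EMP01,CIVIL,2024-01-01,ACTIVE,150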
|
||||
|
||||
# Field mappings for CSV columns to database fields
|
||||
# Legacy header synonyms used as hints only (not required). Auto-discovery will work without exact matches.
|
||||
REQUIRED_MODEL_FIELDS: Dict[str, List[str]] = {
|
||||
# Files: core identifiers and billing/status fields used throughout the app
|
||||
"FILES.csv": [
|
||||
"file_no",
|
||||
"id",
|
||||
"empl_num",
|
||||
"file_type",
|
||||
"opened",
|
||||
"status",
|
||||
"rate_per_hour",
|
||||
],
|
||||
# Ledger: core transaction fields
|
||||
"LEDGER.csv": [
|
||||
"file_no",
|
||||
"date",
|
||||
"empl_num",
|
||||
"t_code",
|
||||
"t_type",
|
||||
"amount",
|
||||
],
|
||||
# Payments: deposit date and amount are the only strictly required model fields
|
||||
"PAYMENTS.csv": [
|
||||
"deposit_date",
|
||||
"amount",
|
||||
],
|
||||
}
|
||||
|
||||
FIELD_MAPPINGS = {
|
||||
"ROLODEX.csv": {
|
||||
"Id": "id",
|
||||
@@ -191,7 +268,14 @@ FIELD_MAPPINGS = {
|
||||
"Draft_Apr": "draft_apr",
|
||||
"Final_Out": "final_out",
|
||||
"Judge": "judge",
|
||||
"Form_Name": "form_name"
|
||||
"Form_Name": "form_name",
|
||||
# Extended workflow/document fields (present in new exports or manual CSVs)
|
||||
"Status": "status",
|
||||
"Content": "content",
|
||||
"Notes": "notes",
|
||||
"Approval_Status": "approval_status",
|
||||
"Approved_Date": "approved_date",
|
||||
"Filed_Date": "filed_date"
|
||||
},
|
||||
"PENSIONS.csv": {
|
||||
"File_No": "file_no",
|
||||
@@ -218,9 +302,17 @@ FIELD_MAPPINGS = {
|
||||
},
|
||||
"EMPLOYEE.csv": {
|
||||
"Empl_Num": "empl_num",
|
||||
"Rate_Per_Hour": "rate_per_hour"
|
||||
# "Empl_Id": not a field in Employee model, using empl_num as identifier
|
||||
# Model has additional fields (first_name, last_name, title, etc.) not in CSV
|
||||
"Empl_Id": "initials", # Map employee ID to initials field
|
||||
"Rate_Per_Hour": "rate_per_hour",
|
||||
# Optional extended fields when present in enhanced exports
|
||||
"First": "first_name",
|
||||
"First_Name": "first_name",
|
||||
"Last": "last_name",
|
||||
"Last_Name": "last_name",
|
||||
"Title": "title",
|
||||
"Email": "email",
|
||||
"Phone": "phone",
|
||||
"Active": "active"
|
||||
},
|
||||
"STATES.csv": {
|
||||
"Abrev": "abbreviation",
|
||||
@@ -228,8 +320,8 @@ FIELD_MAPPINGS = {
|
||||
},
|
||||
"GRUPLKUP.csv": {
|
||||
"Code": "group_code",
|
||||
"Description": "description"
|
||||
# "Title": field not present in model, skipping
|
||||
"Description": "description",
|
||||
"Title": "title"
|
||||
},
|
||||
"TRNSLKUP.csv": {
|
||||
"T_Code": "t_code",
|
||||
@@ -240,10 +332,9 @@ FIELD_MAPPINGS = {
|
||||
},
|
||||
"TRNSTYPE.csv": {
|
||||
"T_Type": "t_type",
|
||||
"T_Type_L": "description"
|
||||
# "Header": maps to debit_credit but needs data transformation
|
||||
# "Footer": doesn't align with active boolean field
|
||||
# These fields may need custom handling or model updates
|
||||
"T_Type_L": "debit_credit", # D=Debit, C=Credit
|
||||
"Header": "description",
|
||||
"Footer": "footer_code"
|
||||
},
|
||||
"FILETYPE.csv": {
|
||||
"File_Type": "type_code",
|
||||
@@ -343,6 +434,10 @@ FIELD_MAPPINGS = {
|
||||
"DEATH.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Beneficiary_Name": "beneficiary_name",
|
||||
"Benefit_Amount": "benefit_amount",
|
||||
"Benefit_Type": "benefit_type",
|
||||
"Notes": "notes",
|
||||
"Lump1": "lump1",
|
||||
"Lump2": "lump2",
|
||||
"Growth1": "growth1",
|
||||
@@ -353,6 +448,9 @@ FIELD_MAPPINGS = {
|
||||
"SEPARATE.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Agreement_Date": "agreement_date",
|
||||
"Terms": "terms",
|
||||
"Notes": "notes",
|
||||
"Separation_Rate": "terms"
|
||||
},
|
||||
"LIFETABL.csv": {
|
||||
@@ -466,6 +564,40 @@ FIELD_MAPPINGS = {
|
||||
"Amount": "amount",
|
||||
"Billed": "billed",
|
||||
"Note": "note"
|
||||
},
|
||||
"EMPLOYEE.csv": {
|
||||
"Empl_Num": "empl_num",
|
||||
"Empl_Id": "initials", # Map employee ID to initials field
|
||||
"Rate_Per_Hour": "rate_per_hour",
|
||||
# Note: first_name, last_name, title, active, email, phone will need manual entry or separate import
|
||||
# as they're not present in the legacy CSV structure
|
||||
},
|
||||
"QDROS.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Plan_Id": "plan_id",
|
||||
"^1": "field1",
|
||||
"^2": "field2",
|
||||
"^Part": "part",
|
||||
"^AltP": "altp",
|
||||
"^Pet": "pet",
|
||||
"^Res": "res",
|
||||
"Case_Type": "case_type",
|
||||
"Case_Code": "case_code",
|
||||
"Section": "section",
|
||||
"Case_Number": "case_number",
|
||||
"Judgment_Date": "judgment_date",
|
||||
"Valuation_Date": "valuation_date",
|
||||
"Married_On": "married_on",
|
||||
"Percent_Awarded": "percent_awarded",
|
||||
"Ven_City": "ven_city",
|
||||
"Ven_Cnty": "ven_cnty",
|
||||
"Ven_St": "ven_st",
|
||||
"Draft_Out": "draft_out",
|
||||
"Draft_Apr": "draft_apr",
|
||||
"Final_Out": "final_out",
|
||||
"Judge": "judge",
|
||||
"Form_Name": "form_name"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,6 +823,21 @@ def _build_dynamic_mapping(headers: List[str], model_class, file_type: str) -> D
|
||||
}
|
||||
|
||||
|
||||
def _validate_required_headers(file_type: str, mapped_headers: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Check that minimal required model fields for a given CSV type are present in mapped headers.
|
||||
|
||||
Returns dict with: required_fields, missing_fields, ok.
|
||||
"""
|
||||
required_fields = REQUIRED_MODEL_FIELDS.get(file_type, [])
|
||||
present_fields = set((mapped_headers or {}).values())
|
||||
missing_fields = [f for f in required_fields if f not in present_fields]
|
||||
return {
|
||||
"required_fields": required_fields,
|
||||
"missing_fields": missing_fields,
|
||||
"ok": len(missing_fields) == 0,
|
||||
}
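A small sketch of how this helper reports a gap, using the LEDGER.csv requirements listed in REQUIRED_MODEL_FIELDS above; the mapped_headers dict here is hypothetical:

# Hypothetical mapping that never resolved an "amount" column.
mapped = {"File_No": "file_no", "Date": "date", "Empl_Num": "empl_num", "T_Code": "t_code", "T_Type": "t_type"}
_validate_required_headers("LEDGER.csv", mapped)
# -> {"required_fields": [..., "amount"], "missing_fields": ["amount"], "ok": False}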
|
||||
|
||||
|
||||
def _get_required_fields(model_class) -> List[str]:
|
||||
"""Infer required (non-nullable) fields for a model to avoid DB errors.
|
||||
|
||||
@@ -721,7 +868,7 @@ def convert_value(value: str, field_name: str) -> Any:
|
||||
|
||||
# Date fields
|
||||
if any(word in field_name.lower() for word in [
|
||||
"date", "dob", "birth", "opened", "closed", "judgment", "valuation", "married", "vests_on", "service"
|
||||
"date", "dob", "birth", "opened", "closed", "judgment", "valuation", "married", "vests_on", "service", "approved", "filed", "agreement"
|
||||
]):
|
||||
parsed_date = parse_date(value)
|
||||
return parsed_date
|
||||
@@ -752,6 +899,15 @@ def convert_value(value: str, field_name: str) -> Any:
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
# Normalize debit_credit textual variants
|
||||
if field_name.lower() == "debit_credit":
|
||||
normalized = value.strip().upper()
|
||||
if normalized in ["D", "DEBIT"]:
|
||||
return "D"
|
||||
if normalized in ["C", "CREDIT"]:
|
||||
return "C"
|
||||
return normalized[:1] if normalized else None
|
||||
|
||||
# Integer fields
|
||||
if any(word in field_name.lower() for word in [
|
||||
"item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number"
|
||||
@@ -786,6 +942,69 @@ def validate_foreign_keys(model_data: dict, model_class, db: Session) -> list[st
|
||||
rolodex_id = model_data["id"]
|
||||
if rolodex_id and not db.query(Rolodex).filter(Rolodex.id == rolodex_id).first():
|
||||
errors.append(f"Owner Rolodex ID '{rolodex_id}' not found")
|
||||
# Check File -> Footer relationship (default footer on file)
|
||||
if model_class == File and "footer_code" in model_data:
|
||||
footer = model_data.get("footer_code")
|
||||
if footer:
|
||||
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
|
||||
if not exists:
|
||||
errors.append(f"Footer code '{footer}' not found for File")
|
||||
|
||||
# Check FileStatus -> Footer (default footer exists)
|
||||
if model_class == FileStatus and "footer_code" in model_data:
|
||||
footer = model_data.get("footer_code")
|
||||
if footer:
|
||||
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
|
||||
if not exists:
|
||||
errors.append(f"Footer code '{footer}' not found for FileStatus")
|
||||
|
||||
# Check TransactionType -> Footer (default footer exists)
|
||||
if model_class == TransactionType and "footer_code" in model_data:
|
||||
footer = model_data.get("footer_code")
|
||||
if footer:
|
||||
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
|
||||
if not exists:
|
||||
errors.append(f"Footer code '{footer}' not found for TransactionType")
|
||||
|
||||
# Check Ledger -> TransactionType/TransactionCode cross references
|
||||
if model_class == Ledger:
|
||||
# Validate t_type exists
|
||||
if "t_type" in model_data:
|
||||
t_type_value = model_data.get("t_type")
|
||||
if t_type_value and not db.query(TransactionType).filter(TransactionType.t_type == t_type_value).first():
|
||||
errors.append(f"Transaction type '{t_type_value}' not found")
|
||||
# Validate t_code exists and matches t_type if both provided
|
||||
if "t_code" in model_data:
|
||||
t_code_value = model_data.get("t_code")
|
||||
if t_code_value:
|
||||
code_row = db.query(TransactionCode).filter(TransactionCode.t_code == t_code_value).first()
|
||||
if not code_row:
|
||||
errors.append(f"Transaction code '{t_code_value}' not found")
|
||||
else:
|
||||
ledger_t_type = model_data.get("t_type")
|
||||
if ledger_t_type and getattr(code_row, "t_type", None) and code_row.t_type != ledger_t_type:
|
||||
errors.append(
|
||||
f"Transaction code '{t_code_value}' t_type '{code_row.t_type}' does not match ledger t_type '{ledger_t_type}'"
|
||||
)
|
||||
|
||||
# Check Payment -> File and Rolodex relationships
|
||||
if model_class == Payment:
|
||||
if "file_no" in model_data:
|
||||
file_no_value = model_data.get("file_no")
|
||||
if file_no_value and not db.query(File).filter(File.file_no == file_no_value).first():
|
||||
errors.append(f"File number '{file_no_value}' not found for Payment")
|
||||
if "client_id" in model_data:
|
||||
client_id_value = model_data.get("client_id")
|
||||
if client_id_value and not db.query(Rolodex).filter(Rolodex.id == client_id_value).first():
|
||||
errors.append(f"Client ID '{client_id_value}' not found for Payment")
|
||||
|
||||
# Check QDRO -> PlanInfo (plan_id exists)
|
||||
if model_class == QDRO and "plan_id" in model_data:
|
||||
plan_id = model_data.get("plan_id")
|
||||
if plan_id:
|
||||
exists = db.query(PlanInfo).filter(PlanInfo.plan_id == plan_id).first()
|
||||
if not exists:
|
||||
errors.append(f"Plan ID '{plan_id}' not found for QDRO")
|
||||
|
||||
# Add more foreign key validations as needed
|
||||
return errors
|
||||
@@ -831,6 +1050,96 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/template/{file_type}")
|
||||
async def download_csv_template(
|
||||
file_type: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Download a minimal CSV template with required headers and one sample row.
|
||||
|
||||
Supported templates include: {list(CSV_IMPORT_TEMPLATES.keys())}
|
||||
"""
|
||||
key = (file_type or "").strip()
|
||||
if key not in CSV_IMPORT_TEMPLATES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported template type: {file_type}. Choose one of: {list(CSV_IMPORT_TEMPLATES.keys())}")
|
||||
|
||||
content = _generate_csv_template_bytes(key)
|
||||
|
||||
from datetime import datetime as _dt
|
||||
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
|
||||
safe_name = key.replace(".csv", "")
|
||||
filename = f"{safe_name}_template_{ts}.csv"
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/templates/bundle")
|
||||
async def download_csv_templates_bundle(
|
||||
files: Optional[List[str]] = Query(None, description="Repeat for each CSV template, e.g., files=FILES.csv&files=LEDGER.csv"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Bundle selected CSV templates into a single ZIP.
|
||||
|
||||
Example: GET /api/import/templates/bundle?files=FILES.csv&files=LEDGER.csv
|
||||
"""
|
||||
requested = files or []
|
||||
if not requested:
|
||||
raise HTTPException(status_code=400, detail="Specify at least one 'files' query parameter")
|
||||
|
||||
# Normalize and validate
|
||||
normalized: List[str] = []
|
||||
for name in requested:
|
||||
if not name:
|
||||
continue
|
||||
n = name.strip()
|
||||
if not n.lower().endswith(".csv"):
|
||||
n = f"{n}.csv"
|
||||
n = n.upper()
|
||||
if n in CSV_IMPORT_TEMPLATES:
|
||||
normalized.append(n)
|
||||
else:
|
||||
# Ignore unknowns rather than fail the whole bundle
|
||||
continue
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen = set()
|
||||
selected = []
|
||||
for n in normalized:
|
||||
if n not in seen:
|
||||
seen.add(n)
|
||||
selected.append(n)
|
||||
|
||||
if not selected:
|
||||
raise HTTPException(status_code=400, detail=f"No supported templates requested. Supported: {list(CSV_IMPORT_TEMPLATES.keys())}")
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
|
||||
for fname in selected:
|
||||
try:
|
||||
content = _generate_csv_template_bytes(fname)
|
||||
# Friendly name in zip: <BASENAME>_template.csv
|
||||
base = fname.replace(".CSV", "").upper()
|
||||
arcname = f"{base}_template.csv"
|
||||
zf.writestr(arcname, content)
|
||||
except HTTPException:
|
||||
# Skip unsupported just in case
|
||||
continue
|
||||
|
||||
zip_buffer.seek(0)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"csv_templates_{ts}.zip"
|
||||
return StreamingResponse(
|
||||
iter([zip_buffer.getvalue()]),
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{filename}\""
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload/{file_type}")
|
||||
async def import_csv_data(
|
||||
file_type: str,
|
||||
@@ -1060,6 +1369,26 @@ async def import_csv_data(
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# FK validation for known relationships
|
||||
fk_errors = validate_foreign_keys(model_data, model_class, db)
|
||||
if fk_errors:
|
||||
for msg in fk_errors:
|
||||
errors.append({"row": row_num, "error": msg})
|
||||
# Persist as flexible for traceability
|
||||
db.add(
|
||||
FlexibleImport(
|
||||
file_type=file_type,
|
||||
target_table=model_class.__tablename__,
|
||||
primary_key_field=None,
|
||||
primary_key_value=None,
|
||||
extra_data={
|
||||
"mapped": model_data,
|
||||
"fk_errors": fk_errors,
|
||||
},
|
||||
)
|
||||
)
|
||||
flexible_saved += 1
|
||||
continue
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
db.flush() # Ensure PK is available
|
||||
@@ -1136,6 +1465,9 @@ async def import_csv_data(
|
||||
"unmapped_headers": unmapped_headers,
|
||||
"flexible_saved_rows": flexible_saved,
|
||||
},
|
||||
"validation": {
|
||||
"fk_errors": len([e for e in errors if isinstance(e, dict) and 'error' in e and 'not found' in str(e['error']).lower()])
|
||||
}
|
||||
}
|
||||
# Include create/update breakdown for printers
|
||||
if file_type == "PRINTERS.csv":
|
||||
@@ -1368,6 +1700,10 @@ async def batch_validate_csv_files(
|
||||
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
|
||||
mapped_headers = mapping_info["mapped_headers"]
|
||||
unmapped_headers = mapping_info["unmapped_headers"]
|
||||
header_validation = _validate_required_headers(file_type, mapped_headers)
|
||||
|
||||
# Sample data validation
|
||||
sample_rows = []
|
||||
@@ -1394,12 +1730,13 @@ async def batch_validate_csv_files(
|
||||
|
||||
validation_results.append({
|
||||
"file_type": file_type,
|
||||
"valid": len(mapped_headers) > 0 and len(errors) == 0,
|
||||
"valid": (len(mapped_headers) > 0 and len(errors) == 0 and header_validation.get("ok", True)),
|
||||
"headers": {
|
||||
"found": csv_headers,
|
||||
"mapped": mapped_headers,
|
||||
"unmapped": unmapped_headers
|
||||
},
|
||||
"header_validation": header_validation,
|
||||
"sample_data": sample_rows[:5], # Limit sample data for batch operation
|
||||
"validation_errors": errors[:5], # First 5 errors only
|
||||
"total_errors": len(errors),
|
||||
@@ -1493,17 +1830,34 @@ async def batch_import_csv_files(
|
||||
if file_type not in CSV_MODEL_MAPPING:
|
||||
# Fallback flexible-only import for unknown file structures
|
||||
try:
|
||||
await file.seek(0)
|
||||
content = await file.read()
|
||||
# Save original upload to disk for potential reruns
|
||||
# Use async file operations for better performance
|
||||
from app.services.async_file_operations import async_file_ops
|
||||
|
||||
# Stream save to disk for potential reruns and processing
|
||||
saved_path = None
|
||||
try:
|
||||
file_path = audit_dir.joinpath(file_type)
|
||||
with open(file_path, "wb") as fh:
|
||||
fh.write(content)
|
||||
saved_path = str(file_path)
|
||||
except Exception:
|
||||
saved_path = None
|
||||
relative_path = f"import_audits/{audit_row.id}/{file_type}"
|
||||
saved_file_path, file_size, checksum = await async_file_ops.stream_upload_file(
|
||||
file, relative_path
|
||||
)
|
||||
saved_path = str(async_file_ops.base_upload_dir / relative_path)
|
||||
|
||||
# Stream read for processing
|
||||
content = b""
|
||||
async for chunk in async_file_ops.stream_read_file(relative_path):
|
||||
content += chunk
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to traditional method
|
||||
await file.seek(0)
|
||||
content = await file.read()
|
||||
try:
|
||||
file_path = audit_dir.joinpath(file_type)
|
||||
with open(file_path, "wb") as fh:
|
||||
fh.write(content)
|
||||
saved_path = str(file_path)
|
||||
except Exception:
|
||||
saved_path = None
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
@@ -1640,10 +1994,12 @@ async def batch_import_csv_files(
|
||||
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
|
||||
mapped_headers = mapping_info["mapped_headers"]
|
||||
unmapped_headers = mapping_info["unmapped_headers"]
|
||||
header_validation = _validate_required_headers(file_type, mapped_headers)
|
||||
|
||||
imported_count = 0
|
||||
errors = []
|
||||
flexible_saved = 0
|
||||
fk_error_summary: Dict[str, int] = {}
|
||||
# Special handling: assign line numbers per form for FORM_LST.csv
|
||||
form_lst_line_counters: Dict[str, int] = {}
|
||||
|
||||
@@ -1713,6 +2069,26 @@ async def batch_import_csv_files(
|
||||
if 'file_no' not in model_data or not model_data['file_no']:
|
||||
continue # Skip ledger records without file number
|
||||
|
||||
# FK validation for known relationships
|
||||
fk_errors = validate_foreign_keys(model_data, model_class, db)
|
||||
if fk_errors:
|
||||
for msg in fk_errors:
|
||||
errors.append({"row": row_num, "error": msg})
|
||||
fk_error_summary[msg] = fk_error_summary.get(msg, 0) + 1
|
||||
db.add(
|
||||
FlexibleImport(
|
||||
file_type=file_type,
|
||||
target_table=model_class.__tablename__,
|
||||
primary_key_field=None,
|
||||
primary_key_value=None,
|
||||
extra_data=make_json_safe({
|
||||
"mapped": model_data,
|
||||
"fk_errors": fk_errors,
|
||||
}),
|
||||
)
|
||||
)
|
||||
flexible_saved += 1
|
||||
continue
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
db.flush()
|
||||
@@ -1779,10 +2155,15 @@ async def batch_import_csv_files(
|
||||
|
||||
results.append({
|
||||
"file_type": file_type,
|
||||
"status": "success" if len(errors) == 0 else "completed_with_errors",
|
||||
"status": "success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
|
||||
"imported_count": imported_count,
|
||||
"errors": len(errors),
|
||||
"message": f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
|
||||
"header_validation": header_validation,
|
||||
"validation": {
|
||||
"fk_errors_total": sum(fk_error_summary.values()),
|
||||
"fk_error_summary": fk_error_summary,
|
||||
},
|
||||
"auto_mapping": {
|
||||
"mapped_headers": mapped_headers,
|
||||
"unmapped_headers": unmapped_headers,
|
||||
@@ -1793,7 +2174,7 @@ async def batch_import_csv_files(
|
||||
db.add(ImportAuditFile(
|
||||
audit_id=audit_row.id,
|
||||
file_type=file_type,
|
||||
status="success" if len(errors) == 0 else "completed_with_errors",
|
||||
status="success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
|
||||
imported_count=imported_count,
|
||||
errors=len(errors),
|
||||
message=f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
|
||||
@@ -1801,6 +2182,9 @@ async def batch_import_csv_files(
|
||||
"mapped_headers": list(mapped_headers.keys()),
|
||||
"unmapped_count": len(unmapped_headers),
|
||||
"flexible_saved_rows": flexible_saved,
|
||||
"fk_errors_total": sum(fk_error_summary.values()),
|
||||
"fk_error_summary": fk_error_summary,
|
||||
"header_validation": header_validation,
|
||||
**({"saved_path": saved_path} if saved_path else {}),
|
||||
}
|
||||
))
|
||||
@@ -2138,6 +2522,7 @@ async def rerun_failed_files(
|
||||
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
|
||||
mapped_headers = mapping_info["mapped_headers"]
|
||||
unmapped_headers = mapping_info["unmapped_headers"]
|
||||
header_validation = _validate_required_headers(file_type, mapped_headers)
|
||||
imported_count = 0
|
||||
errors: List[Dict[str, Any]] = []
|
||||
# Special handling: assign line numbers per form for FORM_LST.csv
|
||||
@@ -2248,20 +2633,21 @@ async def rerun_failed_files(
|
||||
total_errors += len(errors)
|
||||
results.append({
|
||||
"file_type": file_type,
|
||||
"status": "success" if len(errors) == 0 else "completed_with_errors",
|
||||
"status": "success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
|
||||
"imported_count": imported_count,
|
||||
"errors": len(errors),
|
||||
"message": f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
|
||||
"header_validation": header_validation,
|
||||
})
|
||||
try:
|
||||
db.add(ImportAuditFile(
|
||||
audit_id=rerun_audit.id,
|
||||
file_type=file_type,
|
||||
status="success" if len(errors) == 0 else "completed_with_errors",
|
||||
status="success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
|
||||
imported_count=imported_count,
|
||||
errors=len(errors),
|
||||
message=f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
|
||||
details={"saved_path": saved_path} if saved_path else {}
|
||||
details={**({"saved_path": saved_path} if saved_path else {}), "header_validation": header_validation}
|
||||
))
|
||||
db.commit()
|
||||
except Exception:
|
||||
|
||||
469
app/api/jobs.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Job Management API
|
||||
|
||||
Provides lightweight monitoring and management endpoints around `JobRecord`.
|
||||
|
||||
Notes:
|
||||
- This is not a background worker. It exposes status/history/metrics for jobs
|
||||
recorded by various synchronous operations (e.g., documents batch generation).
|
||||
- Retry creates a new queued record that references the original job. Actual
|
||||
processing is not scheduled here.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status, Request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user, get_admin_user
|
||||
from app.models.user import User
|
||||
from app.models.jobs import JobRecord
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter
|
||||
from app.services.storage import get_default_storage
|
||||
from app.services.audit import audit_service
|
||||
from app.utils.logging import app_logger
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# --------------------
|
||||
# Pydantic Schemas
|
||||
# --------------------
|
||||
|
||||
class JobRecordResponse(BaseModel):
|
||||
id: int
|
||||
job_id: str
|
||||
job_type: str
|
||||
status: str
|
||||
requested_by_username: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
total_requested: int = 0
|
||||
total_success: int = 0
|
||||
total_failed: int = 0
|
||||
has_result_bundle: bool = False
|
||||
bundle_url: Optional[str] = None
|
||||
bundle_size: Optional[int] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PaginatedJobsResponse(BaseModel):
|
||||
items: List[JobRecordResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class JobFailRequest(BaseModel):
|
||||
reason: str = Field(..., min_length=1, max_length=1000)
|
||||
details_update: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class JobCompletionUpdate(BaseModel):
|
||||
total_success: Optional[int] = None
|
||||
total_failed: Optional[int] = None
|
||||
result_storage_path: Optional[str] = None
|
||||
result_mime_type: Optional[str] = None
|
||||
result_size: Optional[int] = None
|
||||
details_update: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class RetryRequest(BaseModel):
|
||||
note: Optional[str] = None
|
||||
|
||||
|
||||
class JobsMetricsResponse(BaseModel):
|
||||
by_status: Dict[str, int]
|
||||
by_type: Dict[str, int]
|
||||
avg_duration_seconds: Optional[float] = None
|
||||
running_count: int
|
||||
failed_last_24h: int
|
||||
completed_last_24h: int
|
||||
|
||||
|
||||
# --------------------
|
||||
# Helpers
|
||||
# --------------------
|
||||
|
||||
def _compute_duration_seconds(started_at: Optional[datetime], completed_at: Optional[datetime]) -> Optional[float]:
|
||||
if not started_at or not completed_at:
|
||||
return None
|
||||
try:
|
||||
start_utc = started_at if started_at.tzinfo else started_at.replace(tzinfo=timezone.utc)
|
||||
end_utc = completed_at if completed_at.tzinfo else completed_at.replace(tzinfo=timezone.utc)
|
||||
return max((end_utc - start_utc).total_seconds(), 0.0)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _to_response(
|
||||
job: JobRecord,
|
||||
*,
|
||||
include_url: bool = False,
|
||||
) -> JobRecordResponse:
|
||||
has_bundle = bool(getattr(job, "result_storage_path", None))
|
||||
bundle_url = None
|
||||
if include_url and has_bundle:
|
||||
try:
|
||||
bundle_url = get_default_storage().public_url(job.result_storage_path) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
bundle_url = None
|
||||
return JobRecordResponse(
|
||||
id=job.id,
|
||||
job_id=job.job_id,
|
||||
job_type=job.job_type,
|
||||
status=job.status,
|
||||
requested_by_username=getattr(job, "requested_by_username", None),
|
||||
started_at=getattr(job, "started_at", None),
|
||||
completed_at=getattr(job, "completed_at", None),
|
||||
total_requested=getattr(job, "total_requested", 0) or 0,
|
||||
total_success=getattr(job, "total_success", 0) or 0,
|
||||
total_failed=getattr(job, "total_failed", 0) or 0,
|
||||
has_result_bundle=has_bundle,
|
||||
bundle_url=bundle_url,
|
||||
bundle_size=getattr(job, "result_size", None),
|
||||
duration_seconds=_compute_duration_seconds(getattr(job, "started_at", None), getattr(job, "completed_at", None)),
|
||||
details=getattr(job, "details", None),
|
||||
)
|
||||
|
||||
|
||||
# --------------------
|
||||
# Endpoints
|
||||
# --------------------
|
||||
|
||||
|
||||
@router.get("/", response_model=Union[List[JobRecordResponse], PaginatedJobsResponse])
|
||||
async def list_jobs(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
include_urls: bool = Query(False, description="Include bundle URLs in responses"),
|
||||
status_filter: Optional[str] = Query(None, description="Filter by status"),
|
||||
type_filter: Optional[str] = Query(None, description="Filter by job type"),
|
||||
requested_by: Optional[str] = Query(None, description="Filter by username"),
|
||||
search: Optional[str] = Query(None, description="Tokenized search across job_id, type, status, username"),
|
||||
mine: bool = Query(True, description="When true, restricts to current user's jobs (admins can set false)"),
|
||||
sort_by: Optional[str] = Query("started", description="Sort by: started, completed, status, type"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(JobRecord)
|
||||
|
||||
# Scope: non-admin users always restricted to their jobs
|
||||
is_admin = bool(getattr(current_user, "is_admin", False))
|
||||
if mine or not is_admin:
|
||||
query = query.filter(JobRecord.requested_by_username == current_user.username)
|
||||
|
||||
if status_filter:
|
||||
query = query.filter(JobRecord.status == status_filter)
|
||||
if type_filter:
|
||||
query = query.filter(JobRecord.job_type == type_filter)
|
||||
if requested_by and is_admin:
|
||||
query = query.filter(JobRecord.requested_by_username == requested_by)
|
||||
|
||||
if search:
|
||||
tokens = [t for t in (search or "").split() if t]
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
JobRecord.job_id,
|
||||
JobRecord.job_type,
|
||||
JobRecord.status,
|
||||
JobRecord.requested_by_username,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
# Sorting
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"started": [JobRecord.started_at, JobRecord.id],
|
||||
"completed": [JobRecord.completed_at, JobRecord.id],
|
||||
"status": [JobRecord.status, JobRecord.started_at],
|
||||
"type": [JobRecord.job_type, JobRecord.started_at],
|
||||
},
|
||||
)
|
||||
|
||||
jobs, total = paginate_with_total(query, skip, limit, include_total)
|
||||
items = [_to_response(j, include_url=include_urls) for j in jobs]
|
||||
if include_total:
|
||||
return {"items": items, "total": total or 0}
|
||||
return items
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobRecordResponse)
|
||||
async def get_job(
|
||||
job_id: str,
|
||||
include_url: bool = Query(True),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
|
||||
|
||||
# Authorization: non-admin users can only access their jobs
|
||||
if not getattr(current_user, "is_admin", False):
|
||||
if getattr(job, "requested_by_username", None) != current_user.username:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
|
||||
|
||||
return _to_response(job, include_url=include_url)
|
||||
|
||||
|
||||
@router.post("/{job_id}/mark-failed", response_model=JobRecordResponse)
|
||||
async def mark_job_failed(
|
||||
job_id: str,
|
||||
payload: JobFailRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
|
||||
|
||||
job.status = "failed"
|
||||
job.completed_at = datetime.now(timezone.utc)
|
||||
details = dict(getattr(job, "details", {}) or {})
|
||||
details["last_error"] = payload.reason
|
||||
if payload.details_update:
|
||||
details.update(payload.details_update)
|
||||
job.details = details
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
try:
|
||||
audit_service.log_action(
|
||||
db=db,
|
||||
action="FAIL",
|
||||
resource_type="JOB",
|
||||
user=current_user,
|
||||
resource_id=job.job_id,
|
||||
details={"reason": payload.reason},
|
||||
request=request,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _to_response(job, include_url=True)
|
||||
|
||||
|
||||
@router.post("/{job_id}/mark-running", response_model=JobRecordResponse)
|
||||
async def mark_job_running(
|
||||
job_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
|
||||
|
||||
job.status = "running"
|
||||
# Reset start time when transitioning to running
|
||||
job.started_at = datetime.now(timezone.utc)
|
||||
job.completed_at = None
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
try:
|
||||
audit_service.log_action(
|
||||
db=db,
|
||||
action="RUNNING",
|
||||
resource_type="JOB",
|
||||
user=current_user,
|
||||
resource_id=job.job_id,
|
||||
details=None,
|
||||
request=request,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _to_response(job)
|
||||
|
||||
|
||||
@router.post("/{job_id}/mark-completed", response_model=JobRecordResponse)
|
||||
async def mark_job_completed(
|
||||
job_id: str,
|
||||
payload: JobCompletionUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
|
||||
|
||||
job.status = "completed"
|
||||
job.completed_at = datetime.now(timezone.utc)
|
||||
if payload.total_success is not None:
|
||||
job.total_success = max(int(payload.total_success), 0)
|
||||
if payload.total_failed is not None:
|
||||
job.total_failed = max(int(payload.total_failed), 0)
|
||||
if payload.result_storage_path is not None:
|
||||
job.result_storage_path = payload.result_storage_path
|
||||
if payload.result_mime_type is not None:
|
||||
job.result_mime_type = payload.result_mime_type
|
||||
if payload.result_size is not None:
|
||||
job.result_size = max(int(payload.result_size), 0)
|
||||
|
||||
if payload.details_update:
|
||||
details = dict(getattr(job, "details", {}) or {})
|
||||
details.update(payload.details_update)
|
||||
job.details = details
|
||||
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
try:
|
||||
audit_service.log_action(
|
||||
db=db,
|
||||
action="COMPLETE",
|
||||
resource_type="JOB",
|
||||
user=current_user,
|
||||
resource_id=job.job_id,
|
||||
details={
|
||||
"total_success": job.total_success,
|
||||
"total_failed": job.total_failed,
|
||||
},
|
||||
request=request,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _to_response(job, include_url=True)
|
||||
|
||||
|
||||
@router.post("/{job_id}/retry")
|
||||
async def retry_job(
|
||||
job_id: str,
|
||||
payload: RetryRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
"""
|
||||
Create a new queued job record that references the original job.
|
||||
|
||||
This endpoint does not execute the job; it enables monitoring UIs to
|
||||
track retry intent and external workers to pick it up if/when implemented.
|
||||
"""
|
||||
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
|
||||
if not job:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
|
||||
|
||||
new_job_id = uuid4().hex
|
||||
new_details = dict(getattr(job, "details", {}) or {})
|
||||
new_details["retry_of"] = job.job_id
|
||||
if payload.note:
|
||||
new_details["retry_note"] = payload.note
|
||||
|
||||
cloned = JobRecord(
|
||||
job_id=new_job_id,
|
||||
job_type=job.job_type,
|
||||
status="queued",
|
||||
requested_by_username=current_user.username,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=None,
|
||||
total_requested=getattr(job, "total_requested", 0) or 0,
|
||||
total_success=0,
|
||||
total_failed=0,
|
||||
result_storage_path=None,
|
||||
result_mime_type=None,
|
||||
result_size=None,
|
||||
details=new_details,
|
||||
)
|
||||
db.add(cloned)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
audit_service.log_action(
|
||||
db=db,
|
||||
action="RETRY",
|
||||
resource_type="JOB",
|
||||
user=current_user,
|
||||
resource_id=job.job_id,
|
||||
details={"new_job_id": new_job_id},
|
||||
request=request,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"message": "Retry created", "job_id": new_job_id}
|
||||
|
||||
|
||||
@router.get("/metrics/summary", response_model=JobsMetricsResponse)
|
||||
async def jobs_metrics(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
"""
|
||||
Basic metrics for dashboards/monitoring.
|
||||
"""
|
||||
# By status
|
||||
rows = db.query(JobRecord.status, func.count(JobRecord.id)).group_by(JobRecord.status).all()
|
||||
by_status = {str(k or "unknown"): int(v or 0) for k, v in rows}
|
||||
|
||||
# By type
|
||||
rows = db.query(JobRecord.job_type, func.count(JobRecord.id)).group_by(JobRecord.job_type).all()
|
||||
by_type = {str(k or "unknown"): int(v or 0) for k, v in rows}
|
||||
|
||||
# Running count
|
||||
try:
|
||||
running_count = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "running").scalar() or 0
|
||||
except Exception:
|
||||
running_count = 0
|
||||
|
||||
# Last 24h stats
|
||||
cutoff = datetime.now(timezone.utc).replace(microsecond=0)
|
||||
try:
|
||||
failed_last_24h = db.query(func.count(JobRecord.id)).filter(
|
||||
JobRecord.status == "failed",
|
||||
(JobRecord.completed_at != None), # noqa: E711
|
||||
JobRecord.completed_at >= (cutoff - timedelta(hours=24))
|
||||
).scalar() or 0
|
||||
except Exception:
|
||||
# Fallback without the date condition if the backend rejects the datetime comparison
|
||||
failed_last_24h = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "failed").scalar() or 0
|
||||
|
||||
try:
|
||||
completed_last_24h = db.query(func.count(JobRecord.id)).filter(
|
||||
JobRecord.status == "completed",
|
||||
(JobRecord.completed_at != None), # noqa: E711
|
||||
JobRecord.completed_at >= (cutoff - timedelta(hours=24))
|
||||
).scalar() or 0
|
||||
except Exception:
|
||||
completed_last_24h = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "completed").scalar() or 0
|
||||
|
||||
# Average duration on completed
|
||||
try:
|
||||
completed_jobs = db.query(JobRecord.started_at, JobRecord.completed_at).filter(JobRecord.completed_at != None).limit(500).all() # noqa: E711
|
||||
durations: List[float] = []
|
||||
for s, c in completed_jobs:
|
||||
d = _compute_duration_seconds(s, c)
|
||||
if d is not None:
|
||||
durations.append(d)
|
||||
avg_duration = (sum(durations) / len(durations)) if durations else None
|
||||
except Exception:
|
||||
avg_duration = None
|
||||
|
||||
return JobsMetricsResponse(
|
||||
by_status=by_status,
|
||||
by_type=by_type,
|
||||
avg_duration_seconds=(round(avg_duration, 2) if isinstance(avg_duration, (int, float)) else None),
|
||||
running_count=int(running_count),
|
||||
failed_last_24h=int(failed_last_24h),
|
||||
completed_last_24h=int(completed_last_24h),
|
||||
)
|
||||
|
||||
|
||||
258
app/api/labels.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Mailing Labels & Envelopes API
|
||||
|
||||
Endpoints:
|
||||
- POST /api/labels/rolodex/labels-5160
|
||||
- POST /api/labels/files/labels-5160
|
||||
- POST /api/labels/rolodex/envelopes
|
||||
- POST /api/labels/files/envelopes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
import io
|
||||
import csv
|
||||
|
||||
from app.auth.security import get_current_user
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.models.rolodex import Rolodex
|
||||
from app.services.customers_search import apply_customer_filters
|
||||
from app.services.mailing import (
|
||||
Address,
|
||||
build_addresses_from_files,
|
||||
build_addresses_from_rolodex,
|
||||
build_address_from_rolodex,
|
||||
render_labels_html,
|
||||
render_envelopes_html,
|
||||
save_html_bytes,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class Labels5160Request(BaseModel):
|
||||
ids: List[str] = Field(default_factory=list, description="Rolodex IDs or File numbers depending on route")
|
||||
start_position: int = Field(default=1, ge=1, le=30, description="Starting label position on sheet (1-30)")
|
||||
include_name: bool = Field(default=True, description="Include name/company as first line")
|
||||
|
||||
|
||||
class GenerateResult(BaseModel):
|
||||
url: Optional[str] = None
|
||||
storage_path: Optional[str] = None
|
||||
mime_type: str
|
||||
size: int
|
||||
created_at: str
|
||||
|
||||
|
||||
@router.post("/rolodex/labels-5160", response_model=GenerateResult)
|
||||
async def generate_rolodex_labels(
|
||||
payload: Labels5160Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="No rolodex IDs provided")
|
||||
addresses = build_addresses_from_rolodex(db, payload.ids)
|
||||
if not addresses:
|
||||
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
|
||||
html_bytes = render_labels_html(addresses, start_position=payload.start_position, include_name=payload.include_name)
|
||||
result = save_html_bytes(html_bytes, filename_hint=f"labels_5160_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/labels")
|
||||
return GenerateResult(**result)
|
||||
|
||||
|
||||
@router.post("/files/labels-5160", response_model=GenerateResult)
|
||||
async def generate_file_labels(
|
||||
payload: Labels5160Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="No file numbers provided")
|
||||
addresses = build_addresses_from_files(db, payload.ids)
|
||||
if not addresses:
|
||||
raise HTTPException(status_code=404, detail="No matching file owners found")
|
||||
html_bytes = render_labels_html(addresses, start_position=payload.start_position, include_name=payload.include_name)
|
||||
result = save_html_bytes(html_bytes, filename_hint=f"labels_5160_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/labels")
|
||||
return GenerateResult(**result)
|
||||
|
||||
|
||||
class EnvelopesRequest(BaseModel):
|
||||
ids: List[str] = Field(default_factory=list, description="Rolodex IDs or File numbers depending on route")
|
||||
include_name: bool = Field(default=True)
|
||||
return_address_lines: Optional[List[str]] = Field(default=None, description="Lines for return address (top-left)")
|
||||
|
||||
|
||||
@router.post("/rolodex/envelopes", response_model=GenerateResult)
|
||||
async def generate_rolodex_envelopes(
|
||||
payload: EnvelopesRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="No rolodex IDs provided")
|
||||
addresses = build_addresses_from_rolodex(db, payload.ids)
|
||||
if not addresses:
|
||||
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
|
||||
html_bytes = render_envelopes_html(addresses, return_address_lines=payload.return_address_lines, include_name=payload.include_name)
|
||||
result = save_html_bytes(html_bytes, filename_hint=f"envelopes_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/envelopes")
|
||||
return GenerateResult(**result)
|
||||
|
||||
|
||||
@router.post("/files/envelopes", response_model=GenerateResult)
|
||||
async def generate_file_envelopes(
|
||||
payload: EnvelopesRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not payload.ids:
|
||||
raise HTTPException(status_code=400, detail="No file numbers provided")
|
||||
addresses = build_addresses_from_files(db, payload.ids)
|
||||
if not addresses:
|
||||
raise HTTPException(status_code=404, detail="No matching file owners found")
|
||||
html_bytes = render_envelopes_html(addresses, return_address_lines=payload.return_address_lines, include_name=payload.include_name)
|
||||
result = save_html_bytes(html_bytes, filename_hint=f"envelopes_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/envelopes")
|
||||
return GenerateResult(**result)
|
||||
|
||||
|
||||
@router.get("/rolodex/labels-5160/export")
|
||||
async def export_rolodex_labels_5160(
|
||||
start_position: int = Query(1, ge=1, le=30, description="Starting label position on sheet (1-30)"),
|
||||
include_name: bool = Query(True, description="Include name/company as first line"),
|
||||
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
|
||||
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
|
||||
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
|
||||
format: str = Query("html", description="Output format: html | csv"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Generate Avery 5160 labels for Rolodex entries selected by filters and stream as HTML or CSV."""
|
||||
fmt = (format or "").strip().lower()
|
||||
if fmt not in {"html", "csv"}:
|
||||
raise HTTPException(status_code=400, detail="Invalid format. Use 'html' or 'csv'.")
|
||||
|
||||
q = db.query(Rolodex)
|
||||
q = apply_customer_filters(
|
||||
q,
|
||||
search=None,
|
||||
group=group,
|
||||
state=None,
|
||||
groups=groups,
|
||||
states=None,
|
||||
name_prefix=name_prefix,
|
||||
)
|
||||
entries = q.all()
|
||||
if not entries:
|
||||
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
|
||||
|
||||
if fmt == "html":
|
||||
addresses = [build_address_from_rolodex(r) for r in entries]
|
||||
html_bytes = render_labels_html(addresses, start_position=start_position, include_name=include_name)
|
||||
from fastapi.responses import StreamingResponse
|
||||
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f"labels_5160_{ts}.html"
|
||||
return StreamingResponse(
|
||||
iter([html_bytes]),
|
||||
media_type="text/html",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
else:
|
||||
# CSV of address fields
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(["Name", "Address 1", "Address 2", "Address 3", "City", "State", "ZIP"])
|
||||
for r in entries:
|
||||
addr = build_address_from_rolodex(r)
|
||||
writer.writerow([
|
||||
addr.display_name,
|
||||
r.a1 or "",
|
||||
r.a2 or "",
|
||||
r.a3 or "",
|
||||
r.city or "",
|
||||
r.abrev or "",
|
||||
r.zip or "",
|
||||
])
|
||||
output.seek(0)
|
||||
from fastapi.responses import StreamingResponse
|
||||
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f"labels_5160_{ts}.csv"
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rolodex/envelopes/export")
|
||||
async def export_rolodex_envelopes(
|
||||
include_name: bool = Query(True, description="Include name/company"),
|
||||
return_address_lines: Optional[List[str]] = Query(None, description="Optional return address lines"),
|
||||
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
|
||||
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
|
||||
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
|
||||
format: str = Query("html", description="Output format: html | csv"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Generate envelopes for Rolodex entries selected by filters and stream as HTML or CSV of addresses."""
|
||||
fmt = (format or "").strip().lower()
|
||||
if fmt not in {"html", "csv"}:
|
||||
raise HTTPException(status_code=400, detail="Invalid format. Use 'html' or 'csv'.")
|
||||
|
||||
q = db.query(Rolodex)
|
||||
q = apply_customer_filters(
|
||||
q,
|
||||
search=None,
|
||||
group=group,
|
||||
state=None,
|
||||
groups=groups,
|
||||
states=None,
|
||||
name_prefix=name_prefix,
|
||||
)
|
||||
entries = q.all()
|
||||
if not entries:
|
||||
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
|
||||
|
||||
if fmt == "html":
|
||||
addresses = [build_address_from_rolodex(r) for r in entries]
|
||||
html_bytes = render_envelopes_html(addresses, return_address_lines=return_address_lines, include_name=include_name)
|
||||
from fastapi.responses import StreamingResponse
|
||||
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f"envelopes_{ts}.html"
|
||||
return StreamingResponse(
|
||||
iter([html_bytes]),
|
||||
media_type="text/html",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
else:
|
||||
# CSV of address fields
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(["Name", "Address 1", "Address 2", "Address 3", "City", "State", "ZIP"])
|
||||
for r in entries:
|
||||
addr = build_address_from_rolodex(r)
|
||||
writer.writerow([
|
||||
addr.display_name,
|
||||
r.a1 or "",
|
||||
r.a2 or "",
|
||||
r.a3 or "",
|
||||
r.city or "",
|
||||
r.abrev or "",
|
||||
r.zip or "",
|
||||
])
|
||||
output.seek(0)
|
||||
from fastapi.responses import StreamingResponse
|
||||
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
|
||||
filename = f"envelopes_{ts}.csv"
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
|
||||
)
|
||||
|
||||
230
app/api/pension_valuation.py
Normal file
230
app/api/pension_valuation.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Pension Valuation API endpoints
|
||||
|
||||
Exposes endpoints under /api/pensions/valuation for:
|
||||
- Single-life present value
|
||||
- Joint-survivor present value
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, List, Union, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.pension_valuation import (
|
||||
SingleLifeInputs,
|
||||
JointSurvivorInputs,
|
||||
present_value_single_life,
|
||||
present_value_joint_survivor,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/valuation", tags=["pensions", "pensions-valuation"])
|
||||
|
||||
|
||||
class SingleLifeRequest(BaseModel):
|
||||
monthly_benefit: float = Field(ge=0)
|
||||
term_months: int = Field(ge=0, description="Number of months in evaluation horizon")
|
||||
start_age: Optional[int] = Field(default=None, ge=0)
|
||||
sex: str = Field(description="M, F, or A (all)")
|
||||
race: str = Field(description="W, B, H, or A (all)")
|
||||
discount_rate: float = Field(default=0.0, ge=0, description="Annual percent, e.g. 3.0")
|
||||
cola_rate: float = Field(default=0.0, ge=0, description="Annual percent, e.g. 2.0")
|
||||
defer_months: float = Field(default=0, ge=0, description="Months to delay first payment (supports fractional)")
|
||||
payment_period_months: int = Field(default=1, ge=1, description="Months per payment (1=monthly, 3=quarterly, 12=annual)")
|
||||
certain_months: int = Field(default=0, ge=0, description="Guaranteed months from commencement regardless of mortality")
|
||||
cola_mode: str = Field(default="monthly", description="'monthly' or 'annual_prorated'")
|
||||
cola_cap_percent: Optional[float] = Field(default=None, ge=0)
|
||||
interpolation_method: str = Field(default="linear", description="'linear' or 'step' for NA interpolation")
|
||||
max_age: Optional[int] = Field(default=None, ge=0, description="Optional cap on participant age for term truncation")
|
||||
|
||||
|
||||
class SingleLifeResponse(BaseModel):
|
||||
pv: float
|
||||
|
||||
|
||||
@router.post("/single-life", response_model=SingleLifeResponse)
|
||||
async def value_single_life(
|
||||
payload: SingleLifeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
pv = present_value_single_life(
|
||||
db,
|
||||
SingleLifeInputs(
|
||||
monthly_benefit=payload.monthly_benefit,
|
||||
term_months=payload.term_months,
|
||||
start_age=payload.start_age,
|
||||
sex=payload.sex,
|
||||
race=payload.race,
|
||||
discount_rate=payload.discount_rate,
|
||||
cola_rate=payload.cola_rate,
|
||||
defer_months=payload.defer_months,
|
||||
payment_period_months=payload.payment_period_months,
|
||||
certain_months=payload.certain_months,
|
||||
cola_mode=payload.cola_mode,
|
||||
cola_cap_percent=payload.cola_cap_percent,
|
||||
interpolation_method=payload.interpolation_method,
|
||||
max_age=payload.max_age,
|
||||
),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
return SingleLifeResponse(pv=float(round(pv, 2)))
|
||||
|
||||
|
||||
class JointSurvivorRequest(BaseModel):
|
||||
monthly_benefit: float = Field(ge=0)
|
||||
term_months: int = Field(ge=0)
|
||||
participant_age: Optional[int] = Field(default=None, ge=0)
|
||||
participant_sex: str
|
||||
participant_race: str
|
||||
spouse_age: Optional[int] = Field(default=None, ge=0)
|
||||
spouse_sex: str
|
||||
spouse_race: str
|
||||
survivor_percent: float = Field(ge=0, le=100, description="Percent of benefit to spouse on participant death")
|
||||
discount_rate: float = Field(default=0.0, ge=0)
|
||||
cola_rate: float = Field(default=0.0, ge=0)
|
||||
defer_months: float = Field(default=0, ge=0)
|
||||
payment_period_months: int = Field(default=1, ge=1)
|
||||
certain_months: int = Field(default=0, ge=0)
|
||||
cola_mode: str = Field(default="monthly")
|
||||
cola_cap_percent: Optional[float] = Field(default=None, ge=0)
|
||||
survivor_basis: str = Field(default="contingent", description="'contingent' or 'last_survivor'")
|
||||
survivor_commence_participant_only: bool = Field(default=False, description="If true, survivor component uses participant survival as commencement basis")
|
||||
interpolation_method: str = Field(default="linear")
|
||||
max_age: Optional[int] = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class JointSurvivorResponse(BaseModel):
|
||||
pv_total: float
|
||||
pv_participant_component: float
|
||||
pv_survivor_component: float
|
||||
|
||||
|
||||
@router.post("/joint-survivor", response_model=JointSurvivorResponse)
|
||||
async def value_joint_survivor(
|
||||
payload: JointSurvivorRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
result: Dict[str, float] = present_value_joint_survivor(
|
||||
db,
|
||||
JointSurvivorInputs(
|
||||
monthly_benefit=payload.monthly_benefit,
|
||||
term_months=payload.term_months,
|
||||
participant_age=payload.participant_age,
|
||||
participant_sex=payload.participant_sex,
|
||||
participant_race=payload.participant_race,
|
||||
spouse_age=payload.spouse_age,
|
||||
spouse_sex=payload.spouse_sex,
|
||||
spouse_race=payload.spouse_race,
|
||||
survivor_percent=payload.survivor_percent,
|
||||
discount_rate=payload.discount_rate,
|
||||
cola_rate=payload.cola_rate,
|
||||
defer_months=payload.defer_months,
|
||||
payment_period_months=payload.payment_period_months,
|
||||
certain_months=payload.certain_months,
|
||||
cola_mode=payload.cola_mode,
|
||||
cola_cap_percent=payload.cola_cap_percent,
|
||||
survivor_basis=payload.survivor_basis,
|
||||
survivor_commence_participant_only=payload.survivor_commence_participant_only,
|
||||
interpolation_method=payload.interpolation_method,
|
||||
max_age=payload.max_age,
|
||||
),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
# Round to 2 decimals for response
|
||||
return JointSurvivorResponse(
|
||||
pv_total=float(round(result["pv_total"], 2)),
|
||||
pv_participant_component=float(round(result["pv_participant_component"], 2)),
|
||||
pv_survivor_component=float(round(result["pv_survivor_component"], 2)),
|
||||
)
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: str
|
||||
|
||||
class BatchSingleLifeRequest(BaseModel):
|
||||
# Accept raw dicts to allow per-item validation inside the loop (avoid 422 on the entire batch)
|
||||
items: List[Dict[str, Any]]
|
||||
|
||||
class BatchSingleLifeItemResponse(BaseModel):
|
||||
success: bool
|
||||
result: Optional[SingleLifeResponse] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
class BatchSingleLifeResponse(BaseModel):
|
||||
results: List[BatchSingleLifeItemResponse]
|
||||
|
||||
@router.post("/batch-single-life", response_model=BatchSingleLifeResponse)
|
||||
async def batch_value_single_life(
|
||||
payload: BatchSingleLifeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
results = []
|
||||
for item in payload.items:
|
||||
try:
|
||||
inputs = SingleLifeInputs(**item)
|
||||
pv = present_value_single_life(db, inputs)
|
||||
results.append(BatchSingleLifeItemResponse(
|
||||
success=True,
|
||||
result=SingleLifeResponse(pv=float(round(pv, 2))),
|
||||
))
|
||||
except ValueError as e:
|
||||
results.append(BatchSingleLifeItemResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
))
|
||||
return BatchSingleLifeResponse(results=results)
|
||||
|
||||
class BatchJointSurvivorRequest(BaseModel):
|
||||
# Accept raw dicts to allow per-item validation inside the loop (avoid 422 on the entire batch)
|
||||
items: List[Dict[str, Any]]
|
||||
|
||||
class BatchJointSurvivorItemResponse(BaseModel):
|
||||
success: bool
|
||||
result: Optional[JointSurvivorResponse] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
class BatchJointSurvivorResponse(BaseModel):
|
||||
results: List[BatchJointSurvivorItemResponse]
|
||||
|
||||
@router.post("/batch-joint-survivor", response_model=BatchJointSurvivorResponse)
|
||||
async def batch_value_joint_survivor(
|
||||
payload: BatchJointSurvivorRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
results = []
|
||||
for item in payload.items:
|
||||
try:
|
||||
inputs = JointSurvivorInputs(**item)
|
||||
result = present_value_joint_survivor(db, inputs)
|
||||
results.append(BatchJointSurvivorItemResponse(
|
||||
success=True,
|
||||
result=JointSurvivorResponse(
|
||||
pv_total=float(round(result["pv_total"], 2)),
|
||||
pv_participant_component=float(round(result["pv_participant_component"], 2)),
|
||||
pv_survivor_component=float(round(result["pv_survivor_component"], 2)),
|
||||
),
|
||||
))
|
||||
except ValueError as e:
|
||||
results.append(BatchJointSurvivorItemResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
))
|
||||
return BatchJointSurvivorResponse(results=results)
|
||||
|
||||
|
||||
@@ -579,7 +579,7 @@ async def generate_qdro_document(
|
||||
"MATTER": file_obj.regarding,
|
||||
})
|
||||
# Merge with provided context
|
||||
context = build_context({**base_ctx, **(payload.context or {})})
|
||||
context = build_context({**base_ctx, **(payload.context or {})}, "file", qdro.file_no)
|
||||
resolved, unresolved = resolve_tokens(db, tokens, context)
|
||||
|
||||
output_bytes = content
|
||||
@@ -591,7 +591,21 @@ async def generate_qdro_document(
|
||||
audit_service.log_action(db, action="GENERATE", resource_type="QDRO", user=current_user, resource_id=qdro_id, details={"template_id": payload.template_id, "version_id": version_id, "unresolved": unresolved})
|
||||
except Exception:
|
||||
pass
|
||||
return GenerateResponse(resolved=resolved, unresolved=unresolved, output_mime_type=output_mime, output_size=len(output_bytes))
|
||||
# Sanitize resolved values to ensure JSON-serializable output
|
||||
def _json_sanitize(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (date, datetime)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_sanitize(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _json_sanitize(v) for k, v in value.items()}
|
||||
# Fallback: stringify unsupported types (e.g., functions)
|
||||
return str(value)
|
||||
|
||||
sanitized_resolved = {k: _json_sanitize(v) for k, v in resolved.items()}
|
||||
return GenerateResponse(resolved=sanitized_resolved, unresolved=unresolved, output_mime_type=output_mime, output_size=len(output_bytes))
|
||||
|
||||
|
||||
class PlanInfoCreate(BaseModel):
|
||||
|
||||
@@ -368,8 +368,10 @@ async def advanced_search(
|
||||
|
||||
# Cache lookup keyed by user and entire criteria (including pagination)
|
||||
try:
|
||||
cached = await cache_get_json(
|
||||
kind="advanced",
|
||||
from app.services.adaptive_cache import adaptive_cache_get
|
||||
cached = await adaptive_cache_get(
|
||||
cache_type="advanced",
|
||||
cache_key="advanced_search",
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"criteria": criteria.model_dump(mode="json")},
|
||||
)
|
||||
@@ -438,14 +440,15 @@ async def advanced_search(
|
||||
page_info=page_info,
|
||||
)
|
||||
|
||||
# Store in cache (best-effort)
|
||||
# Store in cache with adaptive TTL
|
||||
try:
|
||||
await cache_set_json(
|
||||
kind="advanced",
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"criteria": criteria.model_dump(mode="json")},
|
||||
from app.services.adaptive_cache import adaptive_cache_set
|
||||
await adaptive_cache_set(
|
||||
cache_type="advanced",
|
||||
cache_key="advanced_search",
|
||||
value=response.model_dump(mode="json"),
|
||||
ttl_seconds=90,
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"criteria": criteria.model_dump(mode="json")}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -462,9 +465,11 @@ async def global_search(
|
||||
):
|
||||
"""Enhanced global search across all entities"""
|
||||
start_time = datetime.now()
|
||||
# Cache lookup
|
||||
cached = await cache_get_json(
|
||||
kind="global",
|
||||
# Cache lookup with adaptive tracking
|
||||
from app.services.adaptive_cache import adaptive_cache_get
|
||||
cached = await adaptive_cache_get(
|
||||
cache_type="global",
|
||||
cache_key="global_search",
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"q": q, "limit": limit},
|
||||
)
|
||||
@@ -505,12 +510,13 @@ async def global_search(
|
||||
phones=phone_results[:limit]
|
||||
)
|
||||
try:
|
||||
await cache_set_json(
|
||||
kind="global",
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"q": q, "limit": limit},
|
||||
from app.services.adaptive_cache import adaptive_cache_set
|
||||
await adaptive_cache_set(
|
||||
cache_type="global",
|
||||
cache_key="global_search",
|
||||
value=response.model_dump(mode="json"),
|
||||
ttl_seconds=90,
|
||||
user_id=str(getattr(current_user, "id", "")),
|
||||
parts={"q": q, "limit": limit}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
503
app/api/session_management.py
Normal file
503
app/api/session_management.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
Session Management API for P2 security features
|
||||
"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user, get_admin_user
|
||||
from app.models.user import User
|
||||
from app.models.sessions import UserSession, SessionConfiguration, SessionSecurityEvent
|
||||
from app.utils.session_manager import SessionManager, get_session_manager
|
||||
from app.utils.responses import create_success_response as success_response
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/session", tags=["Session Management"])
|
||||
|
||||
|
||||
# Pydantic schemas
|
||||
class SessionInfo(BaseModel):
|
||||
"""Session information response"""
|
||||
session_id: str
|
||||
user_id: int
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
device_fingerprint: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
is_suspicious: bool = False
|
||||
risk_score: int = 0
|
||||
status: str
|
||||
created_at: datetime
|
||||
last_activity: datetime
|
||||
expires_at: datetime
|
||||
login_method: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SessionConfigurationSchema(BaseModel):
|
||||
"""Session configuration schema"""
|
||||
max_concurrent_sessions: int = Field(default=3, ge=1, le=20)
|
||||
session_timeout_minutes: int = Field(default=480, ge=30, le=1440) # 30 min to 24 hours
|
||||
idle_timeout_minutes: int = Field(default=60, ge=5, le=240) # 5 min to 4 hours
|
||||
require_session_renewal: bool = True
|
||||
renewal_interval_hours: int = Field(default=24, ge=1, le=168) # 1 hour to 1 week
|
||||
force_logout_on_ip_change: bool = False
|
||||
suspicious_activity_threshold: int = Field(default=5, ge=1, le=20)
|
||||
allowed_countries: Optional[List[str]] = None
|
||||
blocked_countries: Optional[List[str]] = None
|
||||
|
||||
|
||||
class SessionConfigurationUpdate(BaseModel):
|
||||
"""Session configuration update request"""
|
||||
max_concurrent_sessions: Optional[int] = Field(None, ge=1, le=20)
|
||||
session_timeout_minutes: Optional[int] = Field(None, ge=30, le=1440)
|
||||
idle_timeout_minutes: Optional[int] = Field(None, ge=5, le=240)
|
||||
require_session_renewal: Optional[bool] = None
|
||||
renewal_interval_hours: Optional[int] = Field(None, ge=1, le=168)
|
||||
force_logout_on_ip_change: Optional[bool] = None
|
||||
suspicious_activity_threshold: Optional[int] = Field(None, ge=1, le=20)
|
||||
allowed_countries: Optional[List[str]] = None
|
||||
blocked_countries: Optional[List[str]] = None
|
||||
|
||||
|
||||
class SecurityEventInfo(BaseModel):
|
||||
"""Security event information"""
|
||||
id: int
|
||||
event_type: str
|
||||
severity: str
|
||||
description: str
|
||||
ip_address: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
action_taken: Optional[str] = None
|
||||
resolved: bool = False
|
||||
timestamp: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Session Management Endpoints
|
||||
|
||||
@router.get("/current", response_model=SessionInfo)
|
||||
async def get_current_session(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Get current session information"""
|
||||
try:
|
||||
# Extract session ID from request
|
||||
session_id = request.headers.get("X-Session-ID") or request.cookies.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
# For JWT-based sessions, use a portion of the JWT as session identifier
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
session_id = auth_header[7:][:32]
|
||||
|
||||
if not session_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No session identifier found"
|
||||
)
|
||||
|
||||
session = session_manager.validate_session(session_id, request)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found or expired"
|
||||
)
|
||||
|
||||
return SessionInfo.from_orm(session)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current session: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve session information"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=List[SessionInfo])
|
||||
async def list_user_sessions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""List all active sessions for current user"""
|
||||
try:
|
||||
sessions = session_manager.get_active_sessions(current_user.id)
|
||||
return [SessionInfo.from_orm(session) for session in sessions]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/revoke/{session_id}")
|
||||
async def revoke_session(
|
||||
session_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Revoke a specific session"""
|
||||
try:
|
||||
# Verify the session belongs to the current user
|
||||
session = session_manager.db.query(UserSession).filter(
|
||||
UserSession.session_id == session_id,
|
||||
UserSession.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
success = session_manager.revoke_session(session_id, "user_revocation")
|
||||
|
||||
if success:
|
||||
return success_response("Session revoked successfully")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to revoke session"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/revoke-all")
|
||||
async def revoke_all_sessions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Revoke all sessions for current user"""
|
||||
try:
|
||||
count = session_manager.revoke_all_user_sessions(current_user.id, "user_revoke_all")
|
||||
|
||||
return success_response(f"Revoked {count} sessions successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking all sessions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/configuration", response_model=SessionConfigurationSchema)
|
||||
async def get_session_configuration(
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Get session configuration for current user"""
|
||||
try:
|
||||
config = session_manager._get_session_config(current_user)
|
||||
|
||||
return SessionConfigurationSchema(
|
||||
max_concurrent_sessions=config.max_concurrent_sessions,
|
||||
session_timeout_minutes=config.session_timeout_minutes,
|
||||
idle_timeout_minutes=config.idle_timeout_minutes,
|
||||
require_session_renewal=config.require_session_renewal,
|
||||
renewal_interval_hours=config.renewal_interval_hours,
|
||||
force_logout_on_ip_change=config.force_logout_on_ip_change,
|
||||
suspicious_activity_threshold=config.suspicious_activity_threshold,
|
||||
allowed_countries=config.allowed_countries.split(",") if config.allowed_countries else None,
|
||||
blocked_countries=config.blocked_countries.split(",") if config.blocked_countries else None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session configuration: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve session configuration"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/configuration")
|
||||
async def update_session_configuration(
|
||||
config_update: SessionConfigurationUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Update session configuration for current user"""
|
||||
try:
|
||||
config = session_manager._get_session_config(current_user)
|
||||
|
||||
# Ensure user-specific config exists
|
||||
if config.user_id is None:
|
||||
# Create user-specific config based on global config
|
||||
user_config = SessionConfiguration(
|
||||
user_id=current_user.id,
|
||||
max_concurrent_sessions=config.max_concurrent_sessions,
|
||||
session_timeout_minutes=config.session_timeout_minutes,
|
||||
idle_timeout_minutes=config.idle_timeout_minutes,
|
||||
require_session_renewal=config.require_session_renewal,
|
||||
renewal_interval_hours=config.renewal_interval_hours,
|
||||
force_logout_on_ip_change=config.force_logout_on_ip_change,
|
||||
suspicious_activity_threshold=config.suspicious_activity_threshold,
|
||||
allowed_countries=config.allowed_countries,
|
||||
blocked_countries=config.blocked_countries
|
||||
)
|
||||
session_manager.db.add(user_config)
|
||||
session_manager.db.flush()
|
||||
config = user_config
|
||||
|
||||
# Update configuration
|
||||
update_data = config_update.dict(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field in ["allowed_countries", "blocked_countries"] and value:
|
||||
setattr(config, field, ",".join(value))
|
||||
else:
|
||||
setattr(config, field, value)
|
||||
|
||||
config.updated_at = datetime.now(timezone.utc)
|
||||
session_manager.db.commit()
|
||||
|
||||
return success_response("Session configuration updated successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating session configuration: {str(e)}")
|
||||
session_manager.db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update session configuration"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/security-events", response_model=List[SecurityEventInfo])
|
||||
async def get_security_events(
|
||||
limit: int = 50,
|
||||
resolved: Optional[bool] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get security events for current user"""
|
||||
try:
|
||||
query = db.query(SessionSecurityEvent).filter(
|
||||
SessionSecurityEvent.user_id == current_user.id
|
||||
)
|
||||
|
||||
if resolved is not None:
|
||||
query = query.filter(SessionSecurityEvent.resolved == resolved)
|
||||
|
||||
events = query.order_by(SessionSecurityEvent.timestamp.desc()).limit(limit).all()
|
||||
|
||||
return [SecurityEventInfo.from_orm(event) for event in events]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting security events: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve security events"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_session_statistics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Get session statistics for current user"""
|
||||
try:
|
||||
stats = session_manager.get_session_statistics(current_user.id)
|
||||
return success_response(data=stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve session statistics"
|
||||
)
|
||||
|
||||
|
||||
# Admin endpoints
|
||||
|
||||
@router.get("/admin/sessions", response_model=List[SessionInfo])
|
||||
async def admin_list_all_sessions(
|
||||
user_id: Optional[int] = None,
|
||||
limit: int = 100,
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Admin: List sessions for all users or specific user"""
|
||||
try:
|
||||
if user_id:
|
||||
sessions = session_manager.get_active_sessions(user_id)
|
||||
else:
|
||||
sessions = session_manager.db.query(UserSession).filter(
|
||||
UserSession.status == "active"
|
||||
).order_by(UserSession.last_activity.desc()).limit(limit).all()
|
||||
|
||||
return [SessionInfo.from_orm(session) for session in sessions]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting admin sessions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/admin/revoke/{session_id}")
|
||||
async def admin_revoke_session(
|
||||
session_id: str,
|
||||
reason: str = "admin_revocation",
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Admin: Revoke any session"""
|
||||
try:
|
||||
success = session_manager.revoke_session(session_id, f"admin_revocation: {reason}")
|
||||
|
||||
if success:
|
||||
return success_response(f"Session {session_id} revoked successfully")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error admin revoking session: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/admin/revoke-user/{user_id}")
|
||||
async def admin_revoke_user_sessions(
|
||||
user_id: int,
|
||||
reason: str = "admin_action",
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Admin: Revoke all sessions for a specific user"""
|
||||
try:
|
||||
count = session_manager.revoke_all_user_sessions(user_id, f"admin_action: {reason}")
|
||||
|
||||
return success_response(f"Revoked {count} sessions for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error admin revoking user sessions: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke user sessions"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/global-configuration", response_model=SessionConfigurationSchema)
|
||||
async def admin_get_global_configuration(
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Admin: Get global session configuration"""
|
||||
try:
|
||||
config = db.query(SessionConfiguration).filter(
|
||||
SessionConfiguration.user_id.is_(None)
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
# Create default global config
|
||||
config = SessionConfiguration()
|
||||
db.add(config)
|
||||
db.commit()
|
||||
|
||||
return SessionConfigurationSchema(
|
||||
max_concurrent_sessions=config.max_concurrent_sessions,
|
||||
session_timeout_minutes=config.session_timeout_minutes,
|
||||
idle_timeout_minutes=config.idle_timeout_minutes,
|
||||
require_session_renewal=config.require_session_renewal,
|
||||
renewal_interval_hours=config.renewal_interval_hours,
|
||||
force_logout_on_ip_change=config.force_logout_on_ip_change,
|
||||
suspicious_activity_threshold=config.suspicious_activity_threshold,
|
||||
allowed_countries=config.allowed_countries.split(",") if config.allowed_countries else None,
|
||||
blocked_countries=config.blocked_countries.split(",") if config.blocked_countries else None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting global session configuration: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve global session configuration"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/admin/global-configuration")
|
||||
async def admin_update_global_configuration(
|
||||
config_update: SessionConfigurationUpdate,
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Admin: Update global session configuration"""
|
||||
try:
|
||||
config = db.query(SessionConfiguration).filter(
|
||||
SessionConfiguration.user_id.is_(None)
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
config = SessionConfiguration()
|
||||
db.add(config)
|
||||
db.flush()
|
||||
|
||||
# Update configuration
|
||||
update_data = config_update.dict(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field in ["allowed_countries", "blocked_countries"] and value:
|
||||
setattr(config, field, ",".join(value))
|
||||
else:
|
||||
setattr(config, field, value)
|
||||
|
||||
config.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
return success_response("Global session configuration updated successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating global session configuration: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update global session configuration"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/statistics")
|
||||
async def admin_get_global_statistics(
|
||||
admin_user: User = Depends(get_admin_user),
|
||||
session_manager: SessionManager = Depends(get_session_manager)
|
||||
):
|
||||
"""Admin: Get global session statistics"""
|
||||
try:
|
||||
stats = session_manager.get_session_statistics()
|
||||
return success_response(data=stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting global session statistics: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve global session statistics"
|
||||
)
|
||||
@@ -20,12 +20,23 @@ from sqlalchemy import func, or_, exists
|
||||
import hashlib
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.auth.security import get_current_user
|
||||
from app.auth.security import get_current_user, get_admin_user
|
||||
from app.models.user import User
|
||||
from app.models.templates import DocumentTemplate, DocumentTemplateVersion, TemplateKeyword
|
||||
from app.services.storage import get_default_storage
|
||||
from app.services.template_merge import extract_tokens_from_bytes, build_context, resolve_tokens, render_docx
|
||||
from app.services.template_service import (
|
||||
get_template_or_404,
|
||||
list_template_versions as svc_list_template_versions,
|
||||
add_template_version as svc_add_template_version,
|
||||
resolve_template_preview as svc_resolve_template_preview,
|
||||
get_download_payload as svc_get_download_payload,
|
||||
)
|
||||
from app.services.query_utils import paginate_with_total
|
||||
from app.services.template_upload import TemplateUploadService
|
||||
from app.services.template_search import TemplateSearchService
|
||||
from app.config import settings
|
||||
from app.services.cache import _get_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -97,6 +108,12 @@ class PaginatedCategoriesResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class TemplateCacheStatusResponse(BaseModel):
|
||||
cache_enabled: bool
|
||||
redis_available: bool
|
||||
mem_cache: Dict[str, int]
|
||||
|
||||
|
||||
@router.post("/upload", response_model=TemplateResponse)
|
||||
async def upload_template(
|
||||
name: str = Form(...),
|
||||
@@ -107,38 +124,15 @@ async def upload_template(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if file.content_type not in {"application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/pdf"}:
|
||||
raise HTTPException(status_code=400, detail="Only .docx or .pdf templates are supported")
|
||||
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="No file uploaded")
|
||||
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
storage = get_default_storage()
|
||||
storage_path = storage.save_bytes(content=content, filename_hint=file.filename or "template.bin", subdir="templates")
|
||||
|
||||
template = DocumentTemplate(name=name, description=description, category=category, active=True, created_by=getattr(current_user, "username", None))
|
||||
db.add(template)
|
||||
db.flush() # get id
|
||||
|
||||
version = DocumentTemplateVersion(
|
||||
template_id=template.id,
|
||||
service = TemplateUploadService(db)
|
||||
template = await service.upload_template(
|
||||
name=name,
|
||||
category=category,
|
||||
description=description,
|
||||
semantic_version=semantic_version,
|
||||
storage_path=storage_path,
|
||||
mime_type=file.content_type,
|
||||
size=len(content),
|
||||
checksum=sha256,
|
||||
changelog=None,
|
||||
file=file,
|
||||
created_by=getattr(current_user, "username", None),
|
||||
is_approved=True,
|
||||
)
|
||||
db.add(version)
|
||||
db.flush()
|
||||
template.current_version_id = version.id
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
|
||||
return TemplateResponse(
|
||||
id=template.id,
|
||||
name=template.name,
|
||||
@@ -177,88 +171,34 @@ async def search_templates(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(DocumentTemplate)
|
||||
if active_only:
|
||||
query = query.filter(DocumentTemplate.active == True)
|
||||
if q:
|
||||
like = f"%{q}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
DocumentTemplate.name.ilike(like),
|
||||
DocumentTemplate.description.ilike(like),
|
||||
)
|
||||
)
|
||||
# Category filtering (supports repeatable param and CSV within each value)
|
||||
# Normalize category values including CSV-in-parameter support
|
||||
categories: Optional[List[str]] = None
|
||||
if category:
|
||||
raw_values = category or []
|
||||
categories: List[str] = []
|
||||
cat_values: List[str] = []
|
||||
for value in raw_values:
|
||||
parts = [part.strip() for part in (value or "").split(",")]
|
||||
for part in parts:
|
||||
if part:
|
||||
categories.append(part)
|
||||
unique_categories = sorted(set(categories))
|
||||
if unique_categories:
|
||||
query = query.filter(DocumentTemplate.category.in_(unique_categories))
|
||||
if keywords:
|
||||
normalized = [kw.strip().lower() for kw in keywords if kw and kw.strip()]
|
||||
unique_keywords = sorted(set(normalized))
|
||||
if unique_keywords:
|
||||
mode = (keywords_mode or "any").lower()
|
||||
if mode not in ("any", "all"):
|
||||
mode = "any"
|
||||
query = query.join(TemplateKeyword, TemplateKeyword.template_id == DocumentTemplate.id)
|
||||
if mode == "any":
|
||||
query = query.filter(TemplateKeyword.keyword.in_(unique_keywords)).distinct()
|
||||
else:
|
||||
query = query.filter(TemplateKeyword.keyword.in_(unique_keywords))
|
||||
query = query.group_by(DocumentTemplate.id)
|
||||
query = query.having(func.count(func.distinct(TemplateKeyword.keyword)) == len(unique_keywords))
|
||||
# Has keywords filter (independent of specific keyword matches)
|
||||
if has_keywords is not None:
|
||||
kw_exists = exists().where(TemplateKeyword.template_id == DocumentTemplate.id)
|
||||
if has_keywords:
|
||||
query = query.filter(kw_exists)
|
||||
else:
|
||||
query = query.filter(~kw_exists)
|
||||
# Sorting
|
||||
sort_key = (sort_by or "name").lower()
|
||||
direction = (sort_dir or "asc").lower()
|
||||
if sort_key not in ("name", "category", "updated"):
|
||||
sort_key = "name"
|
||||
if direction not in ("asc", "desc"):
|
||||
direction = "asc"
|
||||
cat_values.append(part)
|
||||
categories = sorted(set(cat_values))
|
||||
|
||||
if sort_key == "name":
|
||||
order_col = DocumentTemplate.name
|
||||
elif sort_key == "category":
|
||||
order_col = DocumentTemplate.category
|
||||
else: # updated
|
||||
order_col = func.coalesce(DocumentTemplate.updated_at, DocumentTemplate.created_at)
|
||||
search_service = TemplateSearchService(db)
|
||||
results, total = await search_service.search_templates(
|
||||
q=q,
|
||||
categories=categories,
|
||||
keywords=keywords,
|
||||
keywords_mode=keywords_mode,
|
||||
has_keywords=has_keywords,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
sort_by=sort_by or "name",
|
||||
sort_dir=sort_dir or "asc",
|
||||
active_only=active_only,
|
||||
include_total=include_total,
|
||||
)
|
||||
|
||||
if direction == "asc":
|
||||
query = query.order_by(order_col.asc())
|
||||
else:
|
||||
query = query.order_by(order_col.desc())
|
||||
|
||||
# Pagination with optional total
|
||||
templates, total = paginate_with_total(query, skip, limit, include_total)
|
||||
items: List[SearchResponseItem] = []
|
||||
for tpl in templates:
|
||||
latest_version = None
|
||||
if tpl.current_version_id:
|
||||
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == tpl.current_version_id).first()
|
||||
if ver:
|
||||
latest_version = ver.semantic_version
|
||||
items.append(
|
||||
SearchResponseItem(
|
||||
id=tpl.id,
|
||||
name=tpl.name,
|
||||
category=tpl.category,
|
||||
active=tpl.active,
|
||||
latest_version=latest_version,
|
||||
)
|
||||
)
|
||||
items: List[SearchResponseItem] = [SearchResponseItem(**it) for it in results]
|
||||
if include_total:
|
||||
return {"items": items, "total": int(total or 0)}
|
||||
return items
|
||||
@@ -271,25 +211,65 @@ async def list_template_categories(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(DocumentTemplate.category, func.count(DocumentTemplate.id).label("count"))
|
||||
if active_only:
|
||||
query = query.filter(DocumentTemplate.active == True)
|
||||
rows = query.group_by(DocumentTemplate.category).order_by(DocumentTemplate.category.asc()).all()
|
||||
search_service = TemplateSearchService(db)
|
||||
rows = await search_service.list_categories(active_only=active_only)
|
||||
items = [CategoryCount(category=row[0], count=row[1]) for row in rows]
|
||||
if include_total:
|
||||
return {"items": items, "total": len(items)}
|
||||
return items
|
||||
|
||||
|
||||
@router.get("/_cache_status", response_model=TemplateCacheStatusResponse)
|
||||
async def cache_status(
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
# In-memory cache breakdown
|
||||
with TemplateSearchService._mem_lock:
|
||||
keys = list(TemplateSearchService._mem_cache.keys())
|
||||
mem_templates = sum(1 for k in keys if k.startswith("search:templates:"))
|
||||
mem_categories = sum(1 for k in keys if k.startswith("search:templates_categories:"))
|
||||
|
||||
# Redis availability check (best-effort)
|
||||
redis_available = False
|
||||
try:
|
||||
client = await _get_client()
|
||||
if client is not None:
|
||||
try:
|
||||
pong = await client.ping()
|
||||
redis_available = bool(pong)
|
||||
except Exception:
|
||||
redis_available = False
|
||||
except Exception:
|
||||
redis_available = False
|
||||
|
||||
return TemplateCacheStatusResponse(
|
||||
cache_enabled=bool(getattr(settings, "cache_enabled", False)),
|
||||
redis_available=redis_available,
|
||||
mem_cache={
|
||||
"templates": int(mem_templates),
|
||||
"categories": int(mem_categories),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/_cache_invalidate")
|
||||
async def cache_invalidate(
|
||||
current_user: User = Depends(get_admin_user),
|
||||
):
|
||||
try:
|
||||
await TemplateSearchService.invalidate_all()
|
||||
return {"cleared": True}
|
||||
except Exception as e:
|
||||
return {"cleared": False, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/{template_id}", response_model=TemplateResponse)
|
||||
async def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
tpl = get_template_or_404(db, template_id)
|
||||
return TemplateResponse(
|
||||
id=tpl.id,
|
||||
name=tpl.name,
|
||||
@@ -306,12 +286,7 @@ async def list_versions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
versions = (
|
||||
db.query(DocumentTemplateVersion)
|
||||
.filter(DocumentTemplateVersion.template_id == template_id)
|
||||
.order_by(DocumentTemplateVersion.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
versions = svc_list_template_versions(db, template_id)
|
||||
return [
|
||||
VersionResponse(
|
||||
id=v.id,
|
||||
@@ -337,31 +312,18 @@ async def add_version(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="No file uploaded")
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
storage = get_default_storage()
|
||||
storage_path = storage.save_bytes(content=content, filename_hint=file.filename or "template.bin", subdir="templates")
|
||||
version = DocumentTemplateVersion(
|
||||
version = svc_add_template_version(
|
||||
db,
|
||||
template_id=template_id,
|
||||
semantic_version=semantic_version,
|
||||
storage_path=storage_path,
|
||||
mime_type=file.content_type,
|
||||
size=len(content),
|
||||
checksum=sha256,
|
||||
changelog=changelog,
|
||||
approve=approve,
|
||||
content=content,
|
||||
filename_hint=file.filename or "template.bin",
|
||||
content_type=file.content_type,
|
||||
created_by=getattr(current_user, "username", None),
|
||||
is_approved=bool(approve),
|
||||
)
|
||||
db.add(version)
|
||||
db.flush()
|
||||
if approve:
|
||||
tpl.current_version_id = version.id
|
||||
db.commit()
|
||||
return VersionResponse(
|
||||
id=version.id,
|
||||
template_id=version.template_id,
|
||||
@@ -381,31 +343,32 @@ async def preview_template(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
version_id = payload.version_id or tpl.current_version_id
|
||||
if not version_id:
|
||||
raise HTTPException(status_code=400, detail="Template has no versions")
|
||||
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
|
||||
if not ver:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
resolved, unresolved, output_bytes, output_mime = svc_resolve_template_preview(
|
||||
db,
|
||||
template_id=template_id,
|
||||
version_id=payload.version_id,
|
||||
context=payload.context or {},
|
||||
)
|
||||
|
||||
storage = get_default_storage()
|
||||
content = storage.open_bytes(ver.storage_path)
|
||||
tokens = extract_tokens_from_bytes(content)
|
||||
context = build_context(payload.context or {})
|
||||
resolved, unresolved = resolve_tokens(db, tokens, context)
|
||||
# Sanitize resolved values to ensure JSON-serializable output
|
||||
def _json_sanitize(value: Any) -> Any:
|
||||
from datetime import date, datetime
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (date, datetime)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_json_sanitize(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _json_sanitize(v) for k, v in value.items()}
|
||||
# Fallback: stringify unsupported types (e.g., functions)
|
||||
return str(value)
|
||||
|
||||
output_bytes = content
|
||||
output_mime = ver.mime_type
|
||||
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
output_bytes = render_docx(content, resolved)
|
||||
output_mime = ver.mime_type
|
||||
sanitized_resolved = {k: _json_sanitize(v) for k, v in resolved.items()}
|
||||
|
||||
# We don't store preview output; just return metadata and resolution state
|
||||
return PreviewResponse(
|
||||
resolved=resolved,
|
||||
resolved=sanitized_resolved,
|
||||
unresolved=unresolved,
|
||||
output_mime_type=output_mime,
|
||||
output_size=len(output_bytes),
|
||||
@@ -419,40 +382,16 @@ async def download_template(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
# Determine which version to serve
|
||||
resolved_version_id = version_id or tpl.current_version_id
|
||||
if not resolved_version_id:
|
||||
raise HTTPException(status_code=404, detail="Template has no approved version")
|
||||
|
||||
ver = (
|
||||
db.query(DocumentTemplateVersion)
|
||||
.filter(DocumentTemplateVersion.id == resolved_version_id, DocumentTemplateVersion.template_id == tpl.id)
|
||||
.first()
|
||||
content, mime_type, original_name = svc_get_download_payload(
|
||||
db,
|
||||
template_id=template_id,
|
||||
version_id=version_id,
|
||||
)
|
||||
if not ver:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
storage = get_default_storage()
|
||||
try:
|
||||
content = storage.open_bytes(ver.storage_path)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Stored file not found")
|
||||
|
||||
# Derive original filename from storage_path (uuid_prefix_originalname)
|
||||
base = os.path.basename(ver.storage_path)
|
||||
if "_" in base:
|
||||
original_name = base.split("_", 1)[1]
|
||||
else:
|
||||
original_name = base
|
||||
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=\"{original_name}\"",
|
||||
}
|
||||
return StreamingResponse(iter([content]), media_type=ver.mime_type, headers=headers)
|
||||
return StreamingResponse(iter([content]), media_type=mime_type, headers=headers)
|
||||
|
||||
|
||||
@router.get("/{template_id}/keywords", response_model=KeywordsResponse)
|
||||
@@ -461,16 +400,9 @@ async def list_keywords(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
kws = (
|
||||
db.query(TemplateKeyword)
|
||||
.filter(TemplateKeyword.template_id == template_id)
|
||||
.order_by(TemplateKeyword.keyword.asc())
|
||||
.all()
|
||||
)
|
||||
return KeywordsResponse(keywords=[k.keyword for k in kws])
|
||||
search_service = TemplateSearchService(db)
|
||||
keywords = search_service.list_keywords(template_id)
|
||||
return KeywordsResponse(keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/{template_id}/keywords", response_model=KeywordsResponse)
|
||||
@@ -480,31 +412,9 @@ async def add_keywords(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
to_add = []
|
||||
for kw in (payload.keywords or []):
|
||||
normalized = (kw or "").strip().lower()
|
||||
if not normalized:
|
||||
continue
|
||||
exists = (
|
||||
db.query(TemplateKeyword)
|
||||
.filter(TemplateKeyword.template_id == template_id, TemplateKeyword.keyword == normalized)
|
||||
.first()
|
||||
)
|
||||
if not exists:
|
||||
to_add.append(TemplateKeyword(template_id=template_id, keyword=normalized))
|
||||
if to_add:
|
||||
db.add_all(to_add)
|
||||
db.commit()
|
||||
kws = (
|
||||
db.query(TemplateKeyword)
|
||||
.filter(TemplateKeyword.template_id == template_id)
|
||||
.order_by(TemplateKeyword.keyword.asc())
|
||||
.all()
|
||||
)
|
||||
return KeywordsResponse(keywords=[k.keyword for k in kws])
|
||||
search_service = TemplateSearchService(db)
|
||||
keywords = await search_service.add_keywords(template_id, payload.keywords)
|
||||
return KeywordsResponse(keywords=keywords)
|
||||
|
||||
|
||||
@router.delete("/{template_id}/keywords/{keyword}", response_model=KeywordsResponse)
|
||||
@@ -514,21 +424,7 @@ async def remove_keyword(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
|
||||
if not tpl:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
normalized = (keyword or "").strip().lower()
|
||||
if normalized:
|
||||
db.query(TemplateKeyword).filter(
|
||||
TemplateKeyword.template_id == template_id,
|
||||
TemplateKeyword.keyword == normalized,
|
||||
).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
kws = (
|
||||
db.query(TemplateKeyword)
|
||||
.filter(TemplateKeyword.template_id == template_id)
|
||||
.order_by(TemplateKeyword.keyword.asc())
|
||||
.all()
|
||||
)
|
||||
return KeywordsResponse(keywords=[k.keyword for k in kws])
|
||||
search_service = TemplateSearchService(db)
|
||||
keywords = await search_service.remove_keyword(template_id, keyword)
|
||||
return KeywordsResponse(keywords=keywords)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user