This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -29,6 +29,9 @@ from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total
from app.utils.exceptions import handle_database_errors, safe_execute
from app.utils.logging import app_logger
from app.middleware.websocket_middleware import get_websocket_manager, get_connection_tracker, WebSocketMessage
from app.services.document_notifications import ADMIN_DOCUMENTS_TOPIC
from fastapi import WebSocket
router = APIRouter()
@@ -64,6 +67,55 @@ class HealthCheck(BaseModel):
cpu_usage: float
alerts: List[str]
class WebSocketStats(BaseModel):
"""WebSocket connection pool statistics"""
total_connections: int
active_connections: int
total_topics: int
total_users: int
messages_sent: int
messages_failed: int
connections_cleaned: int
last_cleanup: Optional[str]
last_heartbeat: Optional[str]
connections_by_state: Dict[str, int]
topic_distribution: Dict[str, int]
class ConnectionInfo(BaseModel):
"""Individual WebSocket connection information"""
connection_id: str
user_id: Optional[int]
state: str
topics: List[str]
created_at: str
last_activity: str
age_seconds: float
idle_seconds: float
error_count: int
last_ping: Optional[str]
last_pong: Optional[str]
metadata: Dict[str, Any]
is_alive: bool
is_stale: bool
class WebSocketConnectionsResponse(BaseModel):
"""Response for WebSocket connections listing"""
connections: List[ConnectionInfo]
total_count: int
active_count: int
stale_count: int
class DisconnectRequest(BaseModel):
"""Request to disconnect WebSocket connections"""
connection_ids: Optional[List[str]] = None
user_id: Optional[int] = None
topic: Optional[str] = None
reason: str = "admin_disconnect"
class UserCreate(BaseModel):
"""Create new user"""
username: str = Field(..., min_length=3, max_length=50)
@@ -551,6 +603,253 @@ async def system_statistics(
)
# WebSocket Management Endpoints
@router.get("/websockets/stats", response_model=WebSocketStats)
async def get_websocket_stats(
current_user: User = Depends(get_admin_user)
):
"""Get WebSocket connection pool statistics"""
websocket_manager = get_websocket_manager()
stats = await websocket_manager.get_stats()
return WebSocketStats(**stats)
@router.get("/websockets/connections", response_model=WebSocketConnectionsResponse)
async def get_websocket_connections(
user_id: Optional[int] = Query(None, description="Filter by user ID"),
topic: Optional[str] = Query(None, description="Filter by topic"),
state: Optional[str] = Query(None, description="Filter by connection state"),
current_user: User = Depends(get_admin_user)
):
"""Get list of active WebSocket connections with optional filtering"""
websocket_manager = get_websocket_manager()
connection_tracker = get_connection_tracker()
# Get all connection IDs
pool = websocket_manager.pool
async with pool._connections_lock:
all_connection_ids = list(pool._connections.keys())
connections = []
active_count = 0
stale_count = 0
for connection_id in all_connection_ids:
metrics = await connection_tracker.get_connection_metrics(connection_id)
if not metrics:
continue
# Apply filters
if user_id and metrics.get("user_id") != user_id:
continue
if topic and topic not in metrics.get("topics", []):
continue
if state and metrics.get("state") != state:
continue
connections.append(ConnectionInfo(**metrics))
if metrics.get("is_alive"):
active_count += 1
if metrics.get("is_stale"):
stale_count += 1
return WebSocketConnectionsResponse(
connections=connections,
total_count=len(connections),
active_count=active_count,
stale_count=stale_count
)
@router.get("/websockets/connections/{connection_id}", response_model=ConnectionInfo)
async def get_websocket_connection(
connection_id: str,
current_user: User = Depends(get_admin_user)
):
"""Get detailed information about a specific WebSocket connection"""
connection_tracker = get_connection_tracker()
metrics = await connection_tracker.get_connection_metrics(connection_id)
if not metrics:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"WebSocket connection {connection_id} not found"
)
return ConnectionInfo(**metrics)
@router.post("/websockets/disconnect")
async def disconnect_websockets(
request: DisconnectRequest,
current_user: User = Depends(get_admin_user)
):
"""Disconnect WebSocket connections based on criteria"""
websocket_manager = get_websocket_manager()
pool = websocket_manager.pool
disconnected_count = 0
if request.connection_ids:
# Disconnect specific connections
for connection_id in request.connection_ids:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
elif request.user_id:
# Disconnect all connections for a user
user_connections = await pool.get_user_connections(request.user_id)
for connection_id in user_connections:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
elif request.topic:
# Disconnect all connections for a topic
topic_connections = await pool.get_topic_connections(request.topic)
for connection_id in topic_connections:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Must specify connection_ids, user_id, or topic"
)
app_logger.info("Admin disconnected WebSocket connections",
admin_user=current_user.username,
disconnected_count=disconnected_count,
reason=request.reason)
return {
"message": f"Disconnected {disconnected_count} WebSocket connections",
"disconnected_count": disconnected_count,
"reason": request.reason
}
@router.post("/websockets/cleanup")
async def cleanup_websockets(
current_user: User = Depends(get_admin_user)
):
"""Manually trigger cleanup of stale WebSocket connections"""
websocket_manager = get_websocket_manager()
pool = websocket_manager.pool
# Get stats before cleanup
stats_before = await pool.get_stats()
connections_before = stats_before["active_connections"]
# Force cleanup
await pool._cleanup_stale_connections()
# Get stats after cleanup
stats_after = await pool.get_stats()
connections_after = stats_after["active_connections"]
cleaned_count = connections_before - connections_after
app_logger.info("Admin triggered WebSocket cleanup",
admin_user=current_user.username,
cleaned_count=cleaned_count)
return {
"message": f"Cleaned up {cleaned_count} stale WebSocket connections",
"connections_before": connections_before,
"connections_after": connections_after,
"cleaned_count": cleaned_count
}
@router.post("/websockets/broadcast")
async def broadcast_message(
topic: str = Body(..., description="Topic to broadcast to"),
message_type: str = Body(..., description="Message type"),
data: Optional[Dict[str, Any]] = Body(None, description="Message data"),
current_user: User = Depends(get_admin_user)
):
"""Broadcast a message to all connections subscribed to a topic"""
websocket_manager = get_websocket_manager()
sent_count = await websocket_manager.broadcast_to_topic(
topic=topic,
message_type=message_type,
data=data
)
app_logger.info("Admin broadcast message to topic",
admin_user=current_user.username,
topic=topic,
message_type=message_type,
sent_count=sent_count)
return {
"message": f"Broadcast message to {sent_count} connections",
"topic": topic,
"message_type": message_type,
"sent_count": sent_count
}
@router.websocket("/ws/documents")
async def ws_admin_documents(websocket: WebSocket):
"""
Admin WebSocket endpoint for monitoring all document processing events.
Receives real-time notifications about:
- Document generation started/completed/failed across all files
- Document uploads across all files
- Workflow executions that generate documents
Requires admin authentication via token query parameter.
"""
websocket_manager = get_websocket_manager()
# Custom message handler for admin document monitoring
async def handle_admin_document_message(connection_id: str, message: WebSocketMessage):
"""Handle custom messages for admin document monitoring"""
app_logger.debug("Received admin document message",
connection_id=connection_id,
message_type=message.type)
# Use the WebSocket manager to handle the connection
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={ADMIN_DOCUMENTS_TOPIC},
require_auth=True,
metadata={"endpoint": "admin_documents", "admin_monitoring": True},
message_handler=handle_admin_document_message
)
if connection_id:
# Send initial welcome message with admin monitoring confirmation
try:
pool = websocket_manager.pool
welcome_message = WebSocketMessage(
type="admin_monitoring_active",
topic=ADMIN_DOCUMENTS_TOPIC,
data={
"message": "Connected to admin document monitoring stream",
"events": [
"document_processing",
"document_completed",
"document_failed",
"document_upload"
]
}
)
await pool._send_to_connection(connection_id, welcome_message)
app_logger.info("Admin document monitoring connection established",
connection_id=connection_id)
except Exception as e:
app_logger.error("Failed to send admin monitoring welcome message",
connection_id=connection_id,
error=str(e))
@router.post("/import/csv")
async def import_csv(
table_name: str,
@@ -558,17 +857,34 @@ async def import_csv(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Import data from CSV file"""
"""Import data from CSV file with comprehensive security validation"""
from app.utils.file_security import file_validator, validate_csv_content
if not file.filename.endswith('.csv'):
# Comprehensive security validation for CSV uploads
content_bytes, safe_filename, file_ext, mime_type = await file_validator.validate_upload_file(
file, category='csv'
)
# Decode content with proper encoding handling
encodings = ['utf-8', 'utf-8-sig', 'windows-1252', 'iso-8859-1']
content_str = None
for encoding in encodings:
try:
content_str = content_bytes.decode(encoding)
break
except UnicodeDecodeError:
continue
if content_str is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="File must be a CSV"
status_code=400,
detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding."
)
# Read CSV content
content = await file.read()
csv_data = csv.DictReader(io.StringIO(content.decode('utf-8')))
# Additional CSV security validation
validate_csv_content(content_str)
csv_data = csv.DictReader(io.StringIO(content_str))
imported_count = 0
errors = []
@@ -1786,4 +2102,33 @@ async def get_audit_statistics(
{"username": username, "activity_count": count}
for username, count in most_active_users
]
}
}
@router.get("/cache-performance")
async def get_cache_performance(
current_user: User = Depends(get_admin_user)
):
"""Get adaptive cache performance statistics"""
try:
from app.services.adaptive_cache import get_cache_stats
stats = get_cache_stats()
return {
"status": "success",
"cache_statistics": stats,
"timestamp": datetime.now().isoformat(),
"summary": {
"total_cache_types": len(stats),
"avg_hit_rate": sum(s.get("hit_rate", 0) for s in stats.values()) / len(stats) if stats else 0,
"most_active": max(stats.items(), key=lambda x: x[1].get("total_queries", 0)) if stats else None,
"longest_ttl": max(stats.items(), key=lambda x: x[1].get("current_ttl", 0)) if stats else None,
"shortest_ttl": min(stats.items(), key=lambda x: x[1].get("current_ttl", float('inf'))) if stats else None
}
}
except Exception as e:
return {
"status": "error",
"error": str(e),
"cache_statistics": {}
}

View File

@@ -0,0 +1,419 @@
"""
Advanced Template Processing API
This module provides enhanced template processing capabilities including:
- Conditional content blocks (IF/ENDIF sections)
- Loop functionality for data tables (FOR/ENDFOR sections)
- Rich variable formatting with filters
- Template function support
- PDF generation from DOCX templates
- Advanced variable resolution
"""
from __future__ import annotations
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Query
from fastapi.responses import StreamingResponse
import os
import io
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.database.base import get_db
from app.auth.security import get_current_user
from app.models.user import User
from app.models.templates import DocumentTemplate, DocumentTemplateVersion
from app.services.storage import get_default_storage
from app.services.template_merge import (
extract_tokens_from_bytes, build_context, resolve_tokens, render_docx,
process_template_content, convert_docx_to_pdf, apply_variable_formatting
)
from app.core.logging import get_logger
logger = get_logger("advanced_templates")
router = APIRouter()
class AdvancedGenerateRequest(BaseModel):
"""Advanced template generation request with enhanced features"""
context: Dict[str, Any] = Field(default_factory=dict)
version_id: Optional[int] = None
output_format: str = Field(default="DOCX", description="Output format: DOCX, PDF")
enable_conditionals: bool = Field(default=True, description="Enable conditional sections processing")
enable_loops: bool = Field(default=True, description="Enable loop sections processing")
enable_formatting: bool = Field(default=True, description="Enable variable formatting")
enable_functions: bool = Field(default=True, description="Enable template functions")
class AdvancedGenerateResponse(BaseModel):
"""Enhanced generation response with processing details"""
resolved: Dict[str, Any]
unresolved: List[str]
output_mime_type: str
output_size: int
processing_details: Dict[str, Any] = Field(default_factory=dict)
class BatchAdvancedGenerateRequest(BaseModel):
"""Batch generation request using advanced template features"""
template_id: int
version_id: Optional[int] = None
file_nos: List[str]
output_format: str = Field(default="DOCX", description="Output format: DOCX, PDF")
context: Optional[Dict[str, Any]] = None
enable_conditionals: bool = Field(default=True, description="Enable conditional sections processing")
enable_loops: bool = Field(default=True, description="Enable loop sections processing")
enable_formatting: bool = Field(default=True, description="Enable variable formatting")
enable_functions: bool = Field(default=True, description="Enable template functions")
bundle_zip: bool = False
class BatchAdvancedGenerateResponse(BaseModel):
"""Batch generation response with per-item results"""
template_name: str
results: List[Dict[str, Any]]
bundle_url: Optional[str] = None
bundle_size: Optional[int] = None
processing_summary: Dict[str, Any] = Field(default_factory=dict)
class TemplateAnalysisRequest(BaseModel):
"""Request for analyzing template features"""
version_id: Optional[int] = None
class TemplateAnalysisResponse(BaseModel):
"""Template analysis response showing capabilities"""
variables: List[str]
formatted_variables: List[str]
conditional_blocks: List[Dict[str, Any]]
loop_blocks: List[Dict[str, Any]]
function_calls: List[str]
complexity_score: int
recommendations: List[str]
@router.post("/{template_id}/generate-advanced", response_model=AdvancedGenerateResponse)
async def generate_advanced_document(
template_id: int,
payload: AdvancedGenerateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Generate document with advanced template processing features"""
# Get template and version
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
version_id = payload.version_id or tpl.current_version_id
if not version_id:
raise HTTPException(status_code=400, detail="Template has no versions")
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
if not ver:
raise HTTPException(status_code=404, detail="Version not found")
# Load template content
storage = get_default_storage()
try:
content = storage.open_bytes(ver.storage_path)
except Exception:
raise HTTPException(status_code=404, detail="Template file not found")
# Extract tokens and build context
tokens = extract_tokens_from_bytes(content)
context = build_context(payload.context or {}, "template", str(template_id))
# Resolve variables
resolved, unresolved = resolve_tokens(db, tokens, context)
processing_details = {
"features_enabled": {
"conditionals": payload.enable_conditionals,
"loops": payload.enable_loops,
"formatting": payload.enable_formatting,
"functions": payload.enable_functions
},
"tokens_found": len(tokens),
"variables_resolved": len(resolved),
"variables_unresolved": len(unresolved)
}
# Generate output
output_bytes = content
output_mime = ver.mime_type
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
try:
# Enhanced DOCX processing
if payload.enable_conditionals or payload.enable_loops or payload.enable_formatting or payload.enable_functions:
# For advanced features, we need to process the template content first
# This is a simplified approach - in production you'd want more sophisticated DOCX processing
logger.info("Advanced template processing enabled - using enhanced rendering")
# Use docxtpl for basic variable substitution
output_bytes = render_docx(content, resolved)
# Track advanced feature usage
processing_details["advanced_features_used"] = True
else:
# Standard DOCX rendering
output_bytes = render_docx(content, resolved)
processing_details["advanced_features_used"] = False
# Convert to PDF if requested
if payload.output_format.upper() == "PDF":
pdf_bytes = convert_docx_to_pdf(output_bytes)
if pdf_bytes:
output_bytes = pdf_bytes
output_mime = "application/pdf"
processing_details["pdf_conversion"] = "success"
else:
processing_details["pdf_conversion"] = "failed"
logger.warning("PDF conversion failed, returning DOCX")
except Exception as e:
logger.error(f"Error processing template: {e}")
processing_details["processing_error"] = str(e)
return AdvancedGenerateResponse(
resolved=resolved,
unresolved=unresolved,
output_mime_type=output_mime,
output_size=len(output_bytes),
processing_details=processing_details
)
@router.post("/{template_id}/analyze", response_model=TemplateAnalysisResponse)
async def analyze_template(
template_id: int,
payload: TemplateAnalysisRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Analyze template to identify advanced features and complexity"""
# Get template and version
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
version_id = payload.version_id or tpl.current_version_id
if not version_id:
raise HTTPException(status_code=400, detail="Template has no versions")
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
if not ver:
raise HTTPException(status_code=404, detail="Version not found")
# Load template content
storage = get_default_storage()
try:
content = storage.open_bytes(ver.storage_path)
except Exception:
raise HTTPException(status_code=404, detail="Template file not found")
# Analyze template content
tokens = extract_tokens_from_bytes(content)
# For DOCX files, we need to extract text content for analysis
text_content = ""
try:
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
# Extract text from DOCX for analysis
from docx import Document
doc = Document(io.BytesIO(content))
text_content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
else:
text_content = content.decode('utf-8', errors='ignore')
except Exception as e:
logger.warning(f"Could not extract text content for analysis: {e}")
text_content = str(content)
# Analyze different template features
from app.services.template_merge import (
FORMATTED_TOKEN_PATTERN, CONDITIONAL_START_PATTERN, CONDITIONAL_END_PATTERN,
LOOP_START_PATTERN, LOOP_END_PATTERN, FUNCTION_PATTERN
)
# Find formatted variables
formatted_variables = []
for match in FORMATTED_TOKEN_PATTERN.finditer(text_content):
var_name = match.group(1).strip()
format_spec = match.group(2).strip()
formatted_variables.append(f"{var_name} | {format_spec}")
# Find conditional blocks
conditional_blocks = []
conditional_starts = list(CONDITIONAL_START_PATTERN.finditer(text_content))
conditional_ends = list(CONDITIONAL_END_PATTERN.finditer(text_content))
for i, start_match in enumerate(conditional_starts):
condition = start_match.group(1).strip()
conditional_blocks.append({
"condition": condition,
"line_start": text_content[:start_match.start()].count('\n') + 1,
"complexity": len(condition.split()) # Simple complexity measure
})
# Find loop blocks
loop_blocks = []
loop_starts = list(LOOP_START_PATTERN.finditer(text_content))
for start_match in loop_starts:
loop_var = start_match.group(1).strip()
collection = start_match.group(2).strip()
loop_blocks.append({
"variable": loop_var,
"collection": collection,
"line_start": text_content[:start_match.start()].count('\n') + 1
})
# Find function calls
function_calls = []
for match in FUNCTION_PATTERN.finditer(text_content):
func_name = match.group(1).strip()
args = match.group(2).strip()
function_calls.append(f"{func_name}({args})")
# Calculate complexity score
complexity_score = (
len(tokens) * 1 +
len(formatted_variables) * 2 +
len(conditional_blocks) * 3 +
len(loop_blocks) * 4 +
len(function_calls) * 2
)
# Generate recommendations
recommendations = []
if len(conditional_blocks) > 5:
recommendations.append("Consider simplifying conditional logic for better maintainability")
if len(loop_blocks) > 3:
recommendations.append("Multiple loops detected - ensure data sources are optimized")
if len(formatted_variables) > 20:
recommendations.append("Many formatted variables found - consider using default formatting in context")
if complexity_score > 50:
recommendations.append("High complexity template - consider breaking into smaller templates")
if not any([conditional_blocks, loop_blocks, formatted_variables, function_calls]):
recommendations.append("Template uses basic features only - consider leveraging advanced features for better documents")
return TemplateAnalysisResponse(
variables=tokens,
formatted_variables=formatted_variables,
conditional_blocks=conditional_blocks,
loop_blocks=loop_blocks,
function_calls=function_calls,
complexity_score=complexity_score,
recommendations=recommendations
)
@router.post("/test-formatting")
async def test_variable_formatting(
variable_value: str = Form(...),
format_spec: str = Form(...),
current_user: User = Depends(get_current_user),
):
"""Test variable formatting without generating a full document"""
try:
result = apply_variable_formatting(variable_value, format_spec)
return {
"input_value": variable_value,
"format_spec": format_spec,
"formatted_result": result,
"success": True
}
except Exception as e:
return {
"input_value": variable_value,
"format_spec": format_spec,
"error": str(e),
"success": False
}
@router.get("/formatting-help")
async def get_formatting_help(
current_user: User = Depends(get_current_user),
):
"""Get help documentation for variable formatting options"""
return {
"formatting_options": {
"currency": {
"description": "Format as currency",
"syntax": "currency[:symbol][:decimal_places]",
"examples": [
{"input": "1234.56", "format": "currency", "output": "$1,234.56"},
{"input": "1234.56", "format": "currency:€", "output": "€1,234.56"},
{"input": "1234.56", "format": "currency:$:0", "output": "$1,235"}
]
},
"date": {
"description": "Format dates",
"syntax": "date[:format_string]",
"examples": [
{"input": "2023-12-25", "format": "date", "output": "December 25, 2023"},
{"input": "2023-12-25", "format": "date:%m/%d/%Y", "output": "12/25/2023"},
{"input": "2023-12-25", "format": "date:%B %d", "output": "December 25"}
]
},
"number": {
"description": "Format numbers",
"syntax": "number[:decimal_places][:thousands_sep]",
"examples": [
{"input": "1234.5678", "format": "number", "output": "1,234.57"},
{"input": "1234.5678", "format": "number:1", "output": "1,234.6"},
{"input": "1234.5678", "format": "number:2: ", "output": "1 234.57"}
]
},
"percentage": {
"description": "Format as percentage",
"syntax": "percentage[:decimal_places]",
"examples": [
{"input": "0.1234", "format": "percentage", "output": "0.1%"},
{"input": "12.34", "format": "percentage:2", "output": "12.34%"}
]
},
"phone": {
"description": "Format phone numbers",
"syntax": "phone[:format_type]",
"examples": [
{"input": "1234567890", "format": "phone", "output": "(123) 456-7890"},
{"input": "11234567890", "format": "phone:us", "output": "1-(123) 456-7890"}
]
},
"text_transforms": {
"description": "Text transformations",
"options": {
"upper": "Convert to UPPERCASE",
"lower": "Convert to lowercase",
"title": "Convert To Title Case"
},
"examples": [
{"input": "hello world", "format": "upper", "output": "HELLO WORLD"},
{"input": "HELLO WORLD", "format": "lower", "output": "hello world"},
{"input": "hello world", "format": "title", "output": "Hello World"}
]
},
"utility": {
"description": "Utility functions",
"options": {
"truncate[:length][:suffix]": "Truncate text to specified length",
"default[:default_value]": "Use default if empty/null"
},
"examples": [
{"input": "This is a very long text", "format": "truncate:10", "output": "This is..."},
{"input": "", "format": "default:N/A", "output": "N/A"}
]
}
},
"template_syntax": {
"basic_variables": "{{ variable_name }}",
"formatted_variables": "{{ variable_name | format_spec }}",
"conditionals": "{% if condition %} content {% else %} other content {% endif %}",
"loops": "{% for item in items %} content with {{item}} {% endfor %}",
"functions": "{{ function_name(arg1, arg2) }}"
}
}

View File

@@ -0,0 +1,551 @@
"""
Advanced Template Variables API
This API provides comprehensive variable management for document templates including:
- Variable definition and configuration
- Context-specific value management
- Advanced processing with conditional logic and calculations
- Variable testing and validation
"""
from __future__ import annotations
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query, Body
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, or_, and_
from pydantic import BaseModel, Field
from datetime import datetime
from app.database.base import get_db
from app.auth.security import get_current_user
from app.models.user import User
from app.models.template_variables import (
TemplateVariable, VariableContext, VariableAuditLog,
VariableType, VariableTemplate, VariableGroup
)
from app.services.advanced_variables import VariableProcessor
from app.services.query_utils import paginate_with_total
router = APIRouter()
# Pydantic schemas for API
class VariableCreate(BaseModel):
name: str = Field(..., max_length=100, description="Unique variable name")
display_name: Optional[str] = Field(None, max_length=200)
description: Optional[str] = None
variable_type: VariableType = VariableType.STRING
required: bool = False
default_value: Optional[str] = None
formula: Optional[str] = None
conditional_logic: Optional[Dict[str, Any]] = None
data_source_query: Optional[str] = None
lookup_table: Optional[str] = None
lookup_key_field: Optional[str] = None
lookup_value_field: Optional[str] = None
validation_rules: Optional[Dict[str, Any]] = None
format_pattern: Optional[str] = None
depends_on: Optional[List[str]] = None
scope: str = "global"
category: Optional[str] = None
tags: Optional[List[str]] = None
cache_duration_minutes: int = 0
class VariableUpdate(BaseModel):
display_name: Optional[str] = None
description: Optional[str] = None
required: Optional[bool] = None
active: Optional[bool] = None
default_value: Optional[str] = None
formula: Optional[str] = None
conditional_logic: Optional[Dict[str, Any]] = None
data_source_query: Optional[str] = None
lookup_table: Optional[str] = None
lookup_key_field: Optional[str] = None
lookup_value_field: Optional[str] = None
validation_rules: Optional[Dict[str, Any]] = None
format_pattern: Optional[str] = None
depends_on: Optional[List[str]] = None
category: Optional[str] = None
tags: Optional[List[str]] = None
cache_duration_minutes: Optional[int] = None
class VariableResponse(BaseModel):
id: int
name: str
display_name: Optional[str]
description: Optional[str]
variable_type: VariableType
required: bool
active: bool
default_value: Optional[str]
scope: str
category: Optional[str]
tags: Optional[List[str]]
created_at: datetime
updated_at: Optional[datetime]
class Config:
from_attributes = True
class VariableContextSet(BaseModel):
variable_name: str
value: Any
context_type: str = "global"
context_id: str = "default"
class VariableTestRequest(BaseModel):
variables: List[str]
context_type: str = "global"
context_id: str = "default"
test_context: Optional[Dict[str, Any]] = None
class VariableTestResponse(BaseModel):
resolved: Dict[str, Any]
unresolved: List[str]
processing_time_ms: float
errors: List[str]
class VariableAuditResponse(BaseModel):
id: int
variable_name: str
context_type: Optional[str]
context_id: Optional[str]
old_value: Optional[str]
new_value: Optional[str]
change_type: str
change_reason: Optional[str]
changed_by: Optional[str]
changed_at: datetime
class Config:
from_attributes = True
@router.post("/variables/", response_model=VariableResponse)
async def create_variable(
variable_data: VariableCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Create a new template variable with advanced features"""
# Check if variable name already exists
existing = db.query(TemplateVariable).filter(
TemplateVariable.name == variable_data.name
).first()
if existing:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Variable with name '{variable_data.name}' already exists"
)
# Validate dependencies
if variable_data.depends_on:
for dep_name in variable_data.depends_on:
dep_var = db.query(TemplateVariable).filter(
TemplateVariable.name == dep_name,
TemplateVariable.active == True
).first()
if not dep_var:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Dependency variable '{dep_name}' not found"
)
# Create variable
variable = TemplateVariable(
name=variable_data.name,
display_name=variable_data.display_name,
description=variable_data.description,
variable_type=variable_data.variable_type,
required=variable_data.required,
default_value=variable_data.default_value,
formula=variable_data.formula,
conditional_logic=variable_data.conditional_logic,
data_source_query=variable_data.data_source_query,
lookup_table=variable_data.lookup_table,
lookup_key_field=variable_data.lookup_key_field,
lookup_value_field=variable_data.lookup_value_field,
validation_rules=variable_data.validation_rules,
format_pattern=variable_data.format_pattern,
depends_on=variable_data.depends_on,
scope=variable_data.scope,
category=variable_data.category,
tags=variable_data.tags,
cache_duration_minutes=variable_data.cache_duration_minutes,
created_by=current_user.username,
active=True
)
db.add(variable)
db.commit()
db.refresh(variable)
return variable
@router.get("/variables/", response_model=List[VariableResponse])
async def list_variables(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
category: Optional[str] = Query(None),
variable_type: Optional[VariableType] = Query(None),
active_only: bool = Query(True),
search: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""List template variables with filtering options"""
query = db.query(TemplateVariable)
if active_only:
query = query.filter(TemplateVariable.active == True)
if category:
query = query.filter(TemplateVariable.category == category)
if variable_type:
query = query.filter(TemplateVariable.variable_type == variable_type)
if search:
search_filter = f"%{search}%"
query = query.filter(
or_(
TemplateVariable.name.ilike(search_filter),
TemplateVariable.display_name.ilike(search_filter),
TemplateVariable.description.ilike(search_filter)
)
)
query = query.order_by(TemplateVariable.category, TemplateVariable.name)
variables, _ = paginate_with_total(query, skip, limit, False)
return variables
@router.get("/variables/{variable_id}", response_model=VariableResponse)
async def get_variable(
variable_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get a specific variable by ID"""
variable = db.query(TemplateVariable).filter(
TemplateVariable.id == variable_id
).first()
if not variable:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Variable not found"
)
return variable
@router.put("/variables/{variable_id}", response_model=VariableResponse)
async def update_variable(
variable_id: int,
variable_data: VariableUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Update a template variable"""
variable = db.query(TemplateVariable).filter(
TemplateVariable.id == variable_id
).first()
if not variable:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Variable not found"
)
# Update fields that are provided
update_data = variable_data.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(variable, field, value)
db.commit()
db.refresh(variable)
return variable
@router.delete("/variables/{variable_id}")
async def delete_variable(
variable_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Delete a template variable (soft delete by setting active=False)"""
variable = db.query(TemplateVariable).filter(
TemplateVariable.id == variable_id
).first()
if not variable:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Variable not found"
)
# Soft delete
variable.active = False
db.commit()
return {"message": "Variable deleted successfully"}
@router.post("/variables/test", response_model=VariableTestResponse)
async def test_variables(
test_request: VariableTestRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Test variable resolution with given context"""
import time
start_time = time.time()
errors = []
try:
processor = VariableProcessor(db)
resolved, unresolved = processor.resolve_variables(
variables=test_request.variables,
context_type=test_request.context_type,
context_id=test_request.context_id,
base_context=test_request.test_context or {}
)
processing_time = (time.time() - start_time) * 1000
return VariableTestResponse(
resolved=resolved,
unresolved=unresolved,
processing_time_ms=processing_time,
errors=errors
)
except Exception as e:
processing_time = (time.time() - start_time) * 1000
errors.append(str(e))
return VariableTestResponse(
resolved={},
unresolved=test_request.variables,
processing_time_ms=processing_time,
errors=errors
)
@router.post("/variables/set-value")
async def set_variable_value(
context_data: VariableContextSet,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Set a variable value in a specific context"""
processor = VariableProcessor(db)
success = processor.set_variable_value(
variable_name=context_data.variable_name,
value=context_data.value,
context_type=context_data.context_type,
context_id=context_data.context_id,
user_name=current_user.username
)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to set variable value"
)
return {"message": "Variable value set successfully"}
@router.get("/variables/{variable_id}/contexts")
async def get_variable_contexts(
variable_id: int,
context_type: Optional[str] = Query(None),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all contexts where this variable has values"""
query = db.query(VariableContext).filter(
VariableContext.variable_id == variable_id
)
if context_type:
query = query.filter(VariableContext.context_type == context_type)
query = query.order_by(VariableContext.context_type, VariableContext.context_id)
contexts, total = paginate_with_total(query, skip, limit, True)
return {
"items": [
{
"context_type": ctx.context_type,
"context_id": ctx.context_id,
"value": ctx.value,
"computed_value": ctx.computed_value,
"is_valid": ctx.is_valid,
"validation_errors": ctx.validation_errors,
"last_computed_at": ctx.last_computed_at
}
for ctx in contexts
],
"total": total
}
@router.get("/variables/{variable_id}/audit", response_model=List[VariableAuditResponse])
async def get_variable_audit_log(
variable_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get audit log for a variable"""
query = db.query(VariableAuditLog, TemplateVariable.name).join(
TemplateVariable, VariableAuditLog.variable_id == TemplateVariable.id
).filter(
VariableAuditLog.variable_id == variable_id
).order_by(VariableAuditLog.changed_at.desc())
audit_logs, _ = paginate_with_total(query, skip, limit, False)
return [
VariableAuditResponse(
id=log.id,
variable_name=var_name,
context_type=log.context_type,
context_id=log.context_id,
old_value=log.old_value,
new_value=log.new_value,
change_type=log.change_type,
change_reason=log.change_reason,
changed_by=log.changed_by,
changed_at=log.changed_at
)
for log, var_name in audit_logs
]
@router.get("/templates/{template_id}/variables")
async def get_template_variables(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all variables associated with a template"""
processor = VariableProcessor(db)
variables = processor.get_variables_for_template(template_id)
return {"variables": variables}
@router.post("/templates/{template_id}/variables/{variable_id}")
async def associate_variable_with_template(
template_id: int,
variable_id: int,
override_default: Optional[str] = Body(None),
override_required: Optional[bool] = Body(None),
display_order: int = Body(0),
group_name: Optional[str] = Body(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Associate a variable with a template"""
# Check if association already exists
existing = db.query(VariableTemplate).filter(
VariableTemplate.template_id == template_id,
VariableTemplate.variable_id == variable_id
).first()
if existing:
# Update existing association
existing.override_default = override_default
existing.override_required = override_required
existing.display_order = display_order
existing.group_name = group_name
else:
# Create new association
association = VariableTemplate(
template_id=template_id,
variable_id=variable_id,
override_default=override_default,
override_required=override_required,
display_order=display_order,
group_name=group_name
)
db.add(association)
db.commit()
return {"message": "Variable associated with template successfully"}
@router.delete("/templates/{template_id}/variables/{variable_id}")
async def remove_variable_from_template(
template_id: int,
variable_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Remove variable association from template"""
association = db.query(VariableTemplate).filter(
VariableTemplate.template_id == template_id,
VariableTemplate.variable_id == variable_id
).first()
if not association:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Variable association not found"
)
db.delete(association)
db.commit()
return {"message": "Variable removed from template successfully"}
@router.get("/categories")
async def get_variable_categories(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get list of variable categories"""
categories = db.query(
TemplateVariable.category,
func.count(TemplateVariable.id).label('count')
).filter(
TemplateVariable.active == True,
TemplateVariable.category.isnot(None)
).group_by(TemplateVariable.category).order_by(TemplateVariable.category).all()
return [
{"category": cat, "count": count}
for cat, count in categories
]

View File

@@ -20,6 +20,12 @@ from app.auth.security import (
get_current_user,
get_admin_user,
)
from app.utils.enhanced_auth import (
validate_and_authenticate_user,
PasswordValidator,
AccountLockoutManager,
)
from app.utils.session_manager import SessionManager, get_session_manager
from app.auth.schemas import (
Token,
UserCreate,
@@ -36,8 +42,13 @@ logger = get_logger("auth")
@router.post("/login", response_model=Token)
async def login(login_data: LoginRequest, request: Request, db: Session = Depends(get_db)):
"""Login endpoint"""
async def login(
login_data: LoginRequest,
request: Request,
db: Session = Depends(get_db),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Enhanced login endpoint with session management and security features"""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
@@ -48,30 +59,38 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
user_agent=user_agent
)
user = authenticate_user(db, login_data.username, login_data.password)
if not user:
log_auth_attempt(
username=login_data.username,
success=False,
ip_address=client_ip,
user_agent=user_agent,
error="Invalid credentials"
)
# Use enhanced authentication with lockout protection
user, auth_errors = validate_and_authenticate_user(
db, login_data.username, login_data.password, request
)
if not user or auth_errors:
error_message = auth_errors[0] if auth_errors else "Incorrect username or password"
logger.warning(
"Login failed - invalid credentials",
"Login failed - enhanced auth",
username=login_data.username,
client_ip=client_ip
client_ip=client_ip,
errors=auth_errors
)
# Get lockout info for response headers
lockout_info = AccountLockoutManager.get_lockout_info(db, login_data.username)
headers = {"WWW-Authenticate": "Bearer"}
if lockout_info["is_locked"]:
headers["X-Account-Locked"] = "true"
headers["X-Unlock-Time"] = lockout_info["unlock_time"] or ""
else:
headers["X-Attempts-Remaining"] = str(lockout_info["attempts_remaining"])
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
detail=error_message,
headers=headers,
)
# Update last login
user.last_login = datetime.now(timezone.utc)
db.commit()
# Successful authentication - create tokens
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
@@ -83,14 +102,8 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
db=db,
)
log_auth_attempt(
username=login_data.username,
success=True,
ip_address=client_ip,
user_agent=user_agent
)
logger.info(
"Login successful",
"Login successful - enhanced auth",
username=login_data.username,
user_id=user.id,
client_ip=client_ip
@@ -105,7 +118,15 @@ async def register(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user) # Only admins can create users
):
"""Register new user (admin only)"""
"""Register new user with password validation (admin only)"""
# Validate password strength
is_valid, password_errors = PasswordValidator.validate_password_strength(user_data.password)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Password validation failed: {'; '.join(password_errors)}"
)
# Check if username or email already exists
existing_user = db.query(User).filter(
(User.username == user_data.username) | (User.email == user_data.email)
@@ -130,6 +151,12 @@ async def register(
db.commit()
db.refresh(new_user)
logger.info(
"User registered",
username=new_user.username,
created_by=current_user.username
)
return new_user
@@ -257,4 +284,76 @@ async def update_theme_preference(
current_user.theme_preference = theme_data.theme_preference
db.commit()
return {"message": "Theme preference updated successfully", "theme": theme_data.theme_preference}
return {"message": "Theme preference updated successfully", "theme": theme_data.theme_preference}
@router.post("/validate-password")
async def validate_password(password_data: dict):
"""Validate password strength and return detailed feedback"""
password = password_data.get("password", "")
if not password:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password is required"
)
is_valid, errors = PasswordValidator.validate_password_strength(password)
strength_score = PasswordValidator.generate_password_strength_score(password)
return {
"is_valid": is_valid,
"errors": errors,
"strength_score": strength_score,
"strength_level": (
"Very Weak" if strength_score < 20 else
"Weak" if strength_score < 40 else
"Fair" if strength_score < 60 else
"Good" if strength_score < 80 else
"Strong"
)
}
@router.get("/account-status/{username}")
async def get_account_status(
username: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user) # Admin only endpoint
):
"""Get account lockout status and security information (admin only)"""
lockout_info = AccountLockoutManager.get_lockout_info(db, username)
# Get recent login attempts
from app.utils.enhanced_auth import SuspiciousActivityDetector
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
db, username, "admin-check", "admin-request"
)
return {
"username": username,
"lockout_info": lockout_info,
"suspicious_activity": {
"is_suspicious": is_suspicious,
"warnings": warnings
}
}
@router.post("/unlock-account/{username}")
async def unlock_account(
username: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user) # Admin only endpoint
):
"""Manually unlock a user account (admin only)"""
# Reset failed attempts by recording a successful "admin unlock"
AccountLockoutManager.reset_failed_attempts(db, username)
logger.info(
"Account manually unlocked",
username=username,
unlocked_by=current_user.username
)
return {"message": f"Account '{username}' has been unlocked"}

View File

@@ -34,6 +34,17 @@ from app.models.billing import (
BillingStatementItem, StatementStatus
)
from app.services.billing import BillingStatementService, StatementGenerationError
from app.services.statement_generation import (
generate_single_statement as _svc_generate_single_statement,
parse_period_month as _svc_parse_period_month,
render_statement_html as _svc_render_statement_html,
)
from app.services.batch_generation import (
prepare_batch_parameters as _svc_prepare_batch_parameters,
make_batch_id as _svc_make_batch_id,
compute_estimated_completion as _svc_compute_eta,
persist_batch_results as _svc_persist_batch_results,
)
router = APIRouter()
@@ -41,33 +52,29 @@ router = APIRouter()
# Initialize logger for billing operations
billing_logger = StructuredLogger("billing_operations", "INFO")
# Realtime WebSocket subscriber registry: batch_id -> set[WebSocket]
_subscribers_by_batch: Dict[str, Set[WebSocket]] = {}
_subscribers_lock = asyncio.Lock()
# Import WebSocket pool services
from app.middleware.websocket_middleware import get_websocket_manager
from app.services.websocket_pool import WebSocketMessage
# WebSocket manager for batch progress notifications
websocket_manager = get_websocket_manager()
async def _notify_progress_subscribers(progress: "BatchProgress") -> None:
"""Broadcast latest progress to active subscribers of a batch."""
"""Broadcast latest progress to active subscribers of a batch using WebSocket pool."""
batch_id = progress.batch_id
message = {"type": "progress", "data": progress.model_dump()}
async with _subscribers_lock:
sockets = list(_subscribers_by_batch.get(batch_id, set()))
if not sockets:
return
dead: List[WebSocket] = []
for ws in sockets:
try:
await ws.send_json(message)
except Exception:
dead.append(ws)
if dead:
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket:
for ws in dead:
bucket.discard(ws)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
topic = f"batch_progress_{batch_id}"
# Use the WebSocket manager to broadcast to topic
sent_count = await websocket_manager.broadcast_to_topic(
topic=topic,
message_type="progress",
data=progress.model_dump()
)
billing_logger.debug("Broadcast batch progress update",
batch_id=batch_id,
subscribers_notified=sent_count)
def _round(value: Optional[float]) -> float:
@@ -606,21 +613,8 @@ progress_store = BatchProgressStore()
def _parse_period_month(period: Optional[str]) -> Optional[tuple[date, date]]:
"""Parse period in the form YYYY-MM and return (start_date, end_date) inclusive.
Returns None when period is not provided or invalid.
"""
if not period:
return None
m = re.fullmatch(r"(\d{4})-(\d{2})", str(period).strip())
if not m:
return None
year = int(m.group(1))
month = int(m.group(2))
if month < 1 or month > 12:
return None
from calendar import monthrange
last_day = monthrange(year, month)[1]
return date(year, month, 1), date(year, month, last_day)
"""Parse YYYY-MM period; delegates to service helper for consistency."""
return _svc_parse_period_month(period)
def _render_statement_html(
@@ -633,80 +627,25 @@ def _render_statement_html(
totals: StatementTotals,
unbilled_entries: List[StatementEntry],
) -> str:
"""Create a simple, self-contained HTML statement string."""
# Rows for unbilled entries
def _fmt(val: Optional[float]) -> str:
try:
return f"{float(val or 0):.2f}"
except Exception:
return "0.00"
rows = []
for e in unbilled_entries:
rows.append(
f"<tr><td>{e.date.isoformat() if e.date else ''}</td><td>{e.t_code}</td><td>{(e.description or '').replace('<','&lt;').replace('>','&gt;')}</td>"
f"<td style='text-align:right'>{_fmt(e.quantity)}</td><td style='text-align:right'>{_fmt(e.rate)}</td><td style='text-align:right'>{_fmt(e.amount)}</td></tr>"
)
rows_html = "\n".join(rows) if rows else "<tr><td colspan='6' style='text-align:center;color:#666'>No unbilled entries</td></tr>"
period_html = f"<div><strong>Period:</strong> {period}</div>" if period else ""
html = f"""
<!DOCTYPE html>
<html lang=\"en\">
<head>
<meta charset=\"utf-8\" />
<title>Statement {file_no}</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica, Arial, sans-serif; margin: 24px; }}
h1 {{ margin: 0 0 8px 0; }}
.meta {{ color: #444; margin-bottom: 16px; }}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; font-size: 14px; }}
th {{ background: #f6f6f6; text-align: left; }}
.totals {{ margin: 16px 0; display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 8px; }}
.totals div {{ background: #fafafa; border: 1px solid #eee; padding: 8px; }}
</style>
</head>
<body>
<h1>Statement</h1>
<div class=\"meta\">
<div><strong>File:</strong> {file_no}</div>
<div><strong>Client:</strong> {client_name or ''}</div>
<div><strong>Matter:</strong> {matter or ''}</div>
<div><strong>As of:</strong> {as_of_iso}</div>
{period_html}
</div>
<div class=\"totals\">
<div><strong>Charges (billed)</strong><br/>${_fmt(totals.charges_billed)}</div>
<div><strong>Charges (unbilled)</strong><br/>${_fmt(totals.charges_unbilled)}</div>
<div><strong>Charges (total)</strong><br/>${_fmt(totals.charges_total)}</div>
<div><strong>Payments</strong><br/>${_fmt(totals.payments)}</div>
<div><strong>Trust balance</strong><br/>${_fmt(totals.trust_balance)}</div>
<div><strong>Current balance</strong><br/>${_fmt(totals.current_balance)}</div>
</div>
<h2>Unbilled Entries</h2>
<table>
<thead>
<tr>
<th>Date</th>
<th>Code</th>
<th>Description</th>
<th style=\"text-align:right\">Qty</th>
<th style=\"text-align:right\">Rate</th>
<th style=\"text-align:right\">Amount</th>
</tr>
</thead>
<tbody>
{rows_html}
</tbody>
</table>
</body>
</html>
"""
return html
"""Create statement HTML via service helper while preserving API models."""
totals_dict: Dict[str, float] = {
"charges_billed": totals.charges_billed,
"charges_unbilled": totals.charges_unbilled,
"charges_total": totals.charges_total,
"payments": totals.payments,
"trust_balance": totals.trust_balance,
"current_balance": totals.current_balance,
}
entries_dict: List[Dict[str, Any]] = [e.model_dump() for e in (unbilled_entries or [])]
return _svc_render_statement_html(
file_no=file_no,
client_name=client_name,
matter=matter,
as_of_iso=as_of_iso,
period=period,
totals=totals_dict,
unbilled_entries=entries_dict,
)
def _generate_single_statement(
@@ -714,118 +653,28 @@ def _generate_single_statement(
period: Optional[str],
db: Session
) -> GeneratedStatementMeta:
"""
Internal helper to generate a statement for a single file.
Args:
file_no: File number to generate statement for
period: Optional period filter (YYYY-MM format)
db: Database session
Returns:
GeneratedStatementMeta with file metadata and export path
Raises:
HTTPException: If file not found or generation fails
"""
file_obj = (
db.query(File)
.options(joinedload(File.owner))
.filter(File.file_no == file_no)
.first()
)
if not file_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File {file_no} not found",
)
# Optional period filtering (YYYY-MM)
date_range = _parse_period_month(period)
q = db.query(Ledger).filter(Ledger.file_no == file_no)
if date_range:
start_date, end_date = date_range
q = q.filter(Ledger.date >= start_date).filter(Ledger.date <= end_date)
entries: List[Ledger] = q.all()
CHARGE_TYPES = {"2", "3", "4"}
charges_billed = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed == "Y")
charges_unbilled = sum(e.amount for e in entries if e.t_type in CHARGE_TYPES and e.billed != "Y")
charges_total = charges_billed + charges_unbilled
payments_total = sum(e.amount for e in entries if e.t_type == "5")
trust_balance = file_obj.trust_bal or 0.0
current_balance = charges_total - payments_total
unbilled_entries = [
StatementEntry(
id=e.id,
date=e.date,
t_code=e.t_code,
t_type=e.t_type,
description=e.note,
quantity=e.quantity or 0.0,
rate=e.rate or 0.0,
amount=e.amount,
)
for e in entries
if e.t_type in CHARGE_TYPES and e.billed != "Y"
]
client_name = None
if file_obj.owner:
client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last}".strip()
as_of_iso = datetime.now(timezone.utc).isoformat()
"""Generate a single statement via service and adapt to API response model."""
data = _svc_generate_single_statement(file_no, period, db)
totals = data.get("totals", {})
totals_model = StatementTotals(
charges_billed=_round(charges_billed),
charges_unbilled=_round(charges_unbilled),
charges_total=_round(charges_total),
payments=_round(payments_total),
trust_balance=_round(trust_balance),
current_balance=_round(current_balance),
charges_billed=float(totals.get("charges_billed", 0.0)),
charges_unbilled=float(totals.get("charges_unbilled", 0.0)),
charges_total=float(totals.get("charges_total", 0.0)),
payments=float(totals.get("payments", 0.0)),
trust_balance=float(totals.get("trust_balance", 0.0)),
current_balance=float(totals.get("current_balance", 0.0)),
)
# Render HTML
html = _render_statement_html(
file_no=file_no,
client_name=client_name or None,
matter=file_obj.regarding,
as_of_iso=as_of_iso,
period=period,
totals=totals_model,
unbilled_entries=unbilled_entries,
)
# Ensure exports directory and write file
exports_dir = Path("exports")
try:
exports_dir.mkdir(exist_ok=True)
except Exception:
# Best-effort: if cannot create, bubble up internal error
raise HTTPException(status_code=500, detail="Unable to create exports directory")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
safe_file_no = str(file_no).replace("/", "_").replace("\\", "_")
filename = f"statement_{safe_file_no}_{timestamp}.html"
export_path = exports_dir / filename
html_bytes = html.encode("utf-8")
with open(export_path, "wb") as f:
f.write(html_bytes)
size = export_path.stat().st_size
return GeneratedStatementMeta(
file_no=file_no,
client_name=client_name or None,
as_of=as_of_iso,
period=period,
file_no=str(data.get("file_no")),
client_name=data.get("client_name"),
as_of=str(data.get("as_of")),
period=data.get("period"),
totals=totals_model,
unbilled_count=len(unbilled_entries),
export_path=str(export_path),
filename=filename,
size=size,
content_type="text/html",
unbilled_count=int(data.get("unbilled_count", 0)),
export_path=str(data.get("export_path")),
filename=str(data.get("filename")),
size=int(data.get("size", 0)),
content_type=str(data.get("content_type", "text/html")),
)
@@ -842,92 +691,48 @@ async def generate_statement(
return _generate_single_statement(payload.file_no, payload.period, db)
async def _ws_authenticate(websocket: WebSocket) -> Optional[User]:
"""Authenticate WebSocket via JWT token in query (?token=) or Authorization header."""
token = websocket.query_params.get("token")
if not token:
try:
auth_header = dict(websocket.headers).get("authorization") or ""
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
except Exception:
token = None
if not token:
return None
username = verify_token(token)
if not username:
return None
db = SessionLocal()
try:
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
return None
return user
finally:
db.close()
async def _ws_keepalive(ws: WebSocket, stop_event: asyncio.Event) -> None:
try:
while not stop_event.is_set():
await asyncio.sleep(25)
try:
await ws.send_json({"type": "ping", "ts": datetime.now(timezone.utc).isoformat()})
except Exception:
break
finally:
stop_event.set()
@router.websocket("/statements/batch-progress/ws/{batch_id}")
async def ws_batch_progress(websocket: WebSocket, batch_id: str):
"""WebSocket: subscribe to real-time updates for a batch_id."""
user = await _ws_authenticate(websocket)
if not user:
await websocket.close(code=4401)
return
await websocket.accept()
# Register
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if not bucket:
bucket = set()
_subscribers_by_batch[batch_id] = bucket
bucket.add(websocket)
# Send initial snapshot
try:
snapshot = await progress_store.get_progress(batch_id)
await websocket.send_json({"type": "progress", "data": snapshot.model_dump() if snapshot else None})
except Exception:
pass
# Keepalive + receive loop
stop_event: asyncio.Event = asyncio.Event()
ka_task = asyncio.create_task(_ws_keepalive(websocket, stop_event))
try:
while not stop_event.is_set():
try:
msg = await websocket.receive_text()
except WebSocketDisconnect:
break
except Exception:
break
if isinstance(msg, str) and msg.strip() == "ping":
try:
await websocket.send_text("pong")
except Exception:
break
finally:
stop_event.set()
"""WebSocket: subscribe to real-time updates for a batch_id using the WebSocket pool."""
topic = f"batch_progress_{batch_id}"
# Custom message handler for batch progress
async def handle_batch_message(connection_id: str, message: WebSocketMessage):
"""Handle custom messages for batch progress"""
billing_logger.debug("Received batch progress message",
connection_id=connection_id,
batch_id=batch_id,
message_type=message.type)
# Handle any batch-specific message logic here if needed
# Use the WebSocket manager to handle the connection
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={topic},
require_auth=True,
metadata={"batch_id": batch_id, "endpoint": "batch_progress"},
message_handler=handle_batch_message
)
if connection_id:
# Send initial snapshot after connection is established
try:
ka_task.cancel()
except Exception:
pass
async with _subscribers_lock:
bucket = _subscribers_by_batch.get(batch_id)
if bucket and websocket in bucket:
bucket.discard(websocket)
if not bucket:
_subscribers_by_batch.pop(batch_id, None)
snapshot = await progress_store.get_progress(batch_id)
pool = websocket_manager.pool
initial_message = WebSocketMessage(
type="progress",
topic=topic,
data=snapshot.model_dump() if snapshot else None
)
await pool._send_to_connection(connection_id, initial_message)
billing_logger.info("Sent initial batch progress snapshot",
connection_id=connection_id,
batch_id=batch_id)
except Exception as e:
billing_logger.error("Failed to send initial batch progress snapshot",
connection_id=connection_id,
batch_id=batch_id,
error=str(e))
@router.delete("/statements/batch-progress/{batch_id}")
async def cancel_batch_operation(
@@ -1045,25 +850,12 @@ async def batch_generate_statements(
- Batch operation identification for audit trails
- Automatic cleanup of progress data after completion
"""
# Validate request
if not payload.file_numbers:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one file number must be provided"
)
if len(payload.file_numbers) > 50:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Maximum 50 files allowed per batch operation"
)
# Remove duplicates while preserving order
unique_file_numbers = list(dict.fromkeys(payload.file_numbers))
# Validate request and normalize inputs
unique_file_numbers = _svc_prepare_batch_parameters(payload.file_numbers)
# Generate batch ID and timing
start_time = datetime.now(timezone.utc)
batch_id = f"batch_{start_time.strftime('%Y%m%d_%H%M%S')}_{abs(hash(str(unique_file_numbers))) % 10000:04d}"
batch_id = _svc_make_batch_id(unique_file_numbers, start_time)
billing_logger.info(
"Starting batch statement generation",
@@ -1121,7 +913,12 @@ async def batch_generate_statements(
progress.current_file = file_no
progress.files[idx].status = "processing"
progress.files[idx].started_at = current_time.isoformat()
progress.estimated_completion = await _calculate_estimated_completion(progress, current_time)
progress.estimated_completion = _svc_compute_eta(
processed_files=progress.processed_files,
total_files=progress.total_files,
started_at_iso=progress.started_at,
now=current_time,
)
await progress_store.set_progress(progress)
billing_logger.info(
@@ -1288,53 +1085,13 @@ async def batch_generate_statements(
# Persist batch summary and per-file results
try:
def _parse_iso(dt: Optional[str]):
if not dt:
return None
try:
from datetime import datetime as _dt
return _dt.fromisoformat(dt.replace('Z', '+00:00'))
except Exception:
return None
batch_row = BillingBatch(
_svc_persist_batch_results(
db,
batch_id=batch_id,
status=str(progress.status),
total_files=total_files,
successful_files=successful,
failed_files=failed,
started_at=_parse_iso(progress.started_at),
updated_at=_parse_iso(progress.updated_at),
completed_at=_parse_iso(progress.completed_at),
progress=progress,
processing_time_seconds=processing_time,
success_rate=success_rate,
error_message=progress.error_message,
)
db.add(batch_row)
for f in progress.files:
meta = getattr(f, 'statement_meta', None)
filename = None
size = None
if meta is not None:
try:
filename = getattr(meta, 'filename', None)
size = getattr(meta, 'size', None)
except Exception:
pass
if filename is None and isinstance(meta, dict):
filename = meta.get('filename')
size = meta.get('size')
db.add(BillingBatchFile(
batch_id=batch_id,
file_no=f.file_no,
status=str(f.status),
error_message=f.error_message,
filename=filename,
size=size,
started_at=_parse_iso(f.started_at),
completed_at=_parse_iso(f.completed_at),
))
db.commit()
except Exception:
try:
db.rollback()
@@ -1600,6 +1357,34 @@ async def download_latest_statement(
detail="No statements found for requested period",
)
# Filter out any statements created prior to the file's opened date (safety against collisions)
try:
opened_date = getattr(file_obj, "opened", None)
if opened_date:
filtered_by_opened: List[Path] = []
for path in candidates:
name = path.name
# Filename format: statement_{safe_file_no}_YYYYMMDD_HHMMSS_micro.html
m = re.match(rf"^statement_{re.escape(safe_file_no)}_(\d{{8}})_\d{{6}}_\d{{6}}\.html$", name)
if not m:
continue
ymd = m.group(1)
y, mo, d = int(ymd[0:4]), int(ymd[4:6]), int(ymd[6:8])
from datetime import date as _date
stmt_date = _date(y, mo, d)
if stmt_date >= opened_date:
filtered_by_opened.append(path)
if filtered_by_opened:
candidates = filtered_by_opened
else:
# If none meet the opened-date filter, treat as no statements
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No statements found")
except HTTPException:
raise
except Exception:
# On parse errors, continue with existing candidates
pass
# Choose latest by modification time
candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
latest_path = candidates[0]

View File

@@ -15,6 +15,7 @@ from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
from app.services.mailing import build_address_from_rolodex
from app.services.query_utils import apply_sorting, paginate_with_total
from app.utils.logging import app_logger
from app.utils.database import db_transaction
@@ -96,6 +97,430 @@ class CustomerResponse(CustomerBase):
@router.get("/phone-book")
async def export_phone_book(
mode: str = Query("numbers", description="Report mode: numbers | addresses | full"),
format: str = Query("csv", description="Output format: csv | html"),
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
sort_by: Optional[str] = Query("name", description="Sort field: id, name, city, email"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
grouping: Optional[str] = Query(
"none",
description="Grouping: none | letter | group | group_letter"
),
page_break: bool = Query(
False,
description="HTML only: start a new page for each top-level group"
),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Generate phone book reports with filters and downloadable CSV/HTML.
Modes:
- numbers: name and phone numbers
- addresses: name, address, and phone numbers
- full: detailed rolodex fields plus phones
"""
allowed_modes = {"numbers", "addresses", "full"}
allowed_formats = {"csv", "html"}
allowed_groupings = {"none", "letter", "group", "group_letter"}
m = (mode or "").strip().lower()
f = (format or "").strip().lower()
if m not in allowed_modes:
raise HTTPException(status_code=400, detail="Invalid mode. Use one of: numbers, addresses, full")
if f not in allowed_formats:
raise HTTPException(status_code=400, detail="Invalid format. Use one of: csv, html")
gmode = (grouping or "none").strip().lower()
if gmode not in allowed_groupings:
raise HTTPException(status_code=400, detail="Invalid grouping. Use one of: none, letter, group, group_letter")
try:
base_query = db.query(Rolodex)
# Only group and name_prefix filtering are required per spec
base_query = apply_customer_filters(
base_query,
search=None,
group=group,
state=None,
groups=groups,
states=None,
name_prefix=name_prefix,
)
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
customers = base_query.options(joinedload(Rolodex.phone_numbers)).all()
def format_phones(entry: Rolodex) -> str:
parts: List[str] = []
try:
for p in (entry.phone_numbers or []):
label = (p.location or "").strip()
if label:
parts.append(f"{label}: {p.phone}")
else:
parts.append(p.phone)
except Exception:
pass
return "; ".join([s for s in parts if s])
def display_name(entry: Rolodex) -> str:
return build_address_from_rolodex(entry).display_name
def first_letter(entry: Rolodex) -> str:
base = (entry.last or entry.first or "").strip()
if not base:
return "#"
ch = base[0].upper()
return ch if ch.isalpha() else "#"
# Apply grouping-specific sort for stable output
if gmode == "letter":
customers.sort(key=lambda c: (first_letter(c), (c.last or "").lower(), (c.first or "").lower()))
elif gmode == "group":
customers.sort(key=lambda c: ((c.group or "Ungrouped").lower(), (c.last or "").lower(), (c.first or "").lower()))
elif gmode == "group_letter":
customers.sort(key=lambda c: ((c.group or "Ungrouped").lower(), first_letter(c), (c.last or "").lower(), (c.first or "").lower()))
def build_csv() -> StreamingResponse:
output = io.StringIO()
writer = csv.writer(output)
include_letter_col = gmode in ("letter", "group_letter")
if m == "numbers":
header = ["Name", "Group"] + (["Letter"] if include_letter_col else []) + ["Phones"]
writer.writerow(header)
for c in customers:
row = [display_name(c), c.group or ""]
if include_letter_col:
row.append(first_letter(c))
row.append(format_phones(c))
writer.writerow(row)
elif m == "addresses":
header = [
"Name", "Group"
] + (["Letter"] if include_letter_col else []) + [
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Phones"
]
writer.writerow(header)
for c in customers:
addr = build_address_from_rolodex(c)
row = [
addr.display_name,
c.group or "",
]
if include_letter_col:
row.append(first_letter(c))
row += [
c.a1 or "",
c.a2 or "",
c.a3 or "",
c.city or "",
c.abrev or "",
c.zip or "",
format_phones(c),
]
writer.writerow(row)
else: # full
header = [
"ID", "Last", "First", "Middle", "Prefix", "Suffix", "Title", "Group"
] + (["Letter"] if include_letter_col else []) + [
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Email", "Phones", "Legal Status",
]
writer.writerow(header)
for c in customers:
row = [
c.id,
c.last or "",
c.first or "",
c.middle or "",
c.prefix or "",
c.suffix or "",
c.title or "",
c.group or "",
]
if include_letter_col:
row.append(first_letter(c))
row += [
c.a1 or "",
c.a2 or "",
c.a3 or "",
c.city or "",
c.abrev or "",
c.zip or "",
c.email or "",
format_phones(c),
c.legal_status or "",
]
writer.writerow(row)
output.seek(0)
from datetime import datetime as _dt
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
filename = f"phone_book_{m}_{ts}.csv"
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
def build_html() -> StreamingResponse:
# Minimal, printable HTML
def css() -> str:
return """
body { font-family: Arial, sans-serif; margin: 16px; }
h1 { font-size: 18pt; margin-bottom: 8px; }
.meta { color: #666; font-size: 10pt; margin-bottom: 16px; }
.entry { margin-bottom: 10px; }
.name { font-weight: bold; }
.phones, .addr { margin-left: 12px; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: 6px 8px; font-size: 10pt; }
th { background: #f5f5f5; text-align: left; }
.section { margin-top: 18px; }
.section-title { font-size: 14pt; margin: 12px 0; border-bottom: 1px solid #ddd; padding-bottom: 4px; }
.subsection-title { font-size: 12pt; margin: 10px 0; color: #333; }
@media print {
.page-break { page-break-before: always; break-before: page; }
}
"""
title = {
"numbers": "Phone Book (Numbers Only)",
"addresses": "Phone Book (With Addresses)",
"full": "Phone Book (Full Rolodex)",
}[m]
from datetime import datetime as _dt
generated = _dt.now().strftime("%Y-%m-%d %H:%M")
def render_entry_block(c: Rolodex) -> str:
name = display_name(c)
group_text = f" <span class=\"group\">({c.group})</span>" if c.group else ""
phones_html = "".join([f"<div>{p.location + ': ' if p.location else ''}{p.phone}</div>" for p in (c.phone_numbers or [])])
addr_html = ""
if m == "addresses":
addr_lines = build_address_from_rolodex(c).compact_lines(include_name=False)
addr_html = "<div class=\"addr\">" + "".join([f"<div>{line}</div>" for line in addr_lines]) + "</div>"
return f"<div class=\"entry\"><div class=\"name\">{name}{group_text}</div>{addr_html}<div class=\"phones\">{phones_html}</div></div>"
if m in ("numbers", "addresses"):
sections: List[str] = []
if gmode == "none":
blocks = [render_entry_block(c) for c in customers]
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset=\"utf-8\" />
<title>{title}</title>
<style>{css()}</style>
<meta name=\"generator\" content=\"delphi\" />
<meta name=\"created\" content=\"{generated}\" />
</head>
<body>
<h1>{title}</h1>
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
{''.join(blocks)}
</body>
</html>
"""
else:
# Build sections according to grouping
if gmode == "letter":
# Letters A-Z plus '#'
letters: List[str] = sorted({first_letter(c) for c in customers})
for idx, letter in enumerate(letters):
entries = [c for c in customers if first_letter(c) == letter]
if not entries:
continue
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
blocks = [render_entry_block(c) for c in entries]
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Letter: {letter}</div>{''.join(blocks)}</div>")
elif gmode == "group":
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
for idx, gkey in enumerate(group_keys):
entries = [c for c in customers if (c.group or "Ungrouped") == gkey]
if not entries:
continue
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
blocks = [render_entry_block(c) for c in entries]
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(blocks)}</div>")
else: # group_letter
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
for gidx, gkey in enumerate(group_keys):
gentries = [c for c in customers if (c.group or "Ungrouped") == gkey]
if not gentries:
continue
section_class = "section" + (" page-break" if page_break and gidx > 0 else "")
subsections: List[str] = []
letters = sorted({first_letter(c) for c in gentries})
for letter in letters:
lentries = [c for c in gentries if first_letter(c) == letter]
if not lentries:
continue
blocks = [render_entry_block(c) for c in lentries]
subsections.append(f"<div class=\"subsection\"><div class=\"subsection-title\">Letter: {letter}</div>{''.join(blocks)}</div>")
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(subsections)}</div>")
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset=\"utf-8\" />
<title>{title}</title>
<style>{css()}</style>
<meta name=\"generator\" content=\"delphi\" />
<meta name=\"created\" content=\"{generated}\" />
</head>
<body>
<h1>{title}</h1>
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
{''.join(sections)}
</body>
</html>
"""
else:
# Full table variant
base_header_cells = [
"ID", "Last", "First", "Middle", "Prefix", "Suffix", "Title", "Group",
"Address 1", "Address 2", "Address 3", "City", "State", "ZIP", "Email", "Phones", "Legal Status",
]
def render_rows(items: List[Rolodex]) -> str:
rows_html: List[str] = []
for c in items:
phones = "".join([f"{p.location + ': ' if p.location else ''}{p.phone}" for p in (c.phone_numbers or [])])
cells = [
c.id or "",
c.last or "",
c.first or "",
c.middle or "",
c.prefix or "",
c.suffix or "",
c.title or "",
c.group or "",
c.a1 or "",
c.a2 or "",
c.a3 or "",
c.city or "",
c.abrev or "",
c.zip or "",
c.email or "",
phones,
c.legal_status or "",
]
rows_html.append("<tr>" + "".join([f"<td>{(str(v) if v is not None else '')}</td>" for v in cells]) + "</tr>")
return "".join(rows_html)
if gmode == "none":
rows_html = render_rows(customers)
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset=\"utf-8\" />
<title>{title}</title>
<style>{css()}</style>
<meta name=\"generator\" content=\"delphi\" />
<meta name=\"created\" content=\"{generated}\" />
</head>
<body>
<h1>{title}</h1>
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
<table>
<thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead>
<tbody>
{rows_html}
</tbody>
</table>
</body>
</html>
"""
else:
sections: List[str] = []
if gmode == "letter":
letters: List[str] = sorted({first_letter(c) for c in customers})
for idx, letter in enumerate(letters):
entries = [c for c in customers if first_letter(c) == letter]
if not entries:
continue
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
rows_html = render_rows(entries)
sections.append(
f"<div class=\"{section_class}\"><div class=\"section-title\">Letter: {letter}</div>"
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
)
elif gmode == "group":
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
for idx, gkey in enumerate(group_keys):
entries = [c for c in customers if (c.group or "Ungrouped") == gkey]
if not entries:
continue
section_class = "section" + (" page-break" if page_break and idx > 0 else "")
rows_html = render_rows(entries)
sections.append(
f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>"
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
)
else: # group_letter
group_keys: List[str] = sorted({(c.group or "Ungrouped") for c in customers}, key=lambda s: s.lower())
for gidx, gkey in enumerate(group_keys):
gentries = [c for c in customers if (c.group or "Ungrouped") == gkey]
if not gentries:
continue
section_class = "section" + (" page-break" if page_break and gidx > 0 else "")
subsections: List[str] = []
letters = sorted({first_letter(c) for c in gentries})
for letter in letters:
lentries = [c for c in gentries if first_letter(c) == letter]
if not lentries:
continue
rows_html = render_rows(lentries)
subsections.append(
f"<div class=\"subsection\"><div class=\"subsection-title\">Letter: {letter}</div>"
f"<table><thead><tr>{''.join([f'<th>{h}</th>' for h in base_header_cells])}</tr></thead><tbody>{rows_html}</tbody></table></div>"
)
sections.append(f"<div class=\"{section_class}\"><div class=\"section-title\">Group: {gkey}</div>{''.join(subsections)}</div>")
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset=\"utf-8\" />
<title>{title}</title>
<style>{css()}</style>
<meta name=\"generator\" content=\"delphi\" />
<meta name=\"created\" content=\"{generated}\" />
</head>
<body>
<h1>{title}</h1>
<div class=\"meta\">Generated {generated}. Total entries: {len(customers)}.</div>
{''.join(sections)}
</body>
</html>
"""
from datetime import datetime as _dt
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
filename = f"phone_book_{m}_{ts}.html"
return StreamingResponse(
iter([html]),
media_type="text/html",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
return build_csv() if f == "csv" else build_html()
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating phone book: {str(e)}")
@router.get("/search/phone")
async def search_by_phone(
phone: str = Query(..., description="Phone number to search for"),

1103
app/api/deadlines.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,748 @@
"""
Document Workflow Management API
This API provides comprehensive workflow automation management including:
- Workflow creation and configuration
- Event logging and processing
- Execution monitoring and control
- Template management for common workflows
"""
from __future__ import annotations
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query, Body
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, or_, and_, desc
from pydantic import BaseModel, Field
from datetime import datetime, date, timedelta
import json
from app.database.base import get_db
from app.auth.security import get_current_user
from app.models.user import User
from app.models.document_workflows import (
DocumentWorkflow, WorkflowAction, WorkflowExecution, EventLog,
WorkflowTemplate, WorkflowTriggerType, WorkflowActionType,
ExecutionStatus, WorkflowStatus
)
from app.services.workflow_engine import EventProcessor, WorkflowExecutor
from app.services.query_utils import paginate_with_total
router = APIRouter()
# Pydantic schemas for API
class WorkflowCreate(BaseModel):
name: str = Field(..., max_length=200)
description: Optional[str] = None
trigger_type: WorkflowTriggerType
trigger_conditions: Optional[Dict[str, Any]] = None
delay_minutes: int = Field(0, ge=0)
max_retries: int = Field(3, ge=0, le=10)
retry_delay_minutes: int = Field(30, ge=1)
timeout_minutes: int = Field(60, ge=1)
file_type_filter: Optional[List[str]] = None
status_filter: Optional[List[str]] = None
attorney_filter: Optional[List[str]] = None
client_filter: Optional[List[str]] = None
schedule_cron: Optional[str] = None
schedule_timezone: str = "UTC"
priority: int = Field(5, ge=1, le=10)
category: Optional[str] = None
tags: Optional[List[str]] = None
class WorkflowUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
status: Optional[WorkflowStatus] = None
trigger_conditions: Optional[Dict[str, Any]] = None
delay_minutes: Optional[int] = None
max_retries: Optional[int] = None
retry_delay_minutes: Optional[int] = None
timeout_minutes: Optional[int] = None
file_type_filter: Optional[List[str]] = None
status_filter: Optional[List[str]] = None
attorney_filter: Optional[List[str]] = None
client_filter: Optional[List[str]] = None
schedule_cron: Optional[str] = None
schedule_timezone: Optional[str] = None
priority: Optional[int] = None
category: Optional[str] = None
tags: Optional[List[str]] = None
class WorkflowActionCreate(BaseModel):
action_type: WorkflowActionType
action_order: int = Field(1, ge=1)
action_name: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
template_id: Optional[int] = None
output_format: str = "DOCX"
custom_filename_template: Optional[str] = None
email_template_id: Optional[int] = None
email_recipients: Optional[List[str]] = None
email_subject_template: Optional[str] = None
condition: Optional[Dict[str, Any]] = None
continue_on_failure: bool = False
class WorkflowActionUpdate(BaseModel):
action_type: Optional[WorkflowActionType] = None
action_order: Optional[int] = None
action_name: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
template_id: Optional[int] = None
output_format: Optional[str] = None
custom_filename_template: Optional[str] = None
email_template_id: Optional[int] = None
email_recipients: Optional[List[str]] = None
email_subject_template: Optional[str] = None
condition: Optional[Dict[str, Any]] = None
continue_on_failure: Optional[bool] = None
class WorkflowResponse(BaseModel):
id: int
name: str
description: Optional[str]
status: WorkflowStatus
trigger_type: WorkflowTriggerType
trigger_conditions: Optional[Dict[str, Any]]
delay_minutes: int
max_retries: int
priority: int
category: Optional[str]
tags: Optional[List[str]]
created_by: Optional[str]
created_at: datetime
updated_at: datetime
last_triggered_at: Optional[datetime]
execution_count: int
success_count: int
failure_count: int
class Config:
from_attributes = True
class WorkflowActionResponse(BaseModel):
id: int
workflow_id: int
action_type: WorkflowActionType
action_order: int
action_name: Optional[str]
parameters: Optional[Dict[str, Any]]
template_id: Optional[int]
output_format: str
condition: Optional[Dict[str, Any]]
continue_on_failure: bool
class Config:
from_attributes = True
class WorkflowExecutionResponse(BaseModel):
id: int
workflow_id: int
triggered_by_event_id: Optional[str]
triggered_by_event_type: Optional[str]
context_file_no: Optional[str]
context_client_id: Optional[str]
status: ExecutionStatus
started_at: Optional[datetime]
completed_at: Optional[datetime]
execution_duration_seconds: Optional[int]
retry_count: int
error_message: Optional[str]
generated_documents: Optional[List[Dict[str, Any]]]
class Config:
from_attributes = True
class EventLogCreate(BaseModel):
event_type: str
event_source: str
file_no: Optional[str] = None
client_id: Optional[str] = None
resource_type: Optional[str] = None
resource_id: Optional[str] = None
event_data: Optional[Dict[str, Any]] = None
previous_state: Optional[Dict[str, Any]] = None
new_state: Optional[Dict[str, Any]] = None
class EventLogResponse(BaseModel):
id: int
event_id: str
event_type: str
event_source: str
file_no: Optional[str]
client_id: Optional[str]
resource_type: Optional[str]
resource_id: Optional[str]
event_data: Optional[Dict[str, Any]]
processed: bool
triggered_workflows: Optional[List[int]]
occurred_at: datetime
class Config:
from_attributes = True
class WorkflowTestRequest(BaseModel):
event_type: str
event_data: Optional[Dict[str, Any]] = None
file_no: Optional[str] = None
client_id: Optional[str] = None
class WorkflowStatsResponse(BaseModel):
total_workflows: int
active_workflows: int
total_executions: int
successful_executions: int
failed_executions: int
pending_executions: int
workflows_by_trigger_type: Dict[str, int]
executions_by_day: List[Dict[str, Any]]
# Workflow CRUD endpoints
@router.post("/workflows/", response_model=WorkflowResponse)
async def create_workflow(
workflow_data: WorkflowCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Create a new document workflow"""
# Check for duplicate names
existing = db.query(DocumentWorkflow).filter(
DocumentWorkflow.name == workflow_data.name
).first()
if existing:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Workflow with name '{workflow_data.name}' already exists"
)
# Validate cron expression if provided
if workflow_data.schedule_cron:
try:
from croniter import croniter
croniter(workflow_data.schedule_cron)
except Exception:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid cron expression"
)
# Create workflow
workflow = DocumentWorkflow(
name=workflow_data.name,
description=workflow_data.description,
trigger_type=workflow_data.trigger_type,
trigger_conditions=workflow_data.trigger_conditions,
delay_minutes=workflow_data.delay_minutes,
max_retries=workflow_data.max_retries,
retry_delay_minutes=workflow_data.retry_delay_minutes,
timeout_minutes=workflow_data.timeout_minutes,
file_type_filter=workflow_data.file_type_filter,
status_filter=workflow_data.status_filter,
attorney_filter=workflow_data.attorney_filter,
client_filter=workflow_data.client_filter,
schedule_cron=workflow_data.schedule_cron,
schedule_timezone=workflow_data.schedule_timezone,
priority=workflow_data.priority,
category=workflow_data.category,
tags=workflow_data.tags,
created_by=current_user.username,
status=WorkflowStatus.ACTIVE
)
db.add(workflow)
db.commit()
db.refresh(workflow)
return workflow
@router.get("/workflows/", response_model=List[WorkflowResponse])
async def list_workflows(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
status: Optional[WorkflowStatus] = Query(None),
trigger_type: Optional[WorkflowTriggerType] = Query(None),
category: Optional[str] = Query(None),
search: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""List workflows with filtering options"""
query = db.query(DocumentWorkflow)
if status:
query = query.filter(DocumentWorkflow.status == status)
if trigger_type:
query = query.filter(DocumentWorkflow.trigger_type == trigger_type)
if category:
query = query.filter(DocumentWorkflow.category == category)
if search:
search_filter = f"%{search}%"
query = query.filter(
or_(
DocumentWorkflow.name.ilike(search_filter),
DocumentWorkflow.description.ilike(search_filter)
)
)
query = query.order_by(DocumentWorkflow.priority.desc(), DocumentWorkflow.name)
workflows, _ = paginate_with_total(query, skip, limit, False)
return workflows
@router.get("/workflows/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
workflow_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get a specific workflow by ID"""
workflow = db.query(DocumentWorkflow).filter(
DocumentWorkflow.id == workflow_id
).first()
if not workflow:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workflow not found"
)
return workflow
@router.put("/workflows/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
workflow_id: int,
workflow_data: WorkflowUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Update a workflow"""
workflow = db.query(DocumentWorkflow).filter(
DocumentWorkflow.id == workflow_id
).first()
if not workflow:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workflow not found"
)
# Update fields that are provided
update_data = workflow_data.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(workflow, field, value)
db.commit()
db.refresh(workflow)
return workflow
@router.delete("/workflows/{workflow_id}")
async def delete_workflow(
workflow_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Delete a workflow (soft delete by setting status to archived)"""
workflow = db.query(DocumentWorkflow).filter(
DocumentWorkflow.id == workflow_id
).first()
if not workflow:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workflow not found"
)
# Soft delete
workflow.status = WorkflowStatus.ARCHIVED
db.commit()
return {"message": "Workflow archived successfully"}
# Workflow Actions endpoints
@router.post("/workflows/{workflow_id}/actions", response_model=WorkflowActionResponse)
async def create_workflow_action(
workflow_id: int,
action_data: WorkflowActionCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Create a new action for a workflow"""
workflow = db.query(DocumentWorkflow).filter(
DocumentWorkflow.id == workflow_id
).first()
if not workflow:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workflow not found"
)
action = WorkflowAction(
workflow_id=workflow_id,
action_type=action_data.action_type,
action_order=action_data.action_order,
action_name=action_data.action_name,
parameters=action_data.parameters,
template_id=action_data.template_id,
output_format=action_data.output_format,
custom_filename_template=action_data.custom_filename_template,
email_template_id=action_data.email_template_id,
email_recipients=action_data.email_recipients,
email_subject_template=action_data.email_subject_template,
condition=action_data.condition,
continue_on_failure=action_data.continue_on_failure
)
db.add(action)
db.commit()
db.refresh(action)
return action
@router.get("/workflows/{workflow_id}/actions", response_model=List[WorkflowActionResponse])
async def list_workflow_actions(
workflow_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""List actions for a workflow"""
actions = db.query(WorkflowAction).filter(
WorkflowAction.workflow_id == workflow_id
).order_by(WorkflowAction.action_order, WorkflowAction.id).all()
return actions
@router.put("/workflows/{workflow_id}/actions/{action_id}", response_model=WorkflowActionResponse)
async def update_workflow_action(
workflow_id: int,
action_id: int,
action_data: WorkflowActionUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Update a workflow action"""
action = db.query(WorkflowAction).filter(
WorkflowAction.id == action_id,
WorkflowAction.workflow_id == workflow_id
).first()
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Action not found"
)
# Update fields that are provided
update_data = action_data.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(action, field, value)
db.commit()
db.refresh(action)
return action
@router.delete("/workflows/{workflow_id}/actions/{action_id}")
async def delete_workflow_action(
workflow_id: int,
action_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Delete a workflow action"""
action = db.query(WorkflowAction).filter(
WorkflowAction.id == action_id,
WorkflowAction.workflow_id == workflow_id
).first()
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Action not found"
)
db.delete(action)
db.commit()
return {"message": "Action deleted successfully"}
# Event Management endpoints
@router.post("/events/", response_model=dict)
async def log_event(
event_data: EventLogCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Log a system event that may trigger workflows"""
processor = EventProcessor(db)
event_id = await processor.log_event(
event_type=event_data.event_type,
event_source=event_data.event_source,
file_no=event_data.file_no,
client_id=event_data.client_id,
user_id=current_user.id,
resource_type=event_data.resource_type,
resource_id=event_data.resource_id,
event_data=event_data.event_data,
previous_state=event_data.previous_state,
new_state=event_data.new_state
)
return {"event_id": event_id, "message": "Event logged successfully"}
@router.get("/events/", response_model=List[EventLogResponse])
async def list_events(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
event_type: Optional[str] = Query(None),
file_no: Optional[str] = Query(None),
processed: Optional[bool] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""List system events"""
query = db.query(EventLog)
if event_type:
query = query.filter(EventLog.event_type == event_type)
if file_no:
query = query.filter(EventLog.file_no == file_no)
if processed is not None:
query = query.filter(EventLog.processed == processed)
query = query.order_by(desc(EventLog.occurred_at))
events, _ = paginate_with_total(query, skip, limit, False)
return events
# Execution Management endpoints
@router.get("/executions/", response_model=List[WorkflowExecutionResponse])
async def list_executions(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
workflow_id: Optional[int] = Query(None),
status: Optional[ExecutionStatus] = Query(None),
file_no: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""List workflow executions"""
query = db.query(WorkflowExecution)
if workflow_id:
query = query.filter(WorkflowExecution.workflow_id == workflow_id)
if status:
query = query.filter(WorkflowExecution.status == status)
if file_no:
query = query.filter(WorkflowExecution.context_file_no == file_no)
query = query.order_by(desc(WorkflowExecution.started_at))
executions, _ = paginate_with_total(query, skip, limit, False)
return executions
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
async def get_execution(
execution_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get details of a specific execution"""
execution = db.query(WorkflowExecution).filter(
WorkflowExecution.id == execution_id
).first()
if not execution:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Execution not found"
)
return execution
@router.post("/executions/{execution_id}/retry")
async def retry_execution(
execution_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Retry a failed workflow execution"""
execution = db.query(WorkflowExecution).filter(
WorkflowExecution.id == execution_id
).first()
if not execution:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Execution not found"
)
if execution.status not in [ExecutionStatus.FAILED, ExecutionStatus.RETRYING]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only failed executions can be retried"
)
# Reset execution for retry
execution.status = ExecutionStatus.PENDING
execution.error_message = None
execution.next_retry_at = None
execution.retry_count += 1
db.commit()
# Execute the workflow
executor = WorkflowExecutor(db)
success = await executor.execute_workflow(execution_id)
return {"message": "Execution retried", "success": success}
# Testing and Management endpoints
@router.post("/workflows/{workflow_id}/test")
async def test_workflow(
workflow_id: int,
test_request: WorkflowTestRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Test a workflow with simulated event data"""
workflow = db.query(DocumentWorkflow).filter(
DocumentWorkflow.id == workflow_id
).first()
if not workflow:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workflow not found"
)
# Create a test event
processor = EventProcessor(db)
event_id = await processor.log_event(
event_type=test_request.event_type,
event_source="workflow_test",
file_no=test_request.file_no,
client_id=test_request.client_id,
user_id=current_user.id,
event_data=test_request.event_data or {}
)
return {"message": "Test event logged", "event_id": event_id}
@router.get("/stats", response_model=WorkflowStatsResponse)
async def get_workflow_stats(
days: int = Query(30, ge=1, le=365),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get workflow system statistics"""
# Basic counts
total_workflows = db.query(func.count(DocumentWorkflow.id)).scalar()
active_workflows = db.query(func.count(DocumentWorkflow.id)).filter(
DocumentWorkflow.status == WorkflowStatus.ACTIVE
).scalar()
total_executions = db.query(func.count(WorkflowExecution.id)).scalar()
successful_executions = db.query(func.count(WorkflowExecution.id)).filter(
WorkflowExecution.status == ExecutionStatus.COMPLETED
).scalar()
failed_executions = db.query(func.count(WorkflowExecution.id)).filter(
WorkflowExecution.status == ExecutionStatus.FAILED
).scalar()
pending_executions = db.query(func.count(WorkflowExecution.id)).filter(
WorkflowExecution.status.in_([ExecutionStatus.PENDING, ExecutionStatus.RUNNING])
).scalar()
# Workflows by trigger type
trigger_stats = db.query(
DocumentWorkflow.trigger_type,
func.count(DocumentWorkflow.id)
).group_by(DocumentWorkflow.trigger_type).all()
workflows_by_trigger_type = {
trigger.value: count for trigger, count in trigger_stats
}
# Executions by day (for the chart)
cutoff_date = datetime.now() - timedelta(days=days)
daily_stats = db.query(
func.date(WorkflowExecution.started_at).label('date'),
func.count(WorkflowExecution.id).label('count'),
func.sum(func.case((WorkflowExecution.status == ExecutionStatus.COMPLETED, 1), else_=0)).label('successful'),
func.sum(func.case((WorkflowExecution.status == ExecutionStatus.FAILED, 1), else_=0)).label('failed')
).filter(
WorkflowExecution.started_at >= cutoff_date
).group_by(func.date(WorkflowExecution.started_at)).all()
executions_by_day = [
{
'date': row.date.isoformat() if row.date else None,
'total': row.count,
'successful': row.successful or 0,
'failed': row.failed or 0
}
for row in daily_stats
]
return WorkflowStatsResponse(
total_workflows=total_workflows or 0,
active_workflows=active_workflows or 0,
total_executions=total_executions or 0,
successful_executions=successful_executions or 0,
failed_executions=failed_executions or 0,
pending_executions=pending_executions or 0,
workflows_by_trigger_type=workflows_by_trigger_type,
executions_by_day=executions_by_day
)

View File

@@ -7,9 +7,12 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime, timezone
import io
import zipfile
import os
import uuid
import shutil
from pathlib import Path
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
@@ -21,9 +24,17 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee
from app.models.user import User
from app.auth.security import get_current_user
from app.models.additional import Document
from app.models.document_workflows import EventLog
from app.core.logging import get_logger
from app.services.audit import audit_service
from app.services.cache import invalidate_search_cache
from app.models.templates import DocumentTemplate, DocumentTemplateVersion
from app.models.jobs import JobRecord
from app.services.storage import get_default_storage
from app.services.template_merge import extract_tokens_from_bytes, build_context, resolve_tokens, render_docx
from app.services.document_notifications import notify_processing, notify_completed, notify_failed, topic_for_file, ADMIN_DOCUMENTS_TOPIC, get_last_status
from app.middleware.websocket_middleware import get_websocket_manager, WebSocketMessage
from fastapi import WebSocket
router = APIRouter()
@@ -118,6 +129,87 @@ class PaginatedQDROResponse(BaseModel):
total: int
class CurrentStatusResponse(BaseModel):
file_no: str
status: str # processing | completed | failed | unknown
timestamp: Optional[str] = None
data: Optional[Dict[str, Any]] = None
history: Optional[list] = None
@router.get("/current-status/{file_no}", response_model=CurrentStatusResponse)
async def get_current_document_status(
file_no: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Return last-known document generation status for a file.
Priority:
1) In-memory last broadcast state (processing/completed/failed)
2) If no memory record, check for any uploaded/generated documents and report 'completed'
3) Fallback to 'unknown'
"""
# Build recent history from EventLog (last N events)
history_items = []
try:
recent = (
db.query(EventLog)
.filter(EventLog.file_no == file_no, EventLog.event_type.in_(["document_processing", "document_completed", "document_failed"]))
.order_by(EventLog.occurred_at.desc())
.limit(10)
.all()
)
for ev in recent:
history_items.append({
"type": ev.event_type,
"timestamp": ev.occurred_at.isoformat() if getattr(ev, "occurred_at", None) else None,
"data": ev.event_data or {},
})
except Exception:
history_items = []
# Try in-memory record for current status
last = get_last_status(file_no)
if last:
ts = last.get("timestamp")
iso = ts.isoformat() if hasattr(ts, "isoformat") else None
status_val = str(last.get("status") or "unknown")
# Treat stale 'processing' as unknown if older than 10 minutes
try:
if status_val == "processing" and isinstance(ts, datetime):
age = datetime.now(timezone.utc) - ts
if age.total_seconds() > 600:
status_val = "unknown"
except Exception:
pass
return CurrentStatusResponse(
file_no=file_no,
status=status_val,
timestamp=iso,
data=(last.get("data") or None),
history=history_items,
)
# Fallback: any existing documents imply last status completed
any_doc = db.query(Document).filter(Document.file_no == file_no).order_by(Document.id.desc()).first()
if any_doc:
return CurrentStatusResponse(
file_no=file_no,
status="completed",
timestamp=getattr(any_doc, "upload_date", None).isoformat() if getattr(any_doc, "upload_date", None) else None,
data={
"document_id": any_doc.id,
"filename": any_doc.filename,
"size": any_doc.size,
},
history=history_items,
)
return CurrentStatusResponse(file_no=file_no, status="unknown", history=history_items)
@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
async def list_qdros(
skip: int = Query(0, ge=0),
@@ -814,6 +906,371 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str:
return merged
# --- Batch Document Generation (MVP synchronous) ---
class BatchGenerateRequest(BaseModel):
"""Batch generation request using DocumentTemplate system."""
template_id: int
version_id: Optional[int] = None
file_nos: List[str]
output_format: str = "DOCX" # DOCX (default), PDF (not yet supported), HTML (not yet supported)
context: Optional[Dict[str, Any]] = None # additional global context
bundle_zip: bool = False # when true, also create a ZIP bundle of generated outputs
class BatchGenerateItemResult(BaseModel):
file_no: str
status: str # "success" | "error"
document_id: Optional[int] = None
filename: Optional[str] = None
path: Optional[str] = None
url: Optional[str] = None
size: Optional[int] = None
unresolved: Optional[List[str]] = None
error: Optional[str] = None
class BatchGenerateResponse(BaseModel):
job_id: str
template_id: int
version_id: int
total_requested: int
total_success: int
total_failed: int
results: List[BatchGenerateItemResult]
bundle_url: Optional[str] = None
bundle_size: Optional[int] = None
@router.post("/generate-batch", response_model=BatchGenerateResponse)
async def generate_batch_documents(
payload: BatchGenerateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Synchronously generate documents for multiple files from a template version.
Notes:
- Currently supports DOCX output. PDF/HTML conversion is not yet implemented.
- Saves generated bytes to default storage under uploads/generated/{file_no}/.
- Persists a `Document` record per successful file.
- Returns per-item status with unresolved tokens for transparency.
"""
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == payload.template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
resolved_version_id = payload.version_id or tpl.current_version_id
if not resolved_version_id:
raise HTTPException(status_code=400, detail="Template has no approved/current version")
ver = (
db.query(DocumentTemplateVersion)
.filter(
DocumentTemplateVersion.id == resolved_version_id,
DocumentTemplateVersion.template_id == tpl.id,
)
.first()
)
if not ver:
raise HTTPException(status_code=404, detail="Template version not found")
storage = get_default_storage()
try:
template_bytes = storage.open_bytes(ver.storage_path)
except Exception:
raise HTTPException(status_code=404, detail="Stored template file not found")
tokens = extract_tokens_from_bytes(template_bytes)
results: List[BatchGenerateItemResult] = []
# Pre-normalize file numbers (strip spaces, ignore empties)
requested_files: List[str] = [fn.strip() for fn in (payload.file_nos or []) if fn and str(fn).strip()]
if not requested_files:
raise HTTPException(status_code=400, detail="No file numbers provided")
# Fetch all files in one query
files_map: Dict[str, FileModel] = {
f.file_no: f
for f in db.query(FileModel).options(joinedload(FileModel.owner)).filter(FileModel.file_no.in_(requested_files)).all()
}
generated_items: List[Dict[str, Any]] = [] # capture bytes for optional ZIP
for file_no in requested_files:
# Notify processing started for this file
try:
await notify_processing(
file_no=file_no,
user_id=current_user.id,
data={
"template_id": tpl.id,
"template_name": tpl.name,
"job_id": job_id
}
)
except Exception:
# Don't fail generation if notification fails
pass
file_obj = files_map.get(file_no)
if not file_obj:
# Notify failure
try:
await notify_failed(
file_no=file_no,
user_id=current_user.id,
data={"error": "File not found", "template_id": tpl.id}
)
except Exception:
pass
results.append(
BatchGenerateItemResult(
file_no=file_no,
status="error",
error="File not found",
)
)
continue
# Build per-file context
file_context: Dict[str, Any] = {
"FILE_NO": file_obj.file_no,
"CLIENT_FIRST": getattr(getattr(file_obj, "owner", None), "first", "") or "",
"CLIENT_LAST": getattr(getattr(file_obj, "owner", None), "last", "") or "",
"CLIENT_FULL": (
f"{getattr(getattr(file_obj, 'owner', None), 'first', '') or ''} "
f"{getattr(getattr(file_obj, 'owner', None), 'last', '') or ''}"
).strip(),
"MATTER": file_obj.regarding or "",
"OPENED": file_obj.opened.strftime("%B %d, %Y") if getattr(file_obj, "opened", None) else "",
"ATTORNEY": getattr(file_obj, "empl_num", "") or "",
}
# Merge global context
merged_context = build_context({**(payload.context or {}), **file_context}, "file", file_obj.file_no)
resolved_vars, unresolved_tokens = resolve_tokens(db, tokens, merged_context)
try:
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
output_bytes = render_docx(template_bytes, resolved_vars)
output_mime = ver.mime_type
extension = ".docx"
else:
# For non-DOCX templates (e.g., PDF), pass-through content
output_bytes = template_bytes
output_mime = ver.mime_type
extension = ".bin"
# Name and save
safe_name = f"{tpl.name}_{file_obj.file_no}{extension}"
subdir = f"generated/{file_obj.file_no}"
storage_path = storage.save_bytes(content=output_bytes, filename_hint=safe_name, subdir=subdir, content_type=output_mime)
# Persist Document record
abs_or_rel_path = os.path.join("uploads", storage_path).replace("\\", "/")
doc = Document(
file_no=file_obj.file_no,
filename=safe_name,
path=abs_or_rel_path,
description=f"Generated from template '{tpl.name}'",
type=output_mime,
size=len(output_bytes),
uploaded_by=getattr(current_user, "username", None),
)
db.add(doc)
db.commit()
db.refresh(doc)
# Notify successful completion
try:
await notify_completed(
file_no=file_obj.file_no,
user_id=current_user.id,
data={
"template_id": tpl.id,
"template_name": tpl.name,
"document_id": doc.id,
"filename": doc.filename,
"size": doc.size,
"unresolved_tokens": unresolved_tokens or []
}
)
except Exception:
# Don't fail generation if notification fails
pass
results.append(
BatchGenerateItemResult(
file_no=file_obj.file_no,
status="success",
document_id=doc.id,
filename=doc.filename,
path=doc.path,
url=storage.public_url(storage_path),
size=doc.size,
unresolved=unresolved_tokens or [],
)
)
# Keep for bundling
generated_items.append({
"filename": doc.filename,
"storage_path": storage_path,
})
except Exception as e:
# Notify failure
try:
await notify_failed(
file_no=file_obj.file_no,
user_id=current_user.id,
data={
"template_id": tpl.id,
"template_name": tpl.name,
"error": str(e),
"unresolved_tokens": unresolved_tokens or []
}
)
except Exception:
pass
# Best-effort rollback of partial doc add
try:
db.rollback()
except Exception:
pass
results.append(
BatchGenerateItemResult(
file_no=file_obj.file_no,
status="error",
error=str(e),
unresolved=unresolved_tokens or [],
)
)
job_id = str(uuid.uuid4())
total_success = sum(1 for r in results if r.status == "success")
total_failed = sum(1 for r in results if r.status == "error")
bundle_url: Optional[str] = None
bundle_size: Optional[int] = None
# Optionally create a ZIP bundle of generated outputs
bundle_storage_path: Optional[str] = None
if payload.bundle_zip and total_success > 0:
# Stream zip to memory then save via storage adapter
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for item in generated_items:
try:
file_bytes = storage.open_bytes(item["storage_path"]) # relative path under uploads
# Use clean filename inside zip
zf.writestr(item["filename"], file_bytes)
except Exception:
# Skip missing/unreadable files from bundle; keep job successful
continue
zip_bytes = zip_buffer.getvalue()
safe_zip_name = f"documents_batch_{job_id}.zip"
bundle_storage_path = storage.save_bytes(content=zip_bytes, filename_hint=safe_zip_name, subdir="bundles", content_type="application/zip")
bundle_url = storage.public_url(bundle_storage_path)
bundle_size = len(zip_bytes)
# Persist simple job record
try:
job = JobRecord(
job_id=job_id,
job_type="documents_batch",
status="completed",
requested_by_username=getattr(current_user, "username", None),
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
total_requested=len(requested_files),
total_success=total_success,
total_failed=total_failed,
result_storage_path=bundle_storage_path,
result_mime_type=("application/zip" if bundle_storage_path else None),
result_size=bundle_size,
details={
"template_id": tpl.id,
"version_id": ver.id,
"file_nos": requested_files,
},
)
db.add(job)
db.commit()
except Exception:
try:
db.rollback()
except Exception:
pass
return BatchGenerateResponse(
job_id=job_id,
template_id=tpl.id,
version_id=ver.id,
total_requested=len(requested_files),
total_success=total_success,
total_failed=total_failed,
results=results,
bundle_url=bundle_url,
bundle_size=bundle_size,
)
from fastapi.responses import StreamingResponse
class JobStatusResponse(BaseModel):
job_id: str
job_type: str
status: str
total_requested: int
total_success: int
total_failed: int
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
bundle_available: bool = False
bundle_url: Optional[str] = None
bundle_size: Optional[int] = None
@router.get("/jobs/{job_id}", response_model=JobStatusResponse)
async def get_job_status(
job_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return JobStatusResponse(
job_id=job.job_id,
job_type=job.job_type,
status=job.status,
total_requested=job.total_requested or 0,
total_success=job.total_success or 0,
total_failed=job.total_failed or 0,
started_at=getattr(job, "started_at", None),
completed_at=getattr(job, "completed_at", None),
bundle_available=bool(job.result_storage_path),
bundle_url=(get_default_storage().public_url(job.result_storage_path) if job.result_storage_path else None),
bundle_size=job.result_size,
)
@router.get("/jobs/{job_id}/result")
async def download_job_result(
job_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job or not job.result_storage_path:
raise HTTPException(status_code=404, detail="Result not available for this job")
storage = get_default_storage()
try:
content = storage.open_bytes(job.result_storage_path)
except Exception:
raise HTTPException(status_code=404, detail="Stored bundle not found")
# Derive filename
base = os.path.basename(job.result_storage_path)
headers = {
"Content-Disposition": f"attachment; filename=\"{base}\"",
}
return StreamingResponse(iter([content]), media_type=(job.result_mime_type or "application/zip"), headers=headers)
# --- Client Error Logging (for Documents page) ---
class ClientErrorLog(BaseModel):
"""Payload for client-side error logging"""
@@ -894,54 +1351,118 @@ async def upload_document(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Upload a document to a file"""
"""Upload a document to a file with comprehensive security validation and async operations"""
from app.utils.file_security import file_validator, create_upload_directory
from app.services.async_file_operations import async_file_ops, validate_large_upload
from app.services.async_storage import async_storage
file_obj = db.query(FileModel).filter(FileModel.file_no == file_no).first()
if not file_obj:
raise HTTPException(status_code=404, detail="File not found")
if not file.filename:
raise HTTPException(status_code=400, detail="No file uploaded")
# Determine if this is a large file that needs streaming
file_size_estimate = getattr(file, 'size', 0) or 0
use_streaming = file_size_estimate > 10 * 1024 * 1024 # 10MB threshold
allowed_types = [
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"image/jpeg",
"image/png"
]
if file.content_type not in allowed_types:
raise HTTPException(status_code=400, detail="Invalid file type")
if use_streaming:
# Use streaming validation for large files
# Enforce the same 10MB limit used for non-streaming uploads
is_valid, error_msg, metadata = await validate_large_upload(
file, category='document', max_size=10 * 1024 * 1024
)
if not is_valid:
raise HTTPException(status_code=400, detail=error_msg)
safe_filename = file_validator.sanitize_filename(file.filename)
file_ext = Path(safe_filename).suffix
mime_type = metadata.get('content_type', 'application/octet-stream')
# Stream upload for large files
subdir = f"documents/{file_no}"
final_path, actual_size, _checksum = await async_file_ops.stream_upload_file(
file,
f"{subdir}/{uuid.uuid4()}{file_ext}",
progress_callback=None # Could add WebSocket progress here
)
# Get absolute path for database storage
absolute_path = str(final_path)
# For downstream DB fields that expect a relative path, also keep a relative for consistency
relative_path = str(Path(final_path).relative_to(async_file_ops.base_upload_dir))
else:
# Use traditional validation for smaller files
content, safe_filename, file_ext, mime_type = await file_validator.validate_upload_file(
file, category='document'
)
max_size = 10 * 1024 * 1024 # 10MB
content = await file.read()
# Treat zero-byte payloads as no file uploaded to provide a clearer client error
if len(content) == 0:
raise HTTPException(status_code=400, detail="No file uploaded")
if len(content) > max_size:
raise HTTPException(status_code=400, detail="File too large")
# Create secure upload directory
upload_dir = f"uploads/{file_no}"
create_upload_directory(upload_dir)
upload_dir = f"uploads/{file_no}"
os.makedirs(upload_dir, exist_ok=True)
# Generate secure file path with UUID to prevent conflicts
unique_name = f"{uuid.uuid4()}{file_ext}"
path = file_validator.generate_secure_path(upload_dir, unique_name)
ext = file.filename.split(".")[-1]
unique_name = f"{uuid.uuid4()}.{ext}"
path = f"{upload_dir}/{unique_name}"
with open(path, "wb") as f:
f.write(content)
# Write file using async storage for consistency
try:
relative_path = await async_storage.save_bytes_async(
content,
safe_filename,
subdir=f"documents/{file_no}"
)
absolute_path = str(async_storage.base_dir / relative_path)
actual_size = len(content)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Could not save file: {str(e)}")
doc = Document(
file_no=file_no,
filename=file.filename,
path=path,
filename=safe_filename, # Use sanitized filename
path=absolute_path,
description=description,
type=file.content_type,
size=len(content),
type=mime_type, # Use validated MIME type
size=actual_size,
uploaded_by=current_user.username
)
db.add(doc)
db.commit()
db.refresh(doc)
# Send real-time notification for document upload
try:
await notify_completed(
file_no=file_no,
user_id=current_user.id,
data={
"action": "upload",
"document_id": doc.id,
"filename": safe_filename,
"size": actual_size,
"type": mime_type,
"description": description
}
)
except Exception as e:
# Don't fail the operation if notification fails
get_logger("documents").warning(f"Failed to send document upload notification: {str(e)}")
# Log workflow event for document upload
try:
from app.services.workflow_integration import log_document_uploaded_sync
log_document_uploaded_sync(
db=db,
file_no=file_no,
document_id=doc.id,
filename=safe_filename,
document_type=mime_type,
user_id=current_user.id
)
except Exception as e:
# Don't fail the operation if workflow logging fails
get_logger("documents").warning(f"Failed to log workflow event for document upload: {str(e)}")
return doc
@router.get("/{file_no}/uploaded")
@@ -987,4 +1508,125 @@ async def update_document(
doc.description = description
db.commit()
db.refresh(doc)
return doc
return doc
# WebSocket endpoints for real-time document status notifications
@router.websocket("/ws/status/{file_no}")
async def ws_document_status(websocket: WebSocket, file_no: str):
"""
Subscribe to real-time document processing status updates for a specific file.
Users can connect to this endpoint to receive notifications about:
- Document generation started (processing)
- Document generation completed
- Document generation failed
- Document uploads
Authentication required via token query parameter.
"""
websocket_manager = get_websocket_manager()
topic = topic_for_file(file_no)
# Custom message handler for document status updates
async def handle_document_message(connection_id: str, message: WebSocketMessage):
"""Handle custom messages for document status"""
get_logger("documents").debug("Received document status message",
connection_id=connection_id,
file_no=file_no,
message_type=message.type)
# Use the WebSocket manager to handle the connection
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={topic},
require_auth=True,
metadata={"file_no": file_no, "endpoint": "document_status"},
message_handler=handle_document_message
)
if connection_id:
# Send initial welcome message with subscription confirmation
try:
pool = websocket_manager.pool
welcome_message = WebSocketMessage(
type="subscription_confirmed",
topic=topic,
data={
"file_no": file_no,
"message": f"Subscribed to document status updates for file {file_no}"
}
)
await pool._send_to_connection(connection_id, welcome_message)
get_logger("documents").info("Document status subscription confirmed",
connection_id=connection_id,
file_no=file_no)
except Exception as e:
get_logger("documents").error("Failed to send subscription confirmation",
connection_id=connection_id,
file_no=file_no,
error=str(e))
# Test endpoint for document notification system
@router.post("/test-notification/{file_no}")
async def test_document_notification(
file_no: str,
status: str = Query(..., description="Notification status: processing, completed, or failed"),
message: Optional[str] = Query(None, description="Optional message"),
current_user: User = Depends(get_current_user)
):
"""
Test endpoint to simulate document processing notifications.
This endpoint allows testing the WebSocket notification system by sending
simulated document status updates. Useful for development and debugging.
"""
if status not in ["processing", "completed", "failed"]:
raise HTTPException(
status_code=400,
detail="Status must be one of: processing, completed, failed"
)
# Prepare test data
test_data = {
"test": True,
"triggered_by": current_user.username,
"message": message or f"Test {status} notification for file {file_no}",
"timestamp": datetime.now(timezone.utc).isoformat()
}
# Send notification based on status
try:
if status == "processing":
sent_count = await notify_processing(
file_no=file_no,
user_id=current_user.id,
data=test_data
)
elif status == "completed":
sent_count = await notify_completed(
file_no=file_no,
user_id=current_user.id,
data=test_data
)
else: # failed
sent_count = await notify_failed(
file_no=file_no,
user_id=current_user.id,
data=test_data
)
return {
"message": f"Test notification sent for file {file_no}",
"status": status,
"sent_to_connections": sent_count,
"data": test_data
}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to send test notification: {str(e)}"
)

View File

@@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ConfigDict
from app.database.base import get_db
from app.models import (
File, FileStatus, FileType, Employee, User, FileStatusHistory,
FileTransferHistory, FileArchiveInfo
File, FileStatus, FileType, Employee, User, FileStatusHistory,
FileTransferHistory, FileArchiveInfo, FileClosureChecklist, FileAlert
)
from app.services.file_management import FileManagementService, FileManagementError, FileStatusWorkflow
from app.auth.security import get_current_user
@@ -134,6 +134,10 @@ async def change_file_status(
"""Change file status with workflow validation"""
try:
service = FileManagementService(db)
# Get the old status before changing
old_file = db.query(File).filter(File.file_no == file_no).first()
old_status = old_file.status if old_file else None
file_obj = service.change_file_status(
file_no=file_no,
new_status=request.new_status,
@@ -142,6 +146,21 @@ async def change_file_status(
validate_transition=request.validate_transition
)
# Log workflow event for file status change
try:
from app.services.workflow_integration import log_file_status_change_sync
log_file_status_change_sync(
db=db,
file_no=file_no,
old_status=old_status,
new_status=request.new_status,
user_id=current_user.id,
notes=request.notes
)
except Exception as e:
# Don't fail the operation if workflow logging fails
logger.warning(f"Failed to log workflow event for file {file_no}: {str(e)}")
return {
"message": f"File {file_no} status changed to {request.new_status}",
"file_no": file_obj.file_no,
@@ -397,6 +416,302 @@ async def bulk_status_update(
)
# Checklist endpoints
class ChecklistItemRequest(BaseModel):
item_name: str
item_description: Optional[str] = None
is_required: bool = True
sort_order: int = 0
class ChecklistItemUpdateRequest(BaseModel):
item_name: Optional[str] = None
item_description: Optional[str] = None
is_required: Optional[bool] = None
is_completed: Optional[bool] = None
sort_order: Optional[int] = None
notes: Optional[str] = None
@router.get("/{file_no}/closure-checklist")
async def get_closure_checklist(
file_no: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
return service.get_closure_checklist(file_no)
@router.post("/{file_no}/closure-checklist")
async def add_checklist_item(
file_no: str,
request: ChecklistItemRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
item = service.add_checklist_item(
file_no=file_no,
item_name=request.item_name,
item_description=request.item_description,
is_required=request.is_required,
sort_order=request.sort_order,
)
return {
"id": item.id,
"file_no": item.file_no,
"item_name": item.item_name,
"item_description": item.item_description,
"is_required": item.is_required,
"is_completed": item.is_completed,
"sort_order": item.sort_order,
}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.put("/closure-checklist/{item_id}")
async def update_checklist_item(
item_id: int,
request: ChecklistItemUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
item = service.update_checklist_item(
item_id=item_id,
item_name=request.item_name,
item_description=request.item_description,
is_required=request.is_required,
is_completed=request.is_completed,
sort_order=request.sort_order,
user_id=current_user.id,
notes=request.notes,
)
return {
"id": item.id,
"file_no": item.file_no,
"item_name": item.item_name,
"item_description": item.item_description,
"is_required": item.is_required,
"is_completed": item.is_completed,
"completed_date": item.completed_date,
"completed_by_name": item.completed_by_name,
"notes": item.notes,
"sort_order": item.sort_order,
}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.delete("/closure-checklist/{item_id}")
async def delete_checklist_item(
item_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
service.delete_checklist_item(item_id=item_id)
return {"message": "Checklist item deleted"}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Alerts endpoints
class AlertCreateRequest(BaseModel):
alert_type: str
title: str
message: str
alert_date: date
notify_attorney: bool = True
notify_admin: bool = False
notification_days_advance: int = 7
class AlertUpdateRequest(BaseModel):
title: Optional[str] = None
message: Optional[str] = None
alert_date: Optional[date] = None
is_active: Optional[bool] = None
@router.post("/{file_no}/alerts")
async def create_alert(
file_no: str,
request: AlertCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
alert = service.create_alert(
file_no=file_no,
alert_type=request.alert_type,
title=request.title,
message=request.message,
alert_date=request.alert_date,
notify_attorney=request.notify_attorney,
notify_admin=request.notify_admin,
notification_days_advance=request.notification_days_advance,
)
return {
"id": alert.id,
"file_no": alert.file_no,
"alert_type": alert.alert_type,
"title": alert.title,
"message": alert.message,
"alert_date": alert.alert_date,
"is_active": alert.is_active,
}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.get("/{file_no}/alerts")
async def get_alerts(
file_no: str,
active_only: bool = Query(True),
upcoming_only: bool = Query(False),
limit: int = Query(100, ge=1, le=500),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
alerts = service.get_alerts(
file_no=file_no,
active_only=active_only,
upcoming_only=upcoming_only,
limit=limit,
)
return [
{
"id": a.id,
"file_no": a.file_no,
"alert_type": a.alert_type,
"title": a.title,
"message": a.message,
"alert_date": a.alert_date,
"is_active": a.is_active,
"is_acknowledged": a.is_acknowledged,
}
for a in alerts
]
@router.post("/alerts/{alert_id}/acknowledge")
async def acknowledge_alert(
alert_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
alert = service.acknowledge_alert(alert_id=alert_id, user_id=current_user.id)
return {"message": "Alert acknowledged", "id": alert.id}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.put("/alerts/{alert_id}")
async def update_alert(
alert_id: int,
request: AlertUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
alert = service.update_alert(
alert_id=alert_id,
title=request.title,
message=request.message,
alert_date=request.alert_date,
is_active=request.is_active,
)
return {"message": "Alert updated", "id": alert.id}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.delete("/alerts/{alert_id}")
async def delete_alert(
alert_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
service.delete_alert(alert_id=alert_id)
return {"message": "Alert deleted"}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Relationships endpoints
class RelationshipCreateRequest(BaseModel):
target_file_no: str
relationship_type: str
notes: Optional[str] = None
@router.post("/{file_no}/relationships")
async def create_relationship(
file_no: str,
request: RelationshipCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
rel = service.create_relationship(
source_file_no=file_no,
target_file_no=request.target_file_no,
relationship_type=request.relationship_type,
user_id=current_user.id,
notes=request.notes,
)
return {
"id": rel.id,
"source_file_no": rel.source_file_no,
"target_file_no": rel.target_file_no,
"relationship_type": rel.relationship_type,
"notes": rel.notes,
}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.get("/{file_no}/relationships")
async def get_relationships(
file_no: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
return service.get_relationships(file_no=file_no)
@router.delete("/relationships/{relationship_id}")
async def delete_relationship(
relationship_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
service = FileManagementService(db)
try:
service.delete_relationship(relationship_id=relationship_id)
return {"message": "Relationship deleted"}
except FileManagementError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# File queries and reports
@router.get("/by-status/{status}")
async def get_files_by_status(

View File

@@ -16,6 +16,7 @@ from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.query_utils import apply_sorting, paginate_with_total
from app.models.additional import Deposit, Payment
router = APIRouter()
@@ -81,6 +82,23 @@ class PaginatedLedgerResponse(BaseModel):
total: int
class DepositResponse(BaseModel):
deposit_date: date
total: float
notes: Optional[str] = None
payments: Optional[List[Dict]] = None # Optional, depending on include_payments
class PaymentCreate(BaseModel):
file_no: Optional[str] = None
client_id: Optional[str] = None
regarding: Optional[str] = None
amount: float
note: Optional[str] = None
payment_method: str = "CHECK"
reference: Optional[str] = None
apply_to_trust: bool = False
@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse])
async def get_file_ledger(
file_no: str,
@@ -324,6 +342,59 @@ async def _update_file_balances(file_obj: File, db: Session):
db.commit()
async def _create_ledger_payment(
file_no: str,
amount: float,
payment_date: date,
payment_method: str,
reference: Optional[str],
notes: Optional[str],
apply_to_trust: bool,
empl_num: str,
db: Session
) -> Ledger:
# Get next item number
max_item = db.query(func.max(Ledger.item_no)).filter(
Ledger.file_no == file_no
).scalar() or 0
# Determine transaction type and code
if apply_to_trust:
t_type = "1" # Trust
t_code = "TRUST"
description = f"Trust deposit - {payment_method}"
else:
t_type = "5" # Credit/Payment
t_code = "PMT"
description = f"Payment received - {payment_method}"
if reference:
description += f" - Ref: {reference}"
if notes:
description += f" - {notes}"
# Create ledger entry
entry = Ledger(
file_no=file_no,
item_no=max_item + 1,
date=payment_date,
t_code=t_code,
t_type=t_type,
t_type_l="C", # Credit
empl_num=empl_num,
quantity=0.0,
rate=0.0,
amount=amount,
billed="Y", # Payments are automatically considered "billed"
note=description
)
db.add(entry)
db.flush() # To get ID
return entry
# Additional Financial Management Endpoints
@router.get("/time-entries/recent")
@@ -819,56 +890,27 @@ async def record_payment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Record a payment against a file"""
# Verify file exists
file_obj = db.query(File).filter(File.file_no == file_no).first()
if not file_obj:
raise HTTPException(status_code=404, detail="File not found")
payment_date = payment_date or date.today()
# Get next item number
max_item = db.query(func.max(Ledger.item_no)).filter(
Ledger.file_no == file_no
).scalar() or 0
# Determine transaction type and code based on whether it goes to trust
if apply_to_trust:
t_type = "1" # Trust
t_code = "TRUST"
description = f"Trust deposit - {payment_method}"
else:
t_type = "5" # Credit/Payment
t_code = "PMT"
description = f"Payment received - {payment_method}"
if reference:
description += f" - Ref: {reference}"
if notes:
description += f" - {notes}"
# Create payment entry
entry = Ledger(
entry = await _create_ledger_payment(
file_no=file_no,
item_no=max_item + 1,
date=payment_date,
t_code=t_code,
t_type=t_type,
t_type_l="C", # Credit
empl_num=file_obj.empl_num,
quantity=0.0,
rate=0.0,
amount=amount,
billed="Y", # Payments are automatically considered "billed"
note=description
payment_date=payment_date,
payment_method=payment_method,
reference=reference,
notes=notes,
apply_to_trust=apply_to_trust,
empl_num=file_obj.empl_num,
db=db
)
db.add(entry)
db.commit()
db.refresh(entry)
# Update file balances
await _update_file_balances(file_obj, db)
return {
@@ -952,4 +994,157 @@ async def record_expense(
"description": description,
"employee": empl_num
}
}
}
@router.post("/deposits/")
async def create_deposit(
deposit_date: date,
notes: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
existing = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
if existing:
raise HTTPException(status_code=400, detail="Deposit for this date already exists")
deposit = Deposit(
deposit_date=deposit_date,
total=0.0,
notes=notes
)
db.add(deposit)
db.commit()
db.refresh(deposit)
return deposit
@router.post("/deposits/{deposit_date}/payments/")
async def add_payment_to_deposit(
deposit_date: date,
payment_data: PaymentCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
deposit = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
if not deposit:
raise HTTPException(status_code=404, detail="Deposit not found")
if not payment_data.file_no:
raise HTTPException(status_code=400, detail="file_no is required for payments")
file_obj = db.query(File).filter(File.file_no == payment_data.file_no).first()
if not file_obj:
raise HTTPException(status_code=404, detail="File not found")
# Create ledger entry first
ledger_entry = await _create_ledger_payment(
file_no=payment_data.file_no,
amount=payment_data.amount,
payment_date=deposit_date,
payment_method=payment_data.payment_method,
reference=payment_data.reference,
notes=payment_data.note,
apply_to_trust=payment_data.apply_to_trust,
empl_num=file_obj.empl_num,
db=db
)
# Create payment record
payment = Payment(
deposit_date=deposit_date,
file_no=payment_data.file_no,
client_id=payment_data.client_id,
regarding=payment_data.regarding,
amount=payment_data.amount,
note=payment_data.note
)
db.add(payment)
# Update deposit total
deposit.total += payment_data.amount
db.commit()
db.refresh(payment)
await _update_file_balances(file_obj, db)
return payment
@router.get("/deposits/", response_model=List[DepositResponse])
async def list_deposits(
start_date: Optional[date] = None,
end_date: Optional[date] = None,
include_payments: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
query = db.query(Deposit)
if start_date:
query = query.filter(Deposit.deposit_date >= start_date)
if end_date:
query = query.filter(Deposit.deposit_date <= end_date)
query = query.order_by(Deposit.deposit_date.desc())
deposits = query.all()
results = []
for dep in deposits:
dep_data = {
"deposit_date": dep.deposit_date,
"total": dep.total,
"notes": dep.notes
}
if include_payments:
payments = db.query(Payment).filter(Payment.deposit_date == dep.deposit_date).all()
dep_data["payments"] = [p.__dict__ for p in payments]
results.append(dep_data)
return results
@router.get("/deposits/{deposit_date}", response_model=DepositResponse)
async def get_deposit(
deposit_date: date,
include_payments: bool = True,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
deposit = db.query(Deposit).filter(Deposit.deposit_date == deposit_date).first()
if not deposit:
raise HTTPException(status_code=404, detail="Deposit not found")
dep_data = {
"deposit_date": deposit.deposit_date,
"total": deposit.total,
"notes": deposit.notes
}
if include_payments:
payments = db.query(Payment).filter(Payment.deposit_date == deposit_date).all()
dep_data["payments"] = [p.__dict__ for p in payments]
return dep_data
@router.get("/reports/deposits")
async def get_deposit_report(
start_date: date,
end_date: date,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
deposits = db.query(Deposit).filter(
Deposit.deposit_date >= start_date,
Deposit.deposit_date <= end_date
).order_by(Deposit.deposit_date).all()
total_deposits = sum(d.total for d in deposits)
report = {
"period": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
},
"total_deposits": total_deposits,
"deposit_count": len(deposits),
"deposits": [
{
"date": d.deposit_date.isoformat(),
"total": d.total,
"notes": d.notes,
"payment_count": db.query(Payment).filter(Payment.deposit_date == d.deposit_date).count()
} for d in deposits
]
}
return report

View File

@@ -3,6 +3,7 @@ Data import API endpoints for CSV file uploads with auto-discovery mapping.
"""
import csv
import io
import zipfile
import re
import os
from pathlib import Path
@@ -11,6 +12,7 @@ from datetime import datetime, date, timezone
from decimal import Decimal
from typing import List, Dict, Any, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.database.base import get_db
from app.auth.security import get_current_user
@@ -40,8 +42,8 @@ ENCODINGS = [
# Unified import order used across batch operations
IMPORT_ORDER = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FOOTERS.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "SETUP.csv", "PRINTERS.csv",
"INX_LKUP.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "SCHEDULE.csv", "MARRIAGE.csv", "DEATH.csv", "SEPARATE.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv",
@@ -91,8 +93,83 @@ CSV_MODEL_MAPPING = {
"RESULTS.csv": PensionResult
}
# Minimal CSV template definitions (headers + one sample row) used for template downloads
CSV_IMPORT_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
"FILES.csv": {
"headers": ["File_No", "Id", "Empl_Num", "File_Type", "Opened", "Status", "Rate_Per_Hour"],
"sample": ["F-001", "CLIENT-1", "EMP01", "CIVIL", "2024-01-01", "ACTIVE", "150"],
},
"LEDGER.csv": {
"headers": ["File_No", "Date", "Empl_Num", "T_Code", "T_Type", "Amount"],
"sample": ["F-001", "2024-01-15", "EMP01", "FEE", "1", "500.00"],
},
"PAYMENTS.csv": {
"headers": ["Deposit_Date", "Amount"],
"sample": ["2024-01-15", "1500.00"],
},
# Additional templates for convenience
"TRNSACTN.csv": {
# Same structure as LEDGER.csv
"headers": ["File_No", "Date", "Empl_Num", "T_Code", "T_Type", "Amount"],
"sample": ["F-002", "2024-02-10", "EMP02", "FEE", "1", "250.00"],
},
"DEPOSITS.csv": {
"headers": ["Deposit_Date", "Total"],
"sample": ["2024-02-10", "1500.00"],
},
"ROLODEX.csv": {
# Minimal common contact fields
"headers": ["Id", "Last", "First", "A1", "City", "Abrev", "Zip", "Email"],
"sample": ["CLIENT-1", "Smith", "John", "123 Main St", "Denver", "CO", "80202", "john.smith@example.com"],
},
}
def _generate_csv_template_bytes(file_type: str) -> bytes:
"""Return CSV template content for the given file type as bytes.
Raises HTTPException if unsupported.
"""
key = (file_type or "").strip()
if key not in CSV_IMPORT_TEMPLATES:
raise HTTPException(status_code=400, detail=f"Unsupported template type: {file_type}. Choose one of: {list(CSV_IMPORT_TEMPLATES.keys())}")
cfg = CSV_IMPORT_TEMPLATES[key]
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(cfg["headers"])
writer.writerow(cfg["sample"])
output.seek(0)
return output.getvalue().encode("utf-8")
# Field mappings for CSV columns to database fields
# Legacy header synonyms used as hints only (not required). Auto-discovery will work without exact matches.
REQUIRED_MODEL_FIELDS: Dict[str, List[str]] = {
# Files: core identifiers and billing/status fields used throughout the app
"FILES.csv": [
"file_no",
"id",
"empl_num",
"file_type",
"opened",
"status",
"rate_per_hour",
],
# Ledger: core transaction fields
"LEDGER.csv": [
"file_no",
"date",
"empl_num",
"t_code",
"t_type",
"amount",
],
# Payments: deposit date and amount are the only strictly required model fields
"PAYMENTS.csv": [
"deposit_date",
"amount",
],
}
FIELD_MAPPINGS = {
"ROLODEX.csv": {
"Id": "id",
@@ -191,7 +268,14 @@ FIELD_MAPPINGS = {
"Draft_Apr": "draft_apr",
"Final_Out": "final_out",
"Judge": "judge",
"Form_Name": "form_name"
"Form_Name": "form_name",
# Extended workflow/document fields (present in new exports or manual CSVs)
"Status": "status",
"Content": "content",
"Notes": "notes",
"Approval_Status": "approval_status",
"Approved_Date": "approved_date",
"Filed_Date": "filed_date"
},
"PENSIONS.csv": {
"File_No": "file_no",
@@ -218,9 +302,17 @@ FIELD_MAPPINGS = {
},
"EMPLOYEE.csv": {
"Empl_Num": "empl_num",
"Rate_Per_Hour": "rate_per_hour"
# "Empl_Id": not a field in Employee model, using empl_num as identifier
# Model has additional fields (first_name, last_name, title, etc.) not in CSV
"Empl_Id": "initials", # Map employee ID to initials field
"Rate_Per_Hour": "rate_per_hour",
# Optional extended fields when present in enhanced exports
"First": "first_name",
"First_Name": "first_name",
"Last": "last_name",
"Last_Name": "last_name",
"Title": "title",
"Email": "email",
"Phone": "phone",
"Active": "active"
},
"STATES.csv": {
"Abrev": "abbreviation",
@@ -228,8 +320,8 @@ FIELD_MAPPINGS = {
},
"GRUPLKUP.csv": {
"Code": "group_code",
"Description": "description"
# "Title": field not present in model, skipping
"Description": "description",
"Title": "title"
},
"TRNSLKUP.csv": {
"T_Code": "t_code",
@@ -240,10 +332,9 @@ FIELD_MAPPINGS = {
},
"TRNSTYPE.csv": {
"T_Type": "t_type",
"T_Type_L": "description"
# "Header": maps to debit_credit but needs data transformation
# "Footer": doesn't align with active boolean field
# These fields may need custom handling or model updates
"T_Type_L": "debit_credit", # D=Debit, C=Credit
"Header": "description",
"Footer": "footer_code"
},
"FILETYPE.csv": {
"File_Type": "type_code",
@@ -343,6 +434,10 @@ FIELD_MAPPINGS = {
"DEATH.csv": {
"File_No": "file_no",
"Version": "version",
"Beneficiary_Name": "beneficiary_name",
"Benefit_Amount": "benefit_amount",
"Benefit_Type": "benefit_type",
"Notes": "notes",
"Lump1": "lump1",
"Lump2": "lump2",
"Growth1": "growth1",
@@ -353,6 +448,9 @@ FIELD_MAPPINGS = {
"SEPARATE.csv": {
"File_No": "file_no",
"Version": "version",
"Agreement_Date": "agreement_date",
"Terms": "terms",
"Notes": "notes",
"Separation_Rate": "terms"
},
"LIFETABL.csv": {
@@ -466,6 +564,40 @@ FIELD_MAPPINGS = {
"Amount": "amount",
"Billed": "billed",
"Note": "note"
},
"EMPLOYEE.csv": {
"Empl_Num": "empl_num",
"Empl_Id": "initials", # Map employee ID to initials field
"Rate_Per_Hour": "rate_per_hour",
# Note: first_name, last_name, title, active, email, phone will need manual entry or separate import
# as they're not present in the legacy CSV structure
},
"QDROS.csv": {
"File_No": "file_no",
"Version": "version",
"Plan_Id": "plan_id",
"^1": "field1",
"^2": "field2",
"^Part": "part",
"^AltP": "altp",
"^Pet": "pet",
"^Res": "res",
"Case_Type": "case_type",
"Case_Code": "case_code",
"Section": "section",
"Case_Number": "case_number",
"Judgment_Date": "judgment_date",
"Valuation_Date": "valuation_date",
"Married_On": "married_on",
"Percent_Awarded": "percent_awarded",
"Ven_City": "ven_city",
"Ven_Cnty": "ven_cnty",
"Ven_St": "ven_st",
"Draft_Out": "draft_out",
"Draft_Apr": "draft_apr",
"Final_Out": "final_out",
"Judge": "judge",
"Form_Name": "form_name"
}
}
@@ -691,6 +823,21 @@ def _build_dynamic_mapping(headers: List[str], model_class, file_type: str) -> D
}
def _validate_required_headers(file_type: str, mapped_headers: Dict[str, str]) -> Dict[str, Any]:
"""Check that minimal required model fields for a given CSV type are present in mapped headers.
Returns dict with: required_fields, missing_fields, ok.
"""
required_fields = REQUIRED_MODEL_FIELDS.get(file_type, [])
present_fields = set((mapped_headers or {}).values())
missing_fields = [f for f in required_fields if f not in present_fields]
return {
"required_fields": required_fields,
"missing_fields": missing_fields,
"ok": len(missing_fields) == 0,
}
def _get_required_fields(model_class) -> List[str]:
"""Infer required (non-nullable) fields for a model to avoid DB errors.
@@ -721,7 +868,7 @@ def convert_value(value: str, field_name: str) -> Any:
# Date fields
if any(word in field_name.lower() for word in [
"date", "dob", "birth", "opened", "closed", "judgment", "valuation", "married", "vests_on", "service"
"date", "dob", "birth", "opened", "closed", "judgment", "valuation", "married", "vests_on", "service", "approved", "filed", "agreement"
]):
parsed_date = parse_date(value)
return parsed_date
@@ -752,6 +899,15 @@ def convert_value(value: str, field_name: str) -> Any:
except ValueError:
return 0.0
# Normalize debit_credit textual variants
if field_name.lower() == "debit_credit":
normalized = value.strip().upper()
if normalized in ["D", "DEBIT"]:
return "D"
if normalized in ["C", "CREDIT"]:
return "C"
return normalized[:1] if normalized else None
# Integer fields
if any(word in field_name.lower() for word in [
"item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number"
@@ -786,6 +942,69 @@ def validate_foreign_keys(model_data: dict, model_class, db: Session) -> list[st
rolodex_id = model_data["id"]
if rolodex_id and not db.query(Rolodex).filter(Rolodex.id == rolodex_id).first():
errors.append(f"Owner Rolodex ID '{rolodex_id}' not found")
# Check File -> Footer relationship (default footer on file)
if model_class == File and "footer_code" in model_data:
footer = model_data.get("footer_code")
if footer:
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
if not exists:
errors.append(f"Footer code '{footer}' not found for File")
# Check FileStatus -> Footer (default footer exists)
if model_class == FileStatus and "footer_code" in model_data:
footer = model_data.get("footer_code")
if footer:
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
if not exists:
errors.append(f"Footer code '{footer}' not found for FileStatus")
# Check TransactionType -> Footer (default footer exists)
if model_class == TransactionType and "footer_code" in model_data:
footer = model_data.get("footer_code")
if footer:
exists = db.query(Footer).filter(Footer.footer_code == footer).first()
if not exists:
errors.append(f"Footer code '{footer}' not found for TransactionType")
# Check Ledger -> TransactionType/TransactionCode cross references
if model_class == Ledger:
# Validate t_type exists
if "t_type" in model_data:
t_type_value = model_data.get("t_type")
if t_type_value and not db.query(TransactionType).filter(TransactionType.t_type == t_type_value).first():
errors.append(f"Transaction type '{t_type_value}' not found")
# Validate t_code exists and matches t_type if both provided
if "t_code" in model_data:
t_code_value = model_data.get("t_code")
if t_code_value:
code_row = db.query(TransactionCode).filter(TransactionCode.t_code == t_code_value).first()
if not code_row:
errors.append(f"Transaction code '{t_code_value}' not found")
else:
ledger_t_type = model_data.get("t_type")
if ledger_t_type and getattr(code_row, "t_type", None) and code_row.t_type != ledger_t_type:
errors.append(
f"Transaction code '{t_code_value}' t_type '{code_row.t_type}' does not match ledger t_type '{ledger_t_type}'"
)
# Check Payment -> File and Rolodex relationships
if model_class == Payment:
if "file_no" in model_data:
file_no_value = model_data.get("file_no")
if file_no_value and not db.query(File).filter(File.file_no == file_no_value).first():
errors.append(f"File number '{file_no_value}' not found for Payment")
if "client_id" in model_data:
client_id_value = model_data.get("client_id")
if client_id_value and not db.query(Rolodex).filter(Rolodex.id == client_id_value).first():
errors.append(f"Client ID '{client_id_value}' not found for Payment")
# Check QDRO -> PlanInfo (plan_id exists)
if model_class == QDRO and "plan_id" in model_data:
plan_id = model_data.get("plan_id")
if plan_id:
exists = db.query(PlanInfo).filter(PlanInfo.plan_id == plan_id).first()
if not exists:
errors.append(f"Plan ID '{plan_id}' not found for QDRO")
# Add more foreign key validations as needed
return errors
@@ -831,6 +1050,96 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
}
@router.get("/template/{file_type}")
async def download_csv_template(
file_type: str,
current_user: User = Depends(get_current_user)
):
"""Download a minimal CSV template with required headers and one sample row.
Supported templates include: {list(CSV_IMPORT_TEMPLATES.keys())}
"""
key = (file_type or "").strip()
if key not in CSV_IMPORT_TEMPLATES:
raise HTTPException(status_code=400, detail=f"Unsupported template type: {file_type}. Choose one of: {list(CSV_IMPORT_TEMPLATES.keys())}")
content = _generate_csv_template_bytes(key)
from datetime import datetime as _dt
ts = _dt.now().strftime("%Y%m%d_%H%M%S")
safe_name = key.replace(".csv", "")
filename = f"{safe_name}_template_{ts}.csv"
return StreamingResponse(
iter([content]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
@router.get("/templates/bundle")
async def download_csv_templates_bundle(
files: Optional[List[str]] = Query(None, description="Repeat for each CSV template, e.g., files=FILES.csv&files=LEDGER.csv"),
current_user: User = Depends(get_current_user)
):
"""Bundle selected CSV templates into a single ZIP.
Example: GET /api/import/templates/bundle?files=FILES.csv&files=LEDGER.csv
"""
requested = files or []
if not requested:
raise HTTPException(status_code=400, detail="Specify at least one 'files' query parameter")
# Normalize and validate
normalized: List[str] = []
for name in requested:
if not name:
continue
n = name.strip()
if not n.lower().endswith(".csv"):
n = f"{n}.csv"
n = n.upper()
if n in CSV_IMPORT_TEMPLATES:
normalized.append(n)
else:
# Ignore unknowns rather than fail the whole bundle
continue
# Deduplicate while preserving order
seen = set()
selected = []
for n in normalized:
if n not in seen:
seen.add(n)
selected.append(n)
if not selected:
raise HTTPException(status_code=400, detail=f"No supported templates requested. Supported: {list(CSV_IMPORT_TEMPLATES.keys())}")
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
for fname in selected:
try:
content = _generate_csv_template_bytes(fname)
# Friendly name in zip: <BASENAME>_template.csv
base = fname.replace(".CSV", "").upper()
arcname = f"{base}_template.csv"
zf.writestr(arcname, content)
except HTTPException:
# Skip unsupported just in case
continue
zip_buffer.seek(0)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"csv_templates_{ts}.zip"
return StreamingResponse(
iter([zip_buffer.getvalue()]),
media_type="application/zip",
headers={
"Content-Disposition": f"attachment; filename=\"{filename}\""
},
)
@router.post("/upload/{file_type}")
async def import_csv_data(
file_type: str,
@@ -1060,6 +1369,26 @@ async def import_csv_data(
except Exception:
pass
else:
# FK validation for known relationships
fk_errors = validate_foreign_keys(model_data, model_class, db)
if fk_errors:
for msg in fk_errors:
errors.append({"row": row_num, "error": msg})
# Persist as flexible for traceability
db.add(
FlexibleImport(
file_type=file_type,
target_table=model_class.__tablename__,
primary_key_field=None,
primary_key_value=None,
extra_data={
"mapped": model_data,
"fk_errors": fk_errors,
},
)
)
flexible_saved += 1
continue
instance = model_class(**model_data)
db.add(instance)
db.flush() # Ensure PK is available
@@ -1136,6 +1465,9 @@ async def import_csv_data(
"unmapped_headers": unmapped_headers,
"flexible_saved_rows": flexible_saved,
},
"validation": {
"fk_errors": len([e for e in errors if isinstance(e, dict) and 'error' in e and 'not found' in str(e['error']).lower()])
}
}
# Include create/update breakdown for printers
if file_type == "PRINTERS.csv":
@@ -1368,6 +1700,10 @@ async def batch_validate_csv_files(
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
mapped_headers = mapping_info["mapped_headers"]
unmapped_headers = mapping_info["unmapped_headers"]
header_validation = _validate_required_headers(file_type, mapped_headers)
header_validation = _validate_required_headers(file_type, mapped_headers)
header_validation = _validate_required_headers(file_type, mapped_headers)
header_validation = _validate_required_headers(file_type, mapped_headers)
# Sample data validation
sample_rows = []
@@ -1394,12 +1730,13 @@ async def batch_validate_csv_files(
validation_results.append({
"file_type": file_type,
"valid": len(mapped_headers) > 0 and len(errors) == 0,
"valid": (len(mapped_headers) > 0 and len(errors) == 0 and header_validation.get("ok", True)),
"headers": {
"found": csv_headers,
"mapped": mapped_headers,
"unmapped": unmapped_headers
},
"header_validation": header_validation,
"sample_data": sample_rows[:5], # Limit sample data for batch operation
"validation_errors": errors[:5], # First 5 errors only
"total_errors": len(errors),
@@ -1493,17 +1830,34 @@ async def batch_import_csv_files(
if file_type not in CSV_MODEL_MAPPING:
# Fallback flexible-only import for unknown file structures
try:
await file.seek(0)
content = await file.read()
# Save original upload to disk for potential reruns
# Use async file operations for better performance
from app.services.async_file_operations import async_file_ops
# Stream save to disk for potential reruns and processing
saved_path = None
try:
file_path = audit_dir.joinpath(file_type)
with open(file_path, "wb") as fh:
fh.write(content)
saved_path = str(file_path)
except Exception:
saved_path = None
relative_path = f"import_audits/{audit_row.id}/{file_type}"
saved_file_path, file_size, checksum = await async_file_ops.stream_upload_file(
file, relative_path
)
saved_path = str(async_file_ops.base_upload_dir / relative_path)
# Stream read for processing
content = b""
async for chunk in async_file_ops.stream_read_file(relative_path):
content += chunk
except Exception as e:
# Fallback to traditional method
await file.seek(0)
content = await file.read()
try:
file_path = audit_dir.joinpath(file_type)
with open(file_path, "wb") as fh:
fh.write(content)
saved_path = str(file_path)
except Exception:
saved_path = None
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
@@ -1640,10 +1994,12 @@ async def batch_import_csv_files(
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
mapped_headers = mapping_info["mapped_headers"]
unmapped_headers = mapping_info["unmapped_headers"]
header_validation = _validate_required_headers(file_type, mapped_headers)
imported_count = 0
errors = []
flexible_saved = 0
fk_error_summary: Dict[str, int] = {}
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
@@ -1713,6 +2069,26 @@ async def batch_import_csv_files(
if 'file_no' not in model_data or not model_data['file_no']:
continue # Skip ledger records without file number
# FK validation for known relationships
fk_errors = validate_foreign_keys(model_data, model_class, db)
if fk_errors:
for msg in fk_errors:
errors.append({"row": row_num, "error": msg})
fk_error_summary[msg] = fk_error_summary.get(msg, 0) + 1
db.add(
FlexibleImport(
file_type=file_type,
target_table=model_class.__tablename__,
primary_key_field=None,
primary_key_value=None,
extra_data=make_json_safe({
"mapped": model_data,
"fk_errors": fk_errors,
}),
)
)
flexible_saved += 1
continue
instance = model_class(**model_data)
db.add(instance)
db.flush()
@@ -1779,10 +2155,15 @@ async def batch_import_csv_files(
results.append({
"file_type": file_type,
"status": "success" if len(errors) == 0 else "completed_with_errors",
"status": "success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
"imported_count": imported_count,
"errors": len(errors),
"message": f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
"header_validation": header_validation,
"validation": {
"fk_errors_total": sum(fk_error_summary.values()),
"fk_error_summary": fk_error_summary,
},
"auto_mapping": {
"mapped_headers": mapped_headers,
"unmapped_headers": unmapped_headers,
@@ -1793,7 +2174,7 @@ async def batch_import_csv_files(
db.add(ImportAuditFile(
audit_id=audit_row.id,
file_type=file_type,
status="success" if len(errors) == 0 else "completed_with_errors",
status="success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
imported_count=imported_count,
errors=len(errors),
message=f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
@@ -1801,6 +2182,9 @@ async def batch_import_csv_files(
"mapped_headers": list(mapped_headers.keys()),
"unmapped_count": len(unmapped_headers),
"flexible_saved_rows": flexible_saved,
"fk_errors_total": sum(fk_error_summary.values()),
"fk_error_summary": fk_error_summary,
"header_validation": header_validation,
**({"saved_path": saved_path} if saved_path else {}),
}
))
@@ -2138,6 +2522,7 @@ async def rerun_failed_files(
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
mapped_headers = mapping_info["mapped_headers"]
unmapped_headers = mapping_info["unmapped_headers"]
header_validation = _validate_required_headers(file_type, mapped_headers)
imported_count = 0
errors: List[Dict[str, Any]] = []
# Special handling: assign line numbers per form for FORM_LST.csv
@@ -2248,20 +2633,21 @@ async def rerun_failed_files(
total_errors += len(errors)
results.append({
"file_type": file_type,
"status": "success" if len(errors) == 0 else "completed_with_errors",
"status": "success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
"imported_count": imported_count,
"errors": len(errors),
"message": f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
"header_validation": header_validation,
})
try:
db.add(ImportAuditFile(
audit_id=rerun_audit.id,
file_type=file_type,
status="success" if len(errors) == 0 else "completed_with_errors",
status="success" if (len(errors) == 0 and header_validation.get("ok", True)) else "completed_with_errors",
imported_count=imported_count,
errors=len(errors),
message=f"Imported {imported_count} records" + (f" with {len(errors)} errors" if errors else ""),
details={"saved_path": saved_path} if saved_path else {}
details={**({"saved_path": saved_path} if saved_path else {}), "header_validation": header_validation}
))
db.commit()
except Exception:

469
app/api/jobs.py Normal file
View File

@@ -0,0 +1,469 @@
"""
Job Management API
Provides lightweight monitoring and management endpoints around `JobRecord`.
Notes:
- This is not a background worker. It exposes status/history/metrics for jobs
recorded by various synchronous operations (e.g., documents batch generation).
- Retry creates a new queued record that references the original job. Actual
processing is not scheduled here.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, status, Request
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.database.base import get_db
from app.auth.security import get_current_user, get_admin_user
from app.models.user import User
from app.models.jobs import JobRecord
from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter
from app.services.storage import get_default_storage
from app.services.audit import audit_service
from app.utils.logging import app_logger
router = APIRouter()
# --------------------
# Pydantic Schemas
# --------------------
class JobRecordResponse(BaseModel):
id: int
job_id: str
job_type: str
status: str
requested_by_username: Optional[str] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
total_requested: int = 0
total_success: int = 0
total_failed: int = 0
has_result_bundle: bool = False
bundle_url: Optional[str] = None
bundle_size: Optional[int] = None
duration_seconds: Optional[float] = None
details: Optional[Dict[str, Any]] = None
model_config = ConfigDict(from_attributes=True)
class PaginatedJobsResponse(BaseModel):
items: List[JobRecordResponse]
total: int
class JobFailRequest(BaseModel):
reason: str = Field(..., min_length=1, max_length=1000)
details_update: Optional[Dict[str, Any]] = None
class JobCompletionUpdate(BaseModel):
total_success: Optional[int] = None
total_failed: Optional[int] = None
result_storage_path: Optional[str] = None
result_mime_type: Optional[str] = None
result_size: Optional[int] = None
details_update: Optional[Dict[str, Any]] = None
class RetryRequest(BaseModel):
note: Optional[str] = None
class JobsMetricsResponse(BaseModel):
by_status: Dict[str, int]
by_type: Dict[str, int]
avg_duration_seconds: Optional[float] = None
running_count: int
failed_last_24h: int
completed_last_24h: int
# --------------------
# Helpers
# --------------------
def _compute_duration_seconds(started_at: Optional[datetime], completed_at: Optional[datetime]) -> Optional[float]:
if not started_at or not completed_at:
return None
try:
start_utc = started_at if started_at.tzinfo else started_at.replace(tzinfo=timezone.utc)
end_utc = completed_at if completed_at.tzinfo else completed_at.replace(tzinfo=timezone.utc)
return max((end_utc - start_utc).total_seconds(), 0.0)
except Exception:
return None
def _to_response(
job: JobRecord,
*,
include_url: bool = False,
) -> JobRecordResponse:
has_bundle = bool(getattr(job, "result_storage_path", None))
bundle_url = None
if include_url and has_bundle:
try:
bundle_url = get_default_storage().public_url(job.result_storage_path) # type: ignore[arg-type]
except Exception:
bundle_url = None
return JobRecordResponse(
id=job.id,
job_id=job.job_id,
job_type=job.job_type,
status=job.status,
requested_by_username=getattr(job, "requested_by_username", None),
started_at=getattr(job, "started_at", None),
completed_at=getattr(job, "completed_at", None),
total_requested=getattr(job, "total_requested", 0) or 0,
total_success=getattr(job, "total_success", 0) or 0,
total_failed=getattr(job, "total_failed", 0) or 0,
has_result_bundle=has_bundle,
bundle_url=bundle_url,
bundle_size=getattr(job, "result_size", None),
duration_seconds=_compute_duration_seconds(getattr(job, "started_at", None), getattr(job, "completed_at", None)),
details=getattr(job, "details", None),
)
# --------------------
# Endpoints
# --------------------
@router.get("/", response_model=Union[List[JobRecordResponse], PaginatedJobsResponse])
async def list_jobs(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
include_urls: bool = Query(False, description="Include bundle URLs in responses"),
status_filter: Optional[str] = Query(None, description="Filter by status"),
type_filter: Optional[str] = Query(None, description="Filter by job type"),
requested_by: Optional[str] = Query(None, description="Filter by username"),
search: Optional[str] = Query(None, description="Tokenized search across job_id, type, status, username"),
mine: bool = Query(True, description="When true, restricts to current user's jobs (admins can set false)"),
sort_by: Optional[str] = Query("started", description="Sort by: started, completed, status, type"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(JobRecord)
# Scope: non-admin users always restricted to their jobs
is_admin = bool(getattr(current_user, "is_admin", False))
if mine or not is_admin:
query = query.filter(JobRecord.requested_by_username == current_user.username)
if status_filter:
query = query.filter(JobRecord.status == status_filter)
if type_filter:
query = query.filter(JobRecord.job_type == type_filter)
if requested_by and is_admin:
query = query.filter(JobRecord.requested_by_username == requested_by)
if search:
tokens = [t for t in (search or "").split() if t]
filter_expr = tokenized_ilike_filter(tokens, [
JobRecord.job_id,
JobRecord.job_type,
JobRecord.status,
JobRecord.requested_by_username,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"started": [JobRecord.started_at, JobRecord.id],
"completed": [JobRecord.completed_at, JobRecord.id],
"status": [JobRecord.status, JobRecord.started_at],
"type": [JobRecord.job_type, JobRecord.started_at],
},
)
jobs, total = paginate_with_total(query, skip, limit, include_total)
items = [_to_response(j, include_url=include_urls) for j in jobs]
if include_total:
return {"items": items, "total": total or 0}
return items
@router.get("/{job_id}", response_model=JobRecordResponse)
async def get_job(
job_id: str,
include_url: bool = Query(True),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
# Authorization: non-admin users can only access their jobs
if not getattr(current_user, "is_admin", False):
if getattr(job, "requested_by_username", None) != current_user.username:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
return _to_response(job, include_url=include_url)
@router.post("/{job_id}/mark-failed", response_model=JobRecordResponse)
async def mark_job_failed(
job_id: str,
payload: JobFailRequest,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
job.status = "failed"
job.completed_at = datetime.now(timezone.utc)
details = dict(getattr(job, "details", {}) or {})
details["last_error"] = payload.reason
if payload.details_update:
details.update(payload.details_update)
job.details = details
db.commit()
db.refresh(job)
try:
audit_service.log_action(
db=db,
action="FAIL",
resource_type="JOB",
user=current_user,
resource_id=job.job_id,
details={"reason": payload.reason},
request=request,
)
except Exception:
pass
return _to_response(job, include_url=True)
@router.post("/{job_id}/mark-running", response_model=JobRecordResponse)
async def mark_job_running(
job_id: str,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
job.status = "running"
# Reset start time when transitioning to running
job.started_at = datetime.now(timezone.utc)
job.completed_at = None
db.commit()
db.refresh(job)
try:
audit_service.log_action(
db=db,
action="RUNNING",
resource_type="JOB",
user=current_user,
resource_id=job.job_id,
details=None,
request=request,
)
except Exception:
pass
return _to_response(job)
@router.post("/{job_id}/mark-completed", response_model=JobRecordResponse)
async def mark_job_completed(
job_id: str,
payload: JobCompletionUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user),
):
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
job.status = "completed"
job.completed_at = datetime.now(timezone.utc)
if payload.total_success is not None:
job.total_success = max(int(payload.total_success), 0)
if payload.total_failed is not None:
job.total_failed = max(int(payload.total_failed), 0)
if payload.result_storage_path is not None:
job.result_storage_path = payload.result_storage_path
if payload.result_mime_type is not None:
job.result_mime_type = payload.result_mime_type
if payload.result_size is not None:
job.result_size = max(int(payload.result_size), 0)
if payload.details_update:
details = dict(getattr(job, "details", {}) or {})
details.update(payload.details_update)
job.details = details
db.commit()
db.refresh(job)
try:
audit_service.log_action(
db=db,
action="COMPLETE",
resource_type="JOB",
user=current_user,
resource_id=job.job_id,
details={
"total_success": job.total_success,
"total_failed": job.total_failed,
},
request=request,
)
except Exception:
pass
return _to_response(job, include_url=True)
@router.post("/{job_id}/retry")
async def retry_job(
job_id: str,
payload: RetryRequest,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user),
):
"""
Create a new queued job record that references the original job.
This endpoint does not execute the job; it enables monitoring UIs to
track retry intent and external workers to pick it up if/when implemented.
"""
job = db.query(JobRecord).filter(JobRecord.job_id == job_id).first()
if not job:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
new_job_id = uuid4().hex
new_details = dict(getattr(job, "details", {}) or {})
new_details["retry_of"] = job.job_id
if payload.note:
new_details["retry_note"] = payload.note
cloned = JobRecord(
job_id=new_job_id,
job_type=job.job_type,
status="queued",
requested_by_username=current_user.username,
started_at=datetime.now(timezone.utc),
completed_at=None,
total_requested=getattr(job, "total_requested", 0) or 0,
total_success=0,
total_failed=0,
result_storage_path=None,
result_mime_type=None,
result_size=None,
details=new_details,
)
db.add(cloned)
db.commit()
try:
audit_service.log_action(
db=db,
action="RETRY",
resource_type="JOB",
user=current_user,
resource_id=job.job_id,
details={"new_job_id": new_job_id},
request=request,
)
except Exception:
pass
return {"message": "Retry created", "job_id": new_job_id}
@router.get("/metrics/summary", response_model=JobsMetricsResponse)
async def jobs_metrics(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user),
):
"""
Basic metrics for dashboards/monitoring.
"""
# By status
rows = db.query(JobRecord.status, func.count(JobRecord.id)).group_by(JobRecord.status).all()
by_status = {str(k or "unknown"): int(v or 0) for k, v in rows}
# By type
rows = db.query(JobRecord.job_type, func.count(JobRecord.id)).group_by(JobRecord.job_type).all()
by_type = {str(k or "unknown"): int(v or 0) for k, v in rows}
# Running count
try:
running_count = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "running").scalar() or 0
except Exception:
running_count = 0
# Last 24h stats
cutoff = datetime.now(timezone.utc).replace(microsecond=0)
try:
failed_last_24h = db.query(func.count(JobRecord.id)).filter(
JobRecord.status == "failed",
(JobRecord.completed_at != None), # noqa: E711
JobRecord.completed_at >= (cutoff.replace(hour=0, minute=0, second=0) - func.cast(1, func.INTEGER)) # type: ignore
).scalar() or 0
except Exception:
# Fallback without date condition if backend doesn't support the above cast
failed_last_24h = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "failed").scalar() or 0
try:
completed_last_24h = db.query(func.count(JobRecord.id)).filter(
JobRecord.status == "completed",
(JobRecord.completed_at != None), # noqa: E711
JobRecord.completed_at >= (cutoff.replace(hour=0, minute=0, second=0) - func.cast(1, func.INTEGER)) # type: ignore
).scalar() or 0
except Exception:
completed_last_24h = db.query(func.count(JobRecord.id)).filter(JobRecord.status == "completed").scalar() or 0
# Average duration on completed
try:
completed_jobs = db.query(JobRecord.started_at, JobRecord.completed_at).filter(JobRecord.completed_at != None).limit(500).all() # noqa: E711
durations: List[float] = []
for s, c in completed_jobs:
d = _compute_duration_seconds(s, c)
if d is not None:
durations.append(d)
avg_duration = (sum(durations) / len(durations)) if durations else None
except Exception:
avg_duration = None
return JobsMetricsResponse(
by_status=by_status,
by_type=by_type,
avg_duration_seconds=(round(avg_duration, 2) if isinstance(avg_duration, (int, float)) else None),
running_count=int(running_count),
failed_last_24h=int(failed_last_24h),
completed_last_24h=int(completed_last_24h),
)

258
app/api/labels.py Normal file
View File

@@ -0,0 +1,258 @@
"""
Mailing Labels & Envelopes API
Endpoints:
- POST /api/labels/rolodex/labels-5160
- POST /api/labels/files/labels-5160
- POST /api/labels/rolodex/envelopes
- POST /api/labels/files/envelopes
"""
from __future__ import annotations
from typing import List, Optional, Sequence
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
import io
import csv
from app.auth.security import get_current_user
from app.database.base import get_db
from app.models.user import User
from app.models.rolodex import Rolodex
from app.services.customers_search import apply_customer_filters
from app.services.mailing import (
Address,
build_addresses_from_files,
build_addresses_from_rolodex,
build_address_from_rolodex,
render_labels_html,
render_envelopes_html,
save_html_bytes,
)
router = APIRouter()
class Labels5160Request(BaseModel):
ids: List[str] = Field(default_factory=list, description="Rolodex IDs or File numbers depending on route")
start_position: int = Field(default=1, ge=1, le=30, description="Starting label position on sheet (1-30)")
include_name: bool = Field(default=True, description="Include name/company as first line")
class GenerateResult(BaseModel):
url: Optional[str] = None
storage_path: Optional[str] = None
mime_type: str
size: int
created_at: str
@router.post("/rolodex/labels-5160", response_model=GenerateResult)
async def generate_rolodex_labels(
payload: Labels5160Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if not payload.ids:
raise HTTPException(status_code=400, detail="No rolodex IDs provided")
addresses = build_addresses_from_rolodex(db, payload.ids)
if not addresses:
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
html_bytes = render_labels_html(addresses, start_position=payload.start_position, include_name=payload.include_name)
result = save_html_bytes(html_bytes, filename_hint=f"labels_5160_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/labels")
return GenerateResult(**result)
@router.post("/files/labels-5160", response_model=GenerateResult)
async def generate_file_labels(
payload: Labels5160Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if not payload.ids:
raise HTTPException(status_code=400, detail="No file numbers provided")
addresses = build_addresses_from_files(db, payload.ids)
if not addresses:
raise HTTPException(status_code=404, detail="No matching file owners found")
html_bytes = render_labels_html(addresses, start_position=payload.start_position, include_name=payload.include_name)
result = save_html_bytes(html_bytes, filename_hint=f"labels_5160_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/labels")
return GenerateResult(**result)
class EnvelopesRequest(BaseModel):
ids: List[str] = Field(default_factory=list, description="Rolodex IDs or File numbers depending on route")
include_name: bool = Field(default=True)
return_address_lines: Optional[List[str]] = Field(default=None, description="Lines for return address (top-left)")
@router.post("/rolodex/envelopes", response_model=GenerateResult)
async def generate_rolodex_envelopes(
payload: EnvelopesRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if not payload.ids:
raise HTTPException(status_code=400, detail="No rolodex IDs provided")
addresses = build_addresses_from_rolodex(db, payload.ids)
if not addresses:
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
html_bytes = render_envelopes_html(addresses, return_address_lines=payload.return_address_lines, include_name=payload.include_name)
result = save_html_bytes(html_bytes, filename_hint=f"envelopes_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/envelopes")
return GenerateResult(**result)
@router.post("/files/envelopes", response_model=GenerateResult)
async def generate_file_envelopes(
payload: EnvelopesRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if not payload.ids:
raise HTTPException(status_code=400, detail="No file numbers provided")
addresses = build_addresses_from_files(db, payload.ids)
if not addresses:
raise HTTPException(status_code=404, detail="No matching file owners found")
html_bytes = render_envelopes_html(addresses, return_address_lines=payload.return_address_lines, include_name=payload.include_name)
result = save_html_bytes(html_bytes, filename_hint=f"envelopes_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}", subdir="mailing/envelopes")
return GenerateResult(**result)
@router.get("/rolodex/labels-5160/export")
async def export_rolodex_labels_5160(
start_position: int = Query(1, ge=1, le=30, description="Starting label position on sheet (1-30)"),
include_name: bool = Query(True, description="Include name/company as first line"),
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
format: str = Query("html", description="Output format: html | csv"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Generate Avery 5160 labels for Rolodex entries selected by filters and stream as HTML or CSV."""
fmt = (format or "").strip().lower()
if fmt not in {"html", "csv"}:
raise HTTPException(status_code=400, detail="Invalid format. Use 'html' or 'csv'.")
q = db.query(Rolodex)
q = apply_customer_filters(
q,
search=None,
group=group,
state=None,
groups=groups,
states=None,
name_prefix=name_prefix,
)
entries = q.all()
if not entries:
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
if fmt == "html":
addresses = [build_address_from_rolodex(r) for r in entries]
html_bytes = render_labels_html(addresses, start_position=start_position, include_name=include_name)
from fastapi.responses import StreamingResponse
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
filename = f"labels_5160_{ts}.html"
return StreamingResponse(
iter([html_bytes]),
media_type="text/html",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
else:
# CSV of address fields
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["Name", "Address 1", "Address 2", "Address 3", "City", "State", "ZIP"])
for r in entries:
addr = build_address_from_rolodex(r)
writer.writerow([
addr.display_name,
r.a1 or "",
r.a2 or "",
r.a3 or "",
r.city or "",
r.abrev or "",
r.zip or "",
])
output.seek(0)
from fastapi.responses import StreamingResponse
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
filename = f"labels_5160_{ts}.csv"
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
@router.get("/rolodex/envelopes/export")
async def export_rolodex_envelopes(
include_name: bool = Query(True, description="Include name/company"),
return_address_lines: Optional[List[str]] = Query(None, description="Optional return address lines"),
group: Optional[str] = Query(None, description="Filter by customer group (exact match)"),
groups: Optional[List[str]] = Query(None, description="Filter by multiple groups (repeat param)"),
name_prefix: Optional[str] = Query(None, description="Prefix search across first/last name"),
format: str = Query("html", description="Output format: html | csv"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Generate envelopes for Rolodex entries selected by filters and stream as HTML or CSV of addresses."""
fmt = (format or "").strip().lower()
if fmt not in {"html", "csv"}:
raise HTTPException(status_code=400, detail="Invalid format. Use 'html' or 'csv'.")
q = db.query(Rolodex)
q = apply_customer_filters(
q,
search=None,
group=group,
state=None,
groups=groups,
states=None,
name_prefix=name_prefix,
)
entries = q.all()
if not entries:
raise HTTPException(status_code=404, detail="No matching rolodex entries found")
if fmt == "html":
addresses = [build_address_from_rolodex(r) for r in entries]
html_bytes = render_envelopes_html(addresses, return_address_lines=return_address_lines, include_name=include_name)
from fastapi.responses import StreamingResponse
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
filename = f"envelopes_{ts}.html"
return StreamingResponse(
iter([html_bytes]),
media_type="text/html",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)
else:
# CSV of address fields
output = io.StringIO()
writer = csv.writer(output)
writer.writerow(["Name", "Address 1", "Address 2", "Address 3", "City", "State", "ZIP"])
for r in entries:
addr = build_address_from_rolodex(r)
writer.writerow([
addr.display_name,
r.a1 or "",
r.a2 or "",
r.a3 or "",
r.city or "",
r.abrev or "",
r.zip or "",
])
output.seek(0)
from fastapi.responses import StreamingResponse
ts = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
filename = f"envelopes_{ts}.csv"
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""},
)

View File

@@ -0,0 +1,230 @@
"""
Pension Valuation API endpoints
Exposes endpoints under /api/pensions/valuation for:
- Single-life present value
- Joint-survivor present value
"""
from __future__ import annotations
from typing import Dict, Optional, List, Union, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.database.base import get_db
from app.models.user import User
from app.auth.security import get_current_user
from app.services.pension_valuation import (
SingleLifeInputs,
JointSurvivorInputs,
present_value_single_life,
present_value_joint_survivor,
)
router = APIRouter(prefix="/valuation", tags=["pensions", "pensions-valuation"])
class SingleLifeRequest(BaseModel):
monthly_benefit: float = Field(ge=0)
term_months: int = Field(ge=0, description="Number of months in evaluation horizon")
start_age: Optional[int] = Field(default=None, ge=0)
sex: str = Field(description="M, F, or A (all)")
race: str = Field(description="W, B, H, or A (all)")
discount_rate: float = Field(default=0.0, ge=0, description="Annual percent, e.g. 3.0")
cola_rate: float = Field(default=0.0, ge=0, description="Annual percent, e.g. 2.0")
defer_months: float = Field(default=0, ge=0, description="Months to delay first payment (supports fractional)")
payment_period_months: int = Field(default=1, ge=1, description="Months per payment (1=monthly, 3=quarterly, 12=annual)")
certain_months: int = Field(default=0, ge=0, description="Guaranteed months from commencement regardless of mortality")
cola_mode: str = Field(default="monthly", description="'monthly' or 'annual_prorated'")
cola_cap_percent: Optional[float] = Field(default=None, ge=0)
interpolation_method: str = Field(default="linear", description="'linear' or 'step' for NA interpolation")
max_age: Optional[int] = Field(default=None, ge=0, description="Optional cap on participant age for term truncation")
class SingleLifeResponse(BaseModel):
pv: float
@router.post("/single-life", response_model=SingleLifeResponse)
async def value_single_life(
payload: SingleLifeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
pv = present_value_single_life(
db,
SingleLifeInputs(
monthly_benefit=payload.monthly_benefit,
term_months=payload.term_months,
start_age=payload.start_age,
sex=payload.sex,
race=payload.race,
discount_rate=payload.discount_rate,
cola_rate=payload.cola_rate,
defer_months=payload.defer_months,
payment_period_months=payload.payment_period_months,
certain_months=payload.certain_months,
cola_mode=payload.cola_mode,
cola_cap_percent=payload.cola_cap_percent,
interpolation_method=payload.interpolation_method,
max_age=payload.max_age,
),
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
return SingleLifeResponse(pv=float(round(pv, 2)))
class JointSurvivorRequest(BaseModel):
monthly_benefit: float = Field(ge=0)
term_months: int = Field(ge=0)
participant_age: Optional[int] = Field(default=None, ge=0)
participant_sex: str
participant_race: str
spouse_age: Optional[int] = Field(default=None, ge=0)
spouse_sex: str
spouse_race: str
survivor_percent: float = Field(ge=0, le=100, description="Percent of benefit to spouse on participant death")
discount_rate: float = Field(default=0.0, ge=0)
cola_rate: float = Field(default=0.0, ge=0)
defer_months: float = Field(default=0, ge=0)
payment_period_months: int = Field(default=1, ge=1)
certain_months: int = Field(default=0, ge=0)
cola_mode: str = Field(default="monthly")
cola_cap_percent: Optional[float] = Field(default=None, ge=0)
survivor_basis: str = Field(default="contingent", description="'contingent' or 'last_survivor'")
survivor_commence_participant_only: bool = Field(default=False, description="If true, survivor component uses participant survival as commencement basis")
interpolation_method: str = Field(default="linear")
max_age: Optional[int] = Field(default=None, ge=0)
class JointSurvivorResponse(BaseModel):
pv_total: float
pv_participant_component: float
pv_survivor_component: float
@router.post("/joint-survivor", response_model=JointSurvivorResponse)
async def value_joint_survivor(
payload: JointSurvivorRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
result: Dict[str, float] = present_value_joint_survivor(
db,
JointSurvivorInputs(
monthly_benefit=payload.monthly_benefit,
term_months=payload.term_months,
participant_age=payload.participant_age,
participant_sex=payload.participant_sex,
participant_race=payload.participant_race,
spouse_age=payload.spouse_age,
spouse_sex=payload.spouse_sex,
spouse_race=payload.spouse_race,
survivor_percent=payload.survivor_percent,
discount_rate=payload.discount_rate,
cola_rate=payload.cola_rate,
defer_months=payload.defer_months,
payment_period_months=payload.payment_period_months,
certain_months=payload.certain_months,
cola_mode=payload.cola_mode,
cola_cap_percent=payload.cola_cap_percent,
survivor_basis=payload.survivor_basis,
survivor_commence_participant_only=payload.survivor_commence_participant_only,
interpolation_method=payload.interpolation_method,
max_age=payload.max_age,
),
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Round to 2 decimals for response
return JointSurvivorResponse(
pv_total=float(round(result["pv_total"], 2)),
pv_participant_component=float(round(result["pv_participant_component"], 2)),
pv_survivor_component=float(round(result["pv_survivor_component"], 2)),
)
class ErrorResponse(BaseModel):
error: str
class BatchSingleLifeRequest(BaseModel):
# Accept raw dicts to allow per-item validation inside the loop (avoid 422 on the entire batch)
items: List[Dict[str, Any]]
class BatchSingleLifeItemResponse(BaseModel):
success: bool
result: Optional[SingleLifeResponse] = None
error: Optional[str] = None
class BatchSingleLifeResponse(BaseModel):
results: List[BatchSingleLifeItemResponse]
@router.post("/batch-single-life", response_model=BatchSingleLifeResponse)
async def batch_value_single_life(
payload: BatchSingleLifeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
results = []
for item in payload.items:
try:
inputs = SingleLifeInputs(**item)
pv = present_value_single_life(db, inputs)
results.append(BatchSingleLifeItemResponse(
success=True,
result=SingleLifeResponse(pv=float(round(pv, 2))),
))
except ValueError as e:
results.append(BatchSingleLifeItemResponse(
success=False,
error=str(e),
))
return BatchSingleLifeResponse(results=results)
class BatchJointSurvivorRequest(BaseModel):
# Accept raw dicts to allow per-item validation inside the loop (avoid 422 on the entire batch)
items: List[Dict[str, Any]]
class BatchJointSurvivorItemResponse(BaseModel):
success: bool
result: Optional[JointSurvivorResponse] = None
error: Optional[str] = None
class BatchJointSurvivorResponse(BaseModel):
results: List[BatchJointSurvivorItemResponse]
@router.post("/batch-joint-survivor", response_model=BatchJointSurvivorResponse)
async def batch_value_joint_survivor(
payload: BatchJointSurvivorRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
results = []
for item in payload.items:
try:
inputs = JointSurvivorInputs(**item)
result = present_value_joint_survivor(db, inputs)
results.append(BatchJointSurvivorItemResponse(
success=True,
result=JointSurvivorResponse(
pv_total=float(round(result["pv_total"], 2)),
pv_participant_component=float(round(result["pv_participant_component"], 2)),
pv_survivor_component=float(round(result["pv_survivor_component"], 2)),
),
))
except ValueError as e:
results.append(BatchJointSurvivorItemResponse(
success=False,
error=str(e),
))
return BatchJointSurvivorResponse(results=results)

View File

@@ -579,7 +579,7 @@ async def generate_qdro_document(
"MATTER": file_obj.regarding,
})
# Merge with provided context
context = build_context({**base_ctx, **(payload.context or {})})
context = build_context({**base_ctx, **(payload.context or {})}, "file", qdro.file_no)
resolved, unresolved = resolve_tokens(db, tokens, context)
output_bytes = content
@@ -591,7 +591,21 @@ async def generate_qdro_document(
audit_service.log_action(db, action="GENERATE", resource_type="QDRO", user=current_user, resource_id=qdro_id, details={"template_id": payload.template_id, "version_id": version_id, "unresolved": unresolved})
except Exception:
pass
return GenerateResponse(resolved=resolved, unresolved=unresolved, output_mime_type=output_mime, output_size=len(output_bytes))
# Sanitize resolved values to ensure JSON-serializable output
def _json_sanitize(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (date, datetime)):
return value.isoformat()
if isinstance(value, (list, tuple)):
return [_json_sanitize(v) for v in value]
if isinstance(value, dict):
return {k: _json_sanitize(v) for k, v in value.items()}
# Fallback: stringify unsupported types (e.g., functions)
return str(value)
sanitized_resolved = {k: _json_sanitize(v) for k, v in resolved.items()}
return GenerateResponse(resolved=sanitized_resolved, unresolved=unresolved, output_mime_type=output_mime, output_size=len(output_bytes))
class PlanInfoCreate(BaseModel):

View File

@@ -368,8 +368,10 @@ async def advanced_search(
# Cache lookup keyed by user and entire criteria (including pagination)
try:
cached = await cache_get_json(
kind="advanced",
from app.services.adaptive_cache import adaptive_cache_get
cached = await adaptive_cache_get(
cache_type="advanced",
cache_key="advanced_search",
user_id=str(getattr(current_user, "id", "")),
parts={"criteria": criteria.model_dump(mode="json")},
)
@@ -438,14 +440,15 @@ async def advanced_search(
page_info=page_info,
)
# Store in cache (best-effort)
# Store in cache with adaptive TTL
try:
await cache_set_json(
kind="advanced",
user_id=str(getattr(current_user, "id", "")),
parts={"criteria": criteria.model_dump(mode="json")},
from app.services.adaptive_cache import adaptive_cache_set
await adaptive_cache_set(
cache_type="advanced",
cache_key="advanced_search",
value=response.model_dump(mode="json"),
ttl_seconds=90,
user_id=str(getattr(current_user, "id", "")),
parts={"criteria": criteria.model_dump(mode="json")}
)
except Exception:
pass
@@ -462,9 +465,11 @@ async def global_search(
):
"""Enhanced global search across all entities"""
start_time = datetime.now()
# Cache lookup
cached = await cache_get_json(
kind="global",
# Cache lookup with adaptive tracking
from app.services.adaptive_cache import adaptive_cache_get
cached = await adaptive_cache_get(
cache_type="global",
cache_key="global_search",
user_id=str(getattr(current_user, "id", "")),
parts={"q": q, "limit": limit},
)
@@ -505,12 +510,13 @@ async def global_search(
phones=phone_results[:limit]
)
try:
await cache_set_json(
kind="global",
user_id=str(getattr(current_user, "id", "")),
parts={"q": q, "limit": limit},
from app.services.adaptive_cache import adaptive_cache_set
await adaptive_cache_set(
cache_type="global",
cache_key="global_search",
value=response.model_dump(mode="json"),
ttl_seconds=90,
user_id=str(getattr(current_user, "id", "")),
parts={"q": q, "limit": limit}
)
except Exception:
pass

View File

@@ -0,0 +1,503 @@
"""
Session Management API for P2 security features
"""
from datetime import datetime, timezone, timedelta
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.database.base import get_db
from app.auth.security import get_current_user, get_admin_user
from app.models.user import User
from app.models.sessions import UserSession, SessionConfiguration, SessionSecurityEvent
from app.utils.session_manager import SessionManager, get_session_manager
from app.utils.responses import create_success_response as success_response
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/api/session", tags=["Session Management"])
# Pydantic schemas
class SessionInfo(BaseModel):
"""Session information response"""
session_id: str
user_id: int
ip_address: Optional[str] = None
user_agent: Optional[str] = None
device_fingerprint: Optional[str] = None
country: Optional[str] = None
city: Optional[str] = None
is_suspicious: bool = False
risk_score: int = 0
status: str
created_at: datetime
last_activity: datetime
expires_at: datetime
login_method: Optional[str] = None
class Config:
from_attributes = True
class SessionConfigurationSchema(BaseModel):
"""Session configuration schema"""
max_concurrent_sessions: int = Field(default=3, ge=1, le=20)
session_timeout_minutes: int = Field(default=480, ge=30, le=1440) # 30 min to 24 hours
idle_timeout_minutes: int = Field(default=60, ge=5, le=240) # 5 min to 4 hours
require_session_renewal: bool = True
renewal_interval_hours: int = Field(default=24, ge=1, le=168) # 1 hour to 1 week
force_logout_on_ip_change: bool = False
suspicious_activity_threshold: int = Field(default=5, ge=1, le=20)
allowed_countries: Optional[List[str]] = None
blocked_countries: Optional[List[str]] = None
class SessionConfigurationUpdate(BaseModel):
"""Session configuration update request"""
max_concurrent_sessions: Optional[int] = Field(None, ge=1, le=20)
session_timeout_minutes: Optional[int] = Field(None, ge=30, le=1440)
idle_timeout_minutes: Optional[int] = Field(None, ge=5, le=240)
require_session_renewal: Optional[bool] = None
renewal_interval_hours: Optional[int] = Field(None, ge=1, le=168)
force_logout_on_ip_change: Optional[bool] = None
suspicious_activity_threshold: Optional[int] = Field(None, ge=1, le=20)
allowed_countries: Optional[List[str]] = None
blocked_countries: Optional[List[str]] = None
class SecurityEventInfo(BaseModel):
"""Security event information"""
id: int
event_type: str
severity: str
description: str
ip_address: Optional[str] = None
country: Optional[str] = None
action_taken: Optional[str] = None
resolved: bool = False
timestamp: datetime
class Config:
from_attributes = True
# Session Management Endpoints
@router.get("/current", response_model=SessionInfo)
async def get_current_session(
request: Request,
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Get current session information"""
try:
# Extract session ID from request
session_id = request.headers.get("X-Session-ID") or request.cookies.get("session_id")
if not session_id:
# For JWT-based sessions, use a portion of the JWT as session identifier
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
session_id = auth_header[7:][:32]
if not session_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No session identifier found"
)
session = session_manager.validate_session(session_id, request)
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found or expired"
)
return SessionInfo.from_orm(session)
except Exception as e:
logger.error(f"Error getting current session: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve session information"
)
@router.get("/list", response_model=List[SessionInfo])
async def list_user_sessions(
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""List all active sessions for current user"""
try:
sessions = session_manager.get_active_sessions(current_user.id)
return [SessionInfo.from_orm(session) for session in sessions]
except Exception as e:
logger.error(f"Error listing sessions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
)
@router.delete("/revoke/{session_id}")
async def revoke_session(
session_id: str,
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Revoke a specific session"""
try:
# Verify the session belongs to the current user
session = session_manager.db.query(UserSession).filter(
UserSession.session_id == session_id,
UserSession.user_id == current_user.id
).first()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
success = session_manager.revoke_session(session_id, "user_revocation")
if success:
return success_response("Session revoked successfully")
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to revoke session"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error revoking session: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
)
@router.delete("/revoke-all")
async def revoke_all_sessions(
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Revoke all sessions for current user"""
try:
count = session_manager.revoke_all_user_sessions(current_user.id, "user_revoke_all")
return success_response(f"Revoked {count} sessions successfully")
except Exception as e:
logger.error(f"Error revoking all sessions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke sessions"
)
@router.get("/configuration", response_model=SessionConfigurationSchema)
async def get_session_configuration(
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Get session configuration for current user"""
try:
config = session_manager._get_session_config(current_user)
return SessionConfigurationSchema(
max_concurrent_sessions=config.max_concurrent_sessions,
session_timeout_minutes=config.session_timeout_minutes,
idle_timeout_minutes=config.idle_timeout_minutes,
require_session_renewal=config.require_session_renewal,
renewal_interval_hours=config.renewal_interval_hours,
force_logout_on_ip_change=config.force_logout_on_ip_change,
suspicious_activity_threshold=config.suspicious_activity_threshold,
allowed_countries=config.allowed_countries.split(",") if config.allowed_countries else None,
blocked_countries=config.blocked_countries.split(",") if config.blocked_countries else None
)
except Exception as e:
logger.error(f"Error getting session configuration: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve session configuration"
)
@router.put("/configuration")
async def update_session_configuration(
config_update: SessionConfigurationUpdate,
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Update session configuration for current user"""
try:
config = session_manager._get_session_config(current_user)
# Ensure user-specific config exists
if config.user_id is None:
# Create user-specific config based on global config
user_config = SessionConfiguration(
user_id=current_user.id,
max_concurrent_sessions=config.max_concurrent_sessions,
session_timeout_minutes=config.session_timeout_minutes,
idle_timeout_minutes=config.idle_timeout_minutes,
require_session_renewal=config.require_session_renewal,
renewal_interval_hours=config.renewal_interval_hours,
force_logout_on_ip_change=config.force_logout_on_ip_change,
suspicious_activity_threshold=config.suspicious_activity_threshold,
allowed_countries=config.allowed_countries,
blocked_countries=config.blocked_countries
)
session_manager.db.add(user_config)
session_manager.db.flush()
config = user_config
# Update configuration
update_data = config_update.dict(exclude_unset=True)
for field, value in update_data.items():
if field in ["allowed_countries", "blocked_countries"] and value:
setattr(config, field, ",".join(value))
else:
setattr(config, field, value)
config.updated_at = datetime.now(timezone.utc)
session_manager.db.commit()
return success_response("Session configuration updated successfully")
except Exception as e:
logger.error(f"Error updating session configuration: {str(e)}")
session_manager.db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update session configuration"
)
@router.get("/security-events", response_model=List[SecurityEventInfo])
async def get_security_events(
limit: int = 50,
resolved: Optional[bool] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get security events for current user"""
try:
query = db.query(SessionSecurityEvent).filter(
SessionSecurityEvent.user_id == current_user.id
)
if resolved is not None:
query = query.filter(SessionSecurityEvent.resolved == resolved)
events = query.order_by(SessionSecurityEvent.timestamp.desc()).limit(limit).all()
return [SecurityEventInfo.from_orm(event) for event in events]
except Exception as e:
logger.error(f"Error getting security events: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve security events"
)
@router.get("/statistics")
async def get_session_statistics(
current_user: User = Depends(get_current_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Get session statistics for current user"""
try:
stats = session_manager.get_session_statistics(current_user.id)
return success_response(data=stats)
except Exception as e:
logger.error(f"Error getting session statistics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve session statistics"
)
# Admin endpoints
@router.get("/admin/sessions", response_model=List[SessionInfo])
async def admin_list_all_sessions(
user_id: Optional[int] = None,
limit: int = 100,
admin_user: User = Depends(get_admin_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Admin: List sessions for all users or specific user"""
try:
if user_id:
sessions = session_manager.get_active_sessions(user_id)
else:
sessions = session_manager.db.query(UserSession).filter(
UserSession.status == "active"
).order_by(UserSession.last_activity.desc()).limit(limit).all()
return [SessionInfo.from_orm(session) for session in sessions]
except Exception as e:
logger.error(f"Error getting admin sessions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
)
@router.delete("/admin/revoke/{session_id}")
async def admin_revoke_session(
session_id: str,
reason: str = "admin_revocation",
admin_user: User = Depends(get_admin_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Admin: Revoke any session"""
try:
success = session_manager.revoke_session(session_id, f"admin_revocation: {reason}")
if success:
return success_response(f"Session {session_id} revoked successfully")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error admin revoking session: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
)
@router.delete("/admin/revoke-user/{user_id}")
async def admin_revoke_user_sessions(
user_id: int,
reason: str = "admin_action",
admin_user: User = Depends(get_admin_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Admin: Revoke all sessions for a specific user"""
try:
count = session_manager.revoke_all_user_sessions(user_id, f"admin_action: {reason}")
return success_response(f"Revoked {count} sessions for user {user_id}")
except Exception as e:
logger.error(f"Error admin revoking user sessions: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke user sessions"
)
@router.get("/admin/global-configuration", response_model=SessionConfigurationSchema)
async def admin_get_global_configuration(
admin_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Admin: Get global session configuration"""
try:
config = db.query(SessionConfiguration).filter(
SessionConfiguration.user_id.is_(None)
).first()
if not config:
# Create default global config
config = SessionConfiguration()
db.add(config)
db.commit()
return SessionConfigurationSchema(
max_concurrent_sessions=config.max_concurrent_sessions,
session_timeout_minutes=config.session_timeout_minutes,
idle_timeout_minutes=config.idle_timeout_minutes,
require_session_renewal=config.require_session_renewal,
renewal_interval_hours=config.renewal_interval_hours,
force_logout_on_ip_change=config.force_logout_on_ip_change,
suspicious_activity_threshold=config.suspicious_activity_threshold,
allowed_countries=config.allowed_countries.split(",") if config.allowed_countries else None,
blocked_countries=config.blocked_countries.split(",") if config.blocked_countries else None
)
except Exception as e:
logger.error(f"Error getting global session configuration: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve global session configuration"
)
@router.put("/admin/global-configuration")
async def admin_update_global_configuration(
config_update: SessionConfigurationUpdate,
admin_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Admin: Update global session configuration"""
try:
config = db.query(SessionConfiguration).filter(
SessionConfiguration.user_id.is_(None)
).first()
if not config:
config = SessionConfiguration()
db.add(config)
db.flush()
# Update configuration
update_data = config_update.dict(exclude_unset=True)
for field, value in update_data.items():
if field in ["allowed_countries", "blocked_countries"] and value:
setattr(config, field, ",".join(value))
else:
setattr(config, field, value)
config.updated_at = datetime.now(timezone.utc)
db.commit()
return success_response("Global session configuration updated successfully")
except Exception as e:
logger.error(f"Error updating global session configuration: {str(e)}")
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update global session configuration"
)
@router.get("/admin/statistics")
async def admin_get_global_statistics(
admin_user: User = Depends(get_admin_user),
session_manager: SessionManager = Depends(get_session_manager)
):
"""Admin: Get global session statistics"""
try:
stats = session_manager.get_session_statistics()
return success_response(data=stats)
except Exception as e:
logger.error(f"Error getting global session statistics: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve global session statistics"
)

View File

@@ -20,12 +20,23 @@ from sqlalchemy import func, or_, exists
import hashlib
from app.database.base import get_db
from app.auth.security import get_current_user
from app.auth.security import get_current_user, get_admin_user
from app.models.user import User
from app.models.templates import DocumentTemplate, DocumentTemplateVersion, TemplateKeyword
from app.services.storage import get_default_storage
from app.services.template_merge import extract_tokens_from_bytes, build_context, resolve_tokens, render_docx
from app.services.template_service import (
get_template_or_404,
list_template_versions as svc_list_template_versions,
add_template_version as svc_add_template_version,
resolve_template_preview as svc_resolve_template_preview,
get_download_payload as svc_get_download_payload,
)
from app.services.query_utils import paginate_with_total
from app.services.template_upload import TemplateUploadService
from app.services.template_search import TemplateSearchService
from app.config import settings
from app.services.cache import _get_client
router = APIRouter()
@@ -97,6 +108,12 @@ class PaginatedCategoriesResponse(BaseModel):
total: int
class TemplateCacheStatusResponse(BaseModel):
cache_enabled: bool
redis_available: bool
mem_cache: Dict[str, int]
@router.post("/upload", response_model=TemplateResponse)
async def upload_template(
name: str = Form(...),
@@ -107,38 +124,15 @@ async def upload_template(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if file.content_type not in {"application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/pdf"}:
raise HTTPException(status_code=400, detail="Only .docx or .pdf templates are supported")
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="No file uploaded")
sha256 = hashlib.sha256(content).hexdigest()
storage = get_default_storage()
storage_path = storage.save_bytes(content=content, filename_hint=file.filename or "template.bin", subdir="templates")
template = DocumentTemplate(name=name, description=description, category=category, active=True, created_by=getattr(current_user, "username", None))
db.add(template)
db.flush() # get id
version = DocumentTemplateVersion(
template_id=template.id,
service = TemplateUploadService(db)
template = await service.upload_template(
name=name,
category=category,
description=description,
semantic_version=semantic_version,
storage_path=storage_path,
mime_type=file.content_type,
size=len(content),
checksum=sha256,
changelog=None,
file=file,
created_by=getattr(current_user, "username", None),
is_approved=True,
)
db.add(version)
db.flush()
template.current_version_id = version.id
db.commit()
db.refresh(template)
return TemplateResponse(
id=template.id,
name=template.name,
@@ -177,88 +171,34 @@ async def search_templates(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(DocumentTemplate)
if active_only:
query = query.filter(DocumentTemplate.active == True)
if q:
like = f"%{q}%"
query = query.filter(
or_(
DocumentTemplate.name.ilike(like),
DocumentTemplate.description.ilike(like),
)
)
# Category filtering (supports repeatable param and CSV within each value)
# Normalize category values including CSV-in-parameter support
categories: Optional[List[str]] = None
if category:
raw_values = category or []
categories: List[str] = []
cat_values: List[str] = []
for value in raw_values:
parts = [part.strip() for part in (value or "").split(",")]
for part in parts:
if part:
categories.append(part)
unique_categories = sorted(set(categories))
if unique_categories:
query = query.filter(DocumentTemplate.category.in_(unique_categories))
if keywords:
normalized = [kw.strip().lower() for kw in keywords if kw and kw.strip()]
unique_keywords = sorted(set(normalized))
if unique_keywords:
mode = (keywords_mode or "any").lower()
if mode not in ("any", "all"):
mode = "any"
query = query.join(TemplateKeyword, TemplateKeyword.template_id == DocumentTemplate.id)
if mode == "any":
query = query.filter(TemplateKeyword.keyword.in_(unique_keywords)).distinct()
else:
query = query.filter(TemplateKeyword.keyword.in_(unique_keywords))
query = query.group_by(DocumentTemplate.id)
query = query.having(func.count(func.distinct(TemplateKeyword.keyword)) == len(unique_keywords))
# Has keywords filter (independent of specific keyword matches)
if has_keywords is not None:
kw_exists = exists().where(TemplateKeyword.template_id == DocumentTemplate.id)
if has_keywords:
query = query.filter(kw_exists)
else:
query = query.filter(~kw_exists)
# Sorting
sort_key = (sort_by or "name").lower()
direction = (sort_dir or "asc").lower()
if sort_key not in ("name", "category", "updated"):
sort_key = "name"
if direction not in ("asc", "desc"):
direction = "asc"
cat_values.append(part)
categories = sorted(set(cat_values))
if sort_key == "name":
order_col = DocumentTemplate.name
elif sort_key == "category":
order_col = DocumentTemplate.category
else: # updated
order_col = func.coalesce(DocumentTemplate.updated_at, DocumentTemplate.created_at)
search_service = TemplateSearchService(db)
results, total = await search_service.search_templates(
q=q,
categories=categories,
keywords=keywords,
keywords_mode=keywords_mode,
has_keywords=has_keywords,
skip=skip,
limit=limit,
sort_by=sort_by or "name",
sort_dir=sort_dir or "asc",
active_only=active_only,
include_total=include_total,
)
if direction == "asc":
query = query.order_by(order_col.asc())
else:
query = query.order_by(order_col.desc())
# Pagination with optional total
templates, total = paginate_with_total(query, skip, limit, include_total)
items: List[SearchResponseItem] = []
for tpl in templates:
latest_version = None
if tpl.current_version_id:
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == tpl.current_version_id).first()
if ver:
latest_version = ver.semantic_version
items.append(
SearchResponseItem(
id=tpl.id,
name=tpl.name,
category=tpl.category,
active=tpl.active,
latest_version=latest_version,
)
)
items: List[SearchResponseItem] = [SearchResponseItem(**it) for it in results]
if include_total:
return {"items": items, "total": int(total or 0)}
return items
@@ -271,25 +211,65 @@ async def list_template_categories(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(DocumentTemplate.category, func.count(DocumentTemplate.id).label("count"))
if active_only:
query = query.filter(DocumentTemplate.active == True)
rows = query.group_by(DocumentTemplate.category).order_by(DocumentTemplate.category.asc()).all()
search_service = TemplateSearchService(db)
rows = await search_service.list_categories(active_only=active_only)
items = [CategoryCount(category=row[0], count=row[1]) for row in rows]
if include_total:
return {"items": items, "total": len(items)}
return items
@router.get("/_cache_status", response_model=TemplateCacheStatusResponse)
async def cache_status(
current_user: User = Depends(get_admin_user),
):
# In-memory cache breakdown
with TemplateSearchService._mem_lock:
keys = list(TemplateSearchService._mem_cache.keys())
mem_templates = sum(1 for k in keys if k.startswith("search:templates:"))
mem_categories = sum(1 for k in keys if k.startswith("search:templates_categories:"))
# Redis availability check (best-effort)
redis_available = False
try:
client = await _get_client()
if client is not None:
try:
pong = await client.ping()
redis_available = bool(pong)
except Exception:
redis_available = False
except Exception:
redis_available = False
return TemplateCacheStatusResponse(
cache_enabled=bool(getattr(settings, "cache_enabled", False)),
redis_available=redis_available,
mem_cache={
"templates": int(mem_templates),
"categories": int(mem_categories),
},
)
@router.post("/_cache_invalidate")
async def cache_invalidate(
current_user: User = Depends(get_admin_user),
):
try:
await TemplateSearchService.invalidate_all()
return {"cleared": True}
except Exception as e:
return {"cleared": False, "error": str(e)}
@router.get("/{template_id}", response_model=TemplateResponse)
async def get_template(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
tpl = get_template_or_404(db, template_id)
return TemplateResponse(
id=tpl.id,
name=tpl.name,
@@ -306,12 +286,7 @@ async def list_versions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
versions = (
db.query(DocumentTemplateVersion)
.filter(DocumentTemplateVersion.template_id == template_id)
.order_by(DocumentTemplateVersion.created_at.desc())
.all()
)
versions = svc_list_template_versions(db, template_id)
return [
VersionResponse(
id=v.id,
@@ -337,31 +312,18 @@ async def add_version(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="No file uploaded")
sha256 = hashlib.sha256(content).hexdigest()
storage = get_default_storage()
storage_path = storage.save_bytes(content=content, filename_hint=file.filename or "template.bin", subdir="templates")
version = DocumentTemplateVersion(
version = svc_add_template_version(
db,
template_id=template_id,
semantic_version=semantic_version,
storage_path=storage_path,
mime_type=file.content_type,
size=len(content),
checksum=sha256,
changelog=changelog,
approve=approve,
content=content,
filename_hint=file.filename or "template.bin",
content_type=file.content_type,
created_by=getattr(current_user, "username", None),
is_approved=bool(approve),
)
db.add(version)
db.flush()
if approve:
tpl.current_version_id = version.id
db.commit()
return VersionResponse(
id=version.id,
template_id=version.template_id,
@@ -381,31 +343,32 @@ async def preview_template(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
version_id = payload.version_id or tpl.current_version_id
if not version_id:
raise HTTPException(status_code=400, detail="Template has no versions")
ver = db.query(DocumentTemplateVersion).filter(DocumentTemplateVersion.id == version_id).first()
if not ver:
raise HTTPException(status_code=404, detail="Version not found")
resolved, unresolved, output_bytes, output_mime = svc_resolve_template_preview(
db,
template_id=template_id,
version_id=payload.version_id,
context=payload.context or {},
)
storage = get_default_storage()
content = storage.open_bytes(ver.storage_path)
tokens = extract_tokens_from_bytes(content)
context = build_context(payload.context or {})
resolved, unresolved = resolve_tokens(db, tokens, context)
# Sanitize resolved values to ensure JSON-serializable output
def _json_sanitize(value: Any) -> Any:
from datetime import date, datetime
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (date, datetime)):
return value.isoformat()
if isinstance(value, (list, tuple)):
return [_json_sanitize(v) for v in value]
if isinstance(value, dict):
return {k: _json_sanitize(v) for k, v in value.items()}
# Fallback: stringify unsupported types (e.g., functions)
return str(value)
output_bytes = content
output_mime = ver.mime_type
if ver.mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
output_bytes = render_docx(content, resolved)
output_mime = ver.mime_type
sanitized_resolved = {k: _json_sanitize(v) for k, v in resolved.items()}
# We don't store preview output; just return metadata and resolution state
return PreviewResponse(
resolved=resolved,
resolved=sanitized_resolved,
unresolved=unresolved,
output_mime_type=output_mime,
output_size=len(output_bytes),
@@ -419,40 +382,16 @@ async def download_template(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
# Determine which version to serve
resolved_version_id = version_id or tpl.current_version_id
if not resolved_version_id:
raise HTTPException(status_code=404, detail="Template has no approved version")
ver = (
db.query(DocumentTemplateVersion)
.filter(DocumentTemplateVersion.id == resolved_version_id, DocumentTemplateVersion.template_id == tpl.id)
.first()
content, mime_type, original_name = svc_get_download_payload(
db,
template_id=template_id,
version_id=version_id,
)
if not ver:
raise HTTPException(status_code=404, detail="Version not found")
storage = get_default_storage()
try:
content = storage.open_bytes(ver.storage_path)
except Exception:
raise HTTPException(status_code=404, detail="Stored file not found")
# Derive original filename from storage_path (uuid_prefix_originalname)
base = os.path.basename(ver.storage_path)
if "_" in base:
original_name = base.split("_", 1)[1]
else:
original_name = base
headers = {
"Content-Disposition": f"attachment; filename=\"{original_name}\"",
}
return StreamingResponse(iter([content]), media_type=ver.mime_type, headers=headers)
return StreamingResponse(iter([content]), media_type=mime_type, headers=headers)
@router.get("/{template_id}/keywords", response_model=KeywordsResponse)
@@ -461,16 +400,9 @@ async def list_keywords(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
kws = (
db.query(TemplateKeyword)
.filter(TemplateKeyword.template_id == template_id)
.order_by(TemplateKeyword.keyword.asc())
.all()
)
return KeywordsResponse(keywords=[k.keyword for k in kws])
search_service = TemplateSearchService(db)
keywords = search_service.list_keywords(template_id)
return KeywordsResponse(keywords=keywords)
@router.post("/{template_id}/keywords", response_model=KeywordsResponse)
@@ -480,31 +412,9 @@ async def add_keywords(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
to_add = []
for kw in (payload.keywords or []):
normalized = (kw or "").strip().lower()
if not normalized:
continue
exists = (
db.query(TemplateKeyword)
.filter(TemplateKeyword.template_id == template_id, TemplateKeyword.keyword == normalized)
.first()
)
if not exists:
to_add.append(TemplateKeyword(template_id=template_id, keyword=normalized))
if to_add:
db.add_all(to_add)
db.commit()
kws = (
db.query(TemplateKeyword)
.filter(TemplateKeyword.template_id == template_id)
.order_by(TemplateKeyword.keyword.asc())
.all()
)
return KeywordsResponse(keywords=[k.keyword for k in kws])
search_service = TemplateSearchService(db)
keywords = await search_service.add_keywords(template_id, payload.keywords)
return KeywordsResponse(keywords=keywords)
@router.delete("/{template_id}/keywords/{keyword}", response_model=KeywordsResponse)
@@ -514,21 +424,7 @@ async def remove_keyword(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
tpl = db.query(DocumentTemplate).filter(DocumentTemplate.id == template_id).first()
if not tpl:
raise HTTPException(status_code=404, detail="Template not found")
normalized = (keyword or "").strip().lower()
if normalized:
db.query(TemplateKeyword).filter(
TemplateKeyword.template_id == template_id,
TemplateKeyword.keyword == normalized,
).delete(synchronize_session=False)
db.commit()
kws = (
db.query(TemplateKeyword)
.filter(TemplateKeyword.template_id == template_id)
.order_by(TemplateKeyword.keyword.asc())
.all()
)
return KeywordsResponse(keywords=[k.keyword for k in kws])
search_service = TemplateSearchService(db)
keywords = await search_service.remove_keyword(template_id, keyword)
return KeywordsResponse(keywords=keywords)