fixes and refactor

This commit is contained in:
HotSwapp
2025-08-14 19:16:28 -05:00
parent 5111079149
commit bfc04a6909
61 changed files with 5689 additions and 767 deletions

View File

@@ -1,7 +1,7 @@
"""
Comprehensive Admin API endpoints - User management, system settings, audit logging
"""
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Query, Body, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
@@ -13,24 +13,27 @@ import hashlib
import secrets
import shutil
import time
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta, date, timezone
from pathlib import Path
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
# Track application start time
APPLICATION_START_TIME = time.time()
from app.models import User, Rolodex, File as FileModel, Ledger, QDRO, AuditLog, LoginAttempt
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex, PrinterSetup
from app.auth.security import get_admin_user, get_password_hash, create_access_token
from app.services.audit import audit_service
from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total
router = APIRouter()
# Enhanced Admin Schemas
from pydantic import BaseModel, Field, EmailStr
from pydantic.config import ConfigDict
class SystemStats(BaseModel):
"""Enhanced system statistics"""
@@ -91,8 +94,7 @@ class UserResponse(BaseModel):
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class PasswordReset(BaseModel):
"""Password reset request"""
@@ -124,8 +126,7 @@ class AuditLogEntry(BaseModel):
user_agent: Optional[str]
timestamp: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class BackupInfo(BaseModel):
"""Backup information"""
@@ -135,6 +136,132 @@ class BackupInfo(BaseModel):
backup_type: str
status: str
class PrinterSetupBase(BaseModel):
"""Base schema for printer setup"""
description: Optional[str] = None
driver: Optional[str] = None
port: Optional[str] = None
default_printer: Optional[bool] = None
active: Optional[bool] = None
number: Optional[int] = None
page_break: Optional[str] = None
setup_st: Optional[str] = None
reset_st: Optional[str] = None
b_underline: Optional[str] = None
e_underline: Optional[str] = None
b_bold: Optional[str] = None
e_bold: Optional[str] = None
phone_book: Optional[bool] = None
rolodex_info: Optional[bool] = None
envelope: Optional[bool] = None
file_cabinet: Optional[bool] = None
accounts: Optional[bool] = None
statements: Optional[bool] = None
calendar: Optional[bool] = None
class PrinterSetupCreate(PrinterSetupBase):
printer_name: str
class PrinterSetupUpdate(PrinterSetupBase):
pass
class PrinterSetupResponse(PrinterSetupBase):
printer_name: str
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
# Printer Setup Management
@router.get("/printers", response_model=List[PrinterSetupResponse])
async def list_printers(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
printers = db.query(PrinterSetup).order_by(PrinterSetup.printer_name.asc()).all()
return printers
@router.get("/printers/{printer_name}", response_model=PrinterSetupResponse)
async def get_printer(
printer_name: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
printer = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not printer:
raise HTTPException(status_code=404, detail="Printer not found")
return printer
@router.post("/printers", response_model=PrinterSetupResponse)
async def create_printer(
payload: PrinterSetupCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
exists = db.query(PrinterSetup).filter(PrinterSetup.printer_name == payload.printer_name).first()
if exists:
raise HTTPException(status_code=400, detail="Printer already exists")
data = payload.model_dump(exclude_unset=True)
instance = PrinterSetup(**data)
db.add(instance)
# Enforce single default printer
if data.get("default_printer"):
try:
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
except Exception:
pass
db.commit()
db.refresh(instance)
return instance
@router.put("/printers/{printer_name}", response_model=PrinterSetupResponse)
async def update_printer(
printer_name: str,
payload: PrinterSetupUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not instance:
raise HTTPException(status_code=404, detail="Printer not found")
updates = payload.model_dump(exclude_unset=True)
for k, v in updates.items():
setattr(instance, k, v)
# Enforce single default printer when set true
if updates.get("default_printer"):
try:
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
except Exception:
pass
db.commit()
db.refresh(instance)
return instance
@router.delete("/printers/{printer_name}")
async def delete_printer(
printer_name: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not instance:
raise HTTPException(status_code=404, detail="Printer not found")
db.delete(instance)
db.commit()
return {"message": "Printer deleted"}
class LookupTableInfo(BaseModel):
"""Lookup table information"""
table_name: str
@@ -202,7 +329,7 @@ async def system_health(
# Count active sessions (simplified)
try:
active_sessions = db.query(User).filter(
User.last_login > datetime.now() - timedelta(hours=24)
User.last_login > datetime.now(timezone.utc) - timedelta(hours=24)
).count()
except:
active_sessions = 0
@@ -215,7 +342,7 @@ async def system_health(
backup_files = list(backup_dir.glob("*.db"))
if backup_files:
latest_backup = max(backup_files, key=lambda p: p.stat().st_mtime)
backup_age = datetime.now() - datetime.fromtimestamp(latest_backup.stat().st_mtime)
backup_age = datetime.now(timezone.utc) - datetime.fromtimestamp(latest_backup.stat().st_mtime, tz=timezone.utc)
last_backup = latest_backup.name
if backup_age.days > 7:
alerts.append(f"Last backup is {backup_age.days} days old")
@@ -257,7 +384,7 @@ async def system_statistics(
# Count active users (logged in within last 30 days)
total_active_users = db.query(func.count(User.id)).filter(
User.last_login > datetime.now() - timedelta(days=30)
User.last_login > datetime.now(timezone.utc) - timedelta(days=30)
).scalar()
# Count admin users
@@ -308,7 +435,7 @@ async def system_statistics(
recent_activity.append({
"type": "customer_added",
"description": f"Customer {customer.first} {customer.last} added",
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now(timezone.utc).isoformat()
})
except:
pass
@@ -409,7 +536,7 @@ async def export_table(
# Create exports directory if it doesn't exist
os.makedirs("exports", exist_ok=True)
filename = f"exports/{table_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
filename = f"exports/{table_name}_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv"
try:
if table_name.lower() == "customers" or table_name.lower() == "rolodex":
@@ -470,10 +597,10 @@ async def download_backup(
if "sqlite" in settings.database_url:
db_path = settings.database_url.replace("sqlite:///", "")
if os.path.exists(db_path):
return FileResponse(
return FileResponse(
db_path,
media_type='application/octet-stream',
filename=f"delphi_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.db"
filename=f"delphi_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.db"
)
raise HTTPException(
@@ -484,12 +611,20 @@ async def download_backup(
# User Management Endpoints
@router.get("/users", response_model=List[UserResponse])
class PaginatedUsersResponse(BaseModel):
items: List[UserResponse]
total: int
@router.get("/users", response_model=Union[List[UserResponse], PaginatedUsersResponse])
async def list_users(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
search: Optional[str] = Query(None),
active_only: bool = Query(False),
sort_by: Optional[str] = Query(None, description="Sort by: username, email, first_name, last_name, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
@@ -498,19 +633,38 @@ async def list_users(
query = db.query(User)
if search:
query = query.filter(
or_(
User.username.ilike(f"%{search}%"),
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
)
# DRY: tokenize and apply case-insensitive multi-field search
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
User.username,
User.email,
User.first_name,
User.last_name,
])
if filter_expr is not None:
query = query.filter(filter_expr)
if active_only:
query = query.filter(User.is_active == True)
users = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"username": [User.username],
"email": [User.email],
"first_name": [User.first_name],
"last_name": [User.last_name],
"created": [User.created_at],
"updated": [User.updated_at],
},
)
users, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": users, "total": total or 0}
return users
@@ -567,8 +721,8 @@ async def create_user(
hashed_password=hashed_password,
is_admin=user_data.is_admin,
is_active=user_data.is_active,
created_at=datetime.now(),
updated_at=datetime.now()
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
db.add(new_user)
@@ -659,7 +813,7 @@ async def update_user(
changes[field] = {"from": getattr(user, field), "to": value}
setattr(user, field, value)
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(user)
@@ -702,7 +856,7 @@ async def delete_user(
# Soft delete by deactivating
user.is_active = False
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
@@ -744,7 +898,7 @@ async def reset_user_password(
# Update password
user.hashed_password = get_password_hash(password_data.new_password)
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
@@ -1046,7 +1200,7 @@ async def create_backup(
backup_dir.mkdir(exist_ok=True)
# Generate backup filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
backup_filename = f"delphi_backup_{timestamp}.db"
backup_path = backup_dir / backup_filename
@@ -1063,7 +1217,7 @@ async def create_backup(
"backup_info": {
"filename": backup_filename,
"size": f"{backup_size / (1024*1024):.1f} MB",
"created_at": datetime.now().isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"backup_type": "manual",
"status": "completed"
}
@@ -1118,45 +1272,61 @@ async def get_audit_logs(
resource_type: Optional[str] = Query(None),
action: Optional[str] = Query(None),
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days, max 1 year
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, action, resource_type"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Get audit log entries with filtering"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
"""Get audit log entries with filtering, sorting, and pagination"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = db.query(AuditLog).filter(AuditLog.timestamp >= cutoff_time)
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if resource_type:
query = query.filter(AuditLog.resource_type.ilike(f"%{resource_type}%"))
if action:
query = query.filter(AuditLog.action.ilike(f"%{action}%"))
total_count = query.count()
logs = query.order_by(AuditLog.timestamp.desc()).offset(skip).limit(limit).all()
return {
"total": total_count,
"logs": [
{
"id": log.id,
"user_id": log.user_id,
"username": log.username,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"details": log.details,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"timestamp": log.timestamp.isoformat()
}
for log in logs
]
}
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"timestamp": [AuditLog.timestamp],
"username": [AuditLog.username],
"action": [AuditLog.action],
"resource_type": [AuditLog.resource_type],
},
)
logs, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": log.id,
"user_id": log.user_id,
"username": log.username,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"details": log.details,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"timestamp": log.timestamp.isoformat(),
}
for log in logs
]
if include_total:
return {"items": items, "total": total or 0}
return items
@router.get("/audit/login-attempts")
@@ -1166,39 +1336,55 @@ async def get_login_attempts(
username: Optional[str] = Query(None),
failed_only: bool = Query(False),
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, ip_address, success"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Get login attempts with filtering"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
"""Get login attempts with filtering, sorting, and pagination"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = db.query(LoginAttempt).filter(LoginAttempt.timestamp >= cutoff_time)
if username:
query = query.filter(LoginAttempt.username.ilike(f"%{username}%"))
if failed_only:
query = query.filter(LoginAttempt.success == 0)
total_count = query.count()
attempts = query.order_by(LoginAttempt.timestamp.desc()).offset(skip).limit(limit).all()
return {
"total": total_count,
"attempts": [
{
"id": attempt.id,
"username": attempt.username,
"ip_address": attempt.ip_address,
"user_agent": attempt.user_agent,
"success": bool(attempt.success),
"failure_reason": attempt.failure_reason,
"timestamp": attempt.timestamp.isoformat()
}
for attempt in attempts
]
}
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"timestamp": [LoginAttempt.timestamp],
"username": [LoginAttempt.username],
"ip_address": [LoginAttempt.ip_address],
"success": [LoginAttempt.success],
},
)
attempts, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": attempt.id,
"username": attempt.username,
"ip_address": attempt.ip_address,
"user_agent": attempt.user_agent,
"success": bool(attempt.success),
"failure_reason": attempt.failure_reason,
"timestamp": attempt.timestamp.isoformat(),
}
for attempt in attempts
]
if include_total:
return {"items": items, "total": total or 0}
return items
@router.get("/audit/user-activity/{user_id}")
@@ -1251,7 +1437,7 @@ async def get_security_alerts(
):
"""Get security alerts and suspicious activity"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
# Get failed login attempts
failed_logins = db.query(LoginAttempt).filter(
@@ -1356,7 +1542,7 @@ async def get_audit_statistics(
):
"""Get audit statistics and metrics"""
cutoff_time = datetime.now() - timedelta(days=days_back)
cutoff_time = datetime.now(timezone.utc) - timedelta(days=days_back)
# Total activity counts
total_audit_entries = db.query(func.count(AuditLog.id)).filter(

View File

@@ -1,7 +1,7 @@
"""
Authentication API endpoints
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
@@ -69,7 +69,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
)
# Update last login
user.last_login = datetime.utcnow()
user.last_login = datetime.now(timezone.utc)
db.commit()
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
@@ -144,7 +144,7 @@ async def read_users_me(current_user: User = Depends(get_current_user)):
async def refresh_token_endpoint(
request: Request,
db: Session = Depends(get_db),
body: RefreshRequest | None = None,
body: RefreshRequest = None,
):
"""Issue a new access token using a valid, non-revoked refresh token.
@@ -203,7 +203,7 @@ async def refresh_token_endpoint(
if not user or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
user.last_login = datetime.utcnow()
user.last_login = datetime.now(timezone.utc)
db.commit()
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
@@ -225,7 +225,7 @@ async def list_users(
@router.post("/logout")
async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)):
async def logout(body: RefreshRequest = None, db: Session = Depends(get_db)):
"""Revoke the provided refresh token. Idempotent and safe to call multiple times.
The client should send a JSON body: { "refresh_token": "..." }.

View File

@@ -4,7 +4,7 @@ Customer (Rolodex) API endpoints
from typing import List, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, and_, func, asc, desc
from sqlalchemy import func
from fastapi.responses import StreamingResponse
import csv
import io
@@ -13,12 +13,16 @@ from app.database.base import get_db
from app.models.rolodex import Rolodex, Phone
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
from app.services.query_utils import apply_sorting, paginate_with_total
router = APIRouter()
# Pydantic schemas for request/response
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, Field
from pydantic.config import ConfigDict
from datetime import date
@@ -32,8 +36,7 @@ class PhoneResponse(BaseModel):
location: Optional[str]
phone: str
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class CustomerBase(BaseModel):
@@ -84,10 +87,11 @@ class CustomerUpdate(BaseModel):
class CustomerResponse(CustomerBase):
phone_numbers: List[PhoneResponse] = []
phone_numbers: List[PhoneResponse] = Field(default_factory=list)
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/search/phone")
@@ -196,80 +200,17 @@ async def list_customers(
try:
base_query = db.query(Rolodex)
if search:
s = (search or "").strip()
s_lower = s.lower()
tokens = [t for t in s_lower.split() if t]
# Basic contains search on several fields (case-insensitive)
contains_any = or_(
func.lower(Rolodex.id).contains(s_lower),
func.lower(Rolodex.last).contains(s_lower),
func.lower(Rolodex.first).contains(s_lower),
func.lower(Rolodex.middle).contains(s_lower),
func.lower(Rolodex.city).contains(s_lower),
func.lower(Rolodex.email).contains(s_lower),
)
# Multi-token name support: every token must match either first, middle, or last
name_tokens = [
or_(
func.lower(Rolodex.first).contains(tok),
func.lower(Rolodex.middle).contains(tok),
func.lower(Rolodex.last).contains(tok),
)
for tok in tokens
]
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
# Comma pattern: "Last, First"
last_first_filter = None
if "," in s_lower:
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
if last_part and first_part:
last_first_filter = and_(
func.lower(Rolodex.last).contains(last_part),
func.lower(Rolodex.first).contains(first_part),
)
elif last_part:
last_first_filter = func.lower(Rolodex.last).contains(last_part)
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
base_query = base_query.filter(final_filter)
# Apply group/state filters (support single and multi-select)
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
if effective_groups:
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
if effective_states:
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
base_query = apply_customer_filters(
base_query,
search=search,
group=group,
state=state,
groups=groups,
states=states,
)
# Apply sorting (whitelisted fields only)
normalized_sort_by = (sort_by or "id").lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
order_columns = []
if normalized_sort_by == "id":
order_columns = [Rolodex.id]
elif normalized_sort_by == "name":
# Sort by last, then first
order_columns = [Rolodex.last, Rolodex.first]
elif normalized_sort_by == "city":
# Sort by city, then state abbreviation
order_columns = [Rolodex.city, Rolodex.abrev]
elif normalized_sort_by == "email":
order_columns = [Rolodex.email]
else:
# Fallback to id to avoid arbitrary column injection
order_columns = [Rolodex.id]
# Case-insensitive ordering where applicable, preserving None ordering default
ordered = []
for col in order_columns:
# Use lower() for string-like cols; SQLAlchemy will handle non-string safely enough for SQLite/Postgres
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
ordered.append(desc(expr) if is_desc else asc(expr))
if ordered:
base_query = base_query.order_by(*ordered)
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
customers = base_query.options(joinedload(Rolodex.phone_numbers)).offset(skip).limit(limit).all()
if include_total:
@@ -304,72 +245,16 @@ async def export_customers(
try:
base_query = db.query(Rolodex)
if search:
s = (search or "").strip()
s_lower = s.lower()
tokens = [t for t in s_lower.split() if t]
contains_any = or_(
func.lower(Rolodex.id).contains(s_lower),
func.lower(Rolodex.last).contains(s_lower),
func.lower(Rolodex.first).contains(s_lower),
func.lower(Rolodex.middle).contains(s_lower),
func.lower(Rolodex.city).contains(s_lower),
func.lower(Rolodex.email).contains(s_lower),
)
name_tokens = [
or_(
func.lower(Rolodex.first).contains(tok),
func.lower(Rolodex.middle).contains(tok),
func.lower(Rolodex.last).contains(tok),
)
for tok in tokens
]
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
last_first_filter = None
if "," in s_lower:
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
if last_part and first_part:
last_first_filter = and_(
func.lower(Rolodex.last).contains(last_part),
func.lower(Rolodex.first).contains(first_part),
)
elif last_part:
last_first_filter = func.lower(Rolodex.last).contains(last_part)
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
base_query = base_query.filter(final_filter)
base_query = apply_customer_filters(
base_query,
search=search,
group=group,
state=state,
groups=groups,
states=states,
)
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
if effective_groups:
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
if effective_states:
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
normalized_sort_by = (sort_by or "id").lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
order_columns = []
if normalized_sort_by == "id":
order_columns = [Rolodex.id]
elif normalized_sort_by == "name":
order_columns = [Rolodex.last, Rolodex.first]
elif normalized_sort_by == "city":
order_columns = [Rolodex.city, Rolodex.abrev]
elif normalized_sort_by == "email":
order_columns = [Rolodex.email]
else:
order_columns = [Rolodex.id]
ordered = []
for col in order_columns:
try:
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
except Exception:
expr = col
ordered.append(desc(expr) if is_desc else asc(expr))
if ordered:
base_query = base_query.order_by(*ordered)
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
if not export_all:
if skip is not None:
@@ -382,39 +267,10 @@ async def export_customers(
# Prepare CSV
output = io.StringIO()
writer = csv.writer(output)
allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
header_names = {
"id": "Customer ID",
"name": "Name",
"group": "Group",
"city": "City",
"state": "State",
"phone": "Primary Phone",
"email": "Email",
}
requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
if not selected_fields:
selected_fields = allowed_fields_in_order
writer.writerow([header_names[f] for f in selected_fields])
for c in customers:
full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
primary_phone = ""
try:
if c.phone_numbers:
primary_phone = c.phone_numbers[0].phone or ""
except Exception:
primary_phone = ""
row_map = {
"id": c.id,
"name": full_name,
"group": c.group or "",
"city": c.city or "",
"state": c.abrev or "",
"phone": primary_phone,
"email": c.email or "",
}
writer.writerow([row_map[f] for f in selected_fields])
header_row, rows = prepare_customer_csv_rows(customers, fields)
writer.writerow(header_row)
for row in rows:
writer.writerow(row)
output.seek(0)
filename = "customers_export.csv"
@@ -469,6 +325,10 @@ async def create_customer(
db.commit()
db.refresh(customer)
try:
await invalidate_search_cache()
except Exception:
pass
return customer
@@ -494,7 +354,10 @@ async def update_customer(
db.commit()
db.refresh(customer)
try:
await invalidate_search_cache()
except Exception:
pass
return customer
@@ -515,17 +378,30 @@ async def delete_customer(
db.delete(customer)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Customer deleted successfully"}
@router.get("/{customer_id}/phones", response_model=List[PhoneResponse])
class PaginatedPhonesResponse(BaseModel):
items: List[PhoneResponse]
total: int
@router.get("/{customer_id}/phones", response_model=Union[List[PhoneResponse], PaginatedPhonesResponse])
async def get_customer_phones(
customer_id: str,
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=1000, description="Page size"),
sort_by: Optional[str] = Query("location", description="Sort by: location, phone"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get customer phone numbers"""
"""Get customer phone numbers with optional sorting/pagination"""
customer = db.query(Rolodex).filter(Rolodex.id == customer_id).first()
if not customer:
@@ -534,7 +410,21 @@ async def get_customer_phones(
detail="Customer not found"
)
phones = db.query(Phone).filter(Phone.rolodex_id == customer_id).all()
query = db.query(Phone).filter(Phone.rolodex_id == customer_id)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"location": [Phone.location, Phone.phone],
"phone": [Phone.phone],
},
)
phones, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": phones, "total": total or 0}
return phones

View File

@@ -1,16 +1,19 @@
"""
Document Management API endpoints - QDROs, Templates, and General Documents
"""
from typing import List, Optional, Dict, Any
from __future__ import annotations
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime
from datetime import date, datetime, timezone
import os
import uuid
import shutil
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
from app.models.qdro import QDRO
from app.models.files import File as FileModel
from app.models.rolodex import Rolodex
@@ -20,18 +23,20 @@ from app.auth.security import get_current_user
from app.models.additional import Document
from app.core.logging import get_logger
from app.services.audit import audit_service
from app.services.cache import invalidate_search_cache
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class QDROBase(BaseModel):
file_no: str
version: str = "01"
title: Optional[str] = None
form_name: Optional[str] = None
content: Optional[str] = None
status: str = "DRAFT"
created_date: Optional[date] = None
@@ -51,6 +56,7 @@ class QDROCreate(QDROBase):
class QDROUpdate(BaseModel):
version: Optional[str] = None
title: Optional[str] = None
form_name: Optional[str] = None
content: Optional[str] = None
status: Optional[str] = None
created_date: Optional[date] = None
@@ -66,27 +72,61 @@ class QDROUpdate(BaseModel):
class QDROResponse(QDROBase):
id: int
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/qdros/{file_no}", response_model=List[QDROResponse])
class PaginatedQDROResponse(BaseModel):
items: List[QDROResponse]
total: int
@router.get("/qdros/{file_no}", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
async def get_file_qdros(
file_no: str,
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=1000, description="Page size"),
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created, version, status"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get QDROs for specific file"""
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(QDRO.version).all()
"""Get QDROs for a specific file with optional sorting/pagination"""
query = db.query(QDRO).filter(QDRO.file_no == file_no)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"updated": [QDRO.updated_at, QDRO.id],
"created": [QDRO.created_at, QDRO.id],
"version": [QDRO.version],
"status": [QDRO.status],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": qdros, "total": total or 0}
return qdros
@router.get("/qdros/", response_model=List[QDROResponse])
class PaginatedQDROResponse(BaseModel):
items: List[QDROResponse]
total: int
@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
async def list_qdros(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
status_filter: Optional[str] = Query(None),
search: Optional[str] = Query(None),
sort_by: Optional[str] = Query(None, description="Sort by: file_no, version, status, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -97,17 +137,37 @@ async def list_qdros(
query = query.filter(QDRO.status == status_filter)
if search:
query = query.filter(
or_(
QDRO.file_no.contains(search),
QDRO.title.contains(search),
QDRO.participant_name.contains(search),
QDRO.spouse_name.contains(search),
QDRO.plan_name.contains(search)
)
)
qdros = query.offset(skip).limit(limit).all()
# DRY: tokenize and apply case-insensitive search across common QDRO fields
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
QDRO.file_no,
QDRO.form_name,
QDRO.pet,
QDRO.res,
QDRO.case_number,
QDRO.notes,
QDRO.status,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"file_no": [QDRO.file_no],
"version": [QDRO.version],
"status": [QDRO.status],
"created": [QDRO.created_at],
"updated": [QDRO.updated_at],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": qdros, "total": total or 0}
return qdros
@@ -135,6 +195,10 @@ async def create_qdro(
db.commit()
db.refresh(qdro)
try:
await invalidate_search_cache()
except Exception:
pass
return qdro
@@ -189,6 +253,10 @@ async def update_qdro(
db.commit()
db.refresh(qdro)
try:
await invalidate_search_cache()
except Exception:
pass
return qdro
@@ -213,7 +281,10 @@ async def delete_qdro(
db.delete(qdro)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "QDRO deleted successfully"}
@@ -241,8 +312,7 @@ class TemplateResponse(TemplateBase):
active: bool = True
created_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# Document Generation Schema
class DocumentGenerateRequest(BaseModel):
@@ -269,13 +339,21 @@ class DocumentStats(BaseModel):
recent_activity: List[Dict[str, Any]]
@router.get("/templates/", response_model=List[TemplateResponse])
class PaginatedTemplatesResponse(BaseModel):
items: List[TemplateResponse]
total: int
@router.get("/templates/", response_model=Union[List[TemplateResponse], PaginatedTemplatesResponse])
async def list_templates(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
category: Optional[str] = Query(None),
search: Optional[str] = Query(None),
active_only: bool = Query(True),
sort_by: Optional[str] = Query(None, description="Sort by: form_id, form_name, category, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -289,14 +367,31 @@ async def list_templates(
query = query.filter(FormIndex.category == category)
if search:
query = query.filter(
or_(
FormIndex.form_name.contains(search),
FormIndex.form_id.contains(search)
)
)
templates = query.offset(skip).limit(limit).all()
# DRY: tokenize and apply case-insensitive search for templates
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
FormIndex.form_name,
FormIndex.form_id,
FormIndex.category,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"form_id": [FormIndex.form_id],
"form_name": [FormIndex.form_name],
"category": [FormIndex.category],
"created": [FormIndex.created_at],
"updated": [FormIndex.updated_at],
},
)
templates, total = paginate_with_total(query, skip, limit, include_total)
# Enhanced response with template content
results = []
@@ -317,6 +412,8 @@ async def list_templates(
"variables": _extract_variables_from_content(content)
})
if include_total:
return {"items": results, "total": total or 0}
return results
@@ -356,6 +453,10 @@ async def create_template(
db.commit()
db.refresh(form_index)
try:
await invalidate_search_cache()
except Exception:
pass
return {
"form_id": form_index.form_id,
@@ -440,6 +541,10 @@ async def update_template(
db.commit()
db.refresh(template)
try:
await invalidate_search_cache()
except Exception:
pass
# Get updated content
template_lines = db.query(FormList).filter(
@@ -480,6 +585,10 @@ async def delete_template(
# Delete template
db.delete(template)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Template deleted successfully"}
@@ -574,7 +683,7 @@ async def generate_document(
"file_name": file_name,
"file_path": file_path,
"size": file_size,
"created_at": datetime.now()
"created_at": datetime.now(timezone.utc)
}
@@ -629,32 +738,49 @@ async def get_document_stats(
@router.get("/file/{file_no}/documents")
async def get_file_documents(
file_no: str,
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all documents associated with a specific file"""
# Get QDROs for this file
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(desc(QDRO.updated_at)).all()
# Format response
documents = [
"""Get all documents associated with a specific file, with optional sorting/pagination"""
# Base query for QDROs tied to the file
query = db.query(QDRO).filter(QDRO.file_no == file_no)
# Apply sorting using shared helper (map friendly names to columns)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"updated": [QDRO.updated_at, QDRO.id],
"created": [QDRO.created_at, QDRO.id],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": qdro.id,
"type": "QDRO",
"title": f"QDRO v{qdro.version}",
"status": qdro.status,
"created_date": qdro.created_date.isoformat() if qdro.created_date else None,
"updated_at": qdro.updated_at.isoformat() if qdro.updated_at else None,
"file_no": qdro.file_no
"created_date": qdro.created_date.isoformat() if getattr(qdro, "created_date", None) else None,
"updated_at": qdro.updated_at.isoformat() if getattr(qdro, "updated_at", None) else None,
"file_no": qdro.file_no,
}
for qdro in qdros
]
return {
"file_no": file_no,
"documents": documents,
"total_count": len(documents)
}
payload = {"file_no": file_no, "documents": items, "total_count": (total if include_total else None)}
# Maintain previous shape by omitting total_count when include_total is False? The prior code always returned total_count.
# Keep total_count for backward compatibility but set to actual total when include_total else len(items)
payload["total_count"] = (total if include_total else len(items))
return payload
def _extract_variables_from_content(content: str) -> Dict[str, str]:

View File

@@ -1,25 +1,29 @@
"""
File Management API endpoints
"""
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc
from datetime import date, datetime
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
from app.models.files import File
from app.models.rolodex import Rolodex
from app.models.ledger import Ledger
from app.models.lookups import Employee, FileType, FileStatus
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic.config import ConfigDict
class FileBase(BaseModel):
@@ -67,17 +71,24 @@ class FileResponse(FileBase):
amount_owing: float = 0.0
transferable: float = 0.0
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/", response_model=List[FileResponse])
class PaginatedFilesResponse(BaseModel):
items: List[FileResponse]
total: int
@router.get("/", response_model=Union[List[FileResponse], PaginatedFilesResponse])
async def list_files(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
search: Optional[str] = Query(None),
status_filter: Optional[str] = Query(None),
employee_filter: Optional[str] = Query(None),
sort_by: Optional[str] = Query(None, description="Sort by: file_no, client, opened, closed, status, amount_owing, total_charges"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -85,14 +96,17 @@ async def list_files(
query = db.query(File)
if search:
query = query.filter(
or_(
File.file_no.contains(search),
File.id.contains(search),
File.regarding.contains(search),
File.file_type.contains(search)
)
)
# DRY: tokenize and apply case-insensitive search consistently with search endpoints
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
File.file_no,
File.id,
File.regarding,
File.file_type,
File.memo,
])
if filter_expr is not None:
query = query.filter(filter_expr)
if status_filter:
query = query.filter(File.status == status_filter)
@@ -100,7 +114,25 @@ async def list_files(
if employee_filter:
query = query.filter(File.empl_num == employee_filter)
files = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"file_no": [File.file_no],
"client": [File.id],
"opened": [File.opened],
"closed": [File.closed],
"status": [File.status],
"amount_owing": [File.amount_owing],
"total_charges": [File.total_charges],
},
)
files, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": files, "total": total or 0}
return files
@@ -142,6 +174,10 @@ async def create_file(
db.commit()
db.refresh(file_obj)
try:
await invalidate_search_cache()
except Exception:
pass
return file_obj
@@ -167,7 +203,10 @@ async def update_file(
db.commit()
db.refresh(file_obj)
try:
await invalidate_search_cache()
except Exception:
pass
return file_obj
@@ -188,7 +227,10 @@ async def delete_file(
db.delete(file_obj)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "File deleted successfully"}
@@ -433,11 +475,13 @@ async def advanced_file_search(
query = query.filter(File.file_no.contains(file_no))
if client_name:
# SQLite-safe concatenation for first + last name
full_name_expr = (func.coalesce(Rolodex.first, '') + ' ' + func.coalesce(Rolodex.last, ''))
query = query.join(Rolodex).filter(
or_(
Rolodex.first.contains(client_name),
Rolodex.last.contains(client_name),
func.concat(Rolodex.first, ' ', Rolodex.last).contains(client_name)
full_name_expr.contains(client_name)
)
)

View File

@@ -1,11 +1,11 @@
"""
Financial/Ledger API endpoints
"""
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime, timedelta
from datetime import date, datetime, timedelta, timezone
from app.database.base import get_db
from app.models.ledger import Ledger
@@ -14,12 +14,14 @@ from app.models.rolodex import Rolodex
from app.models.lookups import Employee, TransactionType, TransactionCode
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.query_utils import apply_sorting, paginate_with_total
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class LedgerBase(BaseModel):
@@ -57,8 +59,7 @@ class LedgerResponse(LedgerBase):
id: int
item_no: int
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class FinancialSummary(BaseModel):
@@ -75,23 +76,46 @@ class FinancialSummary(BaseModel):
billed_amount: float
@router.get("/ledger/{file_no}", response_model=List[LedgerResponse])
class PaginatedLedgerResponse(BaseModel):
items: List[LedgerResponse]
total: int
@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse])
async def get_file_ledger(
file_no: str,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
billed_only: Optional[bool] = Query(None),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=500, description="Page size"),
billed_only: Optional[bool] = Query(None, description="Filter billed vs unbilled entries"),
sort_by: Optional[str] = Query("date", description="Sort by: date, item_no, amount, billed"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get ledger entries for specific file"""
query = db.query(Ledger).filter(Ledger.file_no == file_no).order_by(Ledger.date.desc())
query = db.query(Ledger).filter(Ledger.file_no == file_no)
if billed_only is not None:
billed_filter = "Y" if billed_only else "N"
query = query.filter(Ledger.billed == billed_filter)
entries = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"date": [Ledger.date, Ledger.item_no],
"item_no": [Ledger.item_no],
"amount": [Ledger.amount],
"billed": [Ledger.billed, Ledger.date],
},
)
entries, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": entries, "total": total or 0}
return entries
@@ -127,6 +151,10 @@ async def create_ledger_entry(
# Update file balances (simplified version)
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return entry
@@ -158,6 +186,10 @@ async def update_ledger_entry(
if file_obj:
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return entry
@@ -185,6 +217,10 @@ async def delete_ledger_entry(
if file_obj:
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Ledger entry deleted successfully"}

View File

@@ -7,7 +7,7 @@ import re
import os
from pathlib import Path
from difflib import SequenceMatcher
from datetime import datetime, date
from datetime import datetime, date, timezone
from decimal import Decimal
from typing import List, Dict, Any, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
@@ -19,8 +19,8 @@ from app.models.rolodex import Rolodex, Phone
from app.models.files import File
from app.models.ledger import Ledger
from app.models.qdro import QDRO
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable, PensionResult
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup, FormKeyword
from app.models.additional import Payment, Deposit, FileNote, FormVariable, ReportVariable
from app.models.flexible import FlexibleImport
from app.models.audit import ImportAudit, ImportAuditFile
@@ -28,6 +28,25 @@ from app.config import settings
router = APIRouter(tags=["import"])
# Common encodings to try for legacy CSV files (order matters)
ENCODINGS = [
'utf-8-sig',
'utf-8',
'windows-1252',
'iso-8859-1',
'cp1252',
]
# Unified import order used across batch operations
IMPORT_ORDER = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"INX_LKUP.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
# CSV to Model mapping
CSV_MODEL_MAPPING = {
@@ -56,7 +75,6 @@ CSV_MODEL_MAPPING = {
"FOOTERS.csv": Footer,
"PLANINFO.csv": PlanInfo,
# Legacy alternate names from export directories
"SCHEDULE.csv": PensionSchedule,
"FORM_INX.csv": FormIndex,
"FORM_LST.csv": FormList,
"PRINTERS.csv": PrinterSetup,
@@ -67,7 +85,9 @@ CSV_MODEL_MAPPING = {
"FVARLKUP.csv": FormVariable,
"RVARLKUP.csv": ReportVariable,
"PAYMENTS.csv": Payment,
"TRNSACTN.csv": Ledger # Maps to existing Ledger model (same structure)
"TRNSACTN.csv": Ledger, # Maps to existing Ledger model (same structure)
"INX_LKUP.csv": FormKeyword,
"RESULTS.csv": PensionResult
}
# Field mappings for CSV columns to database fields
@@ -230,8 +250,12 @@ FIELD_MAPPINGS = {
"Default_Rate": "default_rate"
},
"FILESTAT.csv": {
"Status": "status_code",
"Status_Code": "status_code",
"Definition": "description",
"Description": "description",
"Send": "send",
"Footer_Code": "footer_code",
"Sort_Order": "sort_order"
},
"FOOTERS.csv": {
@@ -253,22 +277,44 @@ FIELD_MAPPINGS = {
"Phone": "phone",
"Notes": "notes"
},
"INX_LKUP.csv": {
"Keyword": "keyword",
"Description": "description"
},
"FORM_INX.csv": {
"Form_Id": "form_id",
"Form_Name": "form_name",
"Category": "category"
"Name": "form_id",
"Keyword": "keyword"
},
"FORM_LST.csv": {
"Form_Id": "form_id",
"Line_Number": "line_number",
"Content": "content"
"Name": "form_id",
"Memo": "content",
"Status": "status"
},
"PRINTERS.csv": {
# Legacy variants
"Printer_Name": "printer_name",
"Description": "description",
"Driver": "driver",
"Port": "port",
"Default_Printer": "default_printer"
"Default_Printer": "default_printer",
# Observed legacy headers from export
"Number": "number",
"Name": "printer_name",
"Page_Break": "page_break",
"Setup_St": "setup_st",
"Reset_St": "reset_st",
"B_Underline": "b_underline",
"E_Underline": "e_underline",
"B_Bold": "b_bold",
"E_Bold": "e_bold",
# Optional report toggles
"Phone_Book": "phone_book",
"Rolodex_Info": "rolodex_info",
"Envelope": "envelope",
"File_Cabinet": "file_cabinet",
"Accounts": "accounts",
"Statements": "statements",
"Calendar": "calendar",
},
"SETUP.csv": {
"Setting_Key": "setting_key",
@@ -285,32 +331,98 @@ FIELD_MAPPINGS = {
"MARRIAGE.csv": {
"File_No": "file_no",
"Version": "version",
"Marriage_Date": "marriage_date",
"Separation_Date": "separation_date",
"Divorce_Date": "divorce_date"
"Married_From": "married_from",
"Married_To": "married_to",
"Married_Years": "married_years",
"Service_From": "service_from",
"Service_To": "service_to",
"Service_Years": "service_years",
"Marital_%": "marital_percent"
},
"DEATH.csv": {
"File_No": "file_no",
"Version": "version",
"Benefit_Type": "benefit_type",
"Benefit_Amount": "benefit_amount",
"Beneficiary": "beneficiary"
"Lump1": "lump1",
"Lump2": "lump2",
"Growth1": "growth1",
"Growth2": "growth2",
"Disc1": "disc1",
"Disc2": "disc2"
},
"SEPARATE.csv": {
"File_No": "file_no",
"Version": "version",
"Agreement_Date": "agreement_date",
"Terms": "terms"
"Separation_Rate": "terms"
},
"LIFETABL.csv": {
"Age": "age",
"Male_Mortality": "male_mortality",
"Female_Mortality": "female_mortality"
"AGE": "age",
"LE_AA": "le_aa",
"NA_AA": "na_aa",
"LE_AM": "le_am",
"NA_AM": "na_am",
"LE_AF": "le_af",
"NA_AF": "na_af",
"LE_WA": "le_wa",
"NA_WA": "na_wa",
"LE_WM": "le_wm",
"NA_WM": "na_wm",
"LE_WF": "le_wf",
"NA_WF": "na_wf",
"LE_BA": "le_ba",
"NA_BA": "na_ba",
"LE_BM": "le_bm",
"NA_BM": "na_bm",
"LE_BF": "le_bf",
"NA_BF": "na_bf",
"LE_HA": "le_ha",
"NA_HA": "na_ha",
"LE_HM": "le_hm",
"NA_HM": "na_hm",
"LE_HF": "le_hf",
"NA_HF": "na_hf"
},
"NUMBERAL.csv": {
"Table_Name": "table_name",
"Month": "month",
"NA_AA": "na_aa",
"NA_AM": "na_am",
"NA_AF": "na_af",
"NA_WA": "na_wa",
"NA_WM": "na_wm",
"NA_WF": "na_wf",
"NA_BA": "na_ba",
"NA_BM": "na_bm",
"NA_BF": "na_bf",
"NA_HA": "na_ha",
"NA_HM": "na_hm",
"NA_HF": "na_hf"
},
"RESULTS.csv": {
"Accrued": "accrued",
"Start_Age": "start_age",
"COLA": "cola",
"Withdrawal": "withdrawal",
"Pre_DR": "pre_dr",
"Post_DR": "post_dr",
"Tax_Rate": "tax_rate",
"Age": "age",
"Value": "value"
"Years_From": "years_from",
"Life_Exp": "life_exp",
"EV_Monthly": "ev_monthly",
"Payments": "payments",
"Pay_Out": "pay_out",
"Fund_Value": "fund_value",
"PV": "pv",
"Mortality": "mortality",
"PV_AM": "pv_am",
"PV_AMT": "pv_amt",
"PV_Pre_DB": "pv_pre_db",
"PV_Annuity": "pv_annuity",
"WV_AT": "wv_at",
"PV_Plan": "pv_plan",
"Years_Married": "years_married",
"Years_Service": "years_service",
"Marr_Per": "marr_per",
"Marr_Amt": "marr_amt"
},
# Additional CSV file mappings
"DEPOSITS.csv": {
@@ -357,7 +469,7 @@ FIELD_MAPPINGS = {
}
def parse_date(date_str: str) -> Optional[datetime]:
def parse_date(date_str: str) -> Optional[date]:
"""Parse date string in various formats"""
if not date_str or date_str.strip() == "":
return None
@@ -612,7 +724,11 @@ def convert_value(value: str, field_name: str) -> Any:
return parsed_date
# Boolean fields
if any(word in field_name.lower() for word in ["active", "default_printer", "billed", "transferable"]):
if any(word in field_name.lower() for word in [
"active", "default_printer", "billed", "transferable", "send",
# PrinterSetup legacy toggles
"phone_book", "rolodex_info", "envelope", "file_cabinet", "accounts", "statements", "calendar"
]):
if value.lower() in ["true", "1", "yes", "y", "on", "active"]:
return True
elif value.lower() in ["false", "0", "no", "n", "off", "inactive"]:
@@ -621,7 +737,11 @@ def convert_value(value: str, field_name: str) -> Any:
return None
# Numeric fields (float)
if any(word in field_name.lower() for word in ["rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu", "accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality", "value"]):
if any(word in field_name.lower() for word in [
"rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu",
"accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality",
"value"
]) or field_name.lower().startswith(("na_", "le_")):
try:
# Remove currency symbols and commas
cleaned_value = value.replace("$", "").replace(",", "").replace("%", "")
@@ -630,7 +750,9 @@ def convert_value(value: str, field_name: str) -> Any:
return 0.0
# Integer fields
if any(word in field_name.lower() for word in ["item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num"]):
if any(word in field_name.lower() for word in [
"item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number"
]):
try:
return int(float(value)) # Handle cases like "1.0"
except ValueError:
@@ -673,11 +795,18 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
"available_files": list(CSV_MODEL_MAPPING.keys()),
"descriptions": {
"ROLODEX.csv": "Customer/contact information",
"ROLEX_V.csv": "Customer/contact information (alias)",
"PHONE.csv": "Phone numbers linked to customers",
"FILES.csv": "Client files and cases",
"FILES_R.csv": "Client files and cases (alias)",
"FILES_V.csv": "Client files and cases (alias)",
"LEDGER.csv": "Financial transactions per file",
"QDROS.csv": "Legal documents and court orders",
"PENSIONS.csv": "Pension calculation data",
"SCHEDULE.csv": "Vesting schedules for pensions",
"MARRIAGE.csv": "Marriage history data",
"DEATH.csv": "Death benefit calculations",
"SEPARATE.csv": "Separation agreements",
"EMPLOYEE.csv": "Staff and employee information",
"STATES.csv": "US States lookup table",
"FILETYPE.csv": "File type categories",
@@ -688,7 +817,12 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
"FVARLKUP.csv": "Form template variables",
"RVARLKUP.csv": "Report template variables",
"PAYMENTS.csv": "Individual payments within deposits",
"TRNSACTN.csv": "Transaction details (maps to Ledger)"
"TRNSACTN.csv": "Transaction details (maps to Ledger)",
"INX_LKUP.csv": "Form keywords lookup",
"PLANINFO.csv": "Pension plan information",
"RESULTS.csv": "Pension computed results",
"LIFETABL.csv": "Life expectancy table by age, sex, and race (rich typed)",
"NUMBERAL.csv": "Monthly survivor counts by sex and race (rich typed)"
},
"auto_discovery": True
}
@@ -724,7 +858,7 @@ async def import_csv_data(
content = await file.read()
# Try multiple encodings for legacy CSV files
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -736,34 +870,7 @@ async def import_csv_data(
if csv_content is None:
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
# Preprocess CSV content to fix common legacy issues
def preprocess_csv(content):
lines = content.split('\n')
cleaned_lines = []
i = 0
while i < len(lines):
line = lines[i]
# If line doesn't have the expected number of commas, it might be a broken multi-line field
if i == 0: # Header line
cleaned_lines.append(line)
expected_comma_count = line.count(',')
i += 1
continue
# Check if this line has the expected number of commas
if line.count(',') < expected_comma_count:
# This might be a continuation of the previous line
# Try to merge with previous line
if cleaned_lines:
cleaned_lines[-1] += " " + line.replace('\n', ' ').replace('\r', ' ')
else:
cleaned_lines.append(line)
else:
cleaned_lines.append(line)
i += 1
return '\n'.join(cleaned_lines)
# Note: preprocess_csv helper removed as unused; robust parsing handled below
# Custom robust parser for problematic legacy CSV files
class MockCSVReader:
@@ -791,7 +898,7 @@ async def import_csv_data(
header_reader = csv.reader(io.StringIO(lines[0]))
headers = next(header_reader)
headers = [h.strip() for h in headers]
print(f"DEBUG: Found {len(headers)} headers: {headers}")
# Debug logging removed in API path; rely on audit/logging if needed
# Build dynamic header mapping for this file/model
mapping_info = _build_dynamic_mapping(headers, model_class, file_type)
@@ -829,17 +936,21 @@ async def import_csv_data(
continue
csv_reader = MockCSVReader(rows_data, headers)
print(f"SUCCESS: Parsed {len(rows_data)} rows (skipped {skipped_rows} malformed rows)")
# Parsing summary suppressed to avoid noisy stdout in API
except Exception as e:
print(f"Custom parsing failed: {e}")
# Keep error minimal for client; internal logging can capture 'e'
raise HTTPException(status_code=400, detail=f"Could not parse CSV file. The file appears to have serious formatting issues. Error: {str(e)}")
imported_count = 0
created_count = 0
updated_count = 0
errors = []
flexible_saved = 0
mapped_headers = mapping_info.get("mapped_headers", {})
unmapped_headers = mapping_info.get("unmapped_headers", [])
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
# If replace_existing is True, delete all existing records and related flexible extras
if replace_existing:
@@ -860,6 +971,16 @@ async def import_csv_data(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
# Only set if not provided
if "line_number" not in model_data:
model_data["line_number"] = current
# Skip empty rows
if not any(model_data.values()):
@@ -902,10 +1023,43 @@ async def import_csv_data(
if 'file_no' not in model_data or not model_data['file_no']:
continue # Skip ledger records without file number
# Create model instance
instance = model_class(**model_data)
db.add(instance)
db.flush() # Ensure PK is available
# Create or update model instance
instance = None
# Upsert behavior for printers
if model_class == PrinterSetup:
# Determine primary key field name
_, pk_names = _get_model_columns(model_class)
pk_field_name_local = pk_names[0] if len(pk_names) == 1 else None
pk_value_local = model_data.get(pk_field_name_local) if pk_field_name_local else None
if pk_field_name_local and pk_value_local:
existing = db.query(model_class).filter(getattr(model_class, pk_field_name_local) == pk_value_local).first()
if existing:
# Update mutable fields
for k, v in model_data.items():
if k != pk_field_name_local:
setattr(existing, k, v)
instance = existing
updated_count += 1
else:
instance = model_class(**model_data)
db.add(instance)
created_count += 1
else:
# Fallback to insert if PK missing
instance = model_class(**model_data)
db.add(instance)
created_count += 1
db.flush()
# Enforce single default
try:
if bool(model_data.get("default_printer")):
db.query(model_class).filter(getattr(model_class, pk_field_name_local) != getattr(instance, pk_field_name_local)).update({model_class.default_printer: False})
except Exception:
pass
else:
instance = model_class(**model_data)
db.add(instance)
db.flush() # Ensure PK is available
# Capture PK details for flexible storage linkage (single-column PKs only)
_, pk_names = _get_model_columns(model_class)
@@ -980,6 +1134,10 @@ async def import_csv_data(
"flexible_saved_rows": flexible_saved,
},
}
# Include create/update breakdown for printers
if file_type == "PRINTERS.csv":
result["created_count"] = created_count
result["updated_count"] = updated_count
if errors:
result["warning"] = f"Import completed with {len(errors)} errors"
@@ -987,9 +1145,7 @@ async def import_csv_data(
return result
except Exception as e:
print(f"IMPORT ERROR DEBUG: {type(e).__name__}: {str(e)}")
import traceback
print(f"TRACEBACK: {traceback.format_exc()}")
# Suppress stdout debug prints in API layer
db.rollback()
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
@@ -1071,7 +1227,7 @@ async def validate_csv_file(
content = await file.read()
# Try multiple encodings for legacy CSV files
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1083,18 +1239,6 @@ async def validate_csv_file(
if csv_content is None:
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
# Parse CSV with fallback to robust line-by-line parsing
def parse_csv_with_fallback(text: str) -> Tuple[List[Dict[str, str]], List[str]]:
try:
reader = csv.DictReader(io.StringIO(text), delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
headers_local = reader.fieldnames or []
rows_local = []
for r in reader:
rows_local.append(r)
return rows_local, headers_local
except Exception:
return parse_csv_robust(text)
rows_list, csv_headers = parse_csv_with_fallback(csv_content)
model_class = CSV_MODEL_MAPPING[file_type]
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
@@ -1142,9 +1286,7 @@ async def validate_csv_file(
}
except Exception as e:
print(f"VALIDATION ERROR DEBUG: {type(e).__name__}: {str(e)}")
import traceback
print(f"VALIDATION TRACEBACK: {traceback.format_exc()}")
# Suppress stdout debug prints in API layer
raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
@@ -1199,7 +1341,7 @@ async def batch_validate_csv_files(
content = await file.read()
# Try multiple encodings for legacy CSV files (include BOM-friendly utf-8-sig)
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1302,13 +1444,7 @@ async def batch_import_csv_files(
raise HTTPException(status_code=400, detail="Maximum 25 files allowed per batch")
# Define optimal import order based on dependencies
import_order = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
import_order = IMPORT_ORDER
# Sort uploaded files by optimal import order
file_map = {f.filename: f for f in files}
@@ -1365,7 +1501,7 @@ async def batch_import_csv_files(
saved_path = str(file_path)
except Exception:
saved_path = None
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1466,7 +1602,7 @@ async def batch_import_csv_files(
saved_path = None
# Try multiple encodings for legacy CSV files
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1505,6 +1641,8 @@ async def batch_import_csv_files(
imported_count = 0
errors = []
flexible_saved = 0
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
# If replace_existing is True and this is the first file of this type
if replace_existing:
@@ -1523,6 +1661,15 @@ async def batch_import_csv_files(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
if "line_number" not in model_data:
model_data["line_number"] = current
if not any(model_data.values()):
continue
@@ -1697,7 +1844,7 @@ async def batch_import_csv_files(
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
)
audit_row.message = f"Batch import completed: {audit_row.successful_files}/{audit_row.total_files} files"
audit_row.finished_at = datetime.utcnow()
audit_row.finished_at = datetime.now(timezone.utc)
audit_row.details = {
"files": [
{"file_type": r.get("file_type"), "status": r.get("status"), "imported_count": r.get("imported_count", 0), "errors": r.get("errors", 0)}
@@ -1844,13 +1991,7 @@ async def rerun_failed_files(
raise HTTPException(status_code=400, detail="No saved files available to rerun. Upload again.")
# Import order for sorting
import_order = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
import_order = IMPORT_ORDER
order_index = {name: i for i, name in enumerate(import_order)}
items.sort(key=lambda x: order_index.get(x[0], len(import_order) + 1))
@@ -1898,7 +2039,7 @@ async def rerun_failed_files(
if file_type not in CSV_MODEL_MAPPING:
# Flexible-only path
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for enc in encodings:
try:
@@ -1964,7 +2105,7 @@ async def rerun_failed_files(
# Known model path
model_class = CSV_MODEL_MAPPING[file_type]
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for enc in encodings:
try:
@@ -1996,6 +2137,8 @@ async def rerun_failed_files(
unmapped_headers = mapping_info["unmapped_headers"]
imported_count = 0
errors: List[Dict[str, Any]] = []
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
if replace_existing:
db.query(model_class).delete()
@@ -2013,6 +2156,14 @@ async def rerun_failed_files(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
if "line_number" not in model_data:
model_data["line_number"] = current
if not any(model_data.values()):
continue
required_fields = _get_required_fields(model_class)
@@ -2147,7 +2298,7 @@ async def rerun_failed_files(
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
)
rerun_audit.message = f"Rerun completed: {rerun_audit.successful_files}/{rerun_audit.total_files} files"
rerun_audit.finished_at = datetime.utcnow()
rerun_audit.finished_at = datetime.now(timezone.utc)
rerun_audit.details = {"rerun_of": audit_id}
db.add(rerun_audit)
db.commit()
@@ -2183,7 +2334,7 @@ async def upload_flexible_only(
db.commit()
content = await file.read()
encodings = ["utf-8-sig", "utf-8", "windows-1252", "iso-8859-1", "cp1252"]
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:

72
app/api/mortality.py Normal file
View File

@@ -0,0 +1,72 @@
"""
Mortality/Life Table API endpoints
Provides read endpoints to query life tables by age and number tables by month,
filtered by sex (M/F/A) and race (W/B/H/A).
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status, Path
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.database.base import get_db
from app.models.user import User
from app.auth.security import get_current_user
from app.services.mortality import get_life_values, get_number_value, InvalidCodeError
router = APIRouter()
class LifeResponse(BaseModel):
age: int
sex: str = Field(description="M, F, or A (all)")
race: str = Field(description="W, B, H, or A (all)")
le: Optional[float]
na: Optional[float]
class NumberResponse(BaseModel):
month: int
sex: str = Field(description="M, F, or A (all)")
race: str = Field(description="W, B, H, or A (all)")
na: Optional[float]
@router.get("/life/{age}", response_model=LifeResponse)
async def get_life_entry(
age: int = Path(..., ge=0, description="Age in years (>= 0)"),
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get life expectancy (LE) and number alive (NA) for an age/sex/race."""
try:
result = get_life_values(db, age=age, sex=sex, race=race)
except InvalidCodeError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if result is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Age not found")
return result
@router.get("/number/{month}", response_model=NumberResponse)
async def get_number_entry(
month: int = Path(..., ge=0, description="Month index (>= 0)"),
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get monthly number alive (NA) for a month/sex/race."""
try:
result = get_number_value(db, month=month, sex=sex, race=race)
except InvalidCodeError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if result is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Month not found")
return result

File diff suppressed because it is too large Load Diff

View File

@@ -2,8 +2,10 @@
Server-side highlight utilities for search results.
These functions generate HTML snippets with <strong> around matched tokens,
preserving the original casing of the source text. The output is intended to be
sanitized on the client before insertion into the DOM.
preserving the original casing of the source text. All non-HTML segments are
HTML-escaped server-side to prevent injection. Only the <strong> tags added by
this module are emitted as HTML; any pre-existing HTML in source text is
escaped.
"""
from typing import List, Tuple, Any
import re
@@ -42,18 +44,40 @@ def _merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
def highlight_text(value: str, tokens: List[str]) -> str:
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing."""
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing.
Non-highlighted segments and the highlighted text content are HTML-escaped.
Only the surrounding <strong> wrappers are emitted as markup.
"""
if value is None:
return ""
def _escape_html(text: str) -> str:
# Minimal, safe HTML escaping
if text is None:
return ""
# Replace ampersand first to avoid double-escaping
text = str(text)
text = text.replace("&", "&amp;")
text = text.replace("<", "&lt;")
text = text.replace(">", "&gt;")
text = text.replace('"', "&quot;")
text = text.replace("'", "&#39;")
return text
source = str(value)
if not source or not tokens:
return source
return _escape_html(source)
haystack = source.lower()
ranges: List[Tuple[int, int]] = []
# Deduplicate tokens case-insensitively to avoid redundant scans (parity with client)
unique_needles = []
seen_needles = set()
for t in tokens:
needle = str(t or "").lower()
if not needle:
continue
if needle and needle not in seen_needles:
unique_needles.append(needle)
seen_needles.add(needle)
for needle in unique_needles:
start = 0
last_possible = max(0, len(haystack) - len(needle))
while start <= last_possible and len(needle) > 0:
@@ -63,17 +87,17 @@ def highlight_text(value: str, tokens: List[str]) -> str:
ranges.append((idx, idx + len(needle)))
start = idx + 1
if not ranges:
return source
return _escape_html(source)
parts: List[str] = []
merged = _merge_ranges(ranges)
pos = 0
for s, e in merged:
if pos < s:
parts.append(source[pos:s])
parts.append("<strong>" + source[s:e] + "</strong>")
parts.append(_escape_html(source[pos:s]))
parts.append("<strong>" + _escape_html(source[s:e]) + "</strong>")
pos = e
if pos < len(source):
parts.append(source[pos:])
parts.append(_escape_html(source[pos:]))
return "".join(parts)

View File

@@ -2,21 +2,24 @@
Support ticket API endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, desc, and_, or_
from datetime import datetime
from datetime import datetime, timezone
import secrets
from app.database.base import get_db
from app.models import User, SupportTicket, TicketResponse as TicketResponseModel, TicketStatus, TicketPriority, TicketCategory
from app.auth.security import get_current_user, get_admin_user
from app.services.audit import audit_service
from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter
from app.api.search_highlight import build_query_tokens
router = APIRouter()
# Pydantic models for API
from pydantic import BaseModel, Field, EmailStr
from pydantic.config import ConfigDict
class TicketCreate(BaseModel):
@@ -57,8 +60,7 @@ class TicketResponseOut(BaseModel):
author_email: Optional[str]
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class TicketDetail(BaseModel):
@@ -81,15 +83,19 @@ class TicketDetail(BaseModel):
assigned_to: Optional[int]
assigned_admin_name: Optional[str]
submitter_name: Optional[str]
responses: List[TicketResponseOut] = []
responses: List[TicketResponseOut] = Field(default_factory=list)
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class PaginatedTicketsResponse(BaseModel):
items: List[TicketDetail]
total: int
def generate_ticket_number() -> str:
"""Generate unique ticket number like ST-2024-001"""
year = datetime.now().year
year = datetime.now(timezone.utc).year
random_suffix = secrets.token_hex(2).upper()
return f"ST-{year}-{random_suffix}"
@@ -129,7 +135,7 @@ async def create_support_ticket(
ip_address=client_ip,
user_id=current_user.id if current_user else None,
status=TicketStatus.OPEN,
created_at=datetime.utcnow()
created_at=datetime.now(timezone.utc)
)
db.add(new_ticket)
@@ -158,14 +164,18 @@ async def create_support_ticket(
}
@router.get("/tickets", response_model=List[TicketDetail])
@router.get("/tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
async def list_tickets(
status: Optional[TicketStatus] = None,
priority: Optional[TicketPriority] = None,
category: Optional[TicketCategory] = None,
assigned_to_me: bool = False,
skip: int = 0,
limit: int = 50,
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
priority: Optional[TicketPriority] = Query(None, description="Filter by ticket priority"),
category: Optional[TicketCategory] = Query(None, description="Filter by ticket category"),
assigned_to_me: bool = Query(False, description="Only include tickets assigned to the current admin"),
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description, contact name/email, current page, and IP"),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(50, ge=1, le=200, description="Page size"),
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
@@ -186,8 +196,38 @@ async def list_tickets(
query = query.filter(SupportTicket.category == category)
if assigned_to_me:
query = query.filter(SupportTicket.assigned_to == current_user.id)
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
# Search across key text fields
if search:
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
SupportTicket.ticket_number,
SupportTicket.subject,
SupportTicket.description,
SupportTicket.contact_name,
SupportTicket.contact_email,
SupportTicket.current_page,
SupportTicket.ip_address,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"created": [SupportTicket.created_at],
"updated": [SupportTicket.updated_at],
"resolved": [SupportTicket.resolved_at],
"priority": [SupportTicket.priority],
"status": [SupportTicket.status],
"subject": [SupportTicket.subject],
},
)
tickets, total = paginate_with_total(query, skip, limit, include_total)
# Format response
result = []
@@ -226,6 +266,8 @@ async def list_tickets(
}
result.append(ticket_dict)
if include_total:
return {"items": result, "total": total or 0}
return result
@@ -312,10 +354,10 @@ async def update_ticket(
# Set resolved timestamp if status changed to resolved
if ticket_data.status == TicketStatus.RESOLVED and ticket.resolved_at is None:
ticket.resolved_at = datetime.utcnow()
ticket.resolved_at = datetime.now(timezone.utc)
changes["resolved_at"] = {"from": None, "to": ticket.resolved_at}
ticket.updated_at = datetime.utcnow()
ticket.updated_at = datetime.now(timezone.utc)
db.commit()
# Audit logging (non-blocking)
@@ -358,13 +400,13 @@ async def add_response(
message=response_data.message,
is_internal=response_data.is_internal,
user_id=current_user.id,
created_at=datetime.utcnow()
created_at=datetime.now(timezone.utc)
)
db.add(response)
# Update ticket timestamp
ticket.updated_at = datetime.utcnow()
ticket.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(response)
@@ -386,11 +428,15 @@ async def add_response(
return {"message": "Response added successfully", "response_id": response.id}
@router.get("/my-tickets", response_model=List[TicketDetail])
@router.get("/my-tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
async def get_my_tickets(
status: Optional[TicketStatus] = None,
skip: int = 0,
limit: int = 20,
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description"),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(20, ge=1, le=200, description="Page size"),
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -403,7 +449,33 @@ async def get_my_tickets(
if status:
query = query.filter(SupportTicket.status == status)
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
# Search within user's tickets
if search:
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
SupportTicket.ticket_number,
SupportTicket.subject,
SupportTicket.description,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"created": [SupportTicket.created_at],
"updated": [SupportTicket.updated_at],
"resolved": [SupportTicket.resolved_at],
"priority": [SupportTicket.priority],
"status": [SupportTicket.status],
"subject": [SupportTicket.subject],
},
)
tickets, total = paginate_with_total(query, skip, limit, include_total)
# Format response (exclude internal responses for regular users)
result = []
@@ -442,6 +514,8 @@ async def get_my_tickets(
}
result.append(ticket_dict)
if include_total:
return {"items": result, "total": total or 0}
return result
@@ -473,7 +547,7 @@ async def get_ticket_stats(
# Recent tickets (last 7 days)
from datetime import timedelta
week_ago = datetime.utcnow() - timedelta(days=7)
week_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_tickets = db.query(func.count(SupportTicket.id)).filter(
SupportTicket.created_at >= week_ago
).scalar()