fixes and refactor

This commit is contained in:
HotSwapp
2025-08-14 19:16:28 -05:00
parent 5111079149
commit bfc04a6909
61 changed files with 5689 additions and 767 deletions

View File

@@ -1,7 +1,7 @@
"""
Comprehensive Admin API endpoints - User management, system settings, audit logging
"""
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Query, Body, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
@@ -13,24 +13,27 @@ import hashlib
import secrets
import shutil
import time
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta, date, timezone
from pathlib import Path
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
# Track application start time
APPLICATION_START_TIME = time.time()
from app.models import User, Rolodex, File as FileModel, Ledger, QDRO, AuditLog, LoginAttempt
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex, PrinterSetup
from app.auth.security import get_admin_user, get_password_hash, create_access_token
from app.services.audit import audit_service
from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total
router = APIRouter()
# Enhanced Admin Schemas
from pydantic import BaseModel, Field, EmailStr
from pydantic.config import ConfigDict
class SystemStats(BaseModel):
"""Enhanced system statistics"""
@@ -91,8 +94,7 @@ class UserResponse(BaseModel):
created_at: Optional[datetime]
updated_at: Optional[datetime]
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class PasswordReset(BaseModel):
"""Password reset request"""
@@ -124,8 +126,7 @@ class AuditLogEntry(BaseModel):
user_agent: Optional[str]
timestamp: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class BackupInfo(BaseModel):
"""Backup information"""
@@ -135,6 +136,132 @@ class BackupInfo(BaseModel):
backup_type: str
status: str
class PrinterSetupBase(BaseModel):
"""Base schema for printer setup"""
description: Optional[str] = None
driver: Optional[str] = None
port: Optional[str] = None
default_printer: Optional[bool] = None
active: Optional[bool] = None
number: Optional[int] = None
page_break: Optional[str] = None
setup_st: Optional[str] = None
reset_st: Optional[str] = None
b_underline: Optional[str] = None
e_underline: Optional[str] = None
b_bold: Optional[str] = None
e_bold: Optional[str] = None
phone_book: Optional[bool] = None
rolodex_info: Optional[bool] = None
envelope: Optional[bool] = None
file_cabinet: Optional[bool] = None
accounts: Optional[bool] = None
statements: Optional[bool] = None
calendar: Optional[bool] = None
class PrinterSetupCreate(PrinterSetupBase):
printer_name: str
class PrinterSetupUpdate(PrinterSetupBase):
pass
class PrinterSetupResponse(PrinterSetupBase):
printer_name: str
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
# Printer Setup Management
@router.get("/printers", response_model=List[PrinterSetupResponse])
async def list_printers(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
printers = db.query(PrinterSetup).order_by(PrinterSetup.printer_name.asc()).all()
return printers
@router.get("/printers/{printer_name}", response_model=PrinterSetupResponse)
async def get_printer(
printer_name: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
printer = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not printer:
raise HTTPException(status_code=404, detail="Printer not found")
return printer
@router.post("/printers", response_model=PrinterSetupResponse)
async def create_printer(
payload: PrinterSetupCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
exists = db.query(PrinterSetup).filter(PrinterSetup.printer_name == payload.printer_name).first()
if exists:
raise HTTPException(status_code=400, detail="Printer already exists")
data = payload.model_dump(exclude_unset=True)
instance = PrinterSetup(**data)
db.add(instance)
# Enforce single default printer
if data.get("default_printer"):
try:
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
except Exception:
pass
db.commit()
db.refresh(instance)
return instance
@router.put("/printers/{printer_name}", response_model=PrinterSetupResponse)
async def update_printer(
printer_name: str,
payload: PrinterSetupUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not instance:
raise HTTPException(status_code=404, detail="Printer not found")
updates = payload.model_dump(exclude_unset=True)
for k, v in updates.items():
setattr(instance, k, v)
# Enforce single default printer when set true
if updates.get("default_printer"):
try:
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
except Exception:
pass
db.commit()
db.refresh(instance)
return instance
@router.delete("/printers/{printer_name}")
async def delete_printer(
printer_name: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
if not instance:
raise HTTPException(status_code=404, detail="Printer not found")
db.delete(instance)
db.commit()
return {"message": "Printer deleted"}
class LookupTableInfo(BaseModel):
"""Lookup table information"""
table_name: str
@@ -202,7 +329,7 @@ async def system_health(
# Count active sessions (simplified)
try:
active_sessions = db.query(User).filter(
User.last_login > datetime.now() - timedelta(hours=24)
User.last_login > datetime.now(timezone.utc) - timedelta(hours=24)
).count()
except:
active_sessions = 0
@@ -215,7 +342,7 @@ async def system_health(
backup_files = list(backup_dir.glob("*.db"))
if backup_files:
latest_backup = max(backup_files, key=lambda p: p.stat().st_mtime)
backup_age = datetime.now() - datetime.fromtimestamp(latest_backup.stat().st_mtime)
backup_age = datetime.now(timezone.utc) - datetime.fromtimestamp(latest_backup.stat().st_mtime, tz=timezone.utc)
last_backup = latest_backup.name
if backup_age.days > 7:
alerts.append(f"Last backup is {backup_age.days} days old")
@@ -257,7 +384,7 @@ async def system_statistics(
# Count active users (logged in within last 30 days)
total_active_users = db.query(func.count(User.id)).filter(
User.last_login > datetime.now() - timedelta(days=30)
User.last_login > datetime.now(timezone.utc) - timedelta(days=30)
).scalar()
# Count admin users
@@ -308,7 +435,7 @@ async def system_statistics(
recent_activity.append({
"type": "customer_added",
"description": f"Customer {customer.first} {customer.last} added",
"timestamp": datetime.now().isoformat()
"timestamp": datetime.now(timezone.utc).isoformat()
})
except:
pass
@@ -409,7 +536,7 @@ async def export_table(
# Create exports directory if it doesn't exist
os.makedirs("exports", exist_ok=True)
filename = f"exports/{table_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
filename = f"exports/{table_name}_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv"
try:
if table_name.lower() == "customers" or table_name.lower() == "rolodex":
@@ -470,10 +597,10 @@ async def download_backup(
if "sqlite" in settings.database_url:
db_path = settings.database_url.replace("sqlite:///", "")
if os.path.exists(db_path):
return FileResponse(
return FileResponse(
db_path,
media_type='application/octet-stream',
filename=f"delphi_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.db"
filename=f"delphi_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.db"
)
raise HTTPException(
@@ -484,12 +611,20 @@ async def download_backup(
# User Management Endpoints
@router.get("/users", response_model=List[UserResponse])
class PaginatedUsersResponse(BaseModel):
items: List[UserResponse]
total: int
@router.get("/users", response_model=Union[List[UserResponse], PaginatedUsersResponse])
async def list_users(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
search: Optional[str] = Query(None),
active_only: bool = Query(False),
sort_by: Optional[str] = Query(None, description="Sort by: username, email, first_name, last_name, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
@@ -498,19 +633,38 @@ async def list_users(
query = db.query(User)
if search:
query = query.filter(
or_(
User.username.ilike(f"%{search}%"),
User.email.ilike(f"%{search}%"),
User.first_name.ilike(f"%{search}%"),
User.last_name.ilike(f"%{search}%")
)
)
# DRY: tokenize and apply case-insensitive multi-field search
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
User.username,
User.email,
User.first_name,
User.last_name,
])
if filter_expr is not None:
query = query.filter(filter_expr)
if active_only:
query = query.filter(User.is_active == True)
users = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"username": [User.username],
"email": [User.email],
"first_name": [User.first_name],
"last_name": [User.last_name],
"created": [User.created_at],
"updated": [User.updated_at],
},
)
users, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": users, "total": total or 0}
return users
@@ -567,8 +721,8 @@ async def create_user(
hashed_password=hashed_password,
is_admin=user_data.is_admin,
is_active=user_data.is_active,
created_at=datetime.now(),
updated_at=datetime.now()
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
db.add(new_user)
@@ -659,7 +813,7 @@ async def update_user(
changes[field] = {"from": getattr(user, field), "to": value}
setattr(user, field, value)
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(user)
@@ -702,7 +856,7 @@ async def delete_user(
# Soft delete by deactivating
user.is_active = False
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
@@ -744,7 +898,7 @@ async def reset_user_password(
# Update password
user.hashed_password = get_password_hash(password_data.new_password)
user.updated_at = datetime.now()
user.updated_at = datetime.now(timezone.utc)
db.commit()
@@ -1046,7 +1200,7 @@ async def create_backup(
backup_dir.mkdir(exist_ok=True)
# Generate backup filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
backup_filename = f"delphi_backup_{timestamp}.db"
backup_path = backup_dir / backup_filename
@@ -1063,7 +1217,7 @@ async def create_backup(
"backup_info": {
"filename": backup_filename,
"size": f"{backup_size / (1024*1024):.1f} MB",
"created_at": datetime.now().isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"backup_type": "manual",
"status": "completed"
}
@@ -1118,45 +1272,61 @@ async def get_audit_logs(
resource_type: Optional[str] = Query(None),
action: Optional[str] = Query(None),
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days, max 1 year
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, action, resource_type"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Get audit log entries with filtering"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
"""Get audit log entries with filtering, sorting, and pagination"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = db.query(AuditLog).filter(AuditLog.timestamp >= cutoff_time)
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if resource_type:
query = query.filter(AuditLog.resource_type.ilike(f"%{resource_type}%"))
if action:
query = query.filter(AuditLog.action.ilike(f"%{action}%"))
total_count = query.count()
logs = query.order_by(AuditLog.timestamp.desc()).offset(skip).limit(limit).all()
return {
"total": total_count,
"logs": [
{
"id": log.id,
"user_id": log.user_id,
"username": log.username,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"details": log.details,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"timestamp": log.timestamp.isoformat()
}
for log in logs
]
}
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"timestamp": [AuditLog.timestamp],
"username": [AuditLog.username],
"action": [AuditLog.action],
"resource_type": [AuditLog.resource_type],
},
)
logs, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": log.id,
"user_id": log.user_id,
"username": log.username,
"action": log.action,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"details": log.details,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"timestamp": log.timestamp.isoformat(),
}
for log in logs
]
if include_total:
return {"items": items, "total": total or 0}
return items
@router.get("/audit/login-attempts")
@@ -1166,39 +1336,55 @@ async def get_login_attempts(
username: Optional[str] = Query(None),
failed_only: bool = Query(False),
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, ip_address, success"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Get login attempts with filtering"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
"""Get login attempts with filtering, sorting, and pagination"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = db.query(LoginAttempt).filter(LoginAttempt.timestamp >= cutoff_time)
if username:
query = query.filter(LoginAttempt.username.ilike(f"%{username}%"))
if failed_only:
query = query.filter(LoginAttempt.success == 0)
total_count = query.count()
attempts = query.order_by(LoginAttempt.timestamp.desc()).offset(skip).limit(limit).all()
return {
"total": total_count,
"attempts": [
{
"id": attempt.id,
"username": attempt.username,
"ip_address": attempt.ip_address,
"user_agent": attempt.user_agent,
"success": bool(attempt.success),
"failure_reason": attempt.failure_reason,
"timestamp": attempt.timestamp.isoformat()
}
for attempt in attempts
]
}
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"timestamp": [LoginAttempt.timestamp],
"username": [LoginAttempt.username],
"ip_address": [LoginAttempt.ip_address],
"success": [LoginAttempt.success],
},
)
attempts, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": attempt.id,
"username": attempt.username,
"ip_address": attempt.ip_address,
"user_agent": attempt.user_agent,
"success": bool(attempt.success),
"failure_reason": attempt.failure_reason,
"timestamp": attempt.timestamp.isoformat(),
}
for attempt in attempts
]
if include_total:
return {"items": items, "total": total or 0}
return items
@router.get("/audit/user-activity/{user_id}")
@@ -1251,7 +1437,7 @@ async def get_security_alerts(
):
"""Get security alerts and suspicious activity"""
cutoff_time = datetime.now() - timedelta(hours=hours_back)
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
# Get failed login attempts
failed_logins = db.query(LoginAttempt).filter(
@@ -1356,7 +1542,7 @@ async def get_audit_statistics(
):
"""Get audit statistics and metrics"""
cutoff_time = datetime.now() - timedelta(days=days_back)
cutoff_time = datetime.now(timezone.utc) - timedelta(days=days_back)
# Total activity counts
total_audit_entries = db.query(func.count(AuditLog.id)).filter(

View File

@@ -1,7 +1,7 @@
"""
Authentication API endpoints
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
@@ -69,7 +69,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
)
# Update last login
user.last_login = datetime.utcnow()
user.last_login = datetime.now(timezone.utc)
db.commit()
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
@@ -144,7 +144,7 @@ async def read_users_me(current_user: User = Depends(get_current_user)):
async def refresh_token_endpoint(
request: Request,
db: Session = Depends(get_db),
body: RefreshRequest | None = None,
body: RefreshRequest = None,
):
"""Issue a new access token using a valid, non-revoked refresh token.
@@ -203,7 +203,7 @@ async def refresh_token_endpoint(
if not user or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
user.last_login = datetime.utcnow()
user.last_login = datetime.now(timezone.utc)
db.commit()
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
@@ -225,7 +225,7 @@ async def list_users(
@router.post("/logout")
async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)):
async def logout(body: RefreshRequest = None, db: Session = Depends(get_db)):
"""Revoke the provided refresh token. Idempotent and safe to call multiple times.
The client should send a JSON body: { "refresh_token": "..." }.

View File

@@ -4,7 +4,7 @@ Customer (Rolodex) API endpoints
from typing import List, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, and_, func, asc, desc
from sqlalchemy import func
from fastapi.responses import StreamingResponse
import csv
import io
@@ -13,12 +13,16 @@ from app.database.base import get_db
from app.models.rolodex import Rolodex, Phone
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
from app.services.query_utils import apply_sorting, paginate_with_total
router = APIRouter()
# Pydantic schemas for request/response
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, Field
from pydantic.config import ConfigDict
from datetime import date
@@ -32,8 +36,7 @@ class PhoneResponse(BaseModel):
location: Optional[str]
phone: str
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class CustomerBase(BaseModel):
@@ -84,10 +87,11 @@ class CustomerUpdate(BaseModel):
class CustomerResponse(CustomerBase):
phone_numbers: List[PhoneResponse] = []
phone_numbers: List[PhoneResponse] = Field(default_factory=list)
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/search/phone")
@@ -196,80 +200,17 @@ async def list_customers(
try:
base_query = db.query(Rolodex)
if search:
s = (search or "").strip()
s_lower = s.lower()
tokens = [t for t in s_lower.split() if t]
# Basic contains search on several fields (case-insensitive)
contains_any = or_(
func.lower(Rolodex.id).contains(s_lower),
func.lower(Rolodex.last).contains(s_lower),
func.lower(Rolodex.first).contains(s_lower),
func.lower(Rolodex.middle).contains(s_lower),
func.lower(Rolodex.city).contains(s_lower),
func.lower(Rolodex.email).contains(s_lower),
)
# Multi-token name support: every token must match either first, middle, or last
name_tokens = [
or_(
func.lower(Rolodex.first).contains(tok),
func.lower(Rolodex.middle).contains(tok),
func.lower(Rolodex.last).contains(tok),
)
for tok in tokens
]
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
# Comma pattern: "Last, First"
last_first_filter = None
if "," in s_lower:
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
if last_part and first_part:
last_first_filter = and_(
func.lower(Rolodex.last).contains(last_part),
func.lower(Rolodex.first).contains(first_part),
)
elif last_part:
last_first_filter = func.lower(Rolodex.last).contains(last_part)
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
base_query = base_query.filter(final_filter)
# Apply group/state filters (support single and multi-select)
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
if effective_groups:
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
if effective_states:
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
base_query = apply_customer_filters(
base_query,
search=search,
group=group,
state=state,
groups=groups,
states=states,
)
# Apply sorting (whitelisted fields only)
normalized_sort_by = (sort_by or "id").lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
order_columns = []
if normalized_sort_by == "id":
order_columns = [Rolodex.id]
elif normalized_sort_by == "name":
# Sort by last, then first
order_columns = [Rolodex.last, Rolodex.first]
elif normalized_sort_by == "city":
# Sort by city, then state abbreviation
order_columns = [Rolodex.city, Rolodex.abrev]
elif normalized_sort_by == "email":
order_columns = [Rolodex.email]
else:
# Fallback to id to avoid arbitrary column injection
order_columns = [Rolodex.id]
# Case-insensitive ordering where applicable, preserving None ordering default
ordered = []
for col in order_columns:
# Use lower() for string-like cols; SQLAlchemy will handle non-string safely enough for SQLite/Postgres
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
ordered.append(desc(expr) if is_desc else asc(expr))
if ordered:
base_query = base_query.order_by(*ordered)
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
customers = base_query.options(joinedload(Rolodex.phone_numbers)).offset(skip).limit(limit).all()
if include_total:
@@ -304,72 +245,16 @@ async def export_customers(
try:
base_query = db.query(Rolodex)
if search:
s = (search or "").strip()
s_lower = s.lower()
tokens = [t for t in s_lower.split() if t]
contains_any = or_(
func.lower(Rolodex.id).contains(s_lower),
func.lower(Rolodex.last).contains(s_lower),
func.lower(Rolodex.first).contains(s_lower),
func.lower(Rolodex.middle).contains(s_lower),
func.lower(Rolodex.city).contains(s_lower),
func.lower(Rolodex.email).contains(s_lower),
)
name_tokens = [
or_(
func.lower(Rolodex.first).contains(tok),
func.lower(Rolodex.middle).contains(tok),
func.lower(Rolodex.last).contains(tok),
)
for tok in tokens
]
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
last_first_filter = None
if "," in s_lower:
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
if last_part and first_part:
last_first_filter = and_(
func.lower(Rolodex.last).contains(last_part),
func.lower(Rolodex.first).contains(first_part),
)
elif last_part:
last_first_filter = func.lower(Rolodex.last).contains(last_part)
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
base_query = base_query.filter(final_filter)
base_query = apply_customer_filters(
base_query,
search=search,
group=group,
state=state,
groups=groups,
states=states,
)
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
if effective_groups:
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
if effective_states:
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
normalized_sort_by = (sort_by or "id").lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
order_columns = []
if normalized_sort_by == "id":
order_columns = [Rolodex.id]
elif normalized_sort_by == "name":
order_columns = [Rolodex.last, Rolodex.first]
elif normalized_sort_by == "city":
order_columns = [Rolodex.city, Rolodex.abrev]
elif normalized_sort_by == "email":
order_columns = [Rolodex.email]
else:
order_columns = [Rolodex.id]
ordered = []
for col in order_columns:
try:
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
except Exception:
expr = col
ordered.append(desc(expr) if is_desc else asc(expr))
if ordered:
base_query = base_query.order_by(*ordered)
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
if not export_all:
if skip is not None:
@@ -382,39 +267,10 @@ async def export_customers(
# Prepare CSV
output = io.StringIO()
writer = csv.writer(output)
allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
header_names = {
"id": "Customer ID",
"name": "Name",
"group": "Group",
"city": "City",
"state": "State",
"phone": "Primary Phone",
"email": "Email",
}
requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
if not selected_fields:
selected_fields = allowed_fields_in_order
writer.writerow([header_names[f] for f in selected_fields])
for c in customers:
full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
primary_phone = ""
try:
if c.phone_numbers:
primary_phone = c.phone_numbers[0].phone or ""
except Exception:
primary_phone = ""
row_map = {
"id": c.id,
"name": full_name,
"group": c.group or "",
"city": c.city or "",
"state": c.abrev or "",
"phone": primary_phone,
"email": c.email or "",
}
writer.writerow([row_map[f] for f in selected_fields])
header_row, rows = prepare_customer_csv_rows(customers, fields)
writer.writerow(header_row)
for row in rows:
writer.writerow(row)
output.seek(0)
filename = "customers_export.csv"
@@ -469,6 +325,10 @@ async def create_customer(
db.commit()
db.refresh(customer)
try:
await invalidate_search_cache()
except Exception:
pass
return customer
@@ -494,7 +354,10 @@ async def update_customer(
db.commit()
db.refresh(customer)
try:
await invalidate_search_cache()
except Exception:
pass
return customer
@@ -515,17 +378,30 @@ async def delete_customer(
db.delete(customer)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Customer deleted successfully"}
@router.get("/{customer_id}/phones", response_model=List[PhoneResponse])
class PaginatedPhonesResponse(BaseModel):
items: List[PhoneResponse]
total: int
@router.get("/{customer_id}/phones", response_model=Union[List[PhoneResponse], PaginatedPhonesResponse])
async def get_customer_phones(
customer_id: str,
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=1000, description="Page size"),
sort_by: Optional[str] = Query("location", description="Sort by: location, phone"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get customer phone numbers"""
"""Get customer phone numbers with optional sorting/pagination"""
customer = db.query(Rolodex).filter(Rolodex.id == customer_id).first()
if not customer:
@@ -534,7 +410,21 @@ async def get_customer_phones(
detail="Customer not found"
)
phones = db.query(Phone).filter(Phone.rolodex_id == customer_id).all()
query = db.query(Phone).filter(Phone.rolodex_id == customer_id)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"location": [Phone.location, Phone.phone],
"phone": [Phone.phone],
},
)
phones, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": phones, "total": total or 0}
return phones

View File

@@ -1,16 +1,19 @@
"""
Document Management API endpoints - QDROs, Templates, and General Documents
"""
from typing import List, Optional, Dict, Any
from __future__ import annotations
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime
from datetime import date, datetime, timezone
import os
import uuid
import shutil
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
from app.models.qdro import QDRO
from app.models.files import File as FileModel
from app.models.rolodex import Rolodex
@@ -20,18 +23,20 @@ from app.auth.security import get_current_user
from app.models.additional import Document
from app.core.logging import get_logger
from app.services.audit import audit_service
from app.services.cache import invalidate_search_cache
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class QDROBase(BaseModel):
file_no: str
version: str = "01"
title: Optional[str] = None
form_name: Optional[str] = None
content: Optional[str] = None
status: str = "DRAFT"
created_date: Optional[date] = None
@@ -51,6 +56,7 @@ class QDROCreate(QDROBase):
class QDROUpdate(BaseModel):
version: Optional[str] = None
title: Optional[str] = None
form_name: Optional[str] = None
content: Optional[str] = None
status: Optional[str] = None
created_date: Optional[date] = None
@@ -66,27 +72,61 @@ class QDROUpdate(BaseModel):
class QDROResponse(QDROBase):
id: int
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/qdros/{file_no}", response_model=List[QDROResponse])
class PaginatedQDROResponse(BaseModel):
items: List[QDROResponse]
total: int
@router.get("/qdros/{file_no}", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
async def get_file_qdros(
file_no: str,
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=1000, description="Page size"),
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created, version, status"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get QDROs for specific file"""
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(QDRO.version).all()
"""Get QDROs for a specific file with optional sorting/pagination"""
query = db.query(QDRO).filter(QDRO.file_no == file_no)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"updated": [QDRO.updated_at, QDRO.id],
"created": [QDRO.created_at, QDRO.id],
"version": [QDRO.version],
"status": [QDRO.status],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": qdros, "total": total or 0}
return qdros
@router.get("/qdros/", response_model=List[QDROResponse])
class PaginatedQDROResponse(BaseModel):
items: List[QDROResponse]
total: int
@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
async def list_qdros(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
status_filter: Optional[str] = Query(None),
search: Optional[str] = Query(None),
sort_by: Optional[str] = Query(None, description="Sort by: file_no, version, status, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -97,17 +137,37 @@ async def list_qdros(
query = query.filter(QDRO.status == status_filter)
if search:
query = query.filter(
or_(
QDRO.file_no.contains(search),
QDRO.title.contains(search),
QDRO.participant_name.contains(search),
QDRO.spouse_name.contains(search),
QDRO.plan_name.contains(search)
)
)
qdros = query.offset(skip).limit(limit).all()
# DRY: tokenize and apply case-insensitive search across common QDRO fields
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
QDRO.file_no,
QDRO.form_name,
QDRO.pet,
QDRO.res,
QDRO.case_number,
QDRO.notes,
QDRO.status,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"file_no": [QDRO.file_no],
"version": [QDRO.version],
"status": [QDRO.status],
"created": [QDRO.created_at],
"updated": [QDRO.updated_at],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": qdros, "total": total or 0}
return qdros
@@ -135,6 +195,10 @@ async def create_qdro(
db.commit()
db.refresh(qdro)
try:
await invalidate_search_cache()
except Exception:
pass
return qdro
@@ -189,6 +253,10 @@ async def update_qdro(
db.commit()
db.refresh(qdro)
try:
await invalidate_search_cache()
except Exception:
pass
return qdro
@@ -213,7 +281,10 @@ async def delete_qdro(
db.delete(qdro)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "QDRO deleted successfully"}
@@ -241,8 +312,7 @@ class TemplateResponse(TemplateBase):
active: bool = True
created_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# Document Generation Schema
class DocumentGenerateRequest(BaseModel):
@@ -269,13 +339,21 @@ class DocumentStats(BaseModel):
recent_activity: List[Dict[str, Any]]
@router.get("/templates/", response_model=List[TemplateResponse])
class PaginatedTemplatesResponse(BaseModel):
items: List[TemplateResponse]
total: int
@router.get("/templates/", response_model=Union[List[TemplateResponse], PaginatedTemplatesResponse])
async def list_templates(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
category: Optional[str] = Query(None),
search: Optional[str] = Query(None),
active_only: bool = Query(True),
sort_by: Optional[str] = Query(None, description="Sort by: form_id, form_name, category, created, updated"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -289,14 +367,31 @@ async def list_templates(
query = query.filter(FormIndex.category == category)
if search:
query = query.filter(
or_(
FormIndex.form_name.contains(search),
FormIndex.form_id.contains(search)
)
)
templates = query.offset(skip).limit(limit).all()
# DRY: tokenize and apply case-insensitive search for templates
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
FormIndex.form_name,
FormIndex.form_id,
FormIndex.category,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"form_id": [FormIndex.form_id],
"form_name": [FormIndex.form_name],
"category": [FormIndex.category],
"created": [FormIndex.created_at],
"updated": [FormIndex.updated_at],
},
)
templates, total = paginate_with_total(query, skip, limit, include_total)
# Enhanced response with template content
results = []
@@ -317,6 +412,8 @@ async def list_templates(
"variables": _extract_variables_from_content(content)
})
if include_total:
return {"items": results, "total": total or 0}
return results
@@ -356,6 +453,10 @@ async def create_template(
db.commit()
db.refresh(form_index)
try:
await invalidate_search_cache()
except Exception:
pass
return {
"form_id": form_index.form_id,
@@ -440,6 +541,10 @@ async def update_template(
db.commit()
db.refresh(template)
try:
await invalidate_search_cache()
except Exception:
pass
# Get updated content
template_lines = db.query(FormList).filter(
@@ -480,6 +585,10 @@ async def delete_template(
# Delete template
db.delete(template)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Template deleted successfully"}
@@ -574,7 +683,7 @@ async def generate_document(
"file_name": file_name,
"file_path": file_path,
"size": file_size,
"created_at": datetime.now()
"created_at": datetime.now(timezone.utc)
}
@@ -629,32 +738,49 @@ async def get_document_stats(
@router.get("/file/{file_no}/documents")
async def get_file_documents(
file_no: str,
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all documents associated with a specific file"""
# Get QDROs for this file
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(desc(QDRO.updated_at)).all()
# Format response
documents = [
"""Get all documents associated with a specific file, with optional sorting/pagination"""
# Base query for QDROs tied to the file
query = db.query(QDRO).filter(QDRO.file_no == file_no)
# Apply sorting using shared helper (map friendly names to columns)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"updated": [QDRO.updated_at, QDRO.id],
"created": [QDRO.created_at, QDRO.id],
},
)
qdros, total = paginate_with_total(query, skip, limit, include_total)
items = [
{
"id": qdro.id,
"type": "QDRO",
"title": f"QDRO v{qdro.version}",
"status": qdro.status,
"created_date": qdro.created_date.isoformat() if qdro.created_date else None,
"updated_at": qdro.updated_at.isoformat() if qdro.updated_at else None,
"file_no": qdro.file_no
"created_date": qdro.created_date.isoformat() if getattr(qdro, "created_date", None) else None,
"updated_at": qdro.updated_at.isoformat() if getattr(qdro, "updated_at", None) else None,
"file_no": qdro.file_no,
}
for qdro in qdros
]
return {
"file_no": file_no,
"documents": documents,
"total_count": len(documents)
}
payload = {"file_no": file_no, "documents": items, "total_count": (total if include_total else None)}
# Maintain previous shape by omitting total_count when include_total is False? The prior code always returned total_count.
# Keep total_count for backward compatibility but set to actual total when include_total else len(items)
payload["total_count"] = (total if include_total else len(items))
return payload
def _extract_variables_from_content(content: str) -> Dict[str, str]:

View File

@@ -1,25 +1,29 @@
"""
File Management API endpoints
"""
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc
from datetime import date, datetime
from app.database.base import get_db
from app.api.search_highlight import build_query_tokens
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
from app.models.files import File
from app.models.rolodex import Rolodex
from app.models.ledger import Ledger
from app.models.lookups import Employee, FileType, FileStatus
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic.config import ConfigDict
class FileBase(BaseModel):
@@ -67,17 +71,24 @@ class FileResponse(FileBase):
amount_owing: float = 0.0
transferable: float = 0.0
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@router.get("/", response_model=List[FileResponse])
class PaginatedFilesResponse(BaseModel):
items: List[FileResponse]
total: int
@router.get("/", response_model=Union[List[FileResponse], PaginatedFilesResponse])
async def list_files(
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=200),
search: Optional[str] = Query(None),
status_filter: Optional[str] = Query(None),
employee_filter: Optional[str] = Query(None),
sort_by: Optional[str] = Query(None, description="Sort by: file_no, client, opened, closed, status, amount_owing, total_charges"),
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -85,14 +96,17 @@ async def list_files(
query = db.query(File)
if search:
query = query.filter(
or_(
File.file_no.contains(search),
File.id.contains(search),
File.regarding.contains(search),
File.file_type.contains(search)
)
)
# DRY: tokenize and apply case-insensitive search consistently with search endpoints
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
File.file_no,
File.id,
File.regarding,
File.file_type,
File.memo,
])
if filter_expr is not None:
query = query.filter(filter_expr)
if status_filter:
query = query.filter(File.status == status_filter)
@@ -100,7 +114,25 @@ async def list_files(
if employee_filter:
query = query.filter(File.empl_num == employee_filter)
files = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"file_no": [File.file_no],
"client": [File.id],
"opened": [File.opened],
"closed": [File.closed],
"status": [File.status],
"amount_owing": [File.amount_owing],
"total_charges": [File.total_charges],
},
)
files, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": files, "total": total or 0}
return files
@@ -142,6 +174,10 @@ async def create_file(
db.commit()
db.refresh(file_obj)
try:
await invalidate_search_cache()
except Exception:
pass
return file_obj
@@ -167,7 +203,10 @@ async def update_file(
db.commit()
db.refresh(file_obj)
try:
await invalidate_search_cache()
except Exception:
pass
return file_obj
@@ -188,7 +227,10 @@ async def delete_file(
db.delete(file_obj)
db.commit()
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "File deleted successfully"}
@@ -433,11 +475,13 @@ async def advanced_file_search(
query = query.filter(File.file_no.contains(file_no))
if client_name:
# SQLite-safe concatenation for first + last name
full_name_expr = (func.coalesce(Rolodex.first, '') + ' ' + func.coalesce(Rolodex.last, ''))
query = query.join(Rolodex).filter(
or_(
Rolodex.first.contains(client_name),
Rolodex.last.contains(client_name),
func.concat(Rolodex.first, ' ', Rolodex.last).contains(client_name)
full_name_expr.contains(client_name)
)
)

View File

@@ -1,11 +1,11 @@
"""
Financial/Ledger API endpoints
"""
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime, timedelta
from datetime import date, datetime, timedelta, timezone
from app.database.base import get_db
from app.models.ledger import Ledger
@@ -14,12 +14,14 @@ from app.models.rolodex import Rolodex
from app.models.lookups import Employee, TransactionType, TransactionCode
from app.models.user import User
from app.auth.security import get_current_user
from app.services.cache import invalidate_search_cache
from app.services.query_utils import apply_sorting, paginate_with_total
router = APIRouter()
# Pydantic schemas
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class LedgerBase(BaseModel):
@@ -57,8 +59,7 @@ class LedgerResponse(LedgerBase):
id: int
item_no: int
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class FinancialSummary(BaseModel):
@@ -75,23 +76,46 @@ class FinancialSummary(BaseModel):
billed_amount: float
@router.get("/ledger/{file_no}", response_model=List[LedgerResponse])
class PaginatedLedgerResponse(BaseModel):
items: List[LedgerResponse]
total: int
@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse])
async def get_file_ledger(
file_no: str,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
billed_only: Optional[bool] = Query(None),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(100, ge=1, le=500, description="Page size"),
billed_only: Optional[bool] = Query(None, description="Filter billed vs unbilled entries"),
sort_by: Optional[str] = Query("date", description="Sort by: date, item_no, amount, billed"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get ledger entries for specific file"""
query = db.query(Ledger).filter(Ledger.file_no == file_no).order_by(Ledger.date.desc())
query = db.query(Ledger).filter(Ledger.file_no == file_no)
if billed_only is not None:
billed_filter = "Y" if billed_only else "N"
query = query.filter(Ledger.billed == billed_filter)
entries = query.offset(skip).limit(limit).all()
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"date": [Ledger.date, Ledger.item_no],
"item_no": [Ledger.item_no],
"amount": [Ledger.amount],
"billed": [Ledger.billed, Ledger.date],
},
)
entries, total = paginate_with_total(query, skip, limit, include_total)
if include_total:
return {"items": entries, "total": total or 0}
return entries
@@ -127,6 +151,10 @@ async def create_ledger_entry(
# Update file balances (simplified version)
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return entry
@@ -158,6 +186,10 @@ async def update_ledger_entry(
if file_obj:
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return entry
@@ -185,6 +217,10 @@ async def delete_ledger_entry(
if file_obj:
await _update_file_balances(file_obj, db)
try:
await invalidate_search_cache()
except Exception:
pass
return {"message": "Ledger entry deleted successfully"}

View File

@@ -7,7 +7,7 @@ import re
import os
from pathlib import Path
from difflib import SequenceMatcher
from datetime import datetime, date
from datetime import datetime, date, timezone
from decimal import Decimal
from typing import List, Dict, Any, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
@@ -19,8 +19,8 @@ from app.models.rolodex import Rolodex, Phone
from app.models.files import File
from app.models.ledger import Ledger
from app.models.qdro import QDRO
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable, PensionResult
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup, FormKeyword
from app.models.additional import Payment, Deposit, FileNote, FormVariable, ReportVariable
from app.models.flexible import FlexibleImport
from app.models.audit import ImportAudit, ImportAuditFile
@@ -28,6 +28,25 @@ from app.config import settings
router = APIRouter(tags=["import"])
# Common encodings to try for legacy CSV files (order matters)
ENCODINGS = [
'utf-8-sig',
'utf-8',
'windows-1252',
'iso-8859-1',
'cp1252',
]
# Unified import order used across batch operations
IMPORT_ORDER = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"INX_LKUP.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
# CSV to Model mapping
CSV_MODEL_MAPPING = {
@@ -56,7 +75,6 @@ CSV_MODEL_MAPPING = {
"FOOTERS.csv": Footer,
"PLANINFO.csv": PlanInfo,
# Legacy alternate names from export directories
"SCHEDULE.csv": PensionSchedule,
"FORM_INX.csv": FormIndex,
"FORM_LST.csv": FormList,
"PRINTERS.csv": PrinterSetup,
@@ -67,7 +85,9 @@ CSV_MODEL_MAPPING = {
"FVARLKUP.csv": FormVariable,
"RVARLKUP.csv": ReportVariable,
"PAYMENTS.csv": Payment,
"TRNSACTN.csv": Ledger # Maps to existing Ledger model (same structure)
"TRNSACTN.csv": Ledger, # Maps to existing Ledger model (same structure)
"INX_LKUP.csv": FormKeyword,
"RESULTS.csv": PensionResult
}
# Field mappings for CSV columns to database fields
@@ -230,8 +250,12 @@ FIELD_MAPPINGS = {
"Default_Rate": "default_rate"
},
"FILESTAT.csv": {
"Status": "status_code",
"Status_Code": "status_code",
"Definition": "description",
"Description": "description",
"Send": "send",
"Footer_Code": "footer_code",
"Sort_Order": "sort_order"
},
"FOOTERS.csv": {
@@ -253,22 +277,44 @@ FIELD_MAPPINGS = {
"Phone": "phone",
"Notes": "notes"
},
"INX_LKUP.csv": {
"Keyword": "keyword",
"Description": "description"
},
"FORM_INX.csv": {
"Form_Id": "form_id",
"Form_Name": "form_name",
"Category": "category"
"Name": "form_id",
"Keyword": "keyword"
},
"FORM_LST.csv": {
"Form_Id": "form_id",
"Line_Number": "line_number",
"Content": "content"
"Name": "form_id",
"Memo": "content",
"Status": "status"
},
"PRINTERS.csv": {
# Legacy variants
"Printer_Name": "printer_name",
"Description": "description",
"Driver": "driver",
"Port": "port",
"Default_Printer": "default_printer"
"Default_Printer": "default_printer",
# Observed legacy headers from export
"Number": "number",
"Name": "printer_name",
"Page_Break": "page_break",
"Setup_St": "setup_st",
"Reset_St": "reset_st",
"B_Underline": "b_underline",
"E_Underline": "e_underline",
"B_Bold": "b_bold",
"E_Bold": "e_bold",
# Optional report toggles
"Phone_Book": "phone_book",
"Rolodex_Info": "rolodex_info",
"Envelope": "envelope",
"File_Cabinet": "file_cabinet",
"Accounts": "accounts",
"Statements": "statements",
"Calendar": "calendar",
},
"SETUP.csv": {
"Setting_Key": "setting_key",
@@ -285,32 +331,98 @@ FIELD_MAPPINGS = {
"MARRIAGE.csv": {
"File_No": "file_no",
"Version": "version",
"Marriage_Date": "marriage_date",
"Separation_Date": "separation_date",
"Divorce_Date": "divorce_date"
"Married_From": "married_from",
"Married_To": "married_to",
"Married_Years": "married_years",
"Service_From": "service_from",
"Service_To": "service_to",
"Service_Years": "service_years",
"Marital_%": "marital_percent"
},
"DEATH.csv": {
"File_No": "file_no",
"Version": "version",
"Benefit_Type": "benefit_type",
"Benefit_Amount": "benefit_amount",
"Beneficiary": "beneficiary"
"Lump1": "lump1",
"Lump2": "lump2",
"Growth1": "growth1",
"Growth2": "growth2",
"Disc1": "disc1",
"Disc2": "disc2"
},
"SEPARATE.csv": {
"File_No": "file_no",
"Version": "version",
"Agreement_Date": "agreement_date",
"Terms": "terms"
"Separation_Rate": "terms"
},
"LIFETABL.csv": {
"Age": "age",
"Male_Mortality": "male_mortality",
"Female_Mortality": "female_mortality"
"AGE": "age",
"LE_AA": "le_aa",
"NA_AA": "na_aa",
"LE_AM": "le_am",
"NA_AM": "na_am",
"LE_AF": "le_af",
"NA_AF": "na_af",
"LE_WA": "le_wa",
"NA_WA": "na_wa",
"LE_WM": "le_wm",
"NA_WM": "na_wm",
"LE_WF": "le_wf",
"NA_WF": "na_wf",
"LE_BA": "le_ba",
"NA_BA": "na_ba",
"LE_BM": "le_bm",
"NA_BM": "na_bm",
"LE_BF": "le_bf",
"NA_BF": "na_bf",
"LE_HA": "le_ha",
"NA_HA": "na_ha",
"LE_HM": "le_hm",
"NA_HM": "na_hm",
"LE_HF": "le_hf",
"NA_HF": "na_hf"
},
"NUMBERAL.csv": {
"Table_Name": "table_name",
"Month": "month",
"NA_AA": "na_aa",
"NA_AM": "na_am",
"NA_AF": "na_af",
"NA_WA": "na_wa",
"NA_WM": "na_wm",
"NA_WF": "na_wf",
"NA_BA": "na_ba",
"NA_BM": "na_bm",
"NA_BF": "na_bf",
"NA_HA": "na_ha",
"NA_HM": "na_hm",
"NA_HF": "na_hf"
},
"RESULTS.csv": {
"Accrued": "accrued",
"Start_Age": "start_age",
"COLA": "cola",
"Withdrawal": "withdrawal",
"Pre_DR": "pre_dr",
"Post_DR": "post_dr",
"Tax_Rate": "tax_rate",
"Age": "age",
"Value": "value"
"Years_From": "years_from",
"Life_Exp": "life_exp",
"EV_Monthly": "ev_monthly",
"Payments": "payments",
"Pay_Out": "pay_out",
"Fund_Value": "fund_value",
"PV": "pv",
"Mortality": "mortality",
"PV_AM": "pv_am",
"PV_AMT": "pv_amt",
"PV_Pre_DB": "pv_pre_db",
"PV_Annuity": "pv_annuity",
"WV_AT": "wv_at",
"PV_Plan": "pv_plan",
"Years_Married": "years_married",
"Years_Service": "years_service",
"Marr_Per": "marr_per",
"Marr_Amt": "marr_amt"
},
# Additional CSV file mappings
"DEPOSITS.csv": {
@@ -357,7 +469,7 @@ FIELD_MAPPINGS = {
}
def parse_date(date_str: str) -> Optional[datetime]:
def parse_date(date_str: str) -> Optional[date]:
"""Parse date string in various formats"""
if not date_str or date_str.strip() == "":
return None
@@ -612,7 +724,11 @@ def convert_value(value: str, field_name: str) -> Any:
return parsed_date
# Boolean fields
if any(word in field_name.lower() for word in ["active", "default_printer", "billed", "transferable"]):
if any(word in field_name.lower() for word in [
"active", "default_printer", "billed", "transferable", "send",
# PrinterSetup legacy toggles
"phone_book", "rolodex_info", "envelope", "file_cabinet", "accounts", "statements", "calendar"
]):
if value.lower() in ["true", "1", "yes", "y", "on", "active"]:
return True
elif value.lower() in ["false", "0", "no", "n", "off", "inactive"]:
@@ -621,7 +737,11 @@ def convert_value(value: str, field_name: str) -> Any:
return None
# Numeric fields (float)
if any(word in field_name.lower() for word in ["rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu", "accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality", "value"]):
if any(word in field_name.lower() for word in [
"rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu",
"accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality",
"value"
]) or field_name.lower().startswith(("na_", "le_")):
try:
# Remove currency symbols and commas
cleaned_value = value.replace("$", "").replace(",", "").replace("%", "")
@@ -630,7 +750,9 @@ def convert_value(value: str, field_name: str) -> Any:
return 0.0
# Integer fields
if any(word in field_name.lower() for word in ["item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num"]):
if any(word in field_name.lower() for word in [
"item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number"
]):
try:
return int(float(value)) # Handle cases like "1.0"
except ValueError:
@@ -673,11 +795,18 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
"available_files": list(CSV_MODEL_MAPPING.keys()),
"descriptions": {
"ROLODEX.csv": "Customer/contact information",
"ROLEX_V.csv": "Customer/contact information (alias)",
"PHONE.csv": "Phone numbers linked to customers",
"FILES.csv": "Client files and cases",
"FILES_R.csv": "Client files and cases (alias)",
"FILES_V.csv": "Client files and cases (alias)",
"LEDGER.csv": "Financial transactions per file",
"QDROS.csv": "Legal documents and court orders",
"PENSIONS.csv": "Pension calculation data",
"SCHEDULE.csv": "Vesting schedules for pensions",
"MARRIAGE.csv": "Marriage history data",
"DEATH.csv": "Death benefit calculations",
"SEPARATE.csv": "Separation agreements",
"EMPLOYEE.csv": "Staff and employee information",
"STATES.csv": "US States lookup table",
"FILETYPE.csv": "File type categories",
@@ -688,7 +817,12 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
"FVARLKUP.csv": "Form template variables",
"RVARLKUP.csv": "Report template variables",
"PAYMENTS.csv": "Individual payments within deposits",
"TRNSACTN.csv": "Transaction details (maps to Ledger)"
"TRNSACTN.csv": "Transaction details (maps to Ledger)",
"INX_LKUP.csv": "Form keywords lookup",
"PLANINFO.csv": "Pension plan information",
"RESULTS.csv": "Pension computed results",
"LIFETABL.csv": "Life expectancy table by age, sex, and race (rich typed)",
"NUMBERAL.csv": "Monthly survivor counts by sex and race (rich typed)"
},
"auto_discovery": True
}
@@ -724,7 +858,7 @@ async def import_csv_data(
content = await file.read()
# Try multiple encodings for legacy CSV files
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -736,34 +870,7 @@ async def import_csv_data(
if csv_content is None:
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
# Preprocess CSV content to fix common legacy issues
def preprocess_csv(content):
lines = content.split('\n')
cleaned_lines = []
i = 0
while i < len(lines):
line = lines[i]
# If line doesn't have the expected number of commas, it might be a broken multi-line field
if i == 0: # Header line
cleaned_lines.append(line)
expected_comma_count = line.count(',')
i += 1
continue
# Check if this line has the expected number of commas
if line.count(',') < expected_comma_count:
# This might be a continuation of the previous line
# Try to merge with previous line
if cleaned_lines:
cleaned_lines[-1] += " " + line.replace('\n', ' ').replace('\r', ' ')
else:
cleaned_lines.append(line)
else:
cleaned_lines.append(line)
i += 1
return '\n'.join(cleaned_lines)
# Note: preprocess_csv helper removed as unused; robust parsing handled below
# Custom robust parser for problematic legacy CSV files
class MockCSVReader:
@@ -791,7 +898,7 @@ async def import_csv_data(
header_reader = csv.reader(io.StringIO(lines[0]))
headers = next(header_reader)
headers = [h.strip() for h in headers]
print(f"DEBUG: Found {len(headers)} headers: {headers}")
# Debug logging removed in API path; rely on audit/logging if needed
# Build dynamic header mapping for this file/model
mapping_info = _build_dynamic_mapping(headers, model_class, file_type)
@@ -829,17 +936,21 @@ async def import_csv_data(
continue
csv_reader = MockCSVReader(rows_data, headers)
print(f"SUCCESS: Parsed {len(rows_data)} rows (skipped {skipped_rows} malformed rows)")
# Parsing summary suppressed to avoid noisy stdout in API
except Exception as e:
print(f"Custom parsing failed: {e}")
# Keep error minimal for client; internal logging can capture 'e'
raise HTTPException(status_code=400, detail=f"Could not parse CSV file. The file appears to have serious formatting issues. Error: {str(e)}")
imported_count = 0
created_count = 0
updated_count = 0
errors = []
flexible_saved = 0
mapped_headers = mapping_info.get("mapped_headers", {})
unmapped_headers = mapping_info.get("unmapped_headers", [])
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
# If replace_existing is True, delete all existing records and related flexible extras
if replace_existing:
@@ -860,6 +971,16 @@ async def import_csv_data(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
# Only set if not provided
if "line_number" not in model_data:
model_data["line_number"] = current
# Skip empty rows
if not any(model_data.values()):
@@ -902,10 +1023,43 @@ async def import_csv_data(
if 'file_no' not in model_data or not model_data['file_no']:
continue # Skip ledger records without file number
# Create model instance
instance = model_class(**model_data)
db.add(instance)
db.flush() # Ensure PK is available
# Create or update model instance
instance = None
# Upsert behavior for printers
if model_class == PrinterSetup:
# Determine primary key field name
_, pk_names = _get_model_columns(model_class)
pk_field_name_local = pk_names[0] if len(pk_names) == 1 else None
pk_value_local = model_data.get(pk_field_name_local) if pk_field_name_local else None
if pk_field_name_local and pk_value_local:
existing = db.query(model_class).filter(getattr(model_class, pk_field_name_local) == pk_value_local).first()
if existing:
# Update mutable fields
for k, v in model_data.items():
if k != pk_field_name_local:
setattr(existing, k, v)
instance = existing
updated_count += 1
else:
instance = model_class(**model_data)
db.add(instance)
created_count += 1
else:
# Fallback to insert if PK missing
instance = model_class(**model_data)
db.add(instance)
created_count += 1
db.flush()
# Enforce single default
try:
if bool(model_data.get("default_printer")):
db.query(model_class).filter(getattr(model_class, pk_field_name_local) != getattr(instance, pk_field_name_local)).update({model_class.default_printer: False})
except Exception:
pass
else:
instance = model_class(**model_data)
db.add(instance)
db.flush() # Ensure PK is available
# Capture PK details for flexible storage linkage (single-column PKs only)
_, pk_names = _get_model_columns(model_class)
@@ -980,6 +1134,10 @@ async def import_csv_data(
"flexible_saved_rows": flexible_saved,
},
}
# Include create/update breakdown for printers
if file_type == "PRINTERS.csv":
result["created_count"] = created_count
result["updated_count"] = updated_count
if errors:
result["warning"] = f"Import completed with {len(errors)} errors"
@@ -987,9 +1145,7 @@ async def import_csv_data(
return result
except Exception as e:
print(f"IMPORT ERROR DEBUG: {type(e).__name__}: {str(e)}")
import traceback
print(f"TRACEBACK: {traceback.format_exc()}")
# Suppress stdout debug prints in API layer
db.rollback()
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
@@ -1071,7 +1227,7 @@ async def validate_csv_file(
content = await file.read()
# Try multiple encodings for legacy CSV files
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1083,18 +1239,6 @@ async def validate_csv_file(
if csv_content is None:
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
# Parse CSV with fallback to robust line-by-line parsing
def parse_csv_with_fallback(text: str) -> Tuple[List[Dict[str, str]], List[str]]:
try:
reader = csv.DictReader(io.StringIO(text), delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
headers_local = reader.fieldnames or []
rows_local = []
for r in reader:
rows_local.append(r)
return rows_local, headers_local
except Exception:
return parse_csv_robust(text)
rows_list, csv_headers = parse_csv_with_fallback(csv_content)
model_class = CSV_MODEL_MAPPING[file_type]
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
@@ -1142,9 +1286,7 @@ async def validate_csv_file(
}
except Exception as e:
print(f"VALIDATION ERROR DEBUG: {type(e).__name__}: {str(e)}")
import traceback
print(f"VALIDATION TRACEBACK: {traceback.format_exc()}")
# Suppress stdout debug prints in API layer
raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
@@ -1199,7 +1341,7 @@ async def batch_validate_csv_files(
content = await file.read()
# Try multiple encodings for legacy CSV files (include BOM-friendly utf-8-sig)
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1302,13 +1444,7 @@ async def batch_import_csv_files(
raise HTTPException(status_code=400, detail="Maximum 25 files allowed per batch")
# Define optimal import order based on dependencies
import_order = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
import_order = IMPORT_ORDER
# Sort uploaded files by optimal import order
file_map = {f.filename: f for f in files}
@@ -1365,7 +1501,7 @@ async def batch_import_csv_files(
saved_path = str(file_path)
except Exception:
saved_path = None
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1466,7 +1602,7 @@ async def batch_import_csv_files(
saved_path = None
# Try multiple encodings for legacy CSV files
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:
@@ -1505,6 +1641,8 @@ async def batch_import_csv_files(
imported_count = 0
errors = []
flexible_saved = 0
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
# If replace_existing is True and this is the first file of this type
if replace_existing:
@@ -1523,6 +1661,15 @@ async def batch_import_csv_files(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
if "line_number" not in model_data:
model_data["line_number"] = current
if not any(model_data.values()):
continue
@@ -1697,7 +1844,7 @@ async def batch_import_csv_files(
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
)
audit_row.message = f"Batch import completed: {audit_row.successful_files}/{audit_row.total_files} files"
audit_row.finished_at = datetime.utcnow()
audit_row.finished_at = datetime.now(timezone.utc)
audit_row.details = {
"files": [
{"file_type": r.get("file_type"), "status": r.get("status"), "imported_count": r.get("imported_count", 0), "errors": r.get("errors", 0)}
@@ -1844,13 +1991,7 @@ async def rerun_failed_files(
raise HTTPException(status_code=400, detail="No saved files available to rerun. Upload again.")
# Import order for sorting
import_order = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
import_order = IMPORT_ORDER
order_index = {name: i for i, name in enumerate(import_order)}
items.sort(key=lambda x: order_index.get(x[0], len(import_order) + 1))
@@ -1898,7 +2039,7 @@ async def rerun_failed_files(
if file_type not in CSV_MODEL_MAPPING:
# Flexible-only path
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for enc in encodings:
try:
@@ -1964,7 +2105,7 @@ async def rerun_failed_files(
# Known model path
model_class = CSV_MODEL_MAPPING[file_type]
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
encodings = ENCODINGS
csv_content = None
for enc in encodings:
try:
@@ -1996,6 +2137,8 @@ async def rerun_failed_files(
unmapped_headers = mapping_info["unmapped_headers"]
imported_count = 0
errors: List[Dict[str, Any]] = []
# Special handling: assign line numbers per form for FORM_LST.csv
form_lst_line_counters: Dict[str, int] = {}
if replace_existing:
db.query(model_class).delete()
@@ -2013,6 +2156,14 @@ async def rerun_failed_files(
converted_value = convert_value(row[csv_field], db_field)
if converted_value is not None:
model_data[db_field] = converted_value
# Inject sequential line_number for FORM_LST rows grouped by form_id
if file_type == "FORM_LST.csv":
form_id_value = model_data.get("form_id")
if form_id_value:
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
form_lst_line_counters[str(form_id_value)] = current
if "line_number" not in model_data:
model_data["line_number"] = current
if not any(model_data.values()):
continue
required_fields = _get_required_fields(model_class)
@@ -2147,7 +2298,7 @@ async def rerun_failed_files(
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
)
rerun_audit.message = f"Rerun completed: {rerun_audit.successful_files}/{rerun_audit.total_files} files"
rerun_audit.finished_at = datetime.utcnow()
rerun_audit.finished_at = datetime.now(timezone.utc)
rerun_audit.details = {"rerun_of": audit_id}
db.add(rerun_audit)
db.commit()
@@ -2183,7 +2334,7 @@ async def upload_flexible_only(
db.commit()
content = await file.read()
encodings = ["utf-8-sig", "utf-8", "windows-1252", "iso-8859-1", "cp1252"]
encodings = ENCODINGS
csv_content = None
for encoding in encodings:
try:

72
app/api/mortality.py Normal file
View File

@@ -0,0 +1,72 @@
"""
Mortality/Life Table API endpoints
Provides read endpoints to query life tables by age and number tables by month,
filtered by sex (M/F/A) and race (W/B/H/A).
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status, Path
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.database.base import get_db
from app.models.user import User
from app.auth.security import get_current_user
from app.services.mortality import get_life_values, get_number_value, InvalidCodeError
router = APIRouter()
class LifeResponse(BaseModel):
age: int
sex: str = Field(description="M, F, or A (all)")
race: str = Field(description="W, B, H, or A (all)")
le: Optional[float]
na: Optional[float]
class NumberResponse(BaseModel):
month: int
sex: str = Field(description="M, F, or A (all)")
race: str = Field(description="W, B, H, or A (all)")
na: Optional[float]
@router.get("/life/{age}", response_model=LifeResponse)
async def get_life_entry(
age: int = Path(..., ge=0, description="Age in years (>= 0)"),
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get life expectancy (LE) and number alive (NA) for an age/sex/race."""
try:
result = get_life_values(db, age=age, sex=sex, race=race)
except InvalidCodeError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if result is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Age not found")
return result
@router.get("/number/{month}", response_model=NumberResponse)
async def get_number_entry(
month: int = Path(..., ge=0, description="Month index (>= 0)"),
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get monthly number alive (NA) for a month/sex/race."""
try:
result = get_number_value(db, month=month, sex=sex, race=race)
except InvalidCodeError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if result is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Month not found")
return result

File diff suppressed because it is too large Load Diff

View File

@@ -2,8 +2,10 @@
Server-side highlight utilities for search results.
These functions generate HTML snippets with <strong> around matched tokens,
preserving the original casing of the source text. The output is intended to be
sanitized on the client before insertion into the DOM.
preserving the original casing of the source text. All non-HTML segments are
HTML-escaped server-side to prevent injection. Only the <strong> tags added by
this module are emitted as HTML; any pre-existing HTML in source text is
escaped.
"""
from typing import List, Tuple, Any
import re
@@ -42,18 +44,40 @@ def _merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
def highlight_text(value: str, tokens: List[str]) -> str:
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing."""
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing.
Non-highlighted segments and the highlighted text content are HTML-escaped.
Only the surrounding <strong> wrappers are emitted as markup.
"""
if value is None:
return ""
def _escape_html(text: str) -> str:
# Minimal, safe HTML escaping
if text is None:
return ""
# Replace ampersand first to avoid double-escaping
text = str(text)
text = text.replace("&", "&amp;")
text = text.replace("<", "&lt;")
text = text.replace(">", "&gt;")
text = text.replace('"', "&quot;")
text = text.replace("'", "&#39;")
return text
source = str(value)
if not source or not tokens:
return source
return _escape_html(source)
haystack = source.lower()
ranges: List[Tuple[int, int]] = []
# Deduplicate tokens case-insensitively to avoid redundant scans (parity with client)
unique_needles = []
seen_needles = set()
for t in tokens:
needle = str(t or "").lower()
if not needle:
continue
if needle and needle not in seen_needles:
unique_needles.append(needle)
seen_needles.add(needle)
for needle in unique_needles:
start = 0
last_possible = max(0, len(haystack) - len(needle))
while start <= last_possible and len(needle) > 0:
@@ -63,17 +87,17 @@ def highlight_text(value: str, tokens: List[str]) -> str:
ranges.append((idx, idx + len(needle)))
start = idx + 1
if not ranges:
return source
return _escape_html(source)
parts: List[str] = []
merged = _merge_ranges(ranges)
pos = 0
for s, e in merged:
if pos < s:
parts.append(source[pos:s])
parts.append("<strong>" + source[s:e] + "</strong>")
parts.append(_escape_html(source[pos:s]))
parts.append("<strong>" + _escape_html(source[s:e]) + "</strong>")
pos = e
if pos < len(source):
parts.append(source[pos:])
parts.append(_escape_html(source[pos:]))
return "".join(parts)

View File

@@ -2,21 +2,24 @@
Support ticket API endpoints
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, desc, and_, or_
from datetime import datetime
from datetime import datetime, timezone
import secrets
from app.database.base import get_db
from app.models import User, SupportTicket, TicketResponse as TicketResponseModel, TicketStatus, TicketPriority, TicketCategory
from app.auth.security import get_current_user, get_admin_user
from app.services.audit import audit_service
from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter
from app.api.search_highlight import build_query_tokens
router = APIRouter()
# Pydantic models for API
from pydantic import BaseModel, Field, EmailStr
from pydantic.config import ConfigDict
class TicketCreate(BaseModel):
@@ -57,8 +60,7 @@ class TicketResponseOut(BaseModel):
author_email: Optional[str]
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class TicketDetail(BaseModel):
@@ -81,15 +83,19 @@ class TicketDetail(BaseModel):
assigned_to: Optional[int]
assigned_admin_name: Optional[str]
submitter_name: Optional[str]
responses: List[TicketResponseOut] = []
responses: List[TicketResponseOut] = Field(default_factory=list)
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class PaginatedTicketsResponse(BaseModel):
items: List[TicketDetail]
total: int
def generate_ticket_number() -> str:
"""Generate unique ticket number like ST-2024-001"""
year = datetime.now().year
year = datetime.now(timezone.utc).year
random_suffix = secrets.token_hex(2).upper()
return f"ST-{year}-{random_suffix}"
@@ -129,7 +135,7 @@ async def create_support_ticket(
ip_address=client_ip,
user_id=current_user.id if current_user else None,
status=TicketStatus.OPEN,
created_at=datetime.utcnow()
created_at=datetime.now(timezone.utc)
)
db.add(new_ticket)
@@ -158,14 +164,18 @@ async def create_support_ticket(
}
@router.get("/tickets", response_model=List[TicketDetail])
@router.get("/tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
async def list_tickets(
status: Optional[TicketStatus] = None,
priority: Optional[TicketPriority] = None,
category: Optional[TicketCategory] = None,
assigned_to_me: bool = False,
skip: int = 0,
limit: int = 50,
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
priority: Optional[TicketPriority] = Query(None, description="Filter by ticket priority"),
category: Optional[TicketCategory] = Query(None, description="Filter by ticket category"),
assigned_to_me: bool = Query(False, description="Only include tickets assigned to the current admin"),
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description, contact name/email, current page, and IP"),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(50, ge=1, le=200, description="Page size"),
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
@@ -186,8 +196,38 @@ async def list_tickets(
query = query.filter(SupportTicket.category == category)
if assigned_to_me:
query = query.filter(SupportTicket.assigned_to == current_user.id)
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
# Search across key text fields
if search:
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
SupportTicket.ticket_number,
SupportTicket.subject,
SupportTicket.description,
SupportTicket.contact_name,
SupportTicket.contact_email,
SupportTicket.current_page,
SupportTicket.ip_address,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"created": [SupportTicket.created_at],
"updated": [SupportTicket.updated_at],
"resolved": [SupportTicket.resolved_at],
"priority": [SupportTicket.priority],
"status": [SupportTicket.status],
"subject": [SupportTicket.subject],
},
)
tickets, total = paginate_with_total(query, skip, limit, include_total)
# Format response
result = []
@@ -226,6 +266,8 @@ async def list_tickets(
}
result.append(ticket_dict)
if include_total:
return {"items": result, "total": total or 0}
return result
@@ -312,10 +354,10 @@ async def update_ticket(
# Set resolved timestamp if status changed to resolved
if ticket_data.status == TicketStatus.RESOLVED and ticket.resolved_at is None:
ticket.resolved_at = datetime.utcnow()
ticket.resolved_at = datetime.now(timezone.utc)
changes["resolved_at"] = {"from": None, "to": ticket.resolved_at}
ticket.updated_at = datetime.utcnow()
ticket.updated_at = datetime.now(timezone.utc)
db.commit()
# Audit logging (non-blocking)
@@ -358,13 +400,13 @@ async def add_response(
message=response_data.message,
is_internal=response_data.is_internal,
user_id=current_user.id,
created_at=datetime.utcnow()
created_at=datetime.now(timezone.utc)
)
db.add(response)
# Update ticket timestamp
ticket.updated_at = datetime.utcnow()
ticket.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(response)
@@ -386,11 +428,15 @@ async def add_response(
return {"message": "Response added successfully", "response_id": response.id}
@router.get("/my-tickets", response_model=List[TicketDetail])
@router.get("/my-tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
async def get_my_tickets(
status: Optional[TicketStatus] = None,
skip: int = 0,
limit: int = 20,
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description"),
skip: int = Query(0, ge=0, description="Offset for pagination"),
limit: int = Query(20, ge=1, le=200, description="Page size"),
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -403,7 +449,33 @@ async def get_my_tickets(
if status:
query = query.filter(SupportTicket.status == status)
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
# Search within user's tickets
if search:
tokens = build_query_tokens(search)
filter_expr = tokenized_ilike_filter(tokens, [
SupportTicket.ticket_number,
SupportTicket.subject,
SupportTicket.description,
])
if filter_expr is not None:
query = query.filter(filter_expr)
# Sorting (whitelisted)
query = apply_sorting(
query,
sort_by,
sort_dir,
allowed={
"created": [SupportTicket.created_at],
"updated": [SupportTicket.updated_at],
"resolved": [SupportTicket.resolved_at],
"priority": [SupportTicket.priority],
"status": [SupportTicket.status],
"subject": [SupportTicket.subject],
},
)
tickets, total = paginate_with_total(query, skip, limit, include_total)
# Format response (exclude internal responses for regular users)
result = []
@@ -442,6 +514,8 @@ async def get_my_tickets(
}
result.append(ticket_dict)
if include_total:
return {"items": result, "total": total or 0}
return result
@@ -473,7 +547,7 @@ async def get_ticket_stats(
# Recent tickets (last 7 days)
from datetime import timedelta
week_ago = datetime.utcnow() - timedelta(days=7)
week_ago = datetime.now(timezone.utc) - timedelta(days=7)
recent_tickets = db.query(func.count(SupportTicket.id)).filter(
SupportTicket.created_at >= week_ago
).scalar()

View File

@@ -3,6 +3,7 @@ Authentication schemas
"""
from typing import Optional
from pydantic import BaseModel, EmailStr
from pydantic.config import ConfigDict
class UserBase(BaseModel):
@@ -32,8 +33,7 @@ class UserResponse(UserBase):
is_admin: bool
theme_preference: Optional[str] = "light"
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ThemePreferenceUpdate(BaseModel):
@@ -45,7 +45,7 @@ class Token(BaseModel):
"""Token response schema"""
access_token: str
token_type: str
refresh_token: str | None = None
refresh_token: Optional[str] = None
class TokenData(BaseModel):

View File

@@ -1,7 +1,7 @@
"""
Authentication and security utilities
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Tuple
from uuid import uuid4
from jose import JWTError, jwt
@@ -54,12 +54,12 @@ def _decode_with_rotation(token: str) -> dict:
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT access token"""
to_encode = data.copy()
expire = datetime.utcnow() + (
expire = datetime.now(timezone.utc) + (
expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)
)
to_encode.update({
"exp": expire,
"iat": datetime.utcnow(),
"iat": datetime.now(timezone.utc),
"type": "access",
})
return _encode_with_rotation(to_encode)
@@ -68,14 +68,14 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str:
"""Create refresh token, store its JTI in DB for revocation."""
jti = uuid4().hex
expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.refresh_token_expire_minutes)
payload = {
"sub": user.username,
"uid": user.id,
"jti": jti,
"type": "refresh",
"exp": expire,
"iat": datetime.utcnow(),
"iat": datetime.now(timezone.utc),
}
token = _encode_with_rotation(payload)
@@ -84,7 +84,7 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti
jti=jti,
user_agent=user_agent,
ip_address=ip_address,
issued_at=datetime.utcnow(),
issued_at=datetime.now(timezone.utc),
expires_at=expire,
revoked=False,
)
@@ -93,6 +93,15 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti
return token
def _to_utc_aware(dt: Optional[datetime]) -> Optional[datetime]:
"""Convert a datetime to UTC-aware. If naive, assume it's already UTC and attach tzinfo."""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def verify_token(token: str) -> Optional[str]:
"""Verify JWT token and return username"""
try:
@@ -122,14 +131,20 @@ def decode_refresh_token(token: str) -> Optional[dict]:
def is_refresh_token_revoked(jti: str, db: Session) -> bool:
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow()
if not token_row:
return True
if token_row.revoked:
return True
expires_at_utc = _to_utc_aware(token_row.expires_at)
now_utc = datetime.now(timezone.utc)
return expires_at_utc is not None and expires_at_utc <= now_utc
def revoke_refresh_token(jti: str, db: Session) -> None:
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
if token_row and not token_row.revoked:
token_row.revoked = True
token_row.revoked_at = datetime.utcnow()
token_row.revoked_at = datetime.now(timezone.utc)
db.commit()

View File

@@ -57,6 +57,10 @@ class Settings(BaseSettings):
log_rotation: str = "10 MB"
log_retention: str = "30 days"
# Cache / Redis
cache_enabled: bool = False
redis_url: Optional[str] = None
# pydantic-settings v2 configuration
model_config = SettingsConfigDict(
env_file=".env",

View File

@@ -2,8 +2,7 @@
Database configuration and session management
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from typing import Generator
from app.config import settings

248
app/database/fts.py Normal file
View File

@@ -0,0 +1,248 @@
"""
SQLite Full-Text Search (FTS5) helpers.
Creates and maintains FTS virtual tables and triggers to keep them in sync
with their content tables. Designed to be called at app startup.
"""
from typing import Optional
from sqlalchemy.engine import Engine
from sqlalchemy import text
def _execute_ignore_errors(engine: Engine, sql: str) -> None:
"""Execute SQL, ignoring operational errors (e.g., when FTS5 is unavailable)."""
from sqlalchemy.exc import OperationalError
with engine.begin() as conn:
try:
conn.execute(text(sql))
except OperationalError:
# Likely FTS5 extension not available in this SQLite build
pass
def ensure_rolodex_fts(engine: Engine) -> None:
"""Ensure the `rolodex_fts` virtual table and triggers exist and are populated.
This uses content=rolodex so the FTS table shadows the base table and is kept
in sync via triggers.
"""
# Create virtual table (if FTS5 is available)
_create_table = """
CREATE VIRTUAL TABLE IF NOT EXISTS rolodex_fts USING fts5(
id,
first,
last,
city,
email,
memo,
content='rolodex',
content_rowid='rowid'
);
"""
_execute_ignore_errors(engine, _create_table)
# Triggers to keep FTS in sync
_triggers = [
"""
CREATE TRIGGER IF NOT EXISTS rolodex_ai AFTER INSERT ON rolodex BEGIN
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS rolodex_ad AFTER DELETE ON rolodex BEGIN
INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo)
VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS rolodex_au AFTER UPDATE ON rolodex BEGIN
INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo)
VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo);
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo);
END;
""",
]
for trig in _triggers:
_execute_ignore_errors(engine, trig)
# Backfill if the FTS table exists but is empty
with engine.begin() as conn:
try:
count_fts = conn.execute(text("SELECT count(*) FROM rolodex_fts")).scalar() # type: ignore
if count_fts == 0:
# Populate from existing rolodex rows
conn.execute(text(
"""
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
SELECT rowid, id, first, last, city, email, memo FROM rolodex;
"""
))
except Exception:
# If FTS table doesn't exist or any error occurs, ignore silently
pass
def ensure_files_fts(engine: Engine) -> None:
"""Ensure the `files_fts` virtual table and triggers exist and are populated."""
_create_table = """
CREATE VIRTUAL TABLE IF NOT EXISTS files_fts USING fts5(
file_no,
id,
regarding,
file_type,
memo,
content='files',
content_rowid='rowid'
);
"""
_execute_ignore_errors(engine, _create_table)
_triggers = [
"""
CREATE TRIGGER IF NOT EXISTS files_ai AFTER INSERT ON files BEGIN
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS files_ad AFTER DELETE ON files BEGIN
INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo)
VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS files_au AFTER UPDATE ON files BEGIN
INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo)
VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo);
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo);
END;
""",
]
for trig in _triggers:
_execute_ignore_errors(engine, trig)
with engine.begin() as conn:
try:
count_fts = conn.execute(text("SELECT count(*) FROM files_fts")).scalar() # type: ignore
if count_fts == 0:
conn.execute(text(
"""
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
SELECT rowid, file_no, id, regarding, file_type, memo FROM files;
"""
))
except Exception:
pass
def ensure_ledger_fts(engine: Engine) -> None:
"""Ensure the `ledger_fts` virtual table and triggers exist and are populated."""
_create_table = """
CREATE VIRTUAL TABLE IF NOT EXISTS ledger_fts USING fts5(
file_no,
t_code,
note,
empl_num,
content='ledger',
content_rowid='rowid'
);
"""
_execute_ignore_errors(engine, _create_table)
_triggers = [
"""
CREATE TRIGGER IF NOT EXISTS ledger_ai AFTER INSERT ON ledger BEGIN
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS ledger_ad AFTER DELETE ON ledger BEGIN
INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num)
VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS ledger_au AFTER UPDATE ON ledger BEGIN
INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num)
VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num);
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num);
END;
""",
]
for trig in _triggers:
_execute_ignore_errors(engine, trig)
with engine.begin() as conn:
try:
count_fts = conn.execute(text("SELECT count(*) FROM ledger_fts")).scalar() # type: ignore
if count_fts == 0:
conn.execute(text(
"""
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
SELECT rowid, file_no, t_code, note, empl_num FROM ledger;
"""
))
except Exception:
pass
def ensure_qdros_fts(engine: Engine) -> None:
"""Ensure the `qdros_fts` virtual table and triggers exist and are populated."""
_create_table = """
CREATE VIRTUAL TABLE IF NOT EXISTS qdros_fts USING fts5(
file_no,
form_name,
pet,
res,
case_number,
content='qdros',
content_rowid='rowid'
);
"""
_execute_ignore_errors(engine, _create_table)
_triggers = [
"""
CREATE TRIGGER IF NOT EXISTS qdros_ai AFTER INSERT ON qdros BEGIN
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS qdros_ad AFTER DELETE ON qdros BEGIN
INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number)
VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number);
END;
""",
"""
CREATE TRIGGER IF NOT EXISTS qdros_au AFTER UPDATE ON qdros BEGIN
INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number)
VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number);
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number);
END;
""",
]
for trig in _triggers:
_execute_ignore_errors(engine, trig)
with engine.begin() as conn:
try:
count_fts = conn.execute(text("SELECT count(*) FROM qdros_fts")).scalar() # type: ignore
if count_fts == 0:
conn.execute(text(
"""
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
SELECT rowid, file_no, form_name, pet, res, case_number FROM qdros;
"""
))
except Exception:
pass

31
app/database/indexes.py Normal file
View File

@@ -0,0 +1,31 @@
"""
Database secondary indexes helper.
Creates small B-tree indexes for common equality filters to speed up searches.
Uses CREATE INDEX IF NOT EXISTS so it is safe to call repeatedly at startup
and works for existing databases without running a migration.
"""
from sqlalchemy.engine import Engine
from sqlalchemy import text
def ensure_secondary_indexes(engine: Engine) -> None:
statements = [
# Files
"CREATE INDEX IF NOT EXISTS idx_files_status ON files(status)",
"CREATE INDEX IF NOT EXISTS idx_files_file_type ON files(file_type)",
"CREATE INDEX IF NOT EXISTS idx_files_empl_num ON files(empl_num)",
# Ledger
"CREATE INDEX IF NOT EXISTS idx_ledger_t_type ON ledger(t_type)",
"CREATE INDEX IF NOT EXISTS idx_ledger_empl_num ON ledger(empl_num)",
]
with engine.begin() as conn:
for stmt in statements:
try:
conn.execute(text(stmt))
except Exception:
# Ignore failures (e.g., non-SQLite engines that still support IF NOT EXISTS;
# if not supported, users should manage indexes via migrations)
pass

View File

@@ -0,0 +1,130 @@
"""
Lightweight, idempotent schema updates for SQLite.
Adds newly introduced columns to existing tables when running on an
already-initialized database. Safe to call multiple times.
"""
from typing import Dict
from sqlalchemy.engine import Engine
def _existing_columns(conn, table: str) -> set[str]:
rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall()
return {row[1] for row in rows} # name is column 2
def ensure_schema_updates(engine: Engine) -> None:
"""Ensure missing columns are added for backward-compatible updates."""
# Map of table -> {column: SQL type}
updates: Dict[str, Dict[str, str]] = {
# Forms
"form_index": {
"keyword": "TEXT",
},
# Richer Life/Number tables (forms & pensions harmonized)
"life_tables": {
"le_aa": "FLOAT",
"na_aa": "FLOAT",
"le_am": "FLOAT",
"na_am": "FLOAT",
"le_af": "FLOAT",
"na_af": "FLOAT",
"le_wa": "FLOAT",
"na_wa": "FLOAT",
"le_wm": "FLOAT",
"na_wm": "FLOAT",
"le_wf": "FLOAT",
"na_wf": "FLOAT",
"le_ba": "FLOAT",
"na_ba": "FLOAT",
"le_bm": "FLOAT",
"na_bm": "FLOAT",
"le_bf": "FLOAT",
"na_bf": "FLOAT",
"le_ha": "FLOAT",
"na_ha": "FLOAT",
"le_hm": "FLOAT",
"na_hm": "FLOAT",
"le_hf": "FLOAT",
"na_hf": "FLOAT",
"table_year": "INTEGER",
"table_type": "VARCHAR(45)",
},
"number_tables": {
"month": "INTEGER",
"na_aa": "FLOAT",
"na_am": "FLOAT",
"na_af": "FLOAT",
"na_wa": "FLOAT",
"na_wm": "FLOAT",
"na_wf": "FLOAT",
"na_ba": "FLOAT",
"na_bm": "FLOAT",
"na_bf": "FLOAT",
"na_ha": "FLOAT",
"na_hm": "FLOAT",
"na_hf": "FLOAT",
"table_type": "VARCHAR(45)",
"description": "TEXT",
},
"form_list": {
"status": "VARCHAR(45)",
},
# Printers: add advanced legacy fields
"printers": {
"number": "INTEGER",
"page_break": "VARCHAR(50)",
"setup_st": "VARCHAR(200)",
"reset_st": "VARCHAR(200)",
"b_underline": "VARCHAR(100)",
"e_underline": "VARCHAR(100)",
"b_bold": "VARCHAR(100)",
"e_bold": "VARCHAR(100)",
"phone_book": "BOOLEAN",
"rolodex_info": "BOOLEAN",
"envelope": "BOOLEAN",
"file_cabinet": "BOOLEAN",
"accounts": "BOOLEAN",
"statements": "BOOLEAN",
"calendar": "BOOLEAN",
},
# Pensions
"pension_schedules": {
"vests_on": "DATE",
"vests_at": "FLOAT",
},
"marriage_history": {
"married_from": "DATE",
"married_to": "DATE",
"married_years": "FLOAT",
"service_from": "DATE",
"service_to": "DATE",
"service_years": "FLOAT",
"marital_percent": "FLOAT",
},
"death_benefits": {
"lump1": "FLOAT",
"lump2": "FLOAT",
"growth1": "FLOAT",
"growth2": "FLOAT",
"disc1": "FLOAT",
"disc2": "FLOAT",
},
}
with engine.begin() as conn:
for table, cols in updates.items():
try:
existing = _existing_columns(conn, table)
except Exception:
# Table may not exist yet
continue
for col_name, col_type in cols.items():
if col_name not in existing:
try:
conn.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
except Exception:
# Ignore if not applicable (other engines) or race condition
pass

View File

@@ -9,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.database.base import engine
from app.database.fts import ensure_rolodex_fts, ensure_files_fts, ensure_ledger_fts, ensure_qdros_fts
from app.database.indexes import ensure_secondary_indexes
from app.database.schema_updates import ensure_schema_updates
from app.models import BaseModel
from app.models.user import User
from app.auth.security import get_admin_user
@@ -24,6 +27,21 @@ logger = get_logger("main")
logger.info("Creating database tables")
BaseModel.metadata.create_all(bind=engine)
# Initialize SQLite FTS (if available)
logger.info("Initializing FTS (if available)")
ensure_rolodex_fts(engine)
ensure_files_fts(engine)
ensure_ledger_fts(engine)
ensure_qdros_fts(engine)
# Ensure helpful secondary indexes
logger.info("Ensuring secondary indexes (status, type, employee, etc.)")
ensure_secondary_indexes(engine)
# Ensure idempotent schema updates for added columns
logger.info("Ensuring schema updates (new columns)")
ensure_schema_updates(engine)
# Initialize FastAPI app
logger.info("Initializing FastAPI application", version=settings.app_version, debug=settings.debug)
app = FastAPI(
@@ -71,6 +89,7 @@ from app.api.import_data import router as import_router
from app.api.flexible import router as flexible_router
from app.api.support import router as support_router
from app.api.settings import router as settings_router
from app.api.mortality import router as mortality_router
logger.info("Including API routers")
app.include_router(auth_router, prefix="/api/auth", tags=["authentication"])
@@ -84,6 +103,7 @@ app.include_router(import_router, prefix="/api/import", tags=["import"])
app.include_router(support_router, prefix="/api/support", tags=["support"])
app.include_router(settings_router, prefix="/api/settings", tags=["settings"])
app.include_router(flexible_router, prefix="/api")
app.include_router(mortality_router, prefix="/api/mortality", tags=["mortality"])
@app.get("/", response_class=HTMLResponse)

View File

@@ -46,6 +46,25 @@ def _get_correlation_id(request: Request) -> str:
return str(uuid4())
def _json_safe(value: Any) -> Any:
"""Recursively convert non-JSON-serializable objects (like Exceptions) into strings.
Keeps overall structure intact so tests inspecting error details (e.g. 'loc', 'msg') still work.
"""
# Exception -> string message
if isinstance(value, BaseException):
return str(value)
# Mapping types
if isinstance(value, dict):
return {k: _json_safe(v) for k, v in value.items()}
# Sequence types
if isinstance(value, (list, tuple)):
return [
_json_safe(v) for v in value
]
return value
def _build_error_response(
request: Request,
*,
@@ -66,7 +85,7 @@ def _build_error_response(
"correlation_id": correlation_id,
}
if details is not None:
body["error"]["details"] = details
body["error"]["details"] = _json_safe(details)
response = JSONResponse(content=body, status_code=status_code)
response.headers[ERROR_HEADER_NAME] = correlation_id

View File

@@ -14,12 +14,12 @@ from .flexible import FlexibleImport
from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory
from .pensions import (
Pension, PensionSchedule, MarriageHistory, DeathBenefit,
SeparationAgreement, LifeTable, NumberTable
SeparationAgreement, LifeTable, NumberTable, PensionResult
)
from .lookups import (
Employee, FileType, FileStatus, TransactionType, TransactionCode,
State, GroupLookup, Footer, PlanInfo, FormIndex, FormList,
PrinterSetup, SystemSetup
PrinterSetup, SystemSetup, FormKeyword
)
__all__ = [
@@ -28,8 +28,8 @@ __all__ = [
"Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document", "FlexibleImport",
"SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory",
"Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit",
"SeparationAgreement", "LifeTable", "NumberTable",
"SeparationAgreement", "LifeTable", "NumberTable", "PensionResult",
"Employee", "FileType", "FileStatus", "TransactionType", "TransactionCode",
"State", "GroupLookup", "Footer", "PlanInfo", "FormIndex", "FormList",
"PrinterSetup", "SystemSetup"
"PrinterSetup", "SystemSetup", "FormKeyword"
]

View File

@@ -3,7 +3,7 @@ Audit logging models
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from datetime import datetime, timezone
from app.models.base import BaseModel
@@ -22,7 +22,7 @@ class AuditLog(BaseModel):
details = Column(JSON, nullable=True) # Additional details as JSON
ip_address = Column(String(45), nullable=True) # IPv4/IPv6 address
user_agent = Column(Text, nullable=True) # Browser/client information
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
# Relationships
user = relationship("User", back_populates="audit_logs")
@@ -42,7 +42,7 @@ class LoginAttempt(BaseModel):
ip_address = Column(String(45), nullable=False)
user_agent = Column(Text, nullable=True)
success = Column(Integer, default=0) # 1 for success, 0 for failure
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
failure_reason = Column(String(200), nullable=True) # Reason for failure
def __repr__(self):
@@ -56,8 +56,8 @@ class ImportAudit(BaseModel):
__tablename__ = "import_audit"
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
started_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
finished_at = Column(DateTime, nullable=True, index=True)
started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
finished_at = Column(DateTime(timezone=True), nullable=True, index=True)
status = Column(String(30), nullable=False, default="running", index=True) # running|success|completed_with_errors|failed
total_files = Column(Integer, nullable=False, default=0)
@@ -94,7 +94,7 @@ class ImportAuditFile(BaseModel):
errors = Column(Integer, nullable=False, default=0)
message = Column(String(255), nullable=True)
details = Column(JSON, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
audit = relationship("ImportAudit", back_populates="files")

View File

@@ -1,7 +1,7 @@
"""
Authentication-related persistence models
"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint
@@ -19,10 +19,10 @@ class RefreshToken(BaseModel):
jti = Column(String(64), nullable=False, unique=True, index=True)
user_agent = Column(String(255), nullable=True)
ip_address = Column(String(45), nullable=True)
issued_at = Column(DateTime, default=datetime.utcnow, nullable=False)
expires_at = Column(DateTime, nullable=False, index=True)
issued_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=False, index=True)
revoked = Column(Boolean, default=False, nullable=False)
revoked_at = Column(DateTime, nullable=True)
revoked_at = Column(DateTime(timezone=True), nullable=True)
# relationships
user = relationship("User")

View File

@@ -1,7 +1,7 @@
"""
Lookup table models based on legacy system analysis
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, Float
from sqlalchemy import Column, Integer, String, Text, Boolean, Float, ForeignKey
from app.models.base import BaseModel
@@ -53,6 +53,9 @@ class FileStatus(BaseModel):
description = Column(String(200), nullable=False) # Description
active = Column(Boolean, default=True) # Is status active
sort_order = Column(Integer, default=0) # Display order
# Legacy fields for typed import support
send = Column(Boolean, default=True) # Should statements print by default
footer_code = Column(String(45), ForeignKey("footers.footer_code")) # Default footer
def __repr__(self):
return f"<FileStatus(code='{self.status_code}', description='{self.description}')>"
@@ -169,9 +172,10 @@ class FormIndex(BaseModel):
"""
__tablename__ = "form_index"
form_id = Column(String(45), primary_key=True, index=True) # Form identifier
form_id = Column(String(45), primary_key=True, index=True) # Form identifier (maps to Name)
form_name = Column(String(200), nullable=False) # Form name
category = Column(String(45)) # Form category
keyword = Column(String(200)) # Legacy FORM_INX Name/Keyword pair
active = Column(Boolean, default=True) # Is form active
def __repr__(self):
@@ -189,6 +193,7 @@ class FormList(BaseModel):
form_id = Column(String(45), nullable=False) # Form identifier
line_number = Column(Integer, nullable=False) # Line number in form
content = Column(Text) # Line content
status = Column(String(45)) # Legacy FORM_LST Status
def __repr__(self):
return f"<FormList(form_id='{self.form_id}', line={self.line_number})>"
@@ -201,12 +206,34 @@ class PrinterSetup(BaseModel):
"""
__tablename__ = "printers"
# Core identity and basic configuration
printer_name = Column(String(100), primary_key=True, index=True) # Printer name
description = Column(String(200)) # Description
driver = Column(String(100)) # Print driver
port = Column(String(20)) # Port/connection
default_printer = Column(Boolean, default=False) # Is default printer
active = Column(Boolean, default=True) # Is printer active
# Legacy numeric printer number (from PRINTERS.csv "Number")
number = Column(Integer)
# Legacy control sequences and formatting (from PRINTERS.csv)
page_break = Column(String(50))
setup_st = Column(String(200))
reset_st = Column(String(200))
b_underline = Column(String(100))
e_underline = Column(String(100))
b_bold = Column(String(100))
e_bold = Column(String(100))
# Optional report configuration toggles (legacy flags)
phone_book = Column(Boolean, default=False)
rolodex_info = Column(Boolean, default=False)
envelope = Column(Boolean, default=False)
file_cabinet = Column(Boolean, default=False)
accounts = Column(Boolean, default=False)
statements = Column(Boolean, default=False)
calendar = Column(Boolean, default=False)
def __repr__(self):
return f"<Printer(name='{self.printer_name}', description='{self.description}')>"
@@ -225,4 +252,19 @@ class SystemSetup(BaseModel):
setting_type = Column(String(20), default="STRING") # DATA type (STRING, INTEGER, FLOAT, BOOLEAN)
def __repr__(self):
return f"<SystemSetup(key='{self.setting_key}', value='{self.setting_value}')>"
return f"<SystemSetup(key='{self.setting_key}', value='{self.setting_value}')>"
class FormKeyword(BaseModel):
"""
Form keyword lookup
Corresponds to INX_LKUP table in legacy system
"""
__tablename__ = "form_keywords"
keyword = Column(String(200), primary_key=True, index=True)
description = Column(String(200))
active = Column(Boolean, default=True)
def __repr__(self):
return f"<FormKeyword(keyword='{self.keyword}')>"

View File

@@ -60,11 +60,13 @@ class PensionSchedule(BaseModel):
file_no = Column(String(45), ForeignKey("files.file_no"), nullable=False)
version = Column(String(10), default="01")
# Schedule details
# Schedule details (legacy vesting fields)
start_date = Column(Date) # Start date for payments
end_date = Column(Date) # End date for payments
payment_amount = Column(Float, default=0.0) # Payment amount
frequency = Column(String(20)) # Monthly, quarterly, etc.
vests_on = Column(Date) # Legacy SCHEDULE.csv Vests_On
vests_at = Column(Float, default=0.0) # Legacy SCHEDULE.csv Vests_At (percent)
# Relationships
file = relationship("File", back_populates="pension_schedules")
@@ -85,6 +87,15 @@ class MarriageHistory(BaseModel):
divorce_date = Column(Date) # Date of divorce/separation
spouse_name = Column(String(100)) # Spouse name
notes = Column(Text) # Additional notes
# Legacy MARRIAGE.csv fields
married_from = Column(Date)
married_to = Column(Date)
married_years = Column(Float, default=0.0)
service_from = Column(Date)
service_to = Column(Date)
service_years = Column(Float, default=0.0)
marital_percent = Column(Float, default=0.0)
# Relationships
file = relationship("File", back_populates="marriage_history")
@@ -105,6 +116,14 @@ class DeathBenefit(BaseModel):
benefit_amount = Column(Float, default=0.0) # Benefit amount
benefit_type = Column(String(45)) # Type of death benefit
notes = Column(Text) # Additional notes
# Legacy DEATH.csv fields
lump1 = Column(Float, default=0.0)
lump2 = Column(Float, default=0.0)
growth1 = Column(Float, default=0.0)
growth2 = Column(Float, default=0.0)
disc1 = Column(Float, default=0.0)
disc2 = Column(Float, default=0.0)
# Relationships
file = relationship("File", back_populates="death_benefits")
@@ -138,10 +157,36 @@ class LifeTable(BaseModel):
id = Column(Integer, primary_key=True, autoincrement=True)
age = Column(Integer, nullable=False) # Age
male_expectancy = Column(Float) # Male life expectancy
female_expectancy = Column(Float) # Female life expectancy
table_year = Column(Integer) # Year of table (e.g., 2023)
table_type = Column(String(45)) # Type of table
# Rich typed columns reflecting legacy LIFETABL.csv headers
# LE_* = Life Expectancy, NA_* = Number Alive/Survivors
le_aa = Column(Float)
na_aa = Column(Float)
le_am = Column(Float)
na_am = Column(Float)
le_af = Column(Float)
na_af = Column(Float)
le_wa = Column(Float)
na_wa = Column(Float)
le_wm = Column(Float)
na_wm = Column(Float)
le_wf = Column(Float)
na_wf = Column(Float)
le_ba = Column(Float)
na_ba = Column(Float)
le_bm = Column(Float)
na_bm = Column(Float)
le_bf = Column(Float)
na_bf = Column(Float)
le_ha = Column(Float)
na_ha = Column(Float)
le_hm = Column(Float)
na_hm = Column(Float)
le_hf = Column(Float)
na_hf = Column(Float)
# Optional metadata retained for future variations
table_year = Column(Integer) # Year/version of table if known
table_type = Column(String(45)) # Source/type of table (optional)
class NumberTable(BaseModel):
@@ -152,7 +197,63 @@ class NumberTable(BaseModel):
__tablename__ = "number_tables"
id = Column(Integer, primary_key=True, autoincrement=True)
table_type = Column(String(45), nullable=False) # Type of table
key_value = Column(String(45), nullable=False) # Key identifier
numeric_value = Column(Float) # Numeric value
description = Column(Text) # Description
month = Column(Integer, nullable=False)
# Rich typed NA_* columns reflecting legacy NUMBERAL.csv headers
na_aa = Column(Float)
na_am = Column(Float)
na_af = Column(Float)
na_wa = Column(Float)
na_wm = Column(Float)
na_wf = Column(Float)
na_ba = Column(Float)
na_bm = Column(Float)
na_bf = Column(Float)
na_ha = Column(Float)
na_hm = Column(Float)
na_hf = Column(Float)
# Optional metadata retained for future variations
table_type = Column(String(45))
description = Column(Text)
class PensionResult(BaseModel):
"""
Computed pension results summary
Corresponds to RESULTS table in legacy system
"""
__tablename__ = "pension_results"
id = Column(Integer, primary_key=True, autoincrement=True)
# Optional linkage if present in future exports
file_no = Column(String(45))
version = Column(String(10))
# Columns observed in legacy RESULTS.csv header
accrued = Column(Float)
start_age = Column(Integer)
cola = Column(Float)
withdrawal = Column(String(45))
pre_dr = Column(Float)
post_dr = Column(Float)
tax_rate = Column(Float)
age = Column(Integer)
years_from = Column(Float)
life_exp = Column(Float)
ev_monthly = Column(Float)
payments = Column(Float)
pay_out = Column(Float)
fund_value = Column(Float)
pv = Column(Float)
mortality = Column(Float)
pv_am = Column(Float)
pv_amt = Column(Float)
pv_pre_db = Column(Float)
pv_annuity = Column(Float)
wv_at = Column(Float)
pv_plan = Column(Float)
years_married = Column(Float)
years_service = Column(Float)
marr_per = Column(Float)
marr_amt = Column(Float)

View File

@@ -3,7 +3,7 @@ Support ticket models for help desk functionality
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Enum
from sqlalchemy.orm import relationship
from datetime import datetime
from datetime import datetime, timezone
import enum
from app.models.base import BaseModel
@@ -63,9 +63,9 @@ class SupportTicket(BaseModel):
ip_address = Column(String(45)) # IP address
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
resolved_at = Column(DateTime)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
resolved_at = Column(DateTime(timezone=True))
# Admin assignment
assigned_to = Column(Integer, ForeignKey("users.id"))
@@ -95,7 +95,7 @@ class TicketResponse(BaseModel):
author_email = Column(String(100)) # For non-user responses
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
# Relationships
ticket = relationship("SupportTicket", back_populates="responses")

View File

@@ -3,7 +3,7 @@ Audit logging service
"""
import json
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from fastapi import Request
@@ -65,7 +65,7 @@ class AuditService:
details=details,
ip_address=ip_address,
user_agent=user_agent,
timestamp=datetime.utcnow()
timestamp=datetime.now(timezone.utc)
)
try:
@@ -76,7 +76,7 @@ class AuditService:
except Exception as e:
db.rollback()
# Log the error but don't fail the main operation
logger.error("Failed to log audit entry", error=str(e), action=action, user_id=user_id)
logger.error("Failed to log audit entry", error=str(e), action=action)
return audit_log
@staticmethod
@@ -119,7 +119,7 @@ class AuditService:
ip_address=ip_address or "unknown",
user_agent=user_agent,
success=1 if success else 0,
timestamp=datetime.utcnow(),
timestamp=datetime.now(timezone.utc),
failure_reason=failure_reason if not success else None
)
@@ -252,7 +252,7 @@ class AuditService:
Returns:
List of failed login attempts
"""
cutoff_time = datetime.utcnow() - timedelta(hours=hours)
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
query = db.query(LoginAttempt).filter(
LoginAttempt.success == 0,
LoginAttempt.timestamp >= cutoff_time

98
app/services/cache.py Normal file
View File

@@ -0,0 +1,98 @@
"""
Cache utilities with optional Redis backend.
If Redis is not configured or unavailable, all functions degrade to no-ops.
"""
from __future__ import annotations
import asyncio
import json
import hashlib
from typing import Any, Optional
try:
import redis.asyncio as redis # type: ignore
except Exception: # pragma: no cover - allow running without redis installed
redis = None # type: ignore
from app.config import settings
_client: Optional["redis.Redis"] = None # type: ignore
_lock = asyncio.Lock()
async def _get_client() -> Optional["redis.Redis"]: # type: ignore
"""Lazily initialize and return a shared Redis client if enabled."""
global _client
if not getattr(settings, "redis_url", None) or not getattr(settings, "cache_enabled", False):
return None
if redis is None:
return None
if _client is not None:
return _client
async with _lock:
if _client is None:
try:
_client = redis.from_url(settings.redis_url, decode_responses=True) # type: ignore
except Exception:
_client = None
return _client
def _stable_hash(obj: Any) -> str:
data = json.dumps(obj, sort_keys=True, separators=(",", ":"))
return hashlib.sha1(data.encode("utf-8")).hexdigest()
def build_key(kind: str, user_id: Optional[str], parts: dict) -> str:
payload = {"u": user_id or "anon", "p": parts}
return f"search:{kind}:v1:{_stable_hash(payload)}"
async def cache_get_json(kind: str, user_id: Optional[str], parts: dict) -> Optional[Any]:
client = await _get_client()
if client is None:
return None
key = build_key(kind, user_id, parts)
try:
raw = await client.get(key)
if raw is None:
return None
return json.loads(raw)
except Exception:
return None
async def cache_set_json(kind: str, user_id: Optional[str], parts: dict, value: Any, ttl_seconds: int) -> None:
client = await _get_client()
if client is None:
return
key = build_key(kind, user_id, parts)
try:
await client.set(key, json.dumps(value, separators=(",", ":")), ex=ttl_seconds)
except Exception:
return
async def invalidate_prefix(prefix: str) -> None:
client = await _get_client()
if client is None:
return
try:
# Use SCAN to avoid blocking Redis
async for key in client.scan_iter(match=f"{prefix}*"):
try:
await client.delete(key)
except Exception:
pass
except Exception:
return
async def invalidate_search_cache() -> None:
# Wipe both global search and suggestions namespaces
await invalidate_prefix("search:global:")
await invalidate_prefix("search:suggestions:")

View File

@@ -0,0 +1,141 @@
from typing import Optional, List
from sqlalchemy import or_, and_, func, asc, desc
from app.models.rolodex import Rolodex
def apply_customer_filters(base_query, search: Optional[str], group: Optional[str], state: Optional[str], groups: Optional[List[str]], states: Optional[List[str]]):
"""Apply shared search and group/state filters to the provided base_query.
This helper is used by both list and export endpoints to keep logic in sync.
"""
s = (search or "").strip()
if s:
s_lower = s.lower()
tokens = [t for t in s_lower.split() if t]
contains_any = or_(
func.lower(Rolodex.id).contains(s_lower),
func.lower(Rolodex.last).contains(s_lower),
func.lower(Rolodex.first).contains(s_lower),
func.lower(Rolodex.middle).contains(s_lower),
func.lower(Rolodex.city).contains(s_lower),
func.lower(Rolodex.email).contains(s_lower),
)
name_tokens = [
or_(
func.lower(Rolodex.first).contains(tok),
func.lower(Rolodex.middle).contains(tok),
func.lower(Rolodex.last).contains(tok),
)
for tok in tokens
]
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
last_first_filter = None
if "," in s_lower:
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
if last_part and first_part:
last_first_filter = and_(
func.lower(Rolodex.last).contains(last_part),
func.lower(Rolodex.first).contains(first_part),
)
elif last_part:
last_first_filter = func.lower(Rolodex.last).contains(last_part)
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
base_query = base_query.filter(final_filter)
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
if effective_groups:
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
if effective_states:
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
return base_query
def apply_customer_sorting(base_query, sort_by: Optional[str], sort_dir: Optional[str]):
"""Apply shared sorting to the provided base_query.
Supported fields: id, name (last,first), city (city,state), email.
Unknown fields fall back to id. Sorting is case-insensitive for strings.
"""
normalized_sort_by = (sort_by or "id").lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
order_columns = []
if normalized_sort_by == "id":
order_columns = [Rolodex.id]
elif normalized_sort_by == "name":
order_columns = [Rolodex.last, Rolodex.first]
elif normalized_sort_by == "city":
order_columns = [Rolodex.city, Rolodex.abrev]
elif normalized_sort_by == "email":
order_columns = [Rolodex.email]
else:
order_columns = [Rolodex.id]
ordered = []
for col in order_columns:
try:
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
except Exception:
expr = col
ordered.append(desc(expr) if is_desc else asc(expr))
if ordered:
base_query = base_query.order_by(*ordered)
return base_query
def prepare_customer_csv_rows(customers: List[Rolodex], fields: Optional[List[str]]):
"""Prepare CSV header and rows for the given customers and requested fields.
Returns a tuple: (header_row, rows), where header_row is a list of column
titles and rows is a list of row lists ready to be written by csv.writer.
"""
allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
header_names = {
"id": "Customer ID",
"name": "Name",
"group": "Group",
"city": "City",
"state": "State",
"phone": "Primary Phone",
"email": "Email",
}
requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
if not selected_fields:
selected_fields = allowed_fields_in_order
header_row = [header_names[f] for f in selected_fields]
rows: List[List[str]] = []
for c in customers:
full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
primary_phone = ""
try:
if getattr(c, "phone_numbers", None):
primary_phone = c.phone_numbers[0].phone or ""
except Exception:
primary_phone = ""
row_map = {
"id": c.id,
"name": full_name,
"group": c.group or "",
"city": c.city or "",
"state": c.abrev or "",
"phone": primary_phone,
"email": c.email or "",
}
rows.append([row_map[f] for f in selected_fields])
return header_row, rows

127
app/services/mortality.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Mortality/Life table utilities.
Helpers to query `life_tables` and `number_tables` by age/month and
return values filtered by sex/race using compact codes:
- sex: M, F, A (All)
- race: W (White), B (Black), H (Hispanic), A (All)
Column naming in tables follows the pattern:
- LifeTable: le_{race}{sex}, na_{race}{sex}
- NumberTable: na_{race}{sex}
Examples:
- race=W, sex=M => suffix "wm" (columns `le_wm`, `na_wm`)
- race=A, sex=F => suffix "af" (columns `le_af`, `na_af`)
- race=H, sex=A => suffix "ha" (columns `le_ha`, `na_ha`)
"""
from __future__ import annotations
from typing import Dict, Optional, Tuple
from sqlalchemy.orm import Session
from app.models.pensions import LifeTable, NumberTable
_RACE_MAP: Dict[str, str] = {
"W": "w", # White
"B": "b", # Black
"H": "h", # Hispanic
"A": "a", # All races
}
_SEX_MAP: Dict[str, str] = {
"M": "m",
"F": "f",
"A": "a", # All sexes
}
class InvalidCodeError(ValueError):
pass
def _normalize_codes(sex: str, race: str) -> Tuple[str, str, str]:
"""Validate/normalize sex and race to construct the column suffix.
Returns (suffix, sex_u, race_u) where suffix is lowercase like "wm".
Raises InvalidCodeError on invalid inputs.
"""
sex_u = (sex or "").strip().upper()
race_u = (race or "").strip().upper()
if sex_u not in _SEX_MAP:
raise InvalidCodeError(f"Invalid sex code '{sex}'. Expected one of: {', '.join(_SEX_MAP.keys())}")
if race_u not in _RACE_MAP:
raise InvalidCodeError(f"Invalid race code '{race}'. Expected one of: {', '.join(_RACE_MAP.keys())}")
return _RACE_MAP[race_u] + _SEX_MAP[sex_u], sex_u, race_u
def get_life_values(
db: Session,
*,
age: int,
sex: str,
race: str,
) -> Optional[Dict[str, Optional[float]]]:
"""Return life table LE and NA values for a given age, sex, and race.
Returns dict: {"age": int, "sex": str, "race": str, "le": float|None, "na": float|None}
Returns None if the age row does not exist.
Raises InvalidCodeError for invalid codes.
"""
suffix, sex_u, race_u = _normalize_codes(sex, race)
row: Optional[LifeTable] = db.query(LifeTable).filter(LifeTable.age == age).first()
if not row:
return None
le_col = f"le_{suffix}"
na_col = f"na_{suffix}"
le_val = getattr(row, le_col, None)
na_val = getattr(row, na_col, None)
return {
"age": int(age),
"sex": sex_u,
"race": race_u,
"le": float(le_val) if le_val is not None else None,
"na": float(na_val) if na_val is not None else None,
}
def get_number_value(
db: Session,
*,
month: int,
sex: str,
race: str,
) -> Optional[Dict[str, Optional[float]]]:
"""Return number table NA value for a given month, sex, and race.
Returns dict: {"month": int, "sex": str, "race": str, "na": float|None}
Returns None if the month row does not exist.
Raises InvalidCodeError for invalid codes.
"""
suffix, sex_u, race_u = _normalize_codes(sex, race)
row: Optional[NumberTable] = db.query(NumberTable).filter(NumberTable.month == month).first()
if not row:
return None
na_col = f"na_{suffix}"
na_val = getattr(row, na_col, None)
return {
"month": int(month),
"sex": sex_u,
"race": race_u,
"na": float(na_val) if na_val is not None else None,
}
__all__ = [
"InvalidCodeError",
"get_life_values",
"get_number_value",
]

View File

@@ -0,0 +1,72 @@
from typing import Iterable, Optional, Sequence
from sqlalchemy import or_, and_, asc, desc, func
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.schema import Column
def tokenized_ilike_filter(tokens: Sequence[str], columns: Sequence[Column]) -> Optional[BinaryExpression]:
"""Build an AND-of-ORs case-insensitive LIKE filter across columns for each token.
Example: AND(OR(col1 ILIKE %t1%, col2 ILIKE %t1%), OR(col1 ILIKE %t2%, ...))
Returns None when tokens or columns are empty.
"""
if not tokens or not columns:
return None
per_token_clauses = []
for term in tokens:
term = str(term or "").strip()
if not term:
continue
per_token_clauses.append(or_(*[c.ilike(f"%{term}%") for c in columns]))
if not per_token_clauses:
return None
return and_(*per_token_clauses)
def apply_pagination(query, skip: int, limit: int):
"""Apply offset/limit pagination to a SQLAlchemy query in a DRY way."""
return query.offset(skip).limit(limit)
def paginate_with_total(query, skip: int, limit: int, include_total: bool):
"""Return (items, total|None) applying pagination and optionally counting total.
This avoids duplicating count + pagination logic at each endpoint.
"""
total_count = query.count() if include_total else None
items = apply_pagination(query, skip, limit).all()
return items, total_count
def apply_sorting(query, sort_by: Optional[str], sort_dir: Optional[str], allowed: dict[str, list[Column]]):
"""Apply case-insensitive sorting per a whitelist of allowed fields.
allowed: mapping from field name -> list of columns to sort by, in priority order.
For string columns, compares using lower(column) for stable ordering.
Unknown sort_by falls back to the first key in allowed.
sort_dir: "asc" or "desc" (default asc)
"""
if not allowed:
return query
normalized_sort_by = (sort_by or next(iter(allowed.keys()))).lower()
normalized_sort_dir = (sort_dir or "asc").lower()
is_desc = normalized_sort_dir == "desc"
columns = allowed.get(normalized_sort_by)
if not columns:
columns = allowed.get(next(iter(allowed.keys())))
if not columns:
return query
order_exprs = []
for col in columns:
try:
expr = func.lower(col) if getattr(col.type, "python_type", str) is str else col
except Exception:
expr = col
order_exprs.append(desc(expr) if is_desc else asc(expr))
if order_exprs:
query = query.order_by(*order_exprs)
return query