fixes and refactor

app/api/admin.py
@@ -1,7 +1,7 @@
"""
Comprehensive Admin API endpoints - User management, system settings, audit logging
"""
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Query, Body, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
@@ -13,24 +13,27 @@ import hashlib
import secrets
import shutil
import time
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta, date, timezone
from pathlib import Path

from app.database.base import get_db
from app.api.search_highlight import build_query_tokens

# Track application start time
APPLICATION_START_TIME = time.time()
from app.models import User, Rolodex, File as FileModel, Ledger, QDRO, AuditLog, LoginAttempt
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex
from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex, PrinterSetup
from app.auth.security import get_admin_user, get_password_hash, create_access_token
from app.services.audit import audit_service
from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total

router = APIRouter()


# Enhanced Admin Schemas
from pydantic import BaseModel, Field, EmailStr
from pydantic.config import ConfigDict


class SystemStats(BaseModel):
    """Enhanced system statistics"""
@@ -91,8 +94,7 @@ class UserResponse(BaseModel):
    created_at: Optional[datetime]
    updated_at: Optional[datetime]

    class Config:
        from_attributes = True
    model_config = ConfigDict(from_attributes=True)

class PasswordReset(BaseModel):
    """Password reset request"""
@@ -124,8 +126,7 @@ class AuditLogEntry(BaseModel):
    user_agent: Optional[str]
    timestamp: datetime

    class Config:
        from_attributes = True
    model_config = ConfigDict(from_attributes=True)

class BackupInfo(BaseModel):
    """Backup information"""
@@ -135,6 +136,132 @@ class BackupInfo(BaseModel):
    backup_type: str
    status: str


class PrinterSetupBase(BaseModel):
    """Base schema for printer setup"""
    description: Optional[str] = None
    driver: Optional[str] = None
    port: Optional[str] = None
    default_printer: Optional[bool] = None
    active: Optional[bool] = None
    number: Optional[int] = None
    page_break: Optional[str] = None
    setup_st: Optional[str] = None
    reset_st: Optional[str] = None
    b_underline: Optional[str] = None
    e_underline: Optional[str] = None
    b_bold: Optional[str] = None
    e_bold: Optional[str] = None
    phone_book: Optional[bool] = None
    rolodex_info: Optional[bool] = None
    envelope: Optional[bool] = None
    file_cabinet: Optional[bool] = None
    accounts: Optional[bool] = None
    statements: Optional[bool] = None
    calendar: Optional[bool] = None


class PrinterSetupCreate(PrinterSetupBase):
    printer_name: str


class PrinterSetupUpdate(PrinterSetupBase):
    pass


class PrinterSetupResponse(PrinterSetupBase):
    printer_name: str
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    model_config = ConfigDict(from_attributes=True)


# Printer Setup Management

@router.get("/printers", response_model=List[PrinterSetupResponse])
async def list_printers(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_admin_user)
):
    printers = db.query(PrinterSetup).order_by(PrinterSetup.printer_name.asc()).all()
    return printers
|
||||
|
||||
@router.get("/printers/{printer_name}", response_model=PrinterSetupResponse)
|
||||
async def get_printer(
|
||||
printer_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
printer = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
|
||||
if not printer:
|
||||
raise HTTPException(status_code=404, detail="Printer not found")
|
||||
return printer
|
||||
|
||||
|
||||
@router.post("/printers", response_model=PrinterSetupResponse)
|
||||
async def create_printer(
|
||||
payload: PrinterSetupCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
exists = db.query(PrinterSetup).filter(PrinterSetup.printer_name == payload.printer_name).first()
|
||||
if exists:
|
||||
raise HTTPException(status_code=400, detail="Printer already exists")
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
instance = PrinterSetup(**data)
|
||||
db.add(instance)
|
||||
# Enforce single default printer
|
||||
if data.get("default_printer"):
|
||||
try:
|
||||
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
|
||||
except Exception:
|
||||
pass
|
||||
db.commit()
|
||||
db.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
@router.put("/printers/{printer_name}", response_model=PrinterSetupResponse)
|
||||
async def update_printer(
|
||||
printer_name: str,
|
||||
payload: PrinterSetupUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Printer not found")
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
for k, v in updates.items():
|
||||
setattr(instance, k, v)
|
||||
# Enforce single default printer when set true
|
||||
if updates.get("default_printer"):
|
||||
try:
|
||||
db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False})
|
||||
except Exception:
|
||||
pass
|
||||
db.commit()
|
||||
db.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
@router.delete("/printers/{printer_name}")
|
||||
async def delete_printer(
|
||||
printer_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first()
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Printer not found")
|
||||
db.delete(instance)
|
||||
db.commit()
|
||||
return {"message": "Printer deleted"}
|
||||
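A quick client-side sketch of the new printer endpoints (hypothetical: the mount point of this admin router and the auth token are not shown in this diff):

import requests

BASE = "http://localhost:8000/api/admin"           # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}

# Create a printer and mark it as the default; create_printer clears the flag on all others.
requests.post(f"{BASE}/printers", headers=HEADERS, json={
    "printer_name": "FRONT_DESK",
    "driver": "HP LaserJet",
    "port": "LPT1",
    "default_printer": True,
})

# List, update, and delete by printer_name (the key used in the routes above).
printers = requests.get(f"{BASE}/printers", headers=HEADERS).json()
requests.put(f"{BASE}/printers/FRONT_DESK", headers=HEADERS, json={"active": False})
requests.delete(f"{BASE}/printers/FRONT_DESK", headers=HEADERS)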
|
||||
|
||||
|
||||
class LookupTableInfo(BaseModel):
|
||||
"""Lookup table information"""
|
||||
table_name: str
|
||||
@@ -202,7 +329,7 @@ async def system_health(
|
||||
# Count active sessions (simplified)
|
||||
try:
|
||||
active_sessions = db.query(User).filter(
|
||||
User.last_login > datetime.now() - timedelta(hours=24)
|
||||
User.last_login > datetime.now(timezone.utc) - timedelta(hours=24)
|
||||
).count()
|
||||
except:
|
||||
active_sessions = 0
|
||||
@@ -215,7 +342,7 @@ async def system_health(
|
||||
backup_files = list(backup_dir.glob("*.db"))
|
||||
if backup_files:
|
||||
latest_backup = max(backup_files, key=lambda p: p.stat().st_mtime)
|
||||
backup_age = datetime.now() - datetime.fromtimestamp(latest_backup.stat().st_mtime)
|
||||
backup_age = datetime.now(timezone.utc) - datetime.fromtimestamp(latest_backup.stat().st_mtime, tz=timezone.utc)
|
||||
last_backup = latest_backup.name
|
||||
if backup_age.days > 7:
|
||||
alerts.append(f"Last backup is {backup_age.days} days old")
|
||||
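All naive datetime.now() / datetime.fromtimestamp() calls in this commit move to timezone-aware UTC values; a minimal illustration of why (generic Python, not code from this repo):

from datetime import datetime, timezone

naive = datetime.now()                              # no tzinfo attached
aware = datetime.now(timezone.utc)                  # tzinfo=UTC
mtime = datetime.fromtimestamp(0, tz=timezone.utc)  # aware timestamp conversion
age = aware - mtime                                 # fine: both operands are aware
# aware - naive would raise TypeError: can't subtract offset-naive and offset-aware datetimes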
@@ -257,7 +384,7 @@ async def system_statistics(
|
||||
|
||||
# Count active users (logged in within last 30 days)
|
||||
total_active_users = db.query(func.count(User.id)).filter(
|
||||
User.last_login > datetime.now() - timedelta(days=30)
|
||||
User.last_login > datetime.now(timezone.utc) - timedelta(days=30)
|
||||
).scalar()
|
||||
|
||||
# Count admin users
|
||||
@@ -308,7 +435,7 @@ async def system_statistics(
|
||||
recent_activity.append({
|
||||
"type": "customer_added",
|
||||
"description": f"Customer {customer.first} {customer.last} added",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except:
|
||||
pass
|
||||
@@ -409,7 +536,7 @@ async def export_table(
|
||||
# Create exports directory if it doesn't exist
|
||||
os.makedirs("exports", exist_ok=True)
|
||||
|
||||
filename = f"exports/{table_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
filename = f"exports/{table_name}_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
|
||||
try:
|
||||
if table_name.lower() == "customers" or table_name.lower() == "rolodex":
|
||||
@@ -470,10 +597,10 @@ async def download_backup(
|
||||
if "sqlite" in settings.database_url:
|
||||
db_path = settings.database_url.replace("sqlite:///", "")
|
||||
if os.path.exists(db_path):
|
||||
return FileResponse(
|
||||
return FileResponse(
|
||||
db_path,
|
||||
media_type='application/octet-stream',
|
||||
filename=f"delphi_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.db"
|
||||
filename=f"delphi_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.db"
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
@@ -484,12 +611,20 @@ async def download_backup(
|
||||
|
||||
# User Management Endpoints
|
||||
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
class PaginatedUsersResponse(BaseModel):
|
||||
items: List[UserResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/users", response_model=Union[List[UserResponse], PaginatedUsersResponse])
|
||||
async def list_users(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
search: Optional[str] = Query(None),
|
||||
active_only: bool = Query(False),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: username, email, first_name, last_name, created, updated"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
@@ -498,19 +633,38 @@ async def list_users(
|
||||
query = db.query(User)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.username.ilike(f"%{search}%"),
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
# DRY: tokenize and apply case-insensitive multi-field search
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
User.username,
|
||||
User.email,
|
||||
User.first_name,
|
||||
User.last_name,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
if active_only:
|
||||
query = query.filter(User.is_active == True)
|
||||
|
||||
users = query.offset(skip).limit(limit).all()
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"username": [User.username],
|
||||
"email": [User.email],
|
||||
"first_name": [User.first_name],
|
||||
"last_name": [User.last_name],
|
||||
"created": [User.created_at],
|
||||
"updated": [User.updated_at],
|
||||
},
|
||||
)
|
||||
|
||||
users, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": users, "total": total or 0}
|
||||
return users
|
||||
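The apply_sorting / paginate_with_total helpers come from app.services.query_utils, which is not part of this diff; a minimal sketch consistent with the call sites above (signatures are assumptions) could look like:

from typing import Dict, List, Optional, Sequence, Tuple

from sqlalchemy import asc, desc
from sqlalchemy.orm import Query


def apply_sorting(query: Query, sort_by: Optional[str], sort_dir: Optional[str],
                  allowed: Dict[str, Sequence]) -> Query:
    """Order by a whitelisted key only; unknown keys leave the query untouched."""
    columns = allowed.get((sort_by or "").lower())
    if not columns:
        return query
    direction = desc if (sort_dir or "asc").lower() == "desc" else asc
    return query.order_by(*[direction(col) for col in columns])


def paginate_with_total(query: Query, skip: int, limit: int,
                        include_total: bool) -> Tuple[List, Optional[int]]:
    """Return one page of results plus the pre-pagination count when requested."""
    total = query.count() if include_total else None
    return query.offset(skip).limit(limit).all(), total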
|
||||
|
||||
@@ -567,8 +721,8 @@ async def create_user(
|
||||
hashed_password=hashed_password,
|
||||
is_admin=user_data.is_admin,
|
||||
is_active=user_data.is_active,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
@@ -659,7 +813,7 @@ async def update_user(
|
||||
changes[field] = {"from": getattr(user, field), "to": value}
|
||||
setattr(user, field, value)
|
||||
|
||||
user.updated_at = datetime.now()
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
@@ -702,7 +856,7 @@ async def delete_user(
|
||||
|
||||
# Soft delete by deactivating
|
||||
user.is_active = False
|
||||
user.updated_at = datetime.now()
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -744,7 +898,7 @@ async def reset_user_password(
|
||||
|
||||
# Update password
|
||||
user.hashed_password = get_password_hash(password_data.new_password)
|
||||
user.updated_at = datetime.now()
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1046,7 +1200,7 @@ async def create_backup(
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate backup filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"delphi_backup_{timestamp}.db"
|
||||
backup_path = backup_dir / backup_filename
|
||||
|
||||
@@ -1063,7 +1217,7 @@ async def create_backup(
|
||||
"backup_info": {
|
||||
"filename": backup_filename,
|
||||
"size": f"{backup_size / (1024*1024):.1f} MB",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"backup_type": "manual",
|
||||
"status": "completed"
|
||||
}
|
||||
@@ -1118,45 +1272,61 @@ async def get_audit_logs(
|
||||
resource_type: Optional[str] = Query(None),
|
||||
action: Optional[str] = Query(None),
|
||||
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days, max 1 year
|
||||
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, action, resource_type"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Get audit log entries with filtering"""
|
||||
|
||||
cutoff_time = datetime.now() - timedelta(hours=hours_back)
|
||||
|
||||
"""Get audit log entries with filtering, sorting, and pagination"""
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
|
||||
|
||||
query = db.query(AuditLog).filter(AuditLog.timestamp >= cutoff_time)
|
||||
|
||||
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
|
||||
|
||||
if resource_type:
|
||||
query = query.filter(AuditLog.resource_type.ilike(f"%{resource_type}%"))
|
||||
|
||||
|
||||
if action:
|
||||
query = query.filter(AuditLog.action.ilike(f"%{action}%"))
|
||||
|
||||
total_count = query.count()
|
||||
logs = query.order_by(AuditLog.timestamp.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total_count,
|
||||
"logs": [
|
||||
{
|
||||
"id": log.id,
|
||||
"user_id": log.user_id,
|
||||
"username": log.username,
|
||||
"action": log.action,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"details": log.details,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"timestamp": log.timestamp.isoformat()
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
}
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"timestamp": [AuditLog.timestamp],
|
||||
"username": [AuditLog.username],
|
||||
"action": [AuditLog.action],
|
||||
"resource_type": [AuditLog.resource_type],
|
||||
},
|
||||
)
|
||||
|
||||
logs, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": log.id,
|
||||
"user_id": log.user_id,
|
||||
"username": log.username,
|
||||
"action": log.action,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"details": log.details,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"timestamp": log.timestamp.isoformat(),
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
|
||||
if include_total:
|
||||
return {"items": items, "total": total or 0}
|
||||
return items
|
||||
|
||||
|
||||
@router.get("/audit/login-attempts")
|
||||
@@ -1166,39 +1336,55 @@ async def get_login_attempts(
|
||||
username: Optional[str] = Query(None),
|
||||
failed_only: bool = Query(False),
|
||||
hours_back: int = Query(168, ge=1, le=8760), # Default 7 days
|
||||
sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, ip_address, success"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
"""Get login attempts with filtering"""
|
||||
|
||||
cutoff_time = datetime.now() - timedelta(hours=hours_back)
|
||||
|
||||
"""Get login attempts with filtering, sorting, and pagination"""
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
|
||||
|
||||
query = db.query(LoginAttempt).filter(LoginAttempt.timestamp >= cutoff_time)
|
||||
|
||||
|
||||
if username:
|
||||
query = query.filter(LoginAttempt.username.ilike(f"%{username}%"))
|
||||
|
||||
|
||||
if failed_only:
|
||||
query = query.filter(LoginAttempt.success == 0)
|
||||
|
||||
total_count = query.count()
|
||||
attempts = query.order_by(LoginAttempt.timestamp.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total_count,
|
||||
"attempts": [
|
||||
{
|
||||
"id": attempt.id,
|
||||
"username": attempt.username,
|
||||
"ip_address": attempt.ip_address,
|
||||
"user_agent": attempt.user_agent,
|
||||
"success": bool(attempt.success),
|
||||
"failure_reason": attempt.failure_reason,
|
||||
"timestamp": attempt.timestamp.isoformat()
|
||||
}
|
||||
for attempt in attempts
|
||||
]
|
||||
}
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"timestamp": [LoginAttempt.timestamp],
|
||||
"username": [LoginAttempt.username],
|
||||
"ip_address": [LoginAttempt.ip_address],
|
||||
"success": [LoginAttempt.success],
|
||||
},
|
||||
)
|
||||
|
||||
attempts, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": attempt.id,
|
||||
"username": attempt.username,
|
||||
"ip_address": attempt.ip_address,
|
||||
"user_agent": attempt.user_agent,
|
||||
"success": bool(attempt.success),
|
||||
"failure_reason": attempt.failure_reason,
|
||||
"timestamp": attempt.timestamp.isoformat(),
|
||||
}
|
||||
for attempt in attempts
|
||||
]
|
||||
|
||||
if include_total:
|
||||
return {"items": items, "total": total or 0}
|
||||
return items
|
||||
|
||||
|
||||
@router.get("/audit/user-activity/{user_id}")
|
||||
@@ -1251,7 +1437,7 @@ async def get_security_alerts(
|
||||
):
|
||||
"""Get security alerts and suspicious activity"""
|
||||
|
||||
cutoff_time = datetime.now() - timedelta(hours=hours_back)
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
|
||||
|
||||
# Get failed login attempts
|
||||
failed_logins = db.query(LoginAttempt).filter(
|
||||
@@ -1356,7 +1542,7 @@ async def get_audit_statistics(
|
||||
):
|
||||
"""Get audit statistics and metrics"""
|
||||
|
||||
cutoff_time = datetime.now() - timedelta(days=days_back)
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
|
||||
# Total activity counts
|
||||
total_audit_entries = db.query(func.count(AuditLog.id)).filter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Authentication API endpoints
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
@@ -69,7 +69,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
|
||||
)
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.utcnow()
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
@@ -144,7 +144,7 @@ async def read_users_me(current_user: User = Depends(get_current_user)):
|
||||
async def refresh_token_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
body: RefreshRequest | None = None,
|
||||
body: RefreshRequest = None,
|
||||
):
|
||||
"""Issue a new access token using a valid, non-revoked refresh token.
|
||||
|
||||
@@ -203,7 +203,7 @@ async def refresh_token_endpoint(
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
|
||||
|
||||
user.last_login = datetime.utcnow()
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
@@ -225,7 +225,7 @@ async def list_users(
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)):
|
||||
async def logout(body: RefreshRequest = None, db: Session = Depends(get_db)):
|
||||
"""Revoke the provided refresh token. Idempotent and safe to call multiple times.
|
||||
|
||||
The client should send a JSON body: { "refresh_token": "..." }.
|
||||
|
||||
@@ -4,7 +4,7 @@ Customer (Rolodex) API endpoints
|
||||
from typing import List, Optional, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, and_, func, asc, desc
|
||||
from sqlalchemy import func
|
||||
from fastapi.responses import StreamingResponse
|
||||
import csv
|
||||
import io
|
||||
@@ -13,12 +13,16 @@ from app.database.base import get_db
|
||||
from app.models.rolodex import Rolodex, Phone
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas for request/response
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from pydantic.config import ConfigDict
|
||||
from datetime import date
|
||||
|
||||
|
||||
@@ -32,8 +36,7 @@ class PhoneResponse(BaseModel):
|
||||
location: Optional[str]
|
||||
phone: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CustomerBase(BaseModel):
|
||||
@@ -84,10 +87,11 @@ class CustomerUpdate(BaseModel):
|
||||
|
||||
|
||||
class CustomerResponse(CustomerBase):
|
||||
phone_numbers: List[PhoneResponse] = []
|
||||
phone_numbers: List[PhoneResponse] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/search/phone")
|
||||
@@ -196,80 +200,17 @@ async def list_customers(
|
||||
try:
|
||||
base_query = db.query(Rolodex)
|
||||
|
||||
if search:
|
||||
s = (search or "").strip()
|
||||
s_lower = s.lower()
|
||||
tokens = [t for t in s_lower.split() if t]
|
||||
# Basic contains search on several fields (case-insensitive)
|
||||
contains_any = or_(
|
||||
func.lower(Rolodex.id).contains(s_lower),
|
||||
func.lower(Rolodex.last).contains(s_lower),
|
||||
func.lower(Rolodex.first).contains(s_lower),
|
||||
func.lower(Rolodex.middle).contains(s_lower),
|
||||
func.lower(Rolodex.city).contains(s_lower),
|
||||
func.lower(Rolodex.email).contains(s_lower),
|
||||
)
|
||||
# Multi-token name support: every token must match either first, middle, or last
|
||||
name_tokens = [
|
||||
or_(
|
||||
func.lower(Rolodex.first).contains(tok),
|
||||
func.lower(Rolodex.middle).contains(tok),
|
||||
func.lower(Rolodex.last).contains(tok),
|
||||
)
|
||||
for tok in tokens
|
||||
]
|
||||
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
|
||||
# Comma pattern: "Last, First"
|
||||
last_first_filter = None
|
||||
if "," in s_lower:
|
||||
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
|
||||
if last_part and first_part:
|
||||
last_first_filter = and_(
|
||||
func.lower(Rolodex.last).contains(last_part),
|
||||
func.lower(Rolodex.first).contains(first_part),
|
||||
)
|
||||
elif last_part:
|
||||
last_first_filter = func.lower(Rolodex.last).contains(last_part)
|
||||
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
|
||||
base_query = base_query.filter(final_filter)
|
||||
|
||||
# Apply group/state filters (support single and multi-select)
|
||||
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
|
||||
if effective_groups:
|
||||
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
|
||||
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
|
||||
if effective_states:
|
||||
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
|
||||
base_query = apply_customer_filters(
|
||||
base_query,
|
||||
search=search,
|
||||
group=group,
|
||||
state=state,
|
||||
groups=groups,
|
||||
states=states,
|
||||
)
|
||||
|
||||
# Apply sorting (whitelisted fields only)
|
||||
normalized_sort_by = (sort_by or "id").lower()
|
||||
normalized_sort_dir = (sort_dir or "asc").lower()
|
||||
is_desc = normalized_sort_dir == "desc"
|
||||
|
||||
order_columns = []
|
||||
if normalized_sort_by == "id":
|
||||
order_columns = [Rolodex.id]
|
||||
elif normalized_sort_by == "name":
|
||||
# Sort by last, then first
|
||||
order_columns = [Rolodex.last, Rolodex.first]
|
||||
elif normalized_sort_by == "city":
|
||||
# Sort by city, then state abbreviation
|
||||
order_columns = [Rolodex.city, Rolodex.abrev]
|
||||
elif normalized_sort_by == "email":
|
||||
order_columns = [Rolodex.email]
|
||||
else:
|
||||
# Fallback to id to avoid arbitrary column injection
|
||||
order_columns = [Rolodex.id]
|
||||
|
||||
# Case-insensitive ordering where applicable, preserving None ordering default
|
||||
ordered = []
|
||||
for col in order_columns:
|
||||
# Use lower() for string-like cols; SQLAlchemy will handle non-string safely enough for SQLite/Postgres
|
||||
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
|
||||
ordered.append(desc(expr) if is_desc else asc(expr))
|
||||
|
||||
if ordered:
|
||||
base_query = base_query.order_by(*ordered)
|
||||
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
|
||||
|
||||
customers = base_query.options(joinedload(Rolodex.phone_numbers)).offset(skip).limit(limit).all()
|
||||
if include_total:
|
||||
@@ -304,72 +245,16 @@ async def export_customers(
|
||||
try:
|
||||
base_query = db.query(Rolodex)
|
||||
|
||||
if search:
|
||||
s = (search or "").strip()
|
||||
s_lower = s.lower()
|
||||
tokens = [t for t in s_lower.split() if t]
|
||||
contains_any = or_(
|
||||
func.lower(Rolodex.id).contains(s_lower),
|
||||
func.lower(Rolodex.last).contains(s_lower),
|
||||
func.lower(Rolodex.first).contains(s_lower),
|
||||
func.lower(Rolodex.middle).contains(s_lower),
|
||||
func.lower(Rolodex.city).contains(s_lower),
|
||||
func.lower(Rolodex.email).contains(s_lower),
|
||||
)
|
||||
name_tokens = [
|
||||
or_(
|
||||
func.lower(Rolodex.first).contains(tok),
|
||||
func.lower(Rolodex.middle).contains(tok),
|
||||
func.lower(Rolodex.last).contains(tok),
|
||||
)
|
||||
for tok in tokens
|
||||
]
|
||||
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
|
||||
last_first_filter = None
|
||||
if "," in s_lower:
|
||||
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
|
||||
if last_part and first_part:
|
||||
last_first_filter = and_(
|
||||
func.lower(Rolodex.last).contains(last_part),
|
||||
func.lower(Rolodex.first).contains(first_part),
|
||||
)
|
||||
elif last_part:
|
||||
last_first_filter = func.lower(Rolodex.last).contains(last_part)
|
||||
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
|
||||
base_query = base_query.filter(final_filter)
|
||||
base_query = apply_customer_filters(
|
||||
base_query,
|
||||
search=search,
|
||||
group=group,
|
||||
state=state,
|
||||
groups=groups,
|
||||
states=states,
|
||||
)
|
||||
|
||||
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
|
||||
if effective_groups:
|
||||
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
|
||||
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
|
||||
if effective_states:
|
||||
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
|
||||
|
||||
normalized_sort_by = (sort_by or "id").lower()
|
||||
normalized_sort_dir = (sort_dir or "asc").lower()
|
||||
is_desc = normalized_sort_dir == "desc"
|
||||
|
||||
order_columns = []
|
||||
if normalized_sort_by == "id":
|
||||
order_columns = [Rolodex.id]
|
||||
elif normalized_sort_by == "name":
|
||||
order_columns = [Rolodex.last, Rolodex.first]
|
||||
elif normalized_sort_by == "city":
|
||||
order_columns = [Rolodex.city, Rolodex.abrev]
|
||||
elif normalized_sort_by == "email":
|
||||
order_columns = [Rolodex.email]
|
||||
else:
|
||||
order_columns = [Rolodex.id]
|
||||
|
||||
ordered = []
|
||||
for col in order_columns:
|
||||
try:
|
||||
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
expr = col
|
||||
ordered.append(desc(expr) if is_desc else asc(expr))
|
||||
if ordered:
|
||||
base_query = base_query.order_by(*ordered)
|
||||
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
|
||||
|
||||
if not export_all:
|
||||
if skip is not None:
|
||||
@@ -382,39 +267,10 @@ async def export_customers(
|
||||
# Prepare CSV
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
|
||||
header_names = {
|
||||
"id": "Customer ID",
|
||||
"name": "Name",
|
||||
"group": "Group",
|
||||
"city": "City",
|
||||
"state": "State",
|
||||
"phone": "Primary Phone",
|
||||
"email": "Email",
|
||||
}
|
||||
requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
|
||||
selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
|
||||
if not selected_fields:
|
||||
selected_fields = allowed_fields_in_order
|
||||
writer.writerow([header_names[f] for f in selected_fields])
|
||||
for c in customers:
|
||||
full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
|
||||
primary_phone = ""
|
||||
try:
|
||||
if c.phone_numbers:
|
||||
primary_phone = c.phone_numbers[0].phone or ""
|
||||
except Exception:
|
||||
primary_phone = ""
|
||||
row_map = {
|
||||
"id": c.id,
|
||||
"name": full_name,
|
||||
"group": c.group or "",
|
||||
"city": c.city or "",
|
||||
"state": c.abrev or "",
|
||||
"phone": primary_phone,
|
||||
"email": c.email or "",
|
||||
}
|
||||
writer.writerow([row_map[f] for f in selected_fields])
|
||||
header_row, rows = prepare_customer_csv_rows(customers, fields)
|
||||
writer.writerow(header_row)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
output.seek(0)
|
||||
filename = "customers_export.csv"
|
||||
@@ -469,6 +325,10 @@ async def create_customer(
|
||||
db.commit()
|
||||
db.refresh(customer)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return customer
|
||||
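invalidate_search_cache (app.services.cache) is invoked best-effort after every write so a cache failure never blocks a committed change; the module is not included in this diff, and the call sites only assume something along these lines (hypothetical sketch):

_search_cache: dict = {}

async def invalidate_search_cache() -> None:
    """Drop cached search results so the next query reads fresh rows."""
    _search_cache.clear()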
|
||||
|
||||
@@ -494,7 +354,10 @@ async def update_customer(
|
||||
|
||||
db.commit()
|
||||
db.refresh(customer)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return customer
|
||||
|
||||
|
||||
@@ -515,17 +378,30 @@ async def delete_customer(
|
||||
|
||||
db.delete(customer)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return {"message": "Customer deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/{customer_id}/phones", response_model=List[PhoneResponse])
|
||||
class PaginatedPhonesResponse(BaseModel):
|
||||
items: List[PhoneResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/{customer_id}/phones", response_model=Union[List[PhoneResponse], PaginatedPhonesResponse])
|
||||
async def get_customer_phones(
|
||||
customer_id: str,
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Page size"),
|
||||
sort_by: Optional[str] = Query("location", description="Sort by: location, phone"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get customer phone numbers"""
|
||||
"""Get customer phone numbers with optional sorting/pagination"""
|
||||
customer = db.query(Rolodex).filter(Rolodex.id == customer_id).first()
|
||||
|
||||
if not customer:
|
||||
@@ -534,7 +410,21 @@ async def get_customer_phones(
|
||||
detail="Customer not found"
|
||||
)
|
||||
|
||||
phones = db.query(Phone).filter(Phone.rolodex_id == customer_id).all()
|
||||
query = db.query(Phone).filter(Phone.rolodex_id == customer_id)
|
||||
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"location": [Phone.location, Phone.phone],
|
||||
"phone": [Phone.phone],
|
||||
},
|
||||
)
|
||||
|
||||
phones, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": phones, "total": total or 0}
|
||||
return phones
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
Document Management API endpoints - QDROs, Templates, and General Documents
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from __future__ import annotations
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, func, and_, desc, asc, text
|
||||
from datetime import date, datetime
|
||||
from datetime import date, datetime, timezone
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.api.search_highlight import build_query_tokens
|
||||
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
|
||||
from app.models.qdro import QDRO
|
||||
from app.models.files import File as FileModel
|
||||
from app.models.rolodex import Rolodex
|
||||
@@ -20,18 +23,20 @@ from app.auth.security import get_current_user
|
||||
from app.models.additional import Document
|
||||
from app.core.logging import get_logger
|
||||
from app.services.audit import audit_service
|
||||
from app.services.cache import invalidate_search_cache
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class QDROBase(BaseModel):
|
||||
file_no: str
|
||||
version: str = "01"
|
||||
title: Optional[str] = None
|
||||
form_name: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
status: str = "DRAFT"
|
||||
created_date: Optional[date] = None
|
||||
@@ -51,6 +56,7 @@ class QDROCreate(QDROBase):
|
||||
class QDROUpdate(BaseModel):
|
||||
version: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
form_name: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
created_date: Optional[date] = None
|
||||
@@ -66,27 +72,61 @@ class QDROUpdate(BaseModel):
|
||||
class QDROResponse(QDROBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.get("/qdros/{file_no}", response_model=List[QDROResponse])
|
||||
class PaginatedQDROResponse(BaseModel):
|
||||
items: List[QDROResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/qdros/{file_no}", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
|
||||
async def get_file_qdros(
|
||||
file_no: str,
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Page size"),
|
||||
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created, version, status"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get QDROs for specific file"""
|
||||
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(QDRO.version).all()
|
||||
"""Get QDROs for a specific file with optional sorting/pagination"""
|
||||
query = db.query(QDRO).filter(QDRO.file_no == file_no)
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"updated": [QDRO.updated_at, QDRO.id],
|
||||
"created": [QDRO.created_at, QDRO.id],
|
||||
"version": [QDRO.version],
|
||||
"status": [QDRO.status],
|
||||
},
|
||||
)
|
||||
|
||||
qdros, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": qdros, "total": total or 0}
|
||||
return qdros
|
||||
|
||||
|
||||
@router.get("/qdros/", response_model=List[QDROResponse])
|
||||
class PaginatedQDROResponse(BaseModel):
|
||||
items: List[QDROResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse])
|
||||
async def list_qdros(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
status_filter: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: file_no, version, status, created, updated"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -97,17 +137,37 @@ async def list_qdros(
|
||||
query = query.filter(QDRO.status == status_filter)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
QDRO.file_no.contains(search),
|
||||
QDRO.title.contains(search),
|
||||
QDRO.participant_name.contains(search),
|
||||
QDRO.spouse_name.contains(search),
|
||||
QDRO.plan_name.contains(search)
|
||||
)
|
||||
)
|
||||
|
||||
qdros = query.offset(skip).limit(limit).all()
|
||||
# DRY: tokenize and apply case-insensitive search across common QDRO fields
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
QDRO.file_no,
|
||||
QDRO.form_name,
|
||||
QDRO.pet,
|
||||
QDRO.res,
|
||||
QDRO.case_number,
|
||||
QDRO.notes,
|
||||
QDRO.status,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"file_no": [QDRO.file_no],
|
||||
"version": [QDRO.version],
|
||||
"status": [QDRO.status],
|
||||
"created": [QDRO.created_at],
|
||||
"updated": [QDRO.updated_at],
|
||||
},
|
||||
)
|
||||
|
||||
qdros, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": qdros, "total": total or 0}
|
||||
return qdros
|
||||
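build_query_tokens (app.api.search_highlight) and tokenized_ilike_filter (app.services.query_utils) are reused here but not shown in this commit; a minimal sketch matching how they are called (an assumption, not the actual implementation):

from typing import List, Optional, Sequence

from sqlalchemy import and_, or_


def build_query_tokens(search: Optional[str]) -> List[str]:
    """Split a raw search string into lowercase tokens."""
    return [t for t in (search or "").lower().split() if t]


def tokenized_ilike_filter(tokens: Sequence[str], columns: Sequence):
    """Every token must match at least one column, case-insensitively."""
    if not tokens:
        return None
    return and_(*[or_(*[col.ilike(f"%{tok}%") for col in columns]) for tok in tokens])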
|
||||
|
||||
@@ -135,6 +195,10 @@ async def create_qdro(
|
||||
db.commit()
|
||||
db.refresh(qdro)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return qdro
|
||||
|
||||
|
||||
@@ -189,6 +253,10 @@ async def update_qdro(
|
||||
db.commit()
|
||||
db.refresh(qdro)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return qdro
|
||||
|
||||
|
||||
@@ -213,7 +281,10 @@ async def delete_qdro(
|
||||
|
||||
db.delete(qdro)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return {"message": "QDRO deleted successfully"}
|
||||
|
||||
|
||||
@@ -241,8 +312,7 @@ class TemplateResponse(TemplateBase):
|
||||
active: bool = True
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Document Generation Schema
|
||||
class DocumentGenerateRequest(BaseModel):
|
||||
@@ -269,13 +339,21 @@ class DocumentStats(BaseModel):
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.get("/templates/", response_model=List[TemplateResponse])
|
||||
class PaginatedTemplatesResponse(BaseModel):
|
||||
items: List[TemplateResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/templates/", response_model=Union[List[TemplateResponse], PaginatedTemplatesResponse])
|
||||
async def list_templates(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
category: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
active_only: bool = Query(True),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: form_id, form_name, category, created, updated"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -289,14 +367,31 @@ async def list_templates(
|
||||
query = query.filter(FormIndex.category == category)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
FormIndex.form_name.contains(search),
|
||||
FormIndex.form_id.contains(search)
|
||||
)
|
||||
)
|
||||
|
||||
templates = query.offset(skip).limit(limit).all()
|
||||
# DRY: tokenize and apply case-insensitive search for templates
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
FormIndex.form_name,
|
||||
FormIndex.form_id,
|
||||
FormIndex.category,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"form_id": [FormIndex.form_id],
|
||||
"form_name": [FormIndex.form_name],
|
||||
"category": [FormIndex.category],
|
||||
"created": [FormIndex.created_at],
|
||||
"updated": [FormIndex.updated_at],
|
||||
},
|
||||
)
|
||||
|
||||
templates, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
# Enhanced response with template content
|
||||
results = []
|
||||
@@ -317,6 +412,8 @@ async def list_templates(
|
||||
"variables": _extract_variables_from_content(content)
|
||||
})
|
||||
|
||||
if include_total:
|
||||
return {"items": results, "total": total or 0}
|
||||
return results
|
||||
|
||||
|
||||
@@ -356,6 +453,10 @@ async def create_template(
|
||||
|
||||
db.commit()
|
||||
db.refresh(form_index)
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"form_id": form_index.form_id,
|
||||
@@ -440,6 +541,10 @@ async def update_template(
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get updated content
|
||||
template_lines = db.query(FormList).filter(
|
||||
@@ -480,6 +585,10 @@ async def delete_template(
|
||||
# Delete template
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"message": "Template deleted successfully"}
|
||||
|
||||
@@ -574,7 +683,7 @@ async def generate_document(
|
||||
"file_name": file_name,
|
||||
"file_path": file_path,
|
||||
"size": file_size,
|
||||
"created_at": datetime.now()
|
||||
"created_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
|
||||
@@ -629,32 +738,49 @@ async def get_document_stats(
|
||||
@router.get("/file/{file_no}/documents")
|
||||
async def get_file_documents(
|
||||
file_no: str,
|
||||
sort_by: Optional[str] = Query("updated", description="Sort by: updated, created"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all documents associated with a specific file"""
|
||||
# Get QDROs for this file
|
||||
qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(desc(QDRO.updated_at)).all()
|
||||
|
||||
# Format response
|
||||
documents = [
|
||||
"""Get all documents associated with a specific file, with optional sorting/pagination"""
|
||||
# Base query for QDROs tied to the file
|
||||
query = db.query(QDRO).filter(QDRO.file_no == file_no)
|
||||
|
||||
# Apply sorting using shared helper (map friendly names to columns)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"updated": [QDRO.updated_at, QDRO.id],
|
||||
"created": [QDRO.created_at, QDRO.id],
|
||||
},
|
||||
)
|
||||
|
||||
qdros, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": qdro.id,
|
||||
"type": "QDRO",
|
||||
"title": f"QDRO v{qdro.version}",
|
||||
"status": qdro.status,
|
||||
"created_date": qdro.created_date.isoformat() if qdro.created_date else None,
|
||||
"updated_at": qdro.updated_at.isoformat() if qdro.updated_at else None,
|
||||
"file_no": qdro.file_no
|
||||
"created_date": qdro.created_date.isoformat() if getattr(qdro, "created_date", None) else None,
|
||||
"updated_at": qdro.updated_at.isoformat() if getattr(qdro, "updated_at", None) else None,
|
||||
"file_no": qdro.file_no,
|
||||
}
|
||||
for qdro in qdros
|
||||
]
|
||||
|
||||
return {
|
||||
"file_no": file_no,
|
||||
"documents": documents,
|
||||
"total_count": len(documents)
|
||||
}
|
||||
|
||||
# Keep total_count for backward compatibility: the real total when include_total was
# requested, otherwise the length of the returned page.
payload = {"file_no": file_no, "documents": items, "total_count": (total if include_total else len(items))}
return payload
|
||||
|
||||
|
||||
def _extract_variables_from_content(content: str) -> Dict[str, str]:
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
"""
|
||||
File Management API endpoints
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, func, and_, desc
|
||||
from datetime import date, datetime
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.api.search_highlight import build_query_tokens
|
||||
from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total
|
||||
from app.models.files import File
|
||||
from app.models.rolodex import Rolodex
|
||||
from app.models.ledger import Ledger
|
||||
from app.models.lookups import Employee, FileType, FileStatus
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas
|
||||
from pydantic import BaseModel
|
||||
from pydantic.config import ConfigDict
|
||||
|
||||
|
||||
class FileBase(BaseModel):
|
||||
@@ -67,17 +71,24 @@ class FileResponse(FileBase):
|
||||
amount_owing: float = 0.0
|
||||
transferable: float = 0.0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[FileResponse])
|
||||
class PaginatedFilesResponse(BaseModel):
|
||||
items: List[FileResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/", response_model=Union[List[FileResponse], PaginatedFilesResponse])
|
||||
async def list_files(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
search: Optional[str] = Query(None),
|
||||
status_filter: Optional[str] = Query(None),
|
||||
employee_filter: Optional[str] = Query(None),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: file_no, client, opened, closed, status, amount_owing, total_charges"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -85,14 +96,17 @@ async def list_files(
|
||||
query = db.query(File)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
File.file_no.contains(search),
|
||||
File.id.contains(search),
|
||||
File.regarding.contains(search),
|
||||
File.file_type.contains(search)
|
||||
)
|
||||
)
|
||||
# DRY: tokenize and apply case-insensitive search consistently with search endpoints
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
File.file_no,
|
||||
File.id,
|
||||
File.regarding,
|
||||
File.file_type,
|
||||
File.memo,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
if status_filter:
|
||||
query = query.filter(File.status == status_filter)
|
||||
@@ -100,7 +114,25 @@ async def list_files(
|
||||
if employee_filter:
|
||||
query = query.filter(File.empl_num == employee_filter)
|
||||
|
||||
files = query.offset(skip).limit(limit).all()
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"file_no": [File.file_no],
|
||||
"client": [File.id],
|
||||
"opened": [File.opened],
|
||||
"closed": [File.closed],
|
||||
"status": [File.status],
|
||||
"amount_owing": [File.amount_owing],
|
||||
"total_charges": [File.total_charges],
|
||||
},
|
||||
)
|
||||
|
||||
files, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": files, "total": total or 0}
|
||||
return files
|
||||
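Like the other refactored list endpoints, list_files now returns either a plain JSON list or an {items, total} envelope depending on include_total; a hypothetical client call (mount path assumed, not shown in this diff):

import requests

resp = requests.get(
    "http://localhost:8000/api/files/",
    params={"skip": 0, "limit": 50, "sort_by": "opened", "sort_dir": "desc", "include_total": True},
    headers={"Authorization": "Bearer <token>"},
)
page = resp.json()   # {"items": [...], "total": <pre-pagination count>}
# Without include_total=true the same endpoint returns a bare list of files.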
|
||||
|
||||
@@ -142,6 +174,10 @@ async def create_file(
|
||||
db.commit()
|
||||
db.refresh(file_obj)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return file_obj
|
||||
|
||||
|
||||
@@ -167,7 +203,10 @@ async def update_file(
|
||||
|
||||
db.commit()
|
||||
db.refresh(file_obj)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return file_obj
|
||||
|
||||
|
||||
@@ -188,7 +227,10 @@ async def delete_file(
|
||||
|
||||
db.delete(file_obj)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return {"message": "File deleted successfully"}
|
||||
|
||||
|
||||
@@ -433,11 +475,13 @@ async def advanced_file_search(
|
||||
query = query.filter(File.file_no.contains(file_no))
|
||||
|
||||
if client_name:
|
||||
# SQLite-safe concatenation for first + last name
|
||||
full_name_expr = (func.coalesce(Rolodex.first, '') + ' ' + func.coalesce(Rolodex.last, ''))
|
||||
query = query.join(Rolodex).filter(
|
||||
or_(
|
||||
Rolodex.first.contains(client_name),
|
||||
Rolodex.last.contains(client_name),
|
||||
func.concat(Rolodex.first, ' ', Rolodex.last).contains(client_name)
|
||||
full_name_expr.contains(client_name)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Financial/Ledger API endpoints
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, func, and_, desc, asc, text
|
||||
from datetime import date, datetime, timedelta
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.models.ledger import Ledger
|
||||
@@ -14,12 +14,14 @@ from app.models.rolodex import Rolodex
|
||||
from app.models.lookups import Employee, TransactionType, TransactionCode
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class LedgerBase(BaseModel):
|
||||
@@ -57,8 +59,7 @@ class LedgerResponse(LedgerBase):
|
||||
id: int
|
||||
item_no: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class FinancialSummary(BaseModel):
|
||||
@@ -75,23 +76,46 @@ class FinancialSummary(BaseModel):
|
||||
billed_amount: float
|
||||
|
||||
|
||||
@router.get("/ledger/{file_no}", response_model=List[LedgerResponse])
|
||||
class PaginatedLedgerResponse(BaseModel):
|
||||
items: List[LedgerResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse])
|
||||
async def get_file_ledger(
|
||||
file_no: str,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
billed_only: Optional[bool] = Query(None),
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(100, ge=1, le=500, description="Page size"),
|
||||
billed_only: Optional[bool] = Query(None, description="Filter billed vs unbilled entries"),
|
||||
sort_by: Optional[str] = Query("date", description="Sort by: date, item_no, amount, billed"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get ledger entries for specific file"""
|
||||
query = db.query(Ledger).filter(Ledger.file_no == file_no).order_by(Ledger.date.desc())
|
||||
query = db.query(Ledger).filter(Ledger.file_no == file_no)
|
||||
|
||||
if billed_only is not None:
|
||||
billed_filter = "Y" if billed_only else "N"
|
||||
query = query.filter(Ledger.billed == billed_filter)
|
||||
|
||||
entries = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"date": [Ledger.date, Ledger.item_no],
|
||||
"item_no": [Ledger.item_no],
|
||||
"amount": [Ledger.amount],
|
||||
"billed": [Ledger.billed, Ledger.date],
|
||||
},
|
||||
)
|
||||
|
||||
entries, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": entries, "total": total or 0}
|
||||
return entries
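apply_sorting and paginate_with_total come from app.services.query_utils, which is not part of this diff; a minimal sketch of what those helpers might look like, assuming plain SQLAlchemy Query objects:

    # Hypothetical shape of the query_utils helpers used above (real module not shown in this diff).
    from typing import Dict, Optional, Tuple
    from sqlalchemy import asc, desc
    from sqlalchemy.orm import Query

    def apply_sorting(query: Query, sort_by: Optional[str], sort_dir: Optional[str],
                      allowed: Dict[str, list]) -> Query:
        # Only whitelisted sort keys are honored; anything else leaves the query untouched.
        columns = allowed.get(sort_by or "")
        if not columns:
            return query
        order = desc if (sort_dir or "desc").lower() == "desc" else asc
        return query.order_by(*[order(col) for col in columns])

    def paginate_with_total(query: Query, skip: int, limit: int,
                            include_total: bool) -> Tuple[list, Optional[int]]:
        # Count before slicing so the total reflects all matching rows, not just the page.
        total = query.count() if include_total else None
        return query.offset(skip).limit(limit).all(), total

A client that wants the envelope form calls GET /ledger/{file_no}?include_total=true and receives {"items": [...], "total": N}; without the flag the response stays a plain list for backwards compatibility.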
|
||||
|
||||
|
||||
@@ -127,6 +151,10 @@ async def create_ledger_entry(
|
||||
# Update file balances (simplified version)
|
||||
await _update_file_balances(file_obj, db)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return entry
|
||||
|
||||
|
||||
@@ -158,6 +186,10 @@ async def update_ledger_entry(
|
||||
if file_obj:
|
||||
await _update_file_balances(file_obj, db)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return entry
|
||||
|
||||
|
||||
@@ -185,6 +217,10 @@ async def delete_ledger_entry(
|
||||
if file_obj:
|
||||
await _update_file_balances(file_obj, db)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return {"message": "Ledger entry deleted successfully"}
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
from difflib import SequenceMatcher
|
||||
from datetime import datetime, date
|
||||
from datetime import datetime, date, timezone
|
||||
from decimal import Decimal
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query
|
||||
@@ -19,8 +19,8 @@ from app.models.rolodex import Rolodex, Phone
|
||||
from app.models.files import File
|
||||
from app.models.ledger import Ledger
|
||||
from app.models.qdro import QDRO
|
||||
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable
|
||||
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup
|
||||
from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable, PensionResult
|
||||
from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup, FormKeyword
|
||||
from app.models.additional import Payment, Deposit, FileNote, FormVariable, ReportVariable
|
||||
from app.models.flexible import FlexibleImport
|
||||
from app.models.audit import ImportAudit, ImportAuditFile
|
||||
@@ -28,6 +28,25 @@ from app.config import settings
|
||||
|
||||
router = APIRouter(tags=["import"])
|
||||
|
||||
# Common encodings to try for legacy CSV files (order matters)
ENCODINGS = [
'utf-8-sig',
'utf-8',
'windows-1252',
'iso-8859-1',
'cp1252',
]
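ENCODINGS feeds the decode-with-fallback loops that appear repeatedly below; the pattern they follow is roughly:

    # Sketch of the fallback decode used by the import endpoints below.
    def decode_legacy_csv(content: bytes) -> str:
        for encoding in ENCODINGS:
            try:
                return content.decode(encoding)
            except UnicodeDecodeError:
                continue
        # The endpoints below surface this as an HTTP 400 with an encoding hint.
        raise ValueError("Could not decode CSV file with any of the known legacy encodings")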
|
||||
|
||||
# Unified import order used across batch operations
IMPORT_ORDER = [
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
"INX_LKUP.csv",
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
"QDROS.csv", "PENSIONS.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv",
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
]
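IMPORT_ORDER is used further down to sort uploaded files so lookup tables are imported before the rows that reference them; the sorting trick looks roughly like this, with invented filenames for illustration:

    order_index = {name: i for i, name in enumerate(IMPORT_ORDER)}
    uploaded = ["LEDGER.csv", "STATES.csv", "UNKNOWN.csv"]   # example input only
    uploaded.sort(key=lambda name: order_index.get(name, len(IMPORT_ORDER) + 1))
    # -> ["STATES.csv", "LEDGER.csv", "UNKNOWN.csv"]  (unknown names sort last)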
|
||||
|
||||
|
||||
# CSV to Model mapping
|
||||
CSV_MODEL_MAPPING = {
|
||||
@@ -56,7 +75,6 @@ CSV_MODEL_MAPPING = {
|
||||
"FOOTERS.csv": Footer,
|
||||
"PLANINFO.csv": PlanInfo,
|
||||
# Legacy alternate names from export directories
|
||||
"SCHEDULE.csv": PensionSchedule,
|
||||
"FORM_INX.csv": FormIndex,
|
||||
"FORM_LST.csv": FormList,
|
||||
"PRINTERS.csv": PrinterSetup,
|
||||
@@ -67,7 +85,9 @@ CSV_MODEL_MAPPING = {
|
||||
"FVARLKUP.csv": FormVariable,
|
||||
"RVARLKUP.csv": ReportVariable,
|
||||
"PAYMENTS.csv": Payment,
|
||||
"TRNSACTN.csv": Ledger # Maps to existing Ledger model (same structure)
|
||||
"TRNSACTN.csv": Ledger, # Maps to existing Ledger model (same structure)
|
||||
"INX_LKUP.csv": FormKeyword,
|
||||
"RESULTS.csv": PensionResult
|
||||
}
|
||||
|
||||
# Field mappings for CSV columns to database fields
|
||||
@@ -230,8 +250,12 @@ FIELD_MAPPINGS = {
|
||||
"Default_Rate": "default_rate"
|
||||
},
|
||||
"FILESTAT.csv": {
|
||||
"Status": "status_code",
|
||||
"Status_Code": "status_code",
|
||||
"Definition": "description",
|
||||
"Description": "description",
|
||||
"Send": "send",
|
||||
"Footer_Code": "footer_code",
|
||||
"Sort_Order": "sort_order"
|
||||
},
|
||||
"FOOTERS.csv": {
|
||||
@@ -253,22 +277,44 @@ FIELD_MAPPINGS = {
|
||||
"Phone": "phone",
|
||||
"Notes": "notes"
|
||||
},
|
||||
"INX_LKUP.csv": {
|
||||
"Keyword": "keyword",
|
||||
"Description": "description"
|
||||
},
|
||||
"FORM_INX.csv": {
|
||||
"Form_Id": "form_id",
|
||||
"Form_Name": "form_name",
|
||||
"Category": "category"
|
||||
"Name": "form_id",
|
||||
"Keyword": "keyword"
|
||||
},
|
||||
"FORM_LST.csv": {
|
||||
"Form_Id": "form_id",
|
||||
"Line_Number": "line_number",
|
||||
"Content": "content"
|
||||
"Name": "form_id",
|
||||
"Memo": "content",
|
||||
"Status": "status"
|
||||
},
|
||||
"PRINTERS.csv": {
|
||||
# Legacy variants
|
||||
"Printer_Name": "printer_name",
|
||||
"Description": "description",
|
||||
"Driver": "driver",
|
||||
"Port": "port",
|
||||
"Default_Printer": "default_printer"
|
||||
"Default_Printer": "default_printer",
|
||||
# Observed legacy headers from export
|
||||
"Number": "number",
|
||||
"Name": "printer_name",
|
||||
"Page_Break": "page_break",
|
||||
"Setup_St": "setup_st",
|
||||
"Reset_St": "reset_st",
|
||||
"B_Underline": "b_underline",
|
||||
"E_Underline": "e_underline",
|
||||
"B_Bold": "b_bold",
|
||||
"E_Bold": "e_bold",
|
||||
# Optional report toggles
|
||||
"Phone_Book": "phone_book",
|
||||
"Rolodex_Info": "rolodex_info",
|
||||
"Envelope": "envelope",
|
||||
"File_Cabinet": "file_cabinet",
|
||||
"Accounts": "accounts",
|
||||
"Statements": "statements",
|
||||
"Calendar": "calendar",
|
||||
},
|
||||
"SETUP.csv": {
|
||||
"Setting_Key": "setting_key",
|
||||
@@ -285,32 +331,98 @@ FIELD_MAPPINGS = {
|
||||
"MARRIAGE.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Marriage_Date": "marriage_date",
|
||||
"Separation_Date": "separation_date",
|
||||
"Divorce_Date": "divorce_date"
|
||||
"Married_From": "married_from",
|
||||
"Married_To": "married_to",
|
||||
"Married_Years": "married_years",
|
||||
"Service_From": "service_from",
|
||||
"Service_To": "service_to",
|
||||
"Service_Years": "service_years",
|
||||
"Marital_%": "marital_percent"
|
||||
},
|
||||
"DEATH.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Benefit_Type": "benefit_type",
|
||||
"Benefit_Amount": "benefit_amount",
|
||||
"Beneficiary": "beneficiary"
|
||||
"Lump1": "lump1",
|
||||
"Lump2": "lump2",
|
||||
"Growth1": "growth1",
|
||||
"Growth2": "growth2",
|
||||
"Disc1": "disc1",
|
||||
"Disc2": "disc2"
|
||||
},
|
||||
"SEPARATE.csv": {
|
||||
"File_No": "file_no",
|
||||
"Version": "version",
|
||||
"Agreement_Date": "agreement_date",
|
||||
"Terms": "terms"
|
||||
"Separation_Rate": "terms"
|
||||
},
|
||||
"LIFETABL.csv": {
|
||||
"Age": "age",
|
||||
"Male_Mortality": "male_mortality",
|
||||
"Female_Mortality": "female_mortality"
|
||||
"AGE": "age",
|
||||
"LE_AA": "le_aa",
|
||||
"NA_AA": "na_aa",
|
||||
"LE_AM": "le_am",
|
||||
"NA_AM": "na_am",
|
||||
"LE_AF": "le_af",
|
||||
"NA_AF": "na_af",
|
||||
"LE_WA": "le_wa",
|
||||
"NA_WA": "na_wa",
|
||||
"LE_WM": "le_wm",
|
||||
"NA_WM": "na_wm",
|
||||
"LE_WF": "le_wf",
|
||||
"NA_WF": "na_wf",
|
||||
"LE_BA": "le_ba",
|
||||
"NA_BA": "na_ba",
|
||||
"LE_BM": "le_bm",
|
||||
"NA_BM": "na_bm",
|
||||
"LE_BF": "le_bf",
|
||||
"NA_BF": "na_bf",
|
||||
"LE_HA": "le_ha",
|
||||
"NA_HA": "na_ha",
|
||||
"LE_HM": "le_hm",
|
||||
"NA_HM": "na_hm",
|
||||
"LE_HF": "le_hf",
|
||||
"NA_HF": "na_hf"
|
||||
},
|
||||
"NUMBERAL.csv": {
|
||||
"Table_Name": "table_name",
|
||||
"Month": "month",
|
||||
"NA_AA": "na_aa",
|
||||
"NA_AM": "na_am",
|
||||
"NA_AF": "na_af",
|
||||
"NA_WA": "na_wa",
|
||||
"NA_WM": "na_wm",
|
||||
"NA_WF": "na_wf",
|
||||
"NA_BA": "na_ba",
|
||||
"NA_BM": "na_bm",
|
||||
"NA_BF": "na_bf",
|
||||
"NA_HA": "na_ha",
|
||||
"NA_HM": "na_hm",
|
||||
"NA_HF": "na_hf"
|
||||
},
|
||||
"RESULTS.csv": {
|
||||
"Accrued": "accrued",
|
||||
"Start_Age": "start_age",
|
||||
"COLA": "cola",
|
||||
"Withdrawal": "withdrawal",
|
||||
"Pre_DR": "pre_dr",
|
||||
"Post_DR": "post_dr",
|
||||
"Tax_Rate": "tax_rate",
|
||||
"Age": "age",
|
||||
"Value": "value"
|
||||
"Years_From": "years_from",
|
||||
"Life_Exp": "life_exp",
|
||||
"EV_Monthly": "ev_monthly",
|
||||
"Payments": "payments",
|
||||
"Pay_Out": "pay_out",
|
||||
"Fund_Value": "fund_value",
|
||||
"PV": "pv",
|
||||
"Mortality": "mortality",
|
||||
"PV_AM": "pv_am",
|
||||
"PV_AMT": "pv_amt",
|
||||
"PV_Pre_DB": "pv_pre_db",
|
||||
"PV_Annuity": "pv_annuity",
|
||||
"WV_AT": "wv_at",
|
||||
"PV_Plan": "pv_plan",
|
||||
"Years_Married": "years_married",
|
||||
"Years_Service": "years_service",
|
||||
"Marr_Per": "marr_per",
|
||||
"Marr_Amt": "marr_amt"
|
||||
},
|
||||
# Additional CSV file mappings
|
||||
"DEPOSITS.csv": {
|
||||
@@ -357,7 +469,7 @@ FIELD_MAPPINGS = {
|
||||
}
|
||||
|
||||
|
||||
def parse_date(date_str: str) -> Optional[datetime]:
|
||||
def parse_date(date_str: str) -> Optional[date]:
|
||||
"""Parse date string in various formats"""
|
||||
if not date_str or date_str.strip() == "":
|
||||
return None
|
||||
@@ -612,7 +724,11 @@ def convert_value(value: str, field_name: str) -> Any:
|
||||
return parsed_date
|
||||
|
||||
# Boolean fields
|
||||
if any(word in field_name.lower() for word in ["active", "default_printer", "billed", "transferable"]):
|
||||
if any(word in field_name.lower() for word in [
|
||||
"active", "default_printer", "billed", "transferable", "send",
|
||||
# PrinterSetup legacy toggles
|
||||
"phone_book", "rolodex_info", "envelope", "file_cabinet", "accounts", "statements", "calendar"
|
||||
]):
|
||||
if value.lower() in ["true", "1", "yes", "y", "on", "active"]:
|
||||
return True
|
||||
elif value.lower() in ["false", "0", "no", "n", "off", "inactive"]:
|
||||
@@ -621,7 +737,11 @@ def convert_value(value: str, field_name: str) -> Any:
|
||||
return None
|
||||
|
||||
# Numeric fields (float)
|
||||
if any(word in field_name.lower() for word in ["rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu", "accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality", "value"]):
|
||||
if any(word in field_name.lower() for word in [
|
||||
"rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu",
|
||||
"accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality",
|
||||
"value"
|
||||
]) or field_name.lower().startswith(("na_", "le_")):
|
||||
try:
|
||||
# Remove currency symbols and commas
|
||||
cleaned_value = value.replace("$", "").replace(",", "").replace("%", "")
|
||||
@@ -630,7 +750,9 @@ def convert_value(value: str, field_name: str) -> Any:
|
||||
return 0.0
|
||||
|
||||
# Integer fields
|
||||
if any(word in field_name.lower() for word in ["item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num"]):
|
||||
if any(word in field_name.lower() for word in [
|
||||
"item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number"
|
||||
]):
|
||||
try:
|
||||
return int(float(value)) # Handle cases like "1.0"
|
||||
except ValueError:
|
||||
@@ -673,11 +795,18 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
|
||||
"available_files": list(CSV_MODEL_MAPPING.keys()),
|
||||
"descriptions": {
|
||||
"ROLODEX.csv": "Customer/contact information",
|
||||
"ROLEX_V.csv": "Customer/contact information (alias)",
|
||||
"PHONE.csv": "Phone numbers linked to customers",
|
||||
"FILES.csv": "Client files and cases",
|
||||
"FILES_R.csv": "Client files and cases (alias)",
|
||||
"FILES_V.csv": "Client files and cases (alias)",
|
||||
"LEDGER.csv": "Financial transactions per file",
|
||||
"QDROS.csv": "Legal documents and court orders",
|
||||
"PENSIONS.csv": "Pension calculation data",
|
||||
"SCHEDULE.csv": "Vesting schedules for pensions",
|
||||
"MARRIAGE.csv": "Marriage history data",
|
||||
"DEATH.csv": "Death benefit calculations",
|
||||
"SEPARATE.csv": "Separation agreements",
|
||||
"EMPLOYEE.csv": "Staff and employee information",
|
||||
"STATES.csv": "US States lookup table",
|
||||
"FILETYPE.csv": "File type categories",
|
||||
@@ -688,7 +817,12 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user)
|
||||
"FVARLKUP.csv": "Form template variables",
|
||||
"RVARLKUP.csv": "Report template variables",
|
||||
"PAYMENTS.csv": "Individual payments within deposits",
|
||||
"TRNSACTN.csv": "Transaction details (maps to Ledger)"
|
||||
"TRNSACTN.csv": "Transaction details (maps to Ledger)",
|
||||
"INX_LKUP.csv": "Form keywords lookup",
|
||||
"PLANINFO.csv": "Pension plan information",
|
||||
"RESULTS.csv": "Pension computed results",
|
||||
"LIFETABL.csv": "Life expectancy table by age, sex, and race (rich typed)",
|
||||
"NUMBERAL.csv": "Monthly survivor counts by sex and race (rich typed)"
|
||||
},
|
||||
"auto_discovery": True
|
||||
}
|
||||
@@ -724,7 +858,7 @@ async def import_csv_data(
|
||||
content = await file.read()
|
||||
|
||||
# Try multiple encodings for legacy CSV files
|
||||
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -736,34 +870,7 @@ async def import_csv_data(
|
||||
if csv_content is None:
|
||||
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
|
||||
|
||||
# Preprocess CSV content to fix common legacy issues
|
||||
def preprocess_csv(content):
|
||||
lines = content.split('\n')
|
||||
cleaned_lines = []
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
# If line doesn't have the expected number of commas, it might be a broken multi-line field
|
||||
if i == 0: # Header line
|
||||
cleaned_lines.append(line)
|
||||
expected_comma_count = line.count(',')
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check if this line has the expected number of commas
|
||||
if line.count(',') < expected_comma_count:
|
||||
# This might be a continuation of the previous line
|
||||
# Try to merge with previous line
|
||||
if cleaned_lines:
|
||||
cleaned_lines[-1] += " " + line.replace('\n', ' ').replace('\r', ' ')
|
||||
else:
|
||||
cleaned_lines.append(line)
|
||||
else:
|
||||
cleaned_lines.append(line)
|
||||
i += 1
|
||||
|
||||
return '\n'.join(cleaned_lines)
|
||||
# Note: preprocess_csv helper removed as unused; robust parsing handled below
|
||||
|
||||
# Custom robust parser for problematic legacy CSV files
|
||||
class MockCSVReader:
|
||||
@@ -791,7 +898,7 @@ async def import_csv_data(
|
||||
header_reader = csv.reader(io.StringIO(lines[0]))
|
||||
headers = next(header_reader)
|
||||
headers = [h.strip() for h in headers]
|
||||
print(f"DEBUG: Found {len(headers)} headers: {headers}")
|
||||
# Debug logging removed in API path; rely on audit/logging if needed
|
||||
# Build dynamic header mapping for this file/model
|
||||
mapping_info = _build_dynamic_mapping(headers, model_class, file_type)
|
||||
|
||||
@@ -829,17 +936,21 @@ async def import_csv_data(
|
||||
continue
|
||||
|
||||
csv_reader = MockCSVReader(rows_data, headers)
|
||||
print(f"SUCCESS: Parsed {len(rows_data)} rows (skipped {skipped_rows} malformed rows)")
|
||||
# Parsing summary suppressed to avoid noisy stdout in API
|
||||
|
||||
except Exception as e:
|
||||
print(f"Custom parsing failed: {e}")
|
||||
# Keep error minimal for client; internal logging can capture 'e'
|
||||
raise HTTPException(status_code=400, detail=f"Could not parse CSV file. The file appears to have serious formatting issues. Error: {str(e)}")
|
||||
|
||||
imported_count = 0
|
||||
created_count = 0
|
||||
updated_count = 0
|
||||
errors = []
|
||||
flexible_saved = 0
|
||||
mapped_headers = mapping_info.get("mapped_headers", {})
|
||||
unmapped_headers = mapping_info.get("unmapped_headers", [])
|
||||
# Special handling: assign line numbers per form for FORM_LST.csv
|
||||
form_lst_line_counters: Dict[str, int] = {}
|
||||
|
||||
# If replace_existing is True, delete all existing records and related flexible extras
|
||||
if replace_existing:
|
||||
@@ -860,6 +971,16 @@ async def import_csv_data(
|
||||
converted_value = convert_value(row[csv_field], db_field)
|
||||
if converted_value is not None:
|
||||
model_data[db_field] = converted_value
|
||||
|
||||
# Inject sequential line_number for FORM_LST rows grouped by form_id
|
||||
if file_type == "FORM_LST.csv":
|
||||
form_id_value = model_data.get("form_id")
|
||||
if form_id_value:
|
||||
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
|
||||
form_lst_line_counters[str(form_id_value)] = current
|
||||
# Only set if not provided
|
||||
if "line_number" not in model_data:
|
||||
model_data["line_number"] = current
|
||||
|
||||
# Skip empty rows
|
||||
if not any(model_data.values()):
|
||||
@@ -902,10 +1023,43 @@ async def import_csv_data(
|
||||
if 'file_no' not in model_data or not model_data['file_no']:
|
||||
continue # Skip ledger records without file number
|
||||
|
||||
# Create model instance
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
db.flush() # Ensure PK is available
|
||||
# Create or update model instance
|
||||
instance = None
|
||||
# Upsert behavior for printers
|
||||
if model_class == PrinterSetup:
|
||||
# Determine primary key field name
|
||||
_, pk_names = _get_model_columns(model_class)
|
||||
pk_field_name_local = pk_names[0] if len(pk_names) == 1 else None
|
||||
pk_value_local = model_data.get(pk_field_name_local) if pk_field_name_local else None
|
||||
if pk_field_name_local and pk_value_local:
|
||||
existing = db.query(model_class).filter(getattr(model_class, pk_field_name_local) == pk_value_local).first()
|
||||
if existing:
|
||||
# Update mutable fields
|
||||
for k, v in model_data.items():
|
||||
if k != pk_field_name_local:
|
||||
setattr(existing, k, v)
|
||||
instance = existing
|
||||
updated_count += 1
|
||||
else:
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
created_count += 1
|
||||
else:
|
||||
# Fallback to insert if PK missing
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
created_count += 1
|
||||
db.flush()
|
||||
# Enforce single default
|
||||
try:
|
||||
if bool(model_data.get("default_printer")):
|
||||
db.query(model_class).filter(getattr(model_class, pk_field_name_local) != getattr(instance, pk_field_name_local)).update({model_class.default_printer: False})
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
instance = model_class(**model_data)
|
||||
db.add(instance)
|
||||
db.flush() # Ensure PK is available
|
||||
|
||||
# Capture PK details for flexible storage linkage (single-column PKs only)
|
||||
_, pk_names = _get_model_columns(model_class)
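_get_model_columns is defined elsewhere in this module and does not appear in the diff; assuming it simply reads SQLAlchemy table metadata, it would look something like:

    # Hypothetical shape of _get_model_columns (the real helper is outside this hunk).
    from typing import List, Tuple

    def _get_model_columns(model_class) -> Tuple[List[str], List[str]]:
        columns = [c.name for c in model_class.__table__.columns]
        pk_names = [c.name for c in model_class.__table__.primary_key.columns]
        return columns, pk_names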
|
||||
@@ -980,6 +1134,10 @@ async def import_csv_data(
|
||||
"flexible_saved_rows": flexible_saved,
|
||||
},
|
||||
}
|
||||
# Include create/update breakdown for printers
|
||||
if file_type == "PRINTERS.csv":
|
||||
result["created_count"] = created_count
|
||||
result["updated_count"] = updated_count
|
||||
|
||||
if errors:
|
||||
result["warning"] = f"Import completed with {len(errors)} errors"
|
||||
@@ -987,9 +1145,7 @@ async def import_csv_data(
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"IMPORT ERROR DEBUG: {type(e).__name__}: {str(e)}")
|
||||
import traceback
|
||||
print(f"TRACEBACK: {traceback.format_exc()}")
|
||||
# Suppress stdout debug prints in API layer
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||
|
||||
@@ -1071,7 +1227,7 @@ async def validate_csv_file(
|
||||
content = await file.read()
|
||||
|
||||
# Try multiple encodings for legacy CSV files
|
||||
encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -1083,18 +1239,6 @@ async def validate_csv_file(
|
||||
if csv_content is None:
|
||||
raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.")
|
||||
|
||||
# Parse CSV with fallback to robust line-by-line parsing
|
||||
def parse_csv_with_fallback(text: str) -> Tuple[List[Dict[str, str]], List[str]]:
|
||||
try:
|
||||
reader = csv.DictReader(io.StringIO(text), delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||
headers_local = reader.fieldnames or []
|
||||
rows_local = []
|
||||
for r in reader:
|
||||
rows_local.append(r)
|
||||
return rows_local, headers_local
|
||||
except Exception:
|
||||
return parse_csv_robust(text)
|
||||
|
||||
rows_list, csv_headers = parse_csv_with_fallback(csv_content)
|
||||
model_class = CSV_MODEL_MAPPING[file_type]
|
||||
mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type)
|
||||
@@ -1142,9 +1286,7 @@ async def validate_csv_file(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"VALIDATION ERROR DEBUG: {type(e).__name__}: {str(e)}")
|
||||
import traceback
|
||||
print(f"VALIDATION TRACEBACK: {traceback.format_exc()}")
|
||||
# Suppress stdout debug prints in API layer
|
||||
raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
|
||||
|
||||
|
||||
@@ -1199,7 +1341,7 @@ async def batch_validate_csv_files(
|
||||
content = await file.read()
|
||||
|
||||
# Try multiple encodings for legacy CSV files (include BOM-friendly utf-8-sig)
|
||||
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -1302,13 +1444,7 @@ async def batch_import_csv_files(
|
||||
raise HTTPException(status_code=400, detail="Maximum 25 files allowed per batch")
|
||||
|
||||
# Define optimal import order based on dependencies
|
||||
import_order = [
|
||||
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
|
||||
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
|
||||
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
|
||||
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
|
||||
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
|
||||
]
|
||||
import_order = IMPORT_ORDER
|
||||
|
||||
# Sort uploaded files by optimal import order
|
||||
file_map = {f.filename: f for f in files}
|
||||
@@ -1365,7 +1501,7 @@ async def batch_import_csv_files(
|
||||
saved_path = str(file_path)
|
||||
except Exception:
|
||||
saved_path = None
|
||||
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -1466,7 +1602,7 @@ async def batch_import_csv_files(
|
||||
saved_path = None
|
||||
|
||||
# Try multiple encodings for legacy CSV files
|
||||
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
@@ -1505,6 +1641,8 @@ async def batch_import_csv_files(
|
||||
imported_count = 0
|
||||
errors = []
|
||||
flexible_saved = 0
|
||||
# Special handling: assign line numbers per form for FORM_LST.csv
|
||||
form_lst_line_counters: Dict[str, int] = {}
|
||||
|
||||
# If replace_existing is True and this is the first file of this type
|
||||
if replace_existing:
|
||||
@@ -1523,6 +1661,15 @@ async def batch_import_csv_files(
|
||||
converted_value = convert_value(row[csv_field], db_field)
|
||||
if converted_value is not None:
|
||||
model_data[db_field] = converted_value
|
||||
|
||||
# Inject sequential line_number for FORM_LST rows grouped by form_id
|
||||
if file_type == "FORM_LST.csv":
|
||||
form_id_value = model_data.get("form_id")
|
||||
if form_id_value:
|
||||
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
|
||||
form_lst_line_counters[str(form_id_value)] = current
|
||||
if "line_number" not in model_data:
|
||||
model_data["line_number"] = current
|
||||
|
||||
if not any(model_data.values()):
|
||||
continue
|
||||
@@ -1697,7 +1844,7 @@ async def batch_import_csv_files(
|
||||
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
|
||||
)
|
||||
audit_row.message = f"Batch import completed: {audit_row.successful_files}/{audit_row.total_files} files"
|
||||
audit_row.finished_at = datetime.utcnow()
|
||||
audit_row.finished_at = datetime.now(timezone.utc)
|
||||
audit_row.details = {
|
||||
"files": [
|
||||
{"file_type": r.get("file_type"), "status": r.get("status"), "imported_count": r.get("imported_count", 0), "errors": r.get("errors", 0)}
|
||||
@@ -1844,13 +1991,7 @@ async def rerun_failed_files(
|
||||
raise HTTPException(status_code=400, detail="No saved files available to rerun. Upload again.")
|
||||
|
||||
# Import order for sorting
|
||||
import_order = [
|
||||
"STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv",
|
||||
"TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv",
|
||||
"ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv",
|
||||
"QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv",
|
||||
"FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv"
|
||||
]
|
||||
import_order = IMPORT_ORDER
|
||||
order_index = {name: i for i, name in enumerate(import_order)}
|
||||
items.sort(key=lambda x: order_index.get(x[0], len(import_order) + 1))
|
||||
|
||||
@@ -1898,7 +2039,7 @@ async def rerun_failed_files(
|
||||
|
||||
if file_type not in CSV_MODEL_MAPPING:
|
||||
# Flexible-only path
|
||||
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for enc in encodings:
|
||||
try:
|
||||
@@ -1964,7 +2105,7 @@ async def rerun_failed_files(
|
||||
|
||||
# Known model path
|
||||
model_class = CSV_MODEL_MAPPING[file_type]
|
||||
encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252']
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for enc in encodings:
|
||||
try:
|
||||
@@ -1996,6 +2137,8 @@ async def rerun_failed_files(
|
||||
unmapped_headers = mapping_info["unmapped_headers"]
|
||||
imported_count = 0
|
||||
errors: List[Dict[str, Any]] = []
|
||||
# Special handling: assign line numbers per form for FORM_LST.csv
|
||||
form_lst_line_counters: Dict[str, int] = {}
|
||||
|
||||
if replace_existing:
|
||||
db.query(model_class).delete()
|
||||
@@ -2013,6 +2156,14 @@ async def rerun_failed_files(
|
||||
converted_value = convert_value(row[csv_field], db_field)
|
||||
if converted_value is not None:
|
||||
model_data[db_field] = converted_value
|
||||
# Inject sequential line_number for FORM_LST rows grouped by form_id
|
||||
if file_type == "FORM_LST.csv":
|
||||
form_id_value = model_data.get("form_id")
|
||||
if form_id_value:
|
||||
current = form_lst_line_counters.get(str(form_id_value), 0) + 1
|
||||
form_lst_line_counters[str(form_id_value)] = current
|
||||
if "line_number" not in model_data:
|
||||
model_data["line_number"] = current
|
||||
if not any(model_data.values()):
|
||||
continue
|
||||
required_fields = _get_required_fields(model_class)
|
||||
@@ -2147,7 +2298,7 @@ async def rerun_failed_files(
|
||||
"completed_with_errors" if summary["successful_files"] > 0 else "failed"
|
||||
)
|
||||
rerun_audit.message = f"Rerun completed: {rerun_audit.successful_files}/{rerun_audit.total_files} files"
|
||||
rerun_audit.finished_at = datetime.utcnow()
|
||||
rerun_audit.finished_at = datetime.now(timezone.utc)
|
||||
rerun_audit.details = {"rerun_of": audit_id}
|
||||
db.add(rerun_audit)
|
||||
db.commit()
|
||||
@@ -2183,7 +2334,7 @@ async def upload_flexible_only(
|
||||
db.commit()
|
||||
|
||||
content = await file.read()
|
||||
encodings = ["utf-8-sig", "utf-8", "windows-1252", "iso-8859-1", "cp1252"]
|
||||
encodings = ENCODINGS
|
||||
csv_content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
|
||||
72
app/api/mortality.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Mortality/Life Table API endpoints
|
||||
|
||||
Provides read endpoints to query life tables by age and number tables by month,
|
||||
filtered by sex (M/F/A) and race (W/B/H/A).
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status, Path
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.mortality import get_life_values, get_number_value, InvalidCodeError
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LifeResponse(BaseModel):
|
||||
age: int
|
||||
sex: str = Field(description="M, F, or A (all)")
|
||||
race: str = Field(description="W, B, H, or A (all)")
|
||||
le: Optional[float]
|
||||
na: Optional[float]
|
||||
|
||||
|
||||
class NumberResponse(BaseModel):
|
||||
month: int
|
||||
sex: str = Field(description="M, F, or A (all)")
|
||||
race: str = Field(description="W, B, H, or A (all)")
|
||||
na: Optional[float]
|
||||
|
||||
|
||||
@router.get("/life/{age}", response_model=LifeResponse)
|
||||
async def get_life_entry(
|
||||
age: int = Path(..., ge=0, description="Age in years (>= 0)"),
|
||||
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
|
||||
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get life expectancy (LE) and number alive (NA) for an age/sex/race."""
|
||||
try:
|
||||
result = get_life_values(db, age=age, sex=sex, race=race)
|
||||
except InvalidCodeError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if result is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Age not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/number/{month}", response_model=NumberResponse)
|
||||
async def get_number_entry(
|
||||
month: int = Path(..., ge=0, description="Month index (>= 0)"),
|
||||
sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"),
|
||||
race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get monthly number alive (NA) for a month/sex/race."""
|
||||
try:
|
||||
result = get_number_value(db, month=month, sex=sex, race=race)
|
||||
except InvalidCodeError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if result is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Month not found")
|
||||
return result
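Once the router is mounted, these endpoints can be exercised with any HTTP client; the prefix below is an assumption (it depends on how main.py includes this router), and httpx is used purely for illustration:

    import httpx

    headers = {"Authorization": "Bearer <token>"}
    with httpx.Client(base_url="http://localhost:8000") as client:
        life = client.get("/api/mortality/life/65", params={"sex": "F", "race": "A"}, headers=headers)
        number = client.get("/api/mortality/number/120", params={"sex": "M", "race": "W"}, headers=headers)
        print(life.json(), number.json())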
|
||||
|
||||
|
||||
1019
app/api/search.py
File diff suppressed because it is too large
@@ -2,8 +2,10 @@
|
||||
Server-side highlight utilities for search results.
|
||||
|
||||
These functions generate HTML snippets with <strong> around matched tokens,
|
||||
preserving the original casing of the source text. The output is intended to be
|
||||
sanitized on the client before insertion into the DOM.
|
||||
preserving the original casing of the source text. All non-HTML segments are
|
||||
HTML-escaped server-side to prevent injection. Only the <strong> tags added by
|
||||
this module are emitted as HTML; any pre-existing HTML in source text is
|
||||
escaped.
|
||||
"""
|
||||
from typing import List, Tuple, Any
|
||||
import re
|
||||
@@ -42,18 +44,40 @@ def _merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
|
||||
|
||||
def highlight_text(value: str, tokens: List[str]) -> str:
|
||||
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing."""
|
||||
"""Return `value` with case-insensitive matches of `tokens` wrapped in <strong>, preserving original casing.
|
||||
|
||||
Non-highlighted segments and the highlighted text content are HTML-escaped.
|
||||
Only the surrounding <strong> wrappers are emitted as markup.
|
||||
"""
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
def _escape_html(text: str) -> str:
# Minimal, safe HTML escaping
if text is None:
return ""
# Replace ampersand first to avoid double-escaping
text = str(text)
text = text.replace("&", "&amp;")
text = text.replace("<", "&lt;")
text = text.replace(">", "&gt;")
text = text.replace('"', "&quot;")
text = text.replace("'", "&#x27;")
return text
|
||||
source = str(value)
|
||||
if not source or not tokens:
|
||||
return source
|
||||
return _escape_html(source)
|
||||
haystack = source.lower()
|
||||
ranges: List[Tuple[int, int]] = []
|
||||
# Deduplicate tokens case-insensitively to avoid redundant scans (parity with client)
|
||||
unique_needles = []
|
||||
seen_needles = set()
|
||||
for t in tokens:
|
||||
needle = str(t or "").lower()
|
||||
if not needle:
|
||||
continue
|
||||
if needle and needle not in seen_needles:
|
||||
unique_needles.append(needle)
|
||||
seen_needles.add(needle)
|
||||
for needle in unique_needles:
|
||||
start = 0
|
||||
last_possible = max(0, len(haystack) - len(needle))
|
||||
while start <= last_possible and len(needle) > 0:
|
||||
@@ -63,17 +87,17 @@ def highlight_text(value: str, tokens: List[str]) -> str:
|
||||
ranges.append((idx, idx + len(needle)))
|
||||
start = idx + 1
|
||||
if not ranges:
|
||||
return source
|
||||
return _escape_html(source)
|
||||
parts: List[str] = []
|
||||
merged = _merge_ranges(ranges)
|
||||
pos = 0
|
||||
for s, e in merged:
|
||||
if pos < s:
|
||||
parts.append(source[pos:s])
|
||||
parts.append("<strong>" + source[s:e] + "</strong>")
|
||||
parts.append(_escape_html(source[pos:s]))
|
||||
parts.append("<strong>" + _escape_html(source[s:e]) + "</strong>")
|
||||
pos = e
|
||||
if pos < len(source):
|
||||
parts.append(source[pos:])
|
||||
parts.append(_escape_html(source[pos:]))
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
|
||||
@@ -2,21 +2,24 @@
|
||||
Support ticket API endpoints
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import func, desc, and_, or_
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
import secrets
|
||||
|
||||
from app.database.base import get_db
|
||||
from app.models import User, SupportTicket, TicketResponse as TicketResponseModel, TicketStatus, TicketPriority, TicketCategory
|
||||
from app.auth.security import get_current_user, get_admin_user
|
||||
from app.services.audit import audit_service
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter
|
||||
from app.api.search_highlight import build_query_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Pydantic models for API
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from pydantic.config import ConfigDict
|
||||
|
||||
|
||||
class TicketCreate(BaseModel):
|
||||
@@ -57,8 +60,7 @@ class TicketResponseOut(BaseModel):
|
||||
author_email: Optional[str]
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TicketDetail(BaseModel):
|
||||
@@ -81,15 +83,19 @@ class TicketDetail(BaseModel):
|
||||
assigned_to: Optional[int]
|
||||
assigned_admin_name: Optional[str]
|
||||
submitter_name: Optional[str]
|
||||
responses: List[TicketResponseOut] = []
|
||||
responses: List[TicketResponseOut] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PaginatedTicketsResponse(BaseModel):
|
||||
items: List[TicketDetail]
|
||||
total: int
|
||||
|
||||
|
||||
def generate_ticket_number() -> str:
|
||||
"""Generate unique ticket number like ST-2024-001"""
|
||||
year = datetime.now().year
|
||||
year = datetime.now(timezone.utc).year
|
||||
random_suffix = secrets.token_hex(2).upper()
|
||||
return f"ST-{year}-{random_suffix}"
|
||||
|
||||
@@ -129,7 +135,7 @@ async def create_support_ticket(
|
||||
ip_address=client_ip,
|
||||
user_id=current_user.id if current_user else None,
|
||||
status=TicketStatus.OPEN,
|
||||
created_at=datetime.utcnow()
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(new_ticket)
|
||||
@@ -158,14 +164,18 @@ async def create_support_ticket(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tickets", response_model=List[TicketDetail])
|
||||
@router.get("/tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
|
||||
async def list_tickets(
|
||||
status: Optional[TicketStatus] = None,
|
||||
priority: Optional[TicketPriority] = None,
|
||||
category: Optional[TicketCategory] = None,
|
||||
assigned_to_me: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
|
||||
priority: Optional[TicketPriority] = Query(None, description="Filter by ticket priority"),
|
||||
category: Optional[TicketCategory] = Query(None, description="Filter by ticket category"),
|
||||
assigned_to_me: bool = Query(False, description="Only include tickets assigned to the current admin"),
|
||||
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description, contact name/email, current page, and IP"),
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(50, ge=1, le=200, description="Page size"),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_admin_user)
|
||||
):
|
||||
@@ -186,8 +196,38 @@ async def list_tickets(
|
||||
query = query.filter(SupportTicket.category == category)
|
||||
if assigned_to_me:
|
||||
query = query.filter(SupportTicket.assigned_to == current_user.id)
|
||||
|
||||
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
|
||||
|
||||
# Search across key text fields
|
||||
if search:
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
SupportTicket.ticket_number,
|
||||
SupportTicket.subject,
|
||||
SupportTicket.description,
|
||||
SupportTicket.contact_name,
|
||||
SupportTicket.contact_email,
|
||||
SupportTicket.current_page,
|
||||
SupportTicket.ip_address,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"created": [SupportTicket.created_at],
|
||||
"updated": [SupportTicket.updated_at],
|
||||
"resolved": [SupportTicket.resolved_at],
|
||||
"priority": [SupportTicket.priority],
|
||||
"status": [SupportTicket.status],
|
||||
"subject": [SupportTicket.subject],
|
||||
},
|
||||
)
|
||||
|
||||
tickets, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
# Format response
|
||||
result = []
|
||||
@@ -226,6 +266,8 @@ async def list_tickets(
|
||||
}
|
||||
result.append(ticket_dict)
|
||||
|
||||
if include_total:
|
||||
return {"items": result, "total": total or 0}
|
||||
return result
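tokenized_ilike_filter also lives in app.services.query_utils and is not shown in this diff; a plausible minimal implementation, assuming AND-across-tokens and OR-across-columns semantics:

    # Hypothetical sketch of tokenized_ilike_filter (real implementation not in this diff).
    from typing import Any, List, Optional
    from sqlalchemy import and_, or_

    def tokenized_ilike_filter(tokens: List[str], columns: List[Any]) -> Optional[Any]:
        # Every token must match at least one column; all tokens must match somewhere.
        clauses = [or_(*[col.ilike(f"%{token}%") for col in columns]) for token in tokens if token]
        return and_(*clauses) if clauses else None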
|
||||
|
||||
|
||||
@@ -312,10 +354,10 @@ async def update_ticket(
|
||||
|
||||
# Set resolved timestamp if status changed to resolved
|
||||
if ticket_data.status == TicketStatus.RESOLVED and ticket.resolved_at is None:
|
||||
ticket.resolved_at = datetime.utcnow()
|
||||
ticket.resolved_at = datetime.now(timezone.utc)
|
||||
changes["resolved_at"] = {"from": None, "to": ticket.resolved_at}
|
||||
|
||||
ticket.updated_at = datetime.utcnow()
|
||||
ticket.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
# Audit logging (non-blocking)
|
||||
@@ -358,13 +400,13 @@ async def add_response(
|
||||
message=response_data.message,
|
||||
is_internal=response_data.is_internal,
|
||||
user_id=current_user.id,
|
||||
created_at=datetime.utcnow()
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
db.add(response)
|
||||
|
||||
# Update ticket timestamp
|
||||
ticket.updated_at = datetime.utcnow()
|
||||
ticket.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(response)
|
||||
@@ -386,11 +428,15 @@ async def add_response(
|
||||
return {"message": "Response added successfully", "response_id": response.id}
|
||||
|
||||
|
||||
@router.get("/my-tickets", response_model=List[TicketDetail])
|
||||
@router.get("/my-tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse)
|
||||
async def get_my_tickets(
|
||||
status: Optional[TicketStatus] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"),
|
||||
search: Optional[str] = Query(None, description="Tokenized search across number, subject, description"),
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(20, ge=1, le=200, description="Page size"),
|
||||
sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"),
|
||||
sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -403,7 +449,33 @@ async def get_my_tickets(
|
||||
if status:
|
||||
query = query.filter(SupportTicket.status == status)
|
||||
|
||||
tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all()
|
||||
# Search within user's tickets
|
||||
if search:
|
||||
tokens = build_query_tokens(search)
|
||||
filter_expr = tokenized_ilike_filter(tokens, [
|
||||
SupportTicket.ticket_number,
|
||||
SupportTicket.subject,
|
||||
SupportTicket.description,
|
||||
])
|
||||
if filter_expr is not None:
|
||||
query = query.filter(filter_expr)
|
||||
|
||||
# Sorting (whitelisted)
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"created": [SupportTicket.created_at],
|
||||
"updated": [SupportTicket.updated_at],
|
||||
"resolved": [SupportTicket.resolved_at],
|
||||
"priority": [SupportTicket.priority],
|
||||
"status": [SupportTicket.status],
|
||||
"subject": [SupportTicket.subject],
|
||||
},
|
||||
)
|
||||
|
||||
tickets, total = paginate_with_total(query, skip, limit, include_total)
|
||||
|
||||
# Format response (exclude internal responses for regular users)
|
||||
result = []
|
||||
@@ -442,6 +514,8 @@ async def get_my_tickets(
|
||||
}
|
||||
result.append(ticket_dict)
|
||||
|
||||
if include_total:
|
||||
return {"items": result, "total": total or 0}
|
||||
return result
|
||||
|
||||
|
||||
@@ -473,7 +547,7 @@ async def get_ticket_stats(
|
||||
|
||||
# Recent tickets (last 7 days)
|
||||
from datetime import timedelta
|
||||
week_ago = datetime.utcnow() - timedelta(days=7)
|
||||
week_ago = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
recent_tickets = db.query(func.count(SupportTicket.id)).filter(
|
||||
SupportTicket.created_at >= week_ago
|
||||
).scalar()
|
||||
|
||||
@@ -3,6 +3,7 @@ Authentication schemas
|
||||
"""
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic.config import ConfigDict
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
@@ -32,8 +33,7 @@ class UserResponse(UserBase):
|
||||
is_admin: bool
|
||||
theme_preference: Optional[str] = "light"
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ThemePreferenceUpdate(BaseModel):
|
||||
@@ -45,7 +45,7 @@ class Token(BaseModel):
|
||||
"""Token response schema"""
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: str | None = None
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Authentication and security utilities
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Tuple
|
||||
from uuid import uuid4
|
||||
from jose import JWTError, jwt
|
||||
@@ -54,12 +54,12 @@ def _decode_with_rotation(token: str) -> dict:
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT access token"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + (
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"type": "access",
|
||||
})
|
||||
return _encode_with_rotation(to_encode)
|
||||
@@ -68,14 +68,14 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str:
|
||||
"""Create refresh token, store its JTI in DB for revocation."""
|
||||
jti = uuid4().hex
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.refresh_token_expire_minutes)
|
||||
payload = {
|
||||
"sub": user.username,
|
||||
"uid": user.id,
|
||||
"jti": jti,
|
||||
"type": "refresh",
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
token = _encode_with_rotation(payload)
|
||||
|
||||
@@ -84,7 +84,7 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti
|
||||
jti=jti,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
issued_at=datetime.utcnow(),
|
||||
issued_at=datetime.now(timezone.utc),
|
||||
expires_at=expire,
|
||||
revoked=False,
|
||||
)
|
||||
@@ -93,6 +93,15 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti
|
||||
return token
|
||||
|
||||
|
||||
def _to_utc_aware(dt: Optional[datetime]) -> Optional[datetime]:
|
||||
"""Convert a datetime to UTC-aware. If naive, assume it's already UTC and attach tzinfo."""
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
"""Verify JWT token and return username"""
|
||||
try:
|
||||
@@ -122,14 +131,20 @@ def decode_refresh_token(token: str) -> Optional[dict]:
|
||||
|
||||
def is_refresh_token_revoked(jti: str, db: Session) -> bool:
|
||||
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
|
||||
return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow()
|
||||
if not token_row:
|
||||
return True
|
||||
if token_row.revoked:
|
||||
return True
|
||||
expires_at_utc = _to_utc_aware(token_row.expires_at)
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
return expires_at_utc is not None and expires_at_utc <= now_utc
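_to_utc_aware exists because SQLite hands back naive datetimes, which cannot be compared with datetime.now(timezone.utc) directly; for example:

    from datetime import datetime, timezone

    naive = datetime(2024, 1, 1, 12, 0, 0)        # as stored by SQLite, no tzinfo
    aware = _to_utc_aware(naive)                  # 2024-01-01 12:00:00+00:00
    aware <= datetime.now(timezone.utc)           # works; naive <= aware would raise TypeError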
|
||||
|
||||
|
||||
def revoke_refresh_token(jti: str, db: Session) -> None:
|
||||
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
|
||||
if token_row and not token_row.revoked:
|
||||
token_row.revoked = True
|
||||
token_row.revoked_at = datetime.utcnow()
|
||||
token_row.revoked_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,10 @@ class Settings(BaseSettings):
|
||||
log_rotation: str = "10 MB"
|
||||
log_retention: str = "30 days"
|
||||
|
||||
# Cache / Redis
|
||||
cache_enabled: bool = False
|
||||
redis_url: Optional[str] = None
|
||||
|
||||
# pydantic-settings v2 configuration
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
Database configuration and session management
|
||||
"""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from typing import Generator
|
||||
|
||||
from app.config import settings
|
||||
|
||||
248
app/database/fts.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
SQLite Full-Text Search (FTS5) helpers.
|
||||
|
||||
Creates and maintains FTS virtual tables and triggers to keep them in sync
|
||||
with their content tables. Designed to be called at app startup.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
def _execute_ignore_errors(engine: Engine, sql: str) -> None:
|
||||
"""Execute SQL, ignoring operational errors (e.g., when FTS5 is unavailable)."""
|
||||
from sqlalchemy.exc import OperationalError
|
||||
with engine.begin() as conn:
|
||||
try:
|
||||
conn.execute(text(sql))
|
||||
except OperationalError:
|
||||
# Likely FTS5 extension not available in this SQLite build
|
||||
pass
|
||||
|
||||
|
||||
def ensure_rolodex_fts(engine: Engine) -> None:
|
||||
"""Ensure the `rolodex_fts` virtual table and triggers exist and are populated.
|
||||
|
||||
This uses content=rolodex so the FTS table shadows the base table and is kept
|
||||
in sync via triggers.
|
||||
"""
|
||||
# Create virtual table (if FTS5 is available)
|
||||
_create_table = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS rolodex_fts USING fts5(
|
||||
id,
|
||||
first,
|
||||
last,
|
||||
city,
|
||||
email,
|
||||
memo,
|
||||
content='rolodex',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
"""
|
||||
_execute_ignore_errors(engine, _create_table)
|
||||
|
||||
# Triggers to keep FTS in sync
|
||||
_triggers = [
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS rolodex_ai AFTER INSERT ON rolodex BEGIN
|
||||
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
|
||||
VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS rolodex_ad AFTER DELETE ON rolodex BEGIN
|
||||
INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo)
|
||||
VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS rolodex_au AFTER UPDATE ON rolodex BEGIN
|
||||
INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo)
|
||||
VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo);
|
||||
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
|
||||
VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo);
|
||||
END;
|
||||
""",
|
||||
]
|
||||
for trig in _triggers:
|
||||
_execute_ignore_errors(engine, trig)
|
||||
|
||||
# Backfill if the FTS table exists but is empty
|
||||
with engine.begin() as conn:
|
||||
try:
|
||||
count_fts = conn.execute(text("SELECT count(*) FROM rolodex_fts")).scalar() # type: ignore
|
||||
if count_fts == 0:
|
||||
# Populate from existing rolodex rows
|
||||
conn.execute(text(
|
||||
"""
|
||||
INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo)
|
||||
SELECT rowid, id, first, last, city, email, memo FROM rolodex;
|
||||
"""
|
||||
))
|
||||
except Exception:
|
||||
# If FTS table doesn't exist or any error occurs, ignore silently
|
||||
pass
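This helper and the companion ensure_*_fts functions below are meant to be called once at startup; the wiring and a sample query are sketched here under the assumption that app.database.base exposes the Engine (not shown in this diff):

    from sqlalchemy import text
    from app.database.base import engine   # assumption: the module exposes its Engine

    def init_fts() -> None:
        ensure_rolodex_fts(engine)
        ensure_files_fts(engine)
        ensure_ledger_fts(engine)
        ensure_qdros_fts(engine)

    def search_rolodex(term: str):
        # FTS5 MATCH against the shadow table created above.
        with engine.connect() as conn:
            rows = conn.execute(
                text("SELECT rowid, first, last FROM rolodex_fts WHERE rolodex_fts MATCH :q LIMIT 20"),
                {"q": term},
            )
            return rows.fetchall()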
|
||||
|
||||
|
||||
def ensure_files_fts(engine: Engine) -> None:
|
||||
"""Ensure the `files_fts` virtual table and triggers exist and are populated."""
|
||||
_create_table = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS files_fts USING fts5(
|
||||
file_no,
|
||||
id,
|
||||
regarding,
|
||||
file_type,
|
||||
memo,
|
||||
content='files',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
"""
|
||||
_execute_ignore_errors(engine, _create_table)
|
||||
|
||||
_triggers = [
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS files_ai AFTER INSERT ON files BEGIN
|
||||
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
|
||||
VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS files_ad AFTER DELETE ON files BEGIN
|
||||
INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS files_au AFTER UPDATE ON files BEGIN
|
||||
INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo);
|
||||
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
|
||||
VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo);
|
||||
END;
|
||||
""",
|
||||
]
|
||||
for trig in _triggers:
|
||||
_execute_ignore_errors(engine, trig)
|
||||
|
||||
with engine.begin() as conn:
|
||||
try:
|
||||
count_fts = conn.execute(text("SELECT count(*) FROM files_fts")).scalar() # type: ignore
|
||||
if count_fts == 0:
|
||||
conn.execute(text(
|
||||
"""
|
||||
INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo)
|
||||
SELECT rowid, file_no, id, regarding, file_type, memo FROM files;
|
||||
"""
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def ensure_ledger_fts(engine: Engine) -> None:
|
||||
"""Ensure the `ledger_fts` virtual table and triggers exist and are populated."""
|
||||
_create_table = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS ledger_fts USING fts5(
|
||||
file_no,
|
||||
t_code,
|
||||
note,
|
||||
empl_num,
|
||||
content='ledger',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
"""
|
||||
_execute_ignore_errors(engine, _create_table)
|
||||
|
||||
_triggers = [
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS ledger_ai AFTER INSERT ON ledger BEGIN
|
||||
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
|
||||
VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS ledger_ad AFTER DELETE ON ledger BEGIN
|
||||
INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS ledger_au AFTER UPDATE ON ledger BEGIN
|
||||
INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num);
|
||||
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
|
||||
VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num);
|
||||
END;
|
||||
""",
|
||||
]
|
||||
for trig in _triggers:
|
||||
_execute_ignore_errors(engine, trig)
|
||||
|
||||
with engine.begin() as conn:
|
||||
try:
|
||||
count_fts = conn.execute(text("SELECT count(*) FROM ledger_fts")).scalar() # type: ignore
|
||||
if count_fts == 0:
|
||||
conn.execute(text(
|
||||
"""
|
||||
INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num)
|
||||
SELECT rowid, file_no, t_code, note, empl_num FROM ledger;
|
||||
"""
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def ensure_qdros_fts(engine: Engine) -> None:
|
||||
"""Ensure the `qdros_fts` virtual table and triggers exist and are populated."""
|
||||
_create_table = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS qdros_fts USING fts5(
|
||||
file_no,
|
||||
form_name,
|
||||
pet,
|
||||
res,
|
||||
case_number,
|
||||
content='qdros',
|
||||
content_rowid='rowid'
|
||||
);
|
||||
"""
|
||||
_execute_ignore_errors(engine, _create_table)
|
||||
|
||||
_triggers = [
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS qdros_ai AFTER INSERT ON qdros BEGIN
|
||||
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
|
||||
VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS qdros_ad AFTER DELETE ON qdros BEGIN
|
||||
INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number);
|
||||
END;
|
||||
""",
|
||||
"""
|
||||
CREATE TRIGGER IF NOT EXISTS qdros_au AFTER UPDATE ON qdros BEGIN
|
||||
INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number)
|
||||
VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number);
|
||||
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
|
||||
VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number);
|
||||
END;
|
||||
""",
|
||||
]
|
||||
for trig in _triggers:
|
||||
_execute_ignore_errors(engine, trig)
|
||||
|
||||
with engine.begin() as conn:
|
||||
try:
|
||||
count_fts = conn.execute(text("SELECT count(*) FROM qdros_fts")).scalar() # type: ignore
|
||||
if count_fts == 0:
|
||||
conn.execute(text(
|
||||
"""
|
||||
INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number)
|
||||
SELECT rowid, file_no, form_name, pet, res, case_number FROM qdros;
|
||||
"""
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
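Note: a content-synced FTS5 table like those above is queried with MATCH and joined back to its source table by rowid. A minimal sketch, illustrative only and not part of this commit (the helper name and the LIMIT are assumptions):

# Hypothetical illustration: query files_fts and join back to files by rowid.
from sqlalchemy import text
from sqlalchemy.engine import Engine

def search_files_fts(engine: Engine, query: str, limit: int = 20):
    sql = text(
        """
        SELECT f.file_no, f.regarding
        FROM files_fts
        JOIN files AS f ON f.rowid = files_fts.rowid
        WHERE files_fts MATCH :q
        ORDER BY rank
        LIMIT :limit
        """
    )
    with engine.connect() as conn:
        return conn.execute(sql, {"q": query, "limit": limit}).fetchall()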
app/database/indexes.py (new file)
@@ -0,0 +1,31 @@
"""
Database secondary indexes helper.

Creates small B-tree indexes for common equality filters to speed up searches.
Uses CREATE INDEX IF NOT EXISTS so it is safe to call repeatedly at startup
and works for existing databases without running a migration.
"""
from sqlalchemy.engine import Engine
from sqlalchemy import text


def ensure_secondary_indexes(engine: Engine) -> None:
    statements = [
        # Files
        "CREATE INDEX IF NOT EXISTS idx_files_status ON files(status)",
        "CREATE INDEX IF NOT EXISTS idx_files_file_type ON files(file_type)",
        "CREATE INDEX IF NOT EXISTS idx_files_empl_num ON files(empl_num)",
        # Ledger
        "CREATE INDEX IF NOT EXISTS idx_ledger_t_type ON ledger(t_type)",
        "CREATE INDEX IF NOT EXISTS idx_ledger_empl_num ON ledger(empl_num)",
    ]
    with engine.begin() as conn:
        for stmt in statements:
            try:
                conn.execute(text(stmt))
            except Exception:
                # Ignore failures on engines that do not support IF NOT EXISTS;
                # in that case, indexes should be managed via migrations instead.
                pass
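A quick usage sketch (illustrative only; the SQLite URL is an assumption, and the query uses SQLAlchemy 1.4+ result methods): the helper can be called repeatedly, and the created indexes can be checked through sqlite_master.

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///./app.db")  # illustrative database path
ensure_secondary_indexes(engine)
ensure_secondary_indexes(engine)  # second call is a no-op thanks to IF NOT EXISTS

with engine.connect() as conn:
    names = conn.execute(
        text("SELECT name FROM sqlite_master WHERE type = 'index' AND name LIKE 'idx_%'")
    ).scalars().all()
print(names)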
app/database/schema_updates.py (new file)
@@ -0,0 +1,130 @@
"""
Lightweight, idempotent schema updates for SQLite.

Adds newly introduced columns to existing tables when running on an
already-initialized database. Safe to call multiple times.
"""
from typing import Dict
from sqlalchemy.engine import Engine


def _existing_columns(conn, table: str) -> set[str]:
    rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall()
    return {row[1] for row in rows}  # name is column 2


def ensure_schema_updates(engine: Engine) -> None:
    """Ensure missing columns are added for backward-compatible updates."""
    # Map of table -> {column: SQL type}
    updates: Dict[str, Dict[str, str]] = {
        # Forms
        "form_index": {
            "keyword": "TEXT",
        },
        # Richer Life/Number tables (forms & pensions harmonized)
        "life_tables": {
            "le_aa": "FLOAT",
            "na_aa": "FLOAT",
            "le_am": "FLOAT",
            "na_am": "FLOAT",
            "le_af": "FLOAT",
            "na_af": "FLOAT",
            "le_wa": "FLOAT",
            "na_wa": "FLOAT",
            "le_wm": "FLOAT",
            "na_wm": "FLOAT",
            "le_wf": "FLOAT",
            "na_wf": "FLOAT",
            "le_ba": "FLOAT",
            "na_ba": "FLOAT",
            "le_bm": "FLOAT",
            "na_bm": "FLOAT",
            "le_bf": "FLOAT",
            "na_bf": "FLOAT",
            "le_ha": "FLOAT",
            "na_ha": "FLOAT",
            "le_hm": "FLOAT",
            "na_hm": "FLOAT",
            "le_hf": "FLOAT",
            "na_hf": "FLOAT",
            "table_year": "INTEGER",
            "table_type": "VARCHAR(45)",
        },
        "number_tables": {
            "month": "INTEGER",
            "na_aa": "FLOAT",
            "na_am": "FLOAT",
            "na_af": "FLOAT",
            "na_wa": "FLOAT",
            "na_wm": "FLOAT",
            "na_wf": "FLOAT",
            "na_ba": "FLOAT",
            "na_bm": "FLOAT",
            "na_bf": "FLOAT",
            "na_ha": "FLOAT",
            "na_hm": "FLOAT",
            "na_hf": "FLOAT",
            "table_type": "VARCHAR(45)",
            "description": "TEXT",
        },
        "form_list": {
            "status": "VARCHAR(45)",
        },
        # Printers: add advanced legacy fields
        "printers": {
            "number": "INTEGER",
            "page_break": "VARCHAR(50)",
            "setup_st": "VARCHAR(200)",
            "reset_st": "VARCHAR(200)",
            "b_underline": "VARCHAR(100)",
            "e_underline": "VARCHAR(100)",
            "b_bold": "VARCHAR(100)",
            "e_bold": "VARCHAR(100)",
            "phone_book": "BOOLEAN",
            "rolodex_info": "BOOLEAN",
            "envelope": "BOOLEAN",
            "file_cabinet": "BOOLEAN",
            "accounts": "BOOLEAN",
            "statements": "BOOLEAN",
            "calendar": "BOOLEAN",
        },
        # Pensions
        "pension_schedules": {
            "vests_on": "DATE",
            "vests_at": "FLOAT",
        },
        "marriage_history": {
            "married_from": "DATE",
            "married_to": "DATE",
            "married_years": "FLOAT",
            "service_from": "DATE",
            "service_to": "DATE",
            "service_years": "FLOAT",
            "marital_percent": "FLOAT",
        },
        "death_benefits": {
            "lump1": "FLOAT",
            "lump2": "FLOAT",
            "growth1": "FLOAT",
            "growth2": "FLOAT",
            "disc1": "FLOAT",
            "disc2": "FLOAT",
        },
    }

    with engine.begin() as conn:
        for table, cols in updates.items():
            try:
                existing = _existing_columns(conn, table)
            except Exception:
                # Table may not exist yet
                continue
            for col_name, col_type in cols.items():
                if col_name not in existing:
                    try:
                        conn.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
                    except Exception:
                        # Ignore if not applicable (other engines) or race condition
                        pass
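For intuition, the SQLite behaviour this helper relies on can be sketched directly with the standard sqlite3 module (illustrative table and column names, not part of the commit):

# PRAGMA table_info lists existing columns, and ALTER TABLE ADD COLUMN is only
# issued for columns that are not present yet, so re-running changes nothing.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE form_list (form_id TEXT, line_number INTEGER)")

existing = {row[1] for row in conn.execute("PRAGMA table_info('form_list')")}
if "status" not in existing:
    conn.execute("ALTER TABLE form_list ADD COLUMN status VARCHAR(45)")

# Second pass: the column is already there, so nothing is altered.
existing = {row[1] for row in conn.execute("PRAGMA table_info('form_list')")}
assert "status" in existing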
app/main.py
@@ -9,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.database.base import engine
from app.database.fts import ensure_rolodex_fts, ensure_files_fts, ensure_ledger_fts, ensure_qdros_fts
from app.database.indexes import ensure_secondary_indexes
from app.database.schema_updates import ensure_schema_updates
from app.models import BaseModel
from app.models.user import User
from app.auth.security import get_admin_user
@@ -24,6 +27,21 @@ logger = get_logger("main")
logger.info("Creating database tables")
BaseModel.metadata.create_all(bind=engine)

# Initialize SQLite FTS (if available)
logger.info("Initializing FTS (if available)")
ensure_rolodex_fts(engine)
ensure_files_fts(engine)
ensure_ledger_fts(engine)
ensure_qdros_fts(engine)

# Ensure helpful secondary indexes
logger.info("Ensuring secondary indexes (status, type, employee, etc.)")
ensure_secondary_indexes(engine)

# Ensure idempotent schema updates for added columns
logger.info("Ensuring schema updates (new columns)")
ensure_schema_updates(engine)

# Initialize FastAPI app
logger.info("Initializing FastAPI application", version=settings.app_version, debug=settings.debug)
app = FastAPI(
@@ -71,6 +89,7 @@ from app.api.import_data import router as import_router
from app.api.flexible import router as flexible_router
from app.api.support import router as support_router
from app.api.settings import router as settings_router
from app.api.mortality import router as mortality_router

logger.info("Including API routers")
app.include_router(auth_router, prefix="/api/auth", tags=["authentication"])
@@ -84,6 +103,7 @@ app.include_router(import_router, prefix="/api/import", tags=["import"])
app.include_router(support_router, prefix="/api/support", tags=["support"])
app.include_router(settings_router, prefix="/api/settings", tags=["settings"])
app.include_router(flexible_router, prefix="/api")
app.include_router(mortality_router, prefix="/api/mortality", tags=["mortality"])


@app.get("/", response_class=HTMLResponse)
@@ -46,6 +46,25 @@ def _get_correlation_id(request: Request) -> str:
    return str(uuid4())


def _json_safe(value: Any) -> Any:
    """Recursively convert non-JSON-serializable objects (like Exceptions) into strings.

    Keeps overall structure intact so tests inspecting error details (e.g. 'loc', 'msg') still work.
    """
    # Exception -> string message
    if isinstance(value, BaseException):
        return str(value)
    # Mapping types
    if isinstance(value, dict):
        return {k: _json_safe(v) for k, v in value.items()}
    # Sequence types
    if isinstance(value, (list, tuple)):
        return [
            _json_safe(v) for v in value
        ]
    return value


def _build_error_response(
    request: Request,
    *,
@@ -66,7 +85,7 @@ def _build_error_response(
        "correlation_id": correlation_id,
    }
    if details is not None:
        body["error"]["details"] = details
        body["error"]["details"] = _json_safe(details)

    response = JSONResponse(content=body, status_code=status_code)
    response.headers[ERROR_HEADER_NAME] = correlation_id
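Illustration of what the new _json_safe helper does to a validation-style details payload (the values are made up):

details = [{"loc": ("body", "email"), "msg": "invalid email", "ctx": {"error": ValueError("bad address")}}]
print(_json_safe(details))
# [{'loc': ['body', 'email'], 'msg': 'invalid email', 'ctx': {'error': 'bad address'}}]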
@@ -14,12 +14,12 @@ from .flexible import FlexibleImport
from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory
from .pensions import (
    Pension, PensionSchedule, MarriageHistory, DeathBenefit,
    SeparationAgreement, LifeTable, NumberTable
    SeparationAgreement, LifeTable, NumberTable, PensionResult
)
from .lookups import (
    Employee, FileType, FileStatus, TransactionType, TransactionCode,
    State, GroupLookup, Footer, PlanInfo, FormIndex, FormList,
    PrinterSetup, SystemSetup
    PrinterSetup, SystemSetup, FormKeyword
)

__all__ = [
@@ -28,8 +28,8 @@ __all__ = [
    "Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document", "FlexibleImport",
    "SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory",
    "Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit",
    "SeparationAgreement", "LifeTable", "NumberTable",
    "SeparationAgreement", "LifeTable", "NumberTable", "PensionResult",
    "Employee", "FileType", "FileStatus", "TransactionType", "TransactionCode",
    "State", "GroupLookup", "Footer", "PlanInfo", "FormIndex", "FormList",
    "PrinterSetup", "SystemSetup"
    "PrinterSetup", "SystemSetup", "FormKeyword"
]
@@ -3,7 +3,7 @@ Audit logging models
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from datetime import datetime, timezone
from app.models.base import BaseModel


@@ -22,7 +22,7 @@ class AuditLog(BaseModel):
    details = Column(JSON, nullable=True) # Additional details as JSON
    ip_address = Column(String(45), nullable=True) # IPv4/IPv6 address
    user_agent = Column(Text, nullable=True) # Browser/client information
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
    timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)

    # Relationships
    user = relationship("User", back_populates="audit_logs")
@@ -42,7 +42,7 @@ class LoginAttempt(BaseModel):
    ip_address = Column(String(45), nullable=False)
    user_agent = Column(Text, nullable=True)
    success = Column(Integer, default=0) # 1 for success, 0 for failure
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
    timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
    failure_reason = Column(String(200), nullable=True) # Reason for failure

    def __repr__(self):
@@ -56,8 +56,8 @@ class ImportAudit(BaseModel):
    __tablename__ = "import_audit"

    id = Column(Integer, primary_key=True, autoincrement=True, index=True)
    started_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
    finished_at = Column(DateTime, nullable=True, index=True)
    started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
    finished_at = Column(DateTime(timezone=True), nullable=True, index=True)
    status = Column(String(30), nullable=False, default="running", index=True) # running|success|completed_with_errors|failed

    total_files = Column(Integer, nullable=False, default=0)
@@ -94,7 +94,7 @@ class ImportAuditFile(BaseModel):
    errors = Column(Integer, nullable=False, default=0)
    message = Column(String(255), nullable=True)
    details = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True)

    audit = relationship("ImportAudit", back_populates="files")
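The motivation for the column changes above, shown in two lines (illustrative): datetime.utcnow() returns a naive value, while datetime.now(timezone.utc) carries tzinfo, which is what DateTime(timezone=True) columns are meant to hold and compare unambiguously.

from datetime import datetime, timezone

naive = datetime.utcnow()
aware = datetime.now(timezone.utc)
print(naive.tzinfo)  # None
print(aware.tzinfo)  # UTC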
@@ -1,7 +1,7 @@
"""
Authentication-related persistence models
"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint
@@ -19,10 +19,10 @@ class RefreshToken(BaseModel):
    jti = Column(String(64), nullable=False, unique=True, index=True)
    user_agent = Column(String(255), nullable=True)
    ip_address = Column(String(45), nullable=True)
    issued_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    expires_at = Column(DateTime, nullable=False, index=True)
    issued_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
    expires_at = Column(DateTime(timezone=True), nullable=False, index=True)
    revoked = Column(Boolean, default=False, nullable=False)
    revoked_at = Column(DateTime, nullable=True)
    revoked_at = Column(DateTime(timezone=True), nullable=True)

    # relationships
    user = relationship("User")
@@ -1,7 +1,7 @@
"""
Lookup table models based on legacy system analysis
"""
from sqlalchemy import Column, Integer, String, Text, Boolean, Float
from sqlalchemy import Column, Integer, String, Text, Boolean, Float, ForeignKey
from app.models.base import BaseModel


@@ -53,6 +53,9 @@ class FileStatus(BaseModel):
    description = Column(String(200), nullable=False) # Description
    active = Column(Boolean, default=True) # Is status active
    sort_order = Column(Integer, default=0) # Display order
    # Legacy fields for typed import support
    send = Column(Boolean, default=True) # Should statements print by default
    footer_code = Column(String(45), ForeignKey("footers.footer_code")) # Default footer

    def __repr__(self):
        return f"<FileStatus(code='{self.status_code}', description='{self.description}')>"
@@ -169,9 +172,10 @@ class FormIndex(BaseModel):
    """
    __tablename__ = "form_index"

    form_id = Column(String(45), primary_key=True, index=True) # Form identifier
    form_id = Column(String(45), primary_key=True, index=True) # Form identifier (maps to Name)
    form_name = Column(String(200), nullable=False) # Form name
    category = Column(String(45)) # Form category
    keyword = Column(String(200)) # Legacy FORM_INX Name/Keyword pair
    active = Column(Boolean, default=True) # Is form active

    def __repr__(self):
@@ -189,6 +193,7 @@ class FormList(BaseModel):
    form_id = Column(String(45), nullable=False) # Form identifier
    line_number = Column(Integer, nullable=False) # Line number in form
    content = Column(Text) # Line content
    status = Column(String(45)) # Legacy FORM_LST Status

    def __repr__(self):
        return f"<FormList(form_id='{self.form_id}', line={self.line_number})>"
@@ -201,12 +206,34 @@ class PrinterSetup(BaseModel):
    """
    __tablename__ = "printers"

    # Core identity and basic configuration
    printer_name = Column(String(100), primary_key=True, index=True) # Printer name
    description = Column(String(200)) # Description
    driver = Column(String(100)) # Print driver
    port = Column(String(20)) # Port/connection
    default_printer = Column(Boolean, default=False) # Is default printer
    active = Column(Boolean, default=True) # Is printer active

    # Legacy numeric printer number (from PRINTERS.csv "Number")
    number = Column(Integer)

    # Legacy control sequences and formatting (from PRINTERS.csv)
    page_break = Column(String(50))
    setup_st = Column(String(200))
    reset_st = Column(String(200))
    b_underline = Column(String(100))
    e_underline = Column(String(100))
    b_bold = Column(String(100))
    e_bold = Column(String(100))

    # Optional report configuration toggles (legacy flags)
    phone_book = Column(Boolean, default=False)
    rolodex_info = Column(Boolean, default=False)
    envelope = Column(Boolean, default=False)
    file_cabinet = Column(Boolean, default=False)
    accounts = Column(Boolean, default=False)
    statements = Column(Boolean, default=False)
    calendar = Column(Boolean, default=False)

    def __repr__(self):
        return f"<Printer(name='{self.printer_name}', description='{self.description}')>"
@@ -225,4 +252,19 @@ class SystemSetup(BaseModel):
    setting_type = Column(String(20), default="STRING") # DATA type (STRING, INTEGER, FLOAT, BOOLEAN)

    def __repr__(self):
        return f"<SystemSetup(key='{self.setting_key}', value='{self.setting_value}')>"


class FormKeyword(BaseModel):
    """
    Form keyword lookup
    Corresponds to INX_LKUP table in legacy system
    """
    __tablename__ = "form_keywords"

    keyword = Column(String(200), primary_key=True, index=True)
    description = Column(String(200))
    active = Column(Boolean, default=True)

    def __repr__(self):
        return f"<FormKeyword(keyword='{self.keyword}')>"
@@ -60,11 +60,13 @@ class PensionSchedule(BaseModel):
    file_no = Column(String(45), ForeignKey("files.file_no"), nullable=False)
    version = Column(String(10), default="01")

    # Schedule details
    # Schedule details (legacy vesting fields)
    start_date = Column(Date) # Start date for payments
    end_date = Column(Date) # End date for payments
    payment_amount = Column(Float, default=0.0) # Payment amount
    frequency = Column(String(20)) # Monthly, quarterly, etc.
    vests_on = Column(Date) # Legacy SCHEDULE.csv Vests_On
    vests_at = Column(Float, default=0.0) # Legacy SCHEDULE.csv Vests_At (percent)

    # Relationships
    file = relationship("File", back_populates="pension_schedules")
@@ -85,6 +87,15 @@ class MarriageHistory(BaseModel):
    divorce_date = Column(Date) # Date of divorce/separation
    spouse_name = Column(String(100)) # Spouse name
    notes = Column(Text) # Additional notes

    # Legacy MARRIAGE.csv fields
    married_from = Column(Date)
    married_to = Column(Date)
    married_years = Column(Float, default=0.0)
    service_from = Column(Date)
    service_to = Column(Date)
    service_years = Column(Float, default=0.0)
    marital_percent = Column(Float, default=0.0)

    # Relationships
    file = relationship("File", back_populates="marriage_history")
@@ -105,6 +116,14 @@ class DeathBenefit(BaseModel):
    benefit_amount = Column(Float, default=0.0) # Benefit amount
    benefit_type = Column(String(45)) # Type of death benefit
    notes = Column(Text) # Additional notes

    # Legacy DEATH.csv fields
    lump1 = Column(Float, default=0.0)
    lump2 = Column(Float, default=0.0)
    growth1 = Column(Float, default=0.0)
    growth2 = Column(Float, default=0.0)
    disc1 = Column(Float, default=0.0)
    disc2 = Column(Float, default=0.0)

    # Relationships
    file = relationship("File", back_populates="death_benefits")
@@ -138,10 +157,36 @@ class LifeTable(BaseModel):

    id = Column(Integer, primary_key=True, autoincrement=True)
    age = Column(Integer, nullable=False) # Age
    male_expectancy = Column(Float) # Male life expectancy
    female_expectancy = Column(Float) # Female life expectancy
    table_year = Column(Integer) # Year of table (e.g., 2023)
    table_type = Column(String(45)) # Type of table
    # Rich typed columns reflecting legacy LIFETABL.csv headers
    # LE_* = Life Expectancy, NA_* = Number Alive/Survivors
    le_aa = Column(Float)
    na_aa = Column(Float)
    le_am = Column(Float)
    na_am = Column(Float)
    le_af = Column(Float)
    na_af = Column(Float)
    le_wa = Column(Float)
    na_wa = Column(Float)
    le_wm = Column(Float)
    na_wm = Column(Float)
    le_wf = Column(Float)
    na_wf = Column(Float)
    le_ba = Column(Float)
    na_ba = Column(Float)
    le_bm = Column(Float)
    na_bm = Column(Float)
    le_bf = Column(Float)
    na_bf = Column(Float)
    le_ha = Column(Float)
    na_ha = Column(Float)
    le_hm = Column(Float)
    na_hm = Column(Float)
    le_hf = Column(Float)
    na_hf = Column(Float)

    # Optional metadata retained for future variations
    table_year = Column(Integer) # Year/version of table if known
    table_type = Column(String(45)) # Source/type of table (optional)


class NumberTable(BaseModel):
@@ -152,7 +197,63 @@
    __tablename__ = "number_tables"

    id = Column(Integer, primary_key=True, autoincrement=True)
    table_type = Column(String(45), nullable=False) # Type of table
    key_value = Column(String(45), nullable=False) # Key identifier
    numeric_value = Column(Float) # Numeric value
    description = Column(Text) # Description
    month = Column(Integer, nullable=False)
    # Rich typed NA_* columns reflecting legacy NUMBERAL.csv headers
    na_aa = Column(Float)
    na_am = Column(Float)
    na_af = Column(Float)
    na_wa = Column(Float)
    na_wm = Column(Float)
    na_wf = Column(Float)
    na_ba = Column(Float)
    na_bm = Column(Float)
    na_bf = Column(Float)
    na_ha = Column(Float)
    na_hm = Column(Float)
    na_hf = Column(Float)

    # Optional metadata retained for future variations
    table_type = Column(String(45))
    description = Column(Text)


class PensionResult(BaseModel):
    """
    Computed pension results summary
    Corresponds to RESULTS table in legacy system
    """
    __tablename__ = "pension_results"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # Optional linkage if present in future exports
    file_no = Column(String(45))
    version = Column(String(10))

    # Columns observed in legacy RESULTS.csv header
    accrued = Column(Float)
    start_age = Column(Integer)
    cola = Column(Float)
    withdrawal = Column(String(45))
    pre_dr = Column(Float)
    post_dr = Column(Float)
    tax_rate = Column(Float)
    age = Column(Integer)
    years_from = Column(Float)
    life_exp = Column(Float)
    ev_monthly = Column(Float)
    payments = Column(Float)
    pay_out = Column(Float)
    fund_value = Column(Float)
    pv = Column(Float)
    mortality = Column(Float)
    pv_am = Column(Float)
    pv_amt = Column(Float)
    pv_pre_db = Column(Float)
    pv_annuity = Column(Float)
    wv_at = Column(Float)
    pv_plan = Column(Float)
    years_married = Column(Float)
    years_service = Column(Float)
    marr_per = Column(Float)
    marr_amt = Column(Float)
@@ -3,7 +3,7 @@ Support ticket models for help desk functionality
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Enum
from sqlalchemy.orm import relationship
from datetime import datetime
from datetime import datetime, timezone
import enum

from app.models.base import BaseModel
@@ -63,9 +63,9 @@ class SupportTicket(BaseModel):
    ip_address = Column(String(45)) # IP address

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    resolved_at = Column(DateTime)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    resolved_at = Column(DateTime(timezone=True))

    # Admin assignment
    assigned_to = Column(Integer, ForeignKey("users.id"))
@@ -95,7 +95,7 @@ class TicketResponse(BaseModel):
    author_email = Column(String(100)) # For non-user responses

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)

    # Relationships
    ticket = relationship("SupportTicket", back_populates="responses")
@@ -3,7 +3,7 @@ Audit logging service
"""
import json
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from fastapi import Request

@@ -65,7 +65,7 @@ class AuditService:
            details=details,
            ip_address=ip_address,
            user_agent=user_agent,
            timestamp=datetime.utcnow()
            timestamp=datetime.now(timezone.utc)
        )

        try:
@@ -76,7 +76,7 @@ class AuditService:
        except Exception as e:
            db.rollback()
            # Log the error but don't fail the main operation
            logger.error("Failed to log audit entry", error=str(e), action=action, user_id=user_id)
            logger.error("Failed to log audit entry", error=str(e), action=action)
        return audit_log

    @staticmethod
@@ -119,7 +119,7 @@ class AuditService:
            ip_address=ip_address or "unknown",
            user_agent=user_agent,
            success=1 if success else 0,
            timestamp=datetime.utcnow(),
            timestamp=datetime.now(timezone.utc),
            failure_reason=failure_reason if not success else None
        )

@@ -252,7 +252,7 @@ class AuditService:
        Returns:
            List of failed login attempts
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=hours)
        cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
        query = db.query(LoginAttempt).filter(
            LoginAttempt.success == 0,
            LoginAttempt.timestamp >= cutoff_time
app/services/cache.py (new file)
@@ -0,0 +1,98 @@
"""
Cache utilities with optional Redis backend.

If Redis is not configured or unavailable, all functions degrade to no-ops.
"""
from __future__ import annotations

import asyncio
import json
import hashlib
from typing import Any, Optional

try:
    import redis.asyncio as redis # type: ignore
except Exception: # pragma: no cover - allow running without redis installed
    redis = None # type: ignore

from app.config import settings


_client: Optional["redis.Redis"] = None # type: ignore
_lock = asyncio.Lock()


async def _get_client() -> Optional["redis.Redis"]: # type: ignore
    """Lazily initialize and return a shared Redis client if enabled."""
    global _client
    if not getattr(settings, "redis_url", None) or not getattr(settings, "cache_enabled", False):
        return None
    if redis is None:
        return None
    if _client is not None:
        return _client
    async with _lock:
        if _client is None:
            try:
                _client = redis.from_url(settings.redis_url, decode_responses=True) # type: ignore
            except Exception:
                _client = None
    return _client


def _stable_hash(obj: Any) -> str:
    data = json.dumps(obj, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(data.encode("utf-8")).hexdigest()


def build_key(kind: str, user_id: Optional[str], parts: dict) -> str:
    payload = {"u": user_id or "anon", "p": parts}
    return f"search:{kind}:v1:{_stable_hash(payload)}"


async def cache_get_json(kind: str, user_id: Optional[str], parts: dict) -> Optional[Any]:
    client = await _get_client()
    if client is None:
        return None
    key = build_key(kind, user_id, parts)
    try:
        raw = await client.get(key)
        if raw is None:
            return None
        return json.loads(raw)
    except Exception:
        return None


async def cache_set_json(kind: str, user_id: Optional[str], parts: dict, value: Any, ttl_seconds: int) -> None:
    client = await _get_client()
    if client is None:
        return
    key = build_key(kind, user_id, parts)
    try:
        await client.set(key, json.dumps(value, separators=(",", ":")), ex=ttl_seconds)
    except Exception:
        return


async def invalidate_prefix(prefix: str) -> None:
    client = await _get_client()
    if client is None:
        return
    try:
        # Use SCAN to avoid blocking Redis
        async for key in client.scan_iter(match=f"{prefix}*"):
            try:
                await client.delete(key)
            except Exception:
                pass
    except Exception:
        return


async def invalidate_search_cache() -> None:
    # Wipe both global search and suggestions namespaces
    await invalidate_prefix("search:global:")
    await invalidate_prefix("search:suggestions:")
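A hypothetical caller sketch for the helpers above (run_global_search and the key parts are assumptions, not part of this commit): try the cache first, fall back to the database, then store the result with a short TTL.

async def cached_global_search(user_id, q, limit):
    parts = {"q": q, "limit": limit}  # illustrative key parts
    hit = await cache_get_json("global", user_id, parts)
    if hit is not None:
        return hit
    results = run_global_search(q, limit)  # hypothetical DB-backed search
    await cache_set_json("global", user_id, parts, results, ttl_seconds=60)
    return results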
app/services/customers_search.py (new file)
@@ -0,0 +1,141 @@
from typing import Optional, List
from sqlalchemy import or_, and_, func, asc, desc

from app.models.rolodex import Rolodex


def apply_customer_filters(base_query, search: Optional[str], group: Optional[str], state: Optional[str], groups: Optional[List[str]], states: Optional[List[str]]):
    """Apply shared search and group/state filters to the provided base_query.

    This helper is used by both list and export endpoints to keep logic in sync.
    """
    s = (search or "").strip()
    if s:
        s_lower = s.lower()
        tokens = [t for t in s_lower.split() if t]
        contains_any = or_(
            func.lower(Rolodex.id).contains(s_lower),
            func.lower(Rolodex.last).contains(s_lower),
            func.lower(Rolodex.first).contains(s_lower),
            func.lower(Rolodex.middle).contains(s_lower),
            func.lower(Rolodex.city).contains(s_lower),
            func.lower(Rolodex.email).contains(s_lower),
        )
        name_tokens = [
            or_(
                func.lower(Rolodex.first).contains(tok),
                func.lower(Rolodex.middle).contains(tok),
                func.lower(Rolodex.last).contains(tok),
            )
            for tok in tokens
        ]
        combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))

        last_first_filter = None
        if "," in s_lower:
            last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
            if last_part and first_part:
                last_first_filter = and_(
                    func.lower(Rolodex.last).contains(last_part),
                    func.lower(Rolodex.first).contains(first_part),
                )
            elif last_part:
                last_first_filter = func.lower(Rolodex.last).contains(last_part)

        final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
        base_query = base_query.filter(final_filter)

    effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
    if effective_groups:
        base_query = base_query.filter(Rolodex.group.in_(effective_groups))

    effective_states = [s for s in (states or []) if s] or ([state] if state else [])
    if effective_states:
        base_query = base_query.filter(Rolodex.abrev.in_(effective_states))

    return base_query


def apply_customer_sorting(base_query, sort_by: Optional[str], sort_dir: Optional[str]):
    """Apply shared sorting to the provided base_query.

    Supported fields: id, name (last,first), city (city,state), email.
    Unknown fields fall back to id. Sorting is case-insensitive for strings.
    """
    normalized_sort_by = (sort_by or "id").lower()
    normalized_sort_dir = (sort_dir or "asc").lower()
    is_desc = normalized_sort_dir == "desc"

    order_columns = []
    if normalized_sort_by == "id":
        order_columns = [Rolodex.id]
    elif normalized_sort_by == "name":
        order_columns = [Rolodex.last, Rolodex.first]
    elif normalized_sort_by == "city":
        order_columns = [Rolodex.city, Rolodex.abrev]
    elif normalized_sort_by == "email":
        order_columns = [Rolodex.email]
    else:
        order_columns = [Rolodex.id]

    ordered = []
    for col in order_columns:
        try:
            expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
        except Exception:
            expr = col
        ordered.append(desc(expr) if is_desc else asc(expr))

    if ordered:
        base_query = base_query.order_by(*ordered)
    return base_query


def prepare_customer_csv_rows(customers: List[Rolodex], fields: Optional[List[str]]):
    """Prepare CSV header and rows for the given customers and requested fields.

    Returns a tuple: (header_row, rows), where header_row is a list of column
    titles and rows is a list of row lists ready to be written by csv.writer.
    """
    allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
    header_names = {
        "id": "Customer ID",
        "name": "Name",
        "group": "Group",
        "city": "City",
        "state": "State",
        "phone": "Primary Phone",
        "email": "Email",
    }

    requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
    selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
    if not selected_fields:
        selected_fields = allowed_fields_in_order

    header_row = [header_names[f] for f in selected_fields]

    rows: List[List[str]] = []
    for c in customers:
        full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
        primary_phone = ""
        try:
            if getattr(c, "phone_numbers", None):
                primary_phone = c.phone_numbers[0].phone or ""
        except Exception:
            primary_phone = ""

        row_map = {
            "id": c.id,
            "name": full_name,
            "group": c.group or "",
            "city": c.city or "",
            "state": c.abrev or "",
            "phone": primary_phone,
            "email": c.email or "",
        }
        rows.append([row_map[f] for f in selected_fields])

    return header_row, rows
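A hypothetical sketch of how these helpers compose in a list endpoint (the session handling, default limit, and field selection are assumptions):

def list_customers(db, search=None, group=None, state=None, sort_by="name", sort_dir="asc"):
    query = db.query(Rolodex)
    query = apply_customer_filters(query, search, group, state, groups=None, states=None)
    query = apply_customer_sorting(query, sort_by, sort_dir)
    return query.limit(50).all()

# For CSV export, the same filtered rows can be fed to prepare_customer_csv_rows:
# header, rows = prepare_customer_csv_rows(customers, fields=["id", "name", "email"])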
app/services/mortality.py (new file)
@@ -0,0 +1,127 @@
"""
Mortality/Life table utilities.

Helpers to query `life_tables` and `number_tables` by age/month and
return values filtered by sex/race using compact codes:
- sex: M, F, A (All)
- race: W (White), B (Black), H (Hispanic), A (All)

Column naming in tables follows the pattern:
- LifeTable: le_{race}{sex}, na_{race}{sex}
- NumberTable: na_{race}{sex}

Examples:
- race=W, sex=M => suffix "wm" (columns `le_wm`, `na_wm`)
- race=A, sex=F => suffix "af" (columns `le_af`, `na_af`)
- race=H, sex=A => suffix "ha" (columns `le_ha`, `na_ha`)
"""

from __future__ import annotations

from typing import Dict, Optional, Tuple
from sqlalchemy.orm import Session

from app.models.pensions import LifeTable, NumberTable


_RACE_MAP: Dict[str, str] = {
    "W": "w", # White
    "B": "b", # Black
    "H": "h", # Hispanic
    "A": "a", # All races
}

_SEX_MAP: Dict[str, str] = {
    "M": "m",
    "F": "f",
    "A": "a", # All sexes
}


class InvalidCodeError(ValueError):
    pass


def _normalize_codes(sex: str, race: str) -> Tuple[str, str, str]:
    """Validate/normalize sex and race to construct the column suffix.

    Returns (suffix, sex_u, race_u) where suffix is lowercase like "wm".
    Raises InvalidCodeError on invalid inputs.
    """
    sex_u = (sex or "").strip().upper()
    race_u = (race or "").strip().upper()
    if sex_u not in _SEX_MAP:
        raise InvalidCodeError(f"Invalid sex code '{sex}'. Expected one of: {', '.join(_SEX_MAP.keys())}")
    if race_u not in _RACE_MAP:
        raise InvalidCodeError(f"Invalid race code '{race}'. Expected one of: {', '.join(_RACE_MAP.keys())}")
    return _RACE_MAP[race_u] + _SEX_MAP[sex_u], sex_u, race_u


def get_life_values(
    db: Session,
    *,
    age: int,
    sex: str,
    race: str,
) -> Optional[Dict[str, Optional[float]]]:
    """Return life table LE and NA values for a given age, sex, and race.

    Returns dict: {"age": int, "sex": str, "race": str, "le": float|None, "na": float|None}
    Returns None if the age row does not exist.
    Raises InvalidCodeError for invalid codes.
    """
    suffix, sex_u, race_u = _normalize_codes(sex, race)
    row: Optional[LifeTable] = db.query(LifeTable).filter(LifeTable.age == age).first()
    if not row:
        return None

    le_col = f"le_{suffix}"
    na_col = f"na_{suffix}"
    le_val = getattr(row, le_col, None)
    na_val = getattr(row, na_col, None)

    return {
        "age": int(age),
        "sex": sex_u,
        "race": race_u,
        "le": float(le_val) if le_val is not None else None,
        "na": float(na_val) if na_val is not None else None,
    }


def get_number_value(
    db: Session,
    *,
    month: int,
    sex: str,
    race: str,
) -> Optional[Dict[str, Optional[float]]]:
    """Return number table NA value for a given month, sex, and race.

    Returns dict: {"month": int, "sex": str, "race": str, "na": float|None}
    Returns None if the month row does not exist.
    Raises InvalidCodeError for invalid codes.
    """
    suffix, sex_u, race_u = _normalize_codes(sex, race)
    row: Optional[NumberTable] = db.query(NumberTable).filter(NumberTable.month == month).first()
    if not row:
        return None

    na_col = f"na_{suffix}"
    na_val = getattr(row, na_col, None)

    return {
        "month": int(month),
        "sex": sex_u,
        "race": race_u,
        "na": float(na_val) if na_val is not None else None,
    }


__all__ = [
    "InvalidCodeError",
    "get_life_values",
    "get_number_value",
]
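Hypothetical usage of the mortality helpers, assuming `db` is an open Session and the life tables have been imported (the printed values depend entirely on the data):

values = get_life_values(db, age=65, sex="M", race="A")
if values is not None:
    print(values["le"], values["na"])  # life expectancy and number alive at age 65

try:
    get_life_values(db, age=65, sex="X", race="A")
except InvalidCodeError as exc:
    print(exc)  # Invalid sex code 'X'. Expected one of: M, F, A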
app/services/query_utils.py (new file)
@@ -0,0 +1,72 @@
from typing import Iterable, Optional, Sequence
from sqlalchemy import or_, and_, asc, desc, func
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.schema import Column


def tokenized_ilike_filter(tokens: Sequence[str], columns: Sequence[Column]) -> Optional[BinaryExpression]:
    """Build an AND-of-ORs case-insensitive LIKE filter across columns for each token.

    Example: AND(OR(col1 ILIKE %t1%, col2 ILIKE %t1%), OR(col1 ILIKE %t2%, ...))
    Returns None when tokens or columns are empty.
    """
    if not tokens or not columns:
        return None
    per_token_clauses = []
    for term in tokens:
        term = str(term or "").strip()
        if not term:
            continue
        per_token_clauses.append(or_(*[c.ilike(f"%{term}%") for c in columns]))
    if not per_token_clauses:
        return None
    return and_(*per_token_clauses)


def apply_pagination(query, skip: int, limit: int):
    """Apply offset/limit pagination to a SQLAlchemy query in a DRY way."""
    return query.offset(skip).limit(limit)


def paginate_with_total(query, skip: int, limit: int, include_total: bool):
    """Return (items, total|None) applying pagination and optionally counting total.

    This avoids duplicating count + pagination logic at each endpoint.
    """
    total_count = query.count() if include_total else None
    items = apply_pagination(query, skip, limit).all()
    return items, total_count


def apply_sorting(query, sort_by: Optional[str], sort_dir: Optional[str], allowed: dict[str, list[Column]]):
    """Apply case-insensitive sorting per a whitelist of allowed fields.

    allowed: mapping from field name -> list of columns to sort by, in priority order.
    For string columns, compares using lower(column) for stable ordering.
    Unknown sort_by falls back to the first key in allowed.
    sort_dir: "asc" or "desc" (default asc)
    """
    if not allowed:
        return query
    normalized_sort_by = (sort_by or next(iter(allowed.keys()))).lower()
    normalized_sort_dir = (sort_dir or "asc").lower()
    is_desc = normalized_sort_dir == "desc"

    columns = allowed.get(normalized_sort_by)
    if not columns:
        columns = allowed.get(next(iter(allowed.keys())))
    if not columns:
        return query

    order_exprs = []
    for col in columns:
        try:
            expr = func.lower(col) if getattr(col.type, "python_type", str) is str else col
        except Exception:
            expr = col
        order_exprs.append(desc(expr) if is_desc else asc(expr))
    if order_exprs:
        query = query.order_by(*order_exprs)
    return query
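A hypothetical endpoint-side sketch of how these utilities combine (FileModel and the allowed sort fields are assumptions, not part of this commit):

def search_files(db, tokens, sort_by, sort_dir, skip=0, limit=25, include_total=True):
    query = db.query(FileModel)  # FileModel is assumed to be the files ORM class
    flt = tokenized_ilike_filter(tokens, [FileModel.file_no, FileModel.regarding, FileModel.memo])
    if flt is not None:
        query = query.filter(flt)
    allowed = {"file_no": [FileModel.file_no], "regarding": [FileModel.regarding]}
    query = apply_sorting(query, sort_by, sort_dir, allowed)
    items, total = paginate_with_total(query, skip, limit, include_total)
    return items, total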