diff --git a/.e2e-token b/.e2e-token new file mode 100644 index 0000000..57c6271 --- /dev/null +++ b/.e2e-token @@ -0,0 +1 @@ +eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsInR5cGUiOiJhY2Nlc3MiLCJpYXQiOjE3NTUyMDAyNzMsImV4cCI6MTc1NTIxNDY3M30.VfcV_zbhtSe50u1awNC4v2O8CU4PQ9AwhlcNeNn40cM \ No newline at end of file diff --git a/DATA_MIGRATION_README.md b/DATA_MIGRATION_README.md index 3ad5d8f..8156802 100644 --- a/DATA_MIGRATION_README.md +++ b/DATA_MIGRATION_README.md @@ -6,7 +6,7 @@ This guide covers the complete data migration process for importing legacy Delph ## 🔍 Migration Status Summary ### ✅ **READY FOR MIGRATION** -- **Readiness Score**: 100% (31/31 files fully mapped) +- **Readiness Score**: 100% (31/31 files supported; several use flexible extras for non-core columns) - **Security**: All sensitive files excluded from git - **API Endpoints**: Complete import/export functionality - **Data Validation**: Enhanced type conversion and validation @@ -30,8 +30,9 @@ This guide covers the complete data migration process for importing legacy Delph | GRUPLKUP.csv | GroupLookup | ✅ Ready | Group categories | | FOOTERS.csv | Footer | ✅ Ready | Statement footer templates | | PLANINFO.csv | PlanInfo | ✅ Ready | Retirement plan information | -| FORM_INX.csv | FormIndex | ✅ Ready | Form templates index | -| FORM_LST.csv | FormList | ✅ Ready | Form template content | +| FORM_INX.csv | FormIndex | ✅ Ready | Form templates index (non-core fields stored as flexible extras) | +| FORM_LST.csv | FormList | ✅ Ready | Form template content (non-core fields stored as flexible extras) | +| INX_LKUP.csv | FormKeyword | ✅ Ready | Form keywords lookup | | PRINTERS.csv | PrinterSetup | ✅ Ready | Printer configuration | | SETUP.csv | SystemSetup | ✅ Ready | System configuration | | **Pension Sub-tables** | | | | @@ -39,8 +40,9 @@ This guide covers the complete data migration process for importing legacy Delph | MARRIAGE.csv | MarriageHistory | ✅ Ready | Marriage history data | | DEATH.csv | DeathBenefit | ✅ Ready | Death benefit calculations | | SEPARATE.csv | SeparationAgreement | ✅ Ready | Separation agreements | -| LIFETABL.csv | LifeTable | ✅ Ready | Life expectancy tables | -| NUMBERAL.csv | NumberTable | ✅ Ready | Numerical calculation tables | +| LIFETABL.csv | LifeTable | ✅ Ready | Life expectancy tables (simplified model; extra columns stored as flexible extras) | +| NUMBERAL.csv | NumberTable | ✅ Ready | Numerical calculation tables (simplified model; extra columns stored as flexible extras) | +| RESULTS.csv | PensionResult | ✅ Ready | Computed results summary | ### ✅ **Recently Added Files** (6/31 files) | File | Model | Status | Notes | diff --git a/README.md b/README.md index a403963..154e8e3 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,25 @@ Modern database system for legal practice management, financial tracking, and do - **Authentication**: JWT with bcrypt password hashing - **Validation**: Pydantic v2 +## ⚡ Search Performance (FTS + Cache) + +- Full-text search is enabled via SQLite FTS5 for Customers (`rolodex`), Files, Ledger, and QDRO. + - The app creates virtual FTS tables and sync triggers at startup. + - On engines without FTS5, search falls back to standard `ILIKE` queries. +- Common filter columns are indexed for faster filtering: `files(status, file_type, empl_num)` and `ledger(t_type, empl_num)`. +- Response caching (optional) uses Redis for global search and suggestions. + - Cache TTL: ~90s for global search, ~60s for suggestions. 
+ - Cache is auto-invalidated on create/update/delete affecting customers, files, ledger, or QDROs. + +Enable cache: +```bash +export CACHE_ENABLED=true +export REDIS_URL=redis://localhost:6379/0 +``` + +Diagnostics: +- `GET /api/search/_debug` reports whether FTS tables exist and Redis is available (requires auth). + ## 📊 Database Structure Based on analysis of legacy Pascal system: @@ -134,6 +153,48 @@ delphi-database/ ## 🔧 API Endpoints +### Common pagination, sorting, and totals +- Many list endpoints support the same query parameters: + - `skip` (int): offset for pagination. Default varies per endpoint. + - `limit` (int): page size. Most endpoints cap at 200–1000. + - `sort_by` (str): whitelisted field name per endpoint. + - `sort_dir` (str): `asc` or `desc`. + - `include_total` (bool): when `true`, the response is an object `{ items, total }`; otherwise a plain list is returned for backwards compatibility. +- Some endpoints also support `search` (tokenized across multiple columns with AND semantics) for simple text filtering. + +Examples: +```bash +# Support tickets (admin) +curl \ + "http://localhost:6920/api/support/tickets?include_total=true&limit=10&sort_by=created&sort_dir=desc" + +# My support tickets (current user) +curl \ + "http://localhost:6920/api/support/my-tickets?include_total=true&limit=10&sort_by=updated&sort_dir=desc" + +# QDROs for a file +curl \ + "http://localhost:6920/api/documents/qdros/FILE-123?include_total=true&sort_by=updated&sort_dir=desc" + +# Ledger entries for a file +curl \ + "http://localhost:6920/api/financial/ledger/FILE-123?include_total=true&sort_by=date&sort_dir=desc" + +# Customer phones +curl \ + "http://localhost:6920/api/customers/CUST-1/phones?include_total=true&sort_by=location&sort_dir=asc" +``` + +Allowed sort fields (high level): +- Support tickets: `created`, `updated`, `resolved`, `priority`, `status`, `subject` +- My tickets: `created`, `updated`, `resolved`, `priority`, `status`, `subject` +- QDROs (list and by file): `file_no`, `version`, `status`, `created`, `updated` +- Ledger by file: `date`, `item_no`, `amount`, `billed` +- Templates: `form_id`, `form_name`, `category`, `created`, `updated` +- Files: `file_no`, `client`, `opened`, `closed`, `status`, `amount_owing`, `total_charges` +- Admin users: `username`, `email`, `first_name`, `last_name`, `created`, `updated` +- Customer phones: `location`, `phone` + ### Authentication - `POST /api/auth/login` - User login - `POST /api/auth/register` - Register user (admin only) @@ -156,19 +217,28 @@ delphi-database/ - `DELETE /api/files/{file_no}` - Delete file ### Financial (Ledger) -- `GET /api/financial/ledger/{file_no}` - Get ledger entries +- `GET /api/financial/ledger/{file_no}` - Get ledger entries (supports pagination, sorting, `include_total`) - `POST /api/financial/ledger/` - Create transaction - `PUT /api/financial/ledger/{id}` - Update transaction - `DELETE /api/financial/ledger/{id}` - Delete transaction - `GET /api/financial/reports/{file_no}` - Financial reports ### Documents (QDROs) -- `GET /api/documents/qdros/{file_no}` - Get QDROs for file +- `GET /api/documents/qdros/{file_no}` - Get QDROs for file (supports pagination, sorting, `include_total`) - `POST /api/documents/qdros/` - Create QDRO - `GET /api/documents/qdros/{file_no}/{id}` - Get specific QDRO - `PUT /api/documents/qdros/{file_no}/{id}` - Update QDRO - `DELETE /api/documents/qdros/{file_no}/{id}` - Delete QDRO +### Support +- `POST /api/support/tickets` - Create support ticket (public; auth optional) +- `GET 
/api/support/tickets` - List tickets (admin; supports filters, search, pagination, sorting, `include_total`) +- `GET /api/support/tickets/{id}` - Get ticket details (admin) +- `PUT /api/support/tickets/{id}` - Update ticket (admin) +- `POST /api/support/tickets/{id}/responses` - Add response (admin) +- `GET /api/support/my-tickets` - List current user's tickets (supports status filter, search, pagination, sorting, `include_total`) +- `GET /api/support/stats` - Ticket statistics (admin) + ### Search - `GET /api/search/customers?q={query}` - Search customers - `GET /api/search/files?q={query}` - Search files diff --git a/app/api/admin.py b/app/api/admin.py index 2003f19..a808114 100644 --- a/app/api/admin.py +++ b/app/api/admin.py @@ -1,7 +1,7 @@ """ Comprehensive Admin API endpoints - User management, system settings, audit logging """ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Union from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Query, Body, Request from fastapi.responses import FileResponse from sqlalchemy.orm import Session, joinedload @@ -13,24 +13,27 @@ import hashlib import secrets import shutil import time -from datetime import datetime, timedelta, date +from datetime import datetime, timedelta, date, timezone from pathlib import Path from app.database.base import get_db +from app.api.search_highlight import build_query_tokens # Track application start time APPLICATION_START_TIME = time.time() from app.models import User, Rolodex, File as FileModel, Ledger, QDRO, AuditLog, LoginAttempt -from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex +from app.models.lookups import SystemSetup, Employee, FileType, FileStatus, TransactionType, TransactionCode, State, FormIndex, PrinterSetup from app.auth.security import get_admin_user, get_password_hash, create_access_token from app.services.audit import audit_service from app.config import settings +from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total router = APIRouter() # Enhanced Admin Schemas from pydantic import BaseModel, Field, EmailStr +from pydantic.config import ConfigDict class SystemStats(BaseModel): """Enhanced system statistics""" @@ -91,8 +94,7 @@ class UserResponse(BaseModel): created_at: Optional[datetime] updated_at: Optional[datetime] - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class PasswordReset(BaseModel): """Password reset request""" @@ -124,8 +126,7 @@ class AuditLogEntry(BaseModel): user_agent: Optional[str] timestamp: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class BackupInfo(BaseModel): """Backup information""" @@ -135,6 +136,132 @@ class BackupInfo(BaseModel): backup_type: str status: str + +class PrinterSetupBase(BaseModel): + """Base schema for printer setup""" + description: Optional[str] = None + driver: Optional[str] = None + port: Optional[str] = None + default_printer: Optional[bool] = None + active: Optional[bool] = None + number: Optional[int] = None + page_break: Optional[str] = None + setup_st: Optional[str] = None + reset_st: Optional[str] = None + b_underline: Optional[str] = None + e_underline: Optional[str] = None + b_bold: Optional[str] = None + e_bold: Optional[str] = None + phone_book: Optional[bool] = None + rolodex_info: Optional[bool] = None + envelope: Optional[bool] = None + file_cabinet: 
Optional[bool] = None + accounts: Optional[bool] = None + statements: Optional[bool] = None + calendar: Optional[bool] = None + + +class PrinterSetupCreate(PrinterSetupBase): + printer_name: str + + +class PrinterSetupUpdate(PrinterSetupBase): + pass + + +class PrinterSetupResponse(PrinterSetupBase): + printer_name: str + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +# Printer Setup Management + +@router.get("/printers", response_model=List[PrinterSetupResponse]) +async def list_printers( + db: Session = Depends(get_db), + current_user: User = Depends(get_admin_user) +): + printers = db.query(PrinterSetup).order_by(PrinterSetup.printer_name.asc()).all() + return printers + + +@router.get("/printers/{printer_name}", response_model=PrinterSetupResponse) +async def get_printer( + printer_name: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_admin_user) +): + printer = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first() + if not printer: + raise HTTPException(status_code=404, detail="Printer not found") + return printer + + +@router.post("/printers", response_model=PrinterSetupResponse) +async def create_printer( + payload: PrinterSetupCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_admin_user) +): + exists = db.query(PrinterSetup).filter(PrinterSetup.printer_name == payload.printer_name).first() + if exists: + raise HTTPException(status_code=400, detail="Printer already exists") + data = payload.model_dump(exclude_unset=True) + instance = PrinterSetup(**data) + db.add(instance) + # Enforce single default printer + if data.get("default_printer"): + try: + db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False}) + except Exception: + pass + db.commit() + db.refresh(instance) + return instance + + +@router.put("/printers/{printer_name}", response_model=PrinterSetupResponse) +async def update_printer( + printer_name: str, + payload: PrinterSetupUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_admin_user) +): + instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first() + if not instance: + raise HTTPException(status_code=404, detail="Printer not found") + updates = payload.model_dump(exclude_unset=True) + for k, v in updates.items(): + setattr(instance, k, v) + # Enforce single default printer when set true + if updates.get("default_printer"): + try: + db.query(PrinterSetup).filter(PrinterSetup.printer_name != instance.printer_name).update({PrinterSetup.default_printer: False}) + except Exception: + pass + db.commit() + db.refresh(instance) + return instance + + +@router.delete("/printers/{printer_name}") +async def delete_printer( + printer_name: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_admin_user) +): + instance = db.query(PrinterSetup).filter(PrinterSetup.printer_name == printer_name).first() + if not instance: + raise HTTPException(status_code=404, detail="Printer not found") + db.delete(instance) + db.commit() + return {"message": "Printer deleted"} + + + class LookupTableInfo(BaseModel): """Lookup table information""" table_name: str @@ -202,7 +329,7 @@ async def system_health( # Count active sessions (simplified) try: active_sessions = db.query(User).filter( - User.last_login > datetime.now() - timedelta(hours=24) + User.last_login > datetime.now(timezone.utc) - 
timedelta(hours=24) ).count() except: active_sessions = 0 @@ -215,7 +342,7 @@ async def system_health( backup_files = list(backup_dir.glob("*.db")) if backup_files: latest_backup = max(backup_files, key=lambda p: p.stat().st_mtime) - backup_age = datetime.now() - datetime.fromtimestamp(latest_backup.stat().st_mtime) + backup_age = datetime.now(timezone.utc) - datetime.fromtimestamp(latest_backup.stat().st_mtime, tz=timezone.utc) last_backup = latest_backup.name if backup_age.days > 7: alerts.append(f"Last backup is {backup_age.days} days old") @@ -257,7 +384,7 @@ async def system_statistics( # Count active users (logged in within last 30 days) total_active_users = db.query(func.count(User.id)).filter( - User.last_login > datetime.now() - timedelta(days=30) + User.last_login > datetime.now(timezone.utc) - timedelta(days=30) ).scalar() # Count admin users @@ -308,7 +435,7 @@ async def system_statistics( recent_activity.append({ "type": "customer_added", "description": f"Customer {customer.first} {customer.last} added", - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now(timezone.utc).isoformat() }) except: pass @@ -409,7 +536,7 @@ async def export_table( # Create exports directory if it doesn't exist os.makedirs("exports", exist_ok=True) - filename = f"exports/{table_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + filename = f"exports/{table_name}_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv" try: if table_name.lower() == "customers" or table_name.lower() == "rolodex": @@ -470,10 +597,10 @@ async def download_backup( if "sqlite" in settings.database_url: db_path = settings.database_url.replace("sqlite:///", "") if os.path.exists(db_path): - return FileResponse( + return FileResponse( db_path, media_type='application/octet-stream', - filename=f"delphi_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.db" + filename=f"delphi_backup_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.db" ) raise HTTPException( @@ -484,12 +611,20 @@ async def download_backup( # User Management Endpoints -@router.get("/users", response_model=List[UserResponse]) +class PaginatedUsersResponse(BaseModel): + items: List[UserResponse] + total: int + + +@router.get("/users", response_model=Union[List[UserResponse], PaginatedUsersResponse]) async def list_users( skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), search: Optional[str] = Query(None), active_only: bool = Query(False), + sort_by: Optional[str] = Query(None, description="Sort by: username, email, first_name, last_name, created, updated"), + sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_admin_user) ): @@ -498,19 +633,38 @@ async def list_users( query = db.query(User) if search: - query = query.filter( - or_( - User.username.ilike(f"%{search}%"), - User.email.ilike(f"%{search}%"), - User.first_name.ilike(f"%{search}%"), - User.last_name.ilike(f"%{search}%") - ) - ) + # DRY: tokenize and apply case-insensitive multi-field search + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + User.username, + User.email, + User.first_name, + User.last_name, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) if active_only: query = query.filter(User.is_active == True) - users = query.offset(skip).limit(limit).all() + # Sorting (whitelisted) + query = 
apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "username": [User.username], + "email": [User.email], + "first_name": [User.first_name], + "last_name": [User.last_name], + "created": [User.created_at], + "updated": [User.updated_at], + }, + ) + + users, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": users, "total": total or 0} return users @@ -567,8 +721,8 @@ async def create_user( hashed_password=hashed_password, is_admin=user_data.is_admin, is_active=user_data.is_active, - created_at=datetime.now(), - updated_at=datetime.now() + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) ) db.add(new_user) @@ -659,7 +813,7 @@ async def update_user( changes[field] = {"from": getattr(user, field), "to": value} setattr(user, field, value) - user.updated_at = datetime.now() + user.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(user) @@ -702,7 +856,7 @@ async def delete_user( # Soft delete by deactivating user.is_active = False - user.updated_at = datetime.now() + user.updated_at = datetime.now(timezone.utc) db.commit() @@ -744,7 +898,7 @@ async def reset_user_password( # Update password user.hashed_password = get_password_hash(password_data.new_password) - user.updated_at = datetime.now() + user.updated_at = datetime.now(timezone.utc) db.commit() @@ -1046,7 +1200,7 @@ async def create_backup( backup_dir.mkdir(exist_ok=True) # Generate backup filename - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") backup_filename = f"delphi_backup_{timestamp}.db" backup_path = backup_dir / backup_filename @@ -1063,7 +1217,7 @@ async def create_backup( "backup_info": { "filename": backup_filename, "size": f"{backup_size / (1024*1024):.1f} MB", - "created_at": datetime.now().isoformat(), + "created_at": datetime.now(timezone.utc).isoformat(), "backup_type": "manual", "status": "completed" } @@ -1118,45 +1272,61 @@ async def get_audit_logs( resource_type: Optional[str] = Query(None), action: Optional[str] = Query(None), hours_back: int = Query(168, ge=1, le=8760), # Default 7 days, max 1 year + sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, action, resource_type"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_admin_user) ): - """Get audit log entries with filtering""" - - cutoff_time = datetime.now() - timedelta(hours=hours_back) - + """Get audit log entries with filtering, sorting, and pagination""" + + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back) + query = db.query(AuditLog).filter(AuditLog.timestamp >= cutoff_time) - + if user_id: query = query.filter(AuditLog.user_id == user_id) - + if resource_type: query = query.filter(AuditLog.resource_type.ilike(f"%{resource_type}%")) - + if action: query = query.filter(AuditLog.action.ilike(f"%{action}%")) - - total_count = query.count() - logs = query.order_by(AuditLog.timestamp.desc()).offset(skip).limit(limit).all() - - return { - "total": total_count, - "logs": [ - { - "id": log.id, - "user_id": log.user_id, - "username": log.username, - "action": log.action, - "resource_type": log.resource_type, - "resource_id": log.resource_id, - "details": log.details, - "ip_address": log.ip_address, - "user_agent": 
log.user_agent, - "timestamp": log.timestamp.isoformat() - } - for log in logs - ] - } + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "timestamp": [AuditLog.timestamp], + "username": [AuditLog.username], + "action": [AuditLog.action], + "resource_type": [AuditLog.resource_type], + }, + ) + + logs, total = paginate_with_total(query, skip, limit, include_total) + + items = [ + { + "id": log.id, + "user_id": log.user_id, + "username": log.username, + "action": log.action, + "resource_type": log.resource_type, + "resource_id": log.resource_id, + "details": log.details, + "ip_address": log.ip_address, + "user_agent": log.user_agent, + "timestamp": log.timestamp.isoformat(), + } + for log in logs + ] + + if include_total: + return {"items": items, "total": total or 0} + return items @router.get("/audit/login-attempts") @@ -1166,39 +1336,55 @@ async def get_login_attempts( username: Optional[str] = Query(None), failed_only: bool = Query(False), hours_back: int = Query(168, ge=1, le=8760), # Default 7 days + sort_by: Optional[str] = Query("timestamp", description="Sort by: timestamp, username, ip_address, success"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_admin_user) ): - """Get login attempts with filtering""" - - cutoff_time = datetime.now() - timedelta(hours=hours_back) - + """Get login attempts with filtering, sorting, and pagination""" + + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back) + query = db.query(LoginAttempt).filter(LoginAttempt.timestamp >= cutoff_time) - + if username: query = query.filter(LoginAttempt.username.ilike(f"%{username}%")) - + if failed_only: query = query.filter(LoginAttempt.success == 0) - - total_count = query.count() - attempts = query.order_by(LoginAttempt.timestamp.desc()).offset(skip).limit(limit).all() - - return { - "total": total_count, - "attempts": [ - { - "id": attempt.id, - "username": attempt.username, - "ip_address": attempt.ip_address, - "user_agent": attempt.user_agent, - "success": bool(attempt.success), - "failure_reason": attempt.failure_reason, - "timestamp": attempt.timestamp.isoformat() - } - for attempt in attempts - ] - } + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "timestamp": [LoginAttempt.timestamp], + "username": [LoginAttempt.username], + "ip_address": [LoginAttempt.ip_address], + "success": [LoginAttempt.success], + }, + ) + + attempts, total = paginate_with_total(query, skip, limit, include_total) + + items = [ + { + "id": attempt.id, + "username": attempt.username, + "ip_address": attempt.ip_address, + "user_agent": attempt.user_agent, + "success": bool(attempt.success), + "failure_reason": attempt.failure_reason, + "timestamp": attempt.timestamp.isoformat(), + } + for attempt in attempts + ] + + if include_total: + return {"items": items, "total": total or 0} + return items @router.get("/audit/user-activity/{user_id}") @@ -1251,7 +1437,7 @@ async def get_security_alerts( ): """Get security alerts and suspicious activity""" - cutoff_time = datetime.now() - timedelta(hours=hours_back) + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back) # Get failed login attempts failed_logins = db.query(LoginAttempt).filter( @@ -1356,7 +1542,7 @@ async def get_audit_statistics( ): """Get 
audit statistics and metrics""" - cutoff_time = datetime.now() - timedelta(days=days_back) + cutoff_time = datetime.now(timezone.utc) - timedelta(days=days_back) # Total activity counts total_audit_entries = db.query(func.count(AuditLog.id)).filter( diff --git a/app/api/auth.py b/app/api/auth.py index d6194ae..dde00a4 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,7 +1,7 @@ """ Authentication API endpoints """ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import List from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordRequestForm @@ -69,7 +69,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend ) # Update last login - user.last_login = datetime.utcnow() + user.last_login = datetime.now(timezone.utc) db.commit() access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) @@ -144,7 +144,7 @@ async def read_users_me(current_user: User = Depends(get_current_user)): async def refresh_token_endpoint( request: Request, db: Session = Depends(get_db), - body: RefreshRequest | None = None, + body: RefreshRequest = None, ): """Issue a new access token using a valid, non-revoked refresh token. @@ -203,7 +203,7 @@ async def refresh_token_endpoint( if not user or not user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") - user.last_login = datetime.utcnow() + user.last_login = datetime.now(timezone.utc) db.commit() access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) @@ -225,7 +225,7 @@ async def list_users( @router.post("/logout") -async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)): +async def logout(body: RefreshRequest = None, db: Session = Depends(get_db)): """Revoke the provided refresh token. Idempotent and safe to call multiple times. The client should send a JSON body: { "refresh_token": "..." }. 
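    Example (illustrative sketch only; the `/api/auth` route prefix and the
    localhost port 6920 are assumptions carried over from the README examples
    elsewhere in this change, not confirmed by this hunk):

        # Revoke a refresh token; safe to repeat since the endpoint is idempotent
        curl -X POST "http://localhost:6920/api/auth/logout" \
             -H "Content-Type: application/json" \
             -d '{"refresh_token": "<refresh-token>"}'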
diff --git a/app/api/customers.py b/app/api/customers.py index fd3439f..f89fda1 100644 --- a/app/api/customers.py +++ b/app/api/customers.py @@ -4,7 +4,7 @@ Customer (Rolodex) API endpoints from typing import List, Optional, Union from fastapi import APIRouter, Depends, HTTPException, status, Query from sqlalchemy.orm import Session, joinedload -from sqlalchemy import or_, and_, func, asc, desc +from sqlalchemy import func from fastapi.responses import StreamingResponse import csv import io @@ -13,12 +13,16 @@ from app.database.base import get_db from app.models.rolodex import Rolodex, Phone from app.models.user import User from app.auth.security import get_current_user +from app.services.cache import invalidate_search_cache +from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows +from app.services.query_utils import apply_sorting, paginate_with_total router = APIRouter() # Pydantic schemas for request/response -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, Field +from pydantic.config import ConfigDict from datetime import date @@ -32,8 +36,7 @@ class PhoneResponse(BaseModel): location: Optional[str] phone: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class CustomerBase(BaseModel): @@ -84,10 +87,11 @@ class CustomerUpdate(BaseModel): class CustomerResponse(CustomerBase): - phone_numbers: List[PhoneResponse] = [] + phone_numbers: List[PhoneResponse] = Field(default_factory=list) - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) + + @router.get("/search/phone") @@ -196,80 +200,17 @@ async def list_customers( try: base_query = db.query(Rolodex) - if search: - s = (search or "").strip() - s_lower = s.lower() - tokens = [t for t in s_lower.split() if t] - # Basic contains search on several fields (case-insensitive) - contains_any = or_( - func.lower(Rolodex.id).contains(s_lower), - func.lower(Rolodex.last).contains(s_lower), - func.lower(Rolodex.first).contains(s_lower), - func.lower(Rolodex.middle).contains(s_lower), - func.lower(Rolodex.city).contains(s_lower), - func.lower(Rolodex.email).contains(s_lower), - ) - # Multi-token name support: every token must match either first, middle, or last - name_tokens = [ - or_( - func.lower(Rolodex.first).contains(tok), - func.lower(Rolodex.middle).contains(tok), - func.lower(Rolodex.last).contains(tok), - ) - for tok in tokens - ] - combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens)) - # Comma pattern: "Last, First" - last_first_filter = None - if "," in s_lower: - last_part, first_part = [p.strip() for p in s_lower.split(",", 1)] - if last_part and first_part: - last_first_filter = and_( - func.lower(Rolodex.last).contains(last_part), - func.lower(Rolodex.first).contains(first_part), - ) - elif last_part: - last_first_filter = func.lower(Rolodex.last).contains(last_part) - final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined - base_query = base_query.filter(final_filter) - - # Apply group/state filters (support single and multi-select) - effective_groups = [g for g in (groups or []) if g] or ([group] if group else []) - if effective_groups: - base_query = base_query.filter(Rolodex.group.in_(effective_groups)) - effective_states = [s for s in (states or []) if s] or ([state] if state else []) - if effective_states: - base_query = base_query.filter(Rolodex.abrev.in_(effective_states)) + 
base_query = apply_customer_filters( + base_query, + search=search, + group=group, + state=state, + groups=groups, + states=states, + ) # Apply sorting (whitelisted fields only) - normalized_sort_by = (sort_by or "id").lower() - normalized_sort_dir = (sort_dir or "asc").lower() - is_desc = normalized_sort_dir == "desc" - - order_columns = [] - if normalized_sort_by == "id": - order_columns = [Rolodex.id] - elif normalized_sort_by == "name": - # Sort by last, then first - order_columns = [Rolodex.last, Rolodex.first] - elif normalized_sort_by == "city": - # Sort by city, then state abbreviation - order_columns = [Rolodex.city, Rolodex.abrev] - elif normalized_sort_by == "email": - order_columns = [Rolodex.email] - else: - # Fallback to id to avoid arbitrary column injection - order_columns = [Rolodex.id] - - # Case-insensitive ordering where applicable, preserving None ordering default - ordered = [] - for col in order_columns: - # Use lower() for string-like cols; SQLAlchemy will handle non-string safely enough for SQLite/Postgres - expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined] - ordered.append(desc(expr) if is_desc else asc(expr)) - - if ordered: - base_query = base_query.order_by(*ordered) + base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir) customers = base_query.options(joinedload(Rolodex.phone_numbers)).offset(skip).limit(limit).all() if include_total: @@ -304,72 +245,16 @@ async def export_customers( try: base_query = db.query(Rolodex) - if search: - s = (search or "").strip() - s_lower = s.lower() - tokens = [t for t in s_lower.split() if t] - contains_any = or_( - func.lower(Rolodex.id).contains(s_lower), - func.lower(Rolodex.last).contains(s_lower), - func.lower(Rolodex.first).contains(s_lower), - func.lower(Rolodex.middle).contains(s_lower), - func.lower(Rolodex.city).contains(s_lower), - func.lower(Rolodex.email).contains(s_lower), - ) - name_tokens = [ - or_( - func.lower(Rolodex.first).contains(tok), - func.lower(Rolodex.middle).contains(tok), - func.lower(Rolodex.last).contains(tok), - ) - for tok in tokens - ] - combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens)) - last_first_filter = None - if "," in s_lower: - last_part, first_part = [p.strip() for p in s_lower.split(",", 1)] - if last_part and first_part: - last_first_filter = and_( - func.lower(Rolodex.last).contains(last_part), - func.lower(Rolodex.first).contains(first_part), - ) - elif last_part: - last_first_filter = func.lower(Rolodex.last).contains(last_part) - final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined - base_query = base_query.filter(final_filter) + base_query = apply_customer_filters( + base_query, + search=search, + group=group, + state=state, + groups=groups, + states=states, + ) - effective_groups = [g for g in (groups or []) if g] or ([group] if group else []) - if effective_groups: - base_query = base_query.filter(Rolodex.group.in_(effective_groups)) - effective_states = [s for s in (states or []) if s] or ([state] if state else []) - if effective_states: - base_query = base_query.filter(Rolodex.abrev.in_(effective_states)) - - normalized_sort_by = (sort_by or "id").lower() - normalized_sort_dir = (sort_dir or "asc").lower() - is_desc = normalized_sort_dir == "desc" - - order_columns = [] - if normalized_sort_by == "id": - order_columns = [Rolodex.id] - elif normalized_sort_by == "name": - order_columns = [Rolodex.last, Rolodex.first] - elif 
normalized_sort_by == "city": - order_columns = [Rolodex.city, Rolodex.abrev] - elif normalized_sort_by == "email": - order_columns = [Rolodex.email] - else: - order_columns = [Rolodex.id] - - ordered = [] - for col in order_columns: - try: - expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined] - except Exception: - expr = col - ordered.append(desc(expr) if is_desc else asc(expr)) - if ordered: - base_query = base_query.order_by(*ordered) + base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir) if not export_all: if skip is not None: @@ -382,39 +267,10 @@ async def export_customers( # Prepare CSV output = io.StringIO() writer = csv.writer(output) - allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"] - header_names = { - "id": "Customer ID", - "name": "Name", - "group": "Group", - "city": "City", - "state": "State", - "phone": "Primary Phone", - "email": "Email", - } - requested = [f.lower() for f in (fields or []) if isinstance(f, str)] - selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order - if not selected_fields: - selected_fields = allowed_fields_in_order - writer.writerow([header_names[f] for f in selected_fields]) - for c in customers: - full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip() - primary_phone = "" - try: - if c.phone_numbers: - primary_phone = c.phone_numbers[0].phone or "" - except Exception: - primary_phone = "" - row_map = { - "id": c.id, - "name": full_name, - "group": c.group or "", - "city": c.city or "", - "state": c.abrev or "", - "phone": primary_phone, - "email": c.email or "", - } - writer.writerow([row_map[f] for f in selected_fields]) + header_row, rows = prepare_customer_csv_rows(customers, fields) + writer.writerow(header_row) + for row in rows: + writer.writerow(row) output.seek(0) filename = "customers_export.csv" @@ -469,6 +325,10 @@ async def create_customer( db.commit() db.refresh(customer) + try: + await invalidate_search_cache() + except Exception: + pass return customer @@ -494,7 +354,10 @@ async def update_customer( db.commit() db.refresh(customer) - + try: + await invalidate_search_cache() + except Exception: + pass return customer @@ -515,17 +378,30 @@ async def delete_customer( db.delete(customer) db.commit() - + try: + await invalidate_search_cache() + except Exception: + pass return {"message": "Customer deleted successfully"} -@router.get("/{customer_id}/phones", response_model=List[PhoneResponse]) +class PaginatedPhonesResponse(BaseModel): + items: List[PhoneResponse] + total: int + + +@router.get("/{customer_id}/phones", response_model=Union[List[PhoneResponse], PaginatedPhonesResponse]) async def get_customer_phones( customer_id: str, + skip: int = Query(0, ge=0, description="Offset for pagination"), + limit: int = Query(100, ge=1, le=1000, description="Page size"), + sort_by: Optional[str] = Query("location", description="Sort by: location, phone"), + sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """Get customer phone numbers""" + """Get customer phone numbers with optional sorting/pagination""" customer = db.query(Rolodex).filter(Rolodex.id == customer_id).first() if not customer: @@ -534,7 +410,21 @@ async def 
get_customer_phones( detail="Customer not found" ) - phones = db.query(Phone).filter(Phone.rolodex_id == customer_id).all() + query = db.query(Phone).filter(Phone.rolodex_id == customer_id) + + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "location": [Phone.location, Phone.phone], + "phone": [Phone.phone], + }, + ) + + phones, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": phones, "total": total or 0} return phones diff --git a/app/api/documents.py b/app/api/documents.py index 69e7e51..a33a22d 100644 --- a/app/api/documents.py +++ b/app/api/documents.py @@ -1,16 +1,19 @@ """ Document Management API endpoints - QDROs, Templates, and General Documents """ -from typing import List, Optional, Dict, Any +from __future__ import annotations +from typing import List, Optional, Dict, Any, Union from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request from sqlalchemy.orm import Session, joinedload from sqlalchemy import or_, func, and_, desc, asc, text -from datetime import date, datetime +from datetime import date, datetime, timezone import os import uuid import shutil from app.database.base import get_db +from app.api.search_highlight import build_query_tokens +from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total from app.models.qdro import QDRO from app.models.files import File as FileModel from app.models.rolodex import Rolodex @@ -20,18 +23,20 @@ from app.auth.security import get_current_user from app.models.additional import Document from app.core.logging import get_logger from app.services.audit import audit_service +from app.services.cache import invalidate_search_cache router = APIRouter() # Pydantic schemas -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class QDROBase(BaseModel): file_no: str version: str = "01" title: Optional[str] = None + form_name: Optional[str] = None content: Optional[str] = None status: str = "DRAFT" created_date: Optional[date] = None @@ -51,6 +56,7 @@ class QDROCreate(QDROBase): class QDROUpdate(BaseModel): version: Optional[str] = None title: Optional[str] = None + form_name: Optional[str] = None content: Optional[str] = None status: Optional[str] = None created_date: Optional[date] = None @@ -66,27 +72,61 @@ class QDROUpdate(BaseModel): class QDROResponse(QDROBase): id: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) -@router.get("/qdros/{file_no}", response_model=List[QDROResponse]) +class PaginatedQDROResponse(BaseModel): + items: List[QDROResponse] + total: int + + +@router.get("/qdros/{file_no}", response_model=Union[List[QDROResponse], PaginatedQDROResponse]) async def get_file_qdros( file_no: str, + skip: int = Query(0, ge=0, description="Offset for pagination"), + limit: int = Query(100, ge=1, le=1000, description="Page size"), + sort_by: Optional[str] = Query("updated", description="Sort by: updated, created, version, status"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """Get QDROs for specific file""" - qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(QDRO.version).all() + """Get QDROs for a specific file with optional sorting/pagination""" + query 
= db.query(QDRO).filter(QDRO.file_no == file_no) + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "updated": [QDRO.updated_at, QDRO.id], + "created": [QDRO.created_at, QDRO.id], + "version": [QDRO.version], + "status": [QDRO.status], + }, + ) + + qdros, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": qdros, "total": total or 0} return qdros -@router.get("/qdros/", response_model=List[QDROResponse]) +class PaginatedQDROResponse(BaseModel): + items: List[QDROResponse] + total: int + + +@router.get("/qdros/", response_model=Union[List[QDROResponse], PaginatedQDROResponse]) async def list_qdros( skip: int = Query(0, ge=0), limit: int = Query(50, ge=1, le=200), status_filter: Optional[str] = Query(None), search: Optional[str] = Query(None), + sort_by: Optional[str] = Query(None, description="Sort by: file_no, version, status, created, updated"), + sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -97,17 +137,37 @@ async def list_qdros( query = query.filter(QDRO.status == status_filter) if search: - query = query.filter( - or_( - QDRO.file_no.contains(search), - QDRO.title.contains(search), - QDRO.participant_name.contains(search), - QDRO.spouse_name.contains(search), - QDRO.plan_name.contains(search) - ) - ) - - qdros = query.offset(skip).limit(limit).all() + # DRY: tokenize and apply case-insensitive search across common QDRO fields + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + QDRO.file_no, + QDRO.form_name, + QDRO.pet, + QDRO.res, + QDRO.case_number, + QDRO.notes, + QDRO.status, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "file_no": [QDRO.file_no], + "version": [QDRO.version], + "status": [QDRO.status], + "created": [QDRO.created_at], + "updated": [QDRO.updated_at], + }, + ) + + qdros, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": qdros, "total": total or 0} return qdros @@ -135,6 +195,10 @@ async def create_qdro( db.commit() db.refresh(qdro) + try: + await invalidate_search_cache() + except Exception: + pass return qdro @@ -189,6 +253,10 @@ async def update_qdro( db.commit() db.refresh(qdro) + try: + await invalidate_search_cache() + except Exception: + pass return qdro @@ -213,7 +281,10 @@ async def delete_qdro( db.delete(qdro) db.commit() - + try: + await invalidate_search_cache() + except Exception: + pass return {"message": "QDRO deleted successfully"} @@ -241,8 +312,7 @@ class TemplateResponse(TemplateBase): active: bool = True created_at: Optional[datetime] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # Document Generation Schema class DocumentGenerateRequest(BaseModel): @@ -269,13 +339,21 @@ class DocumentStats(BaseModel): recent_activity: List[Dict[str, Any]] -@router.get("/templates/", response_model=List[TemplateResponse]) +class PaginatedTemplatesResponse(BaseModel): + items: List[TemplateResponse] + total: int + + +@router.get("/templates/", response_model=Union[List[TemplateResponse], PaginatedTemplatesResponse]) async def list_templates( skip: int = Query(0, ge=0), limit: int = 
Query(50, ge=1, le=200), category: Optional[str] = Query(None), search: Optional[str] = Query(None), active_only: bool = Query(True), + sort_by: Optional[str] = Query(None, description="Sort by: form_id, form_name, category, created, updated"), + sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -289,14 +367,31 @@ async def list_templates( query = query.filter(FormIndex.category == category) if search: - query = query.filter( - or_( - FormIndex.form_name.contains(search), - FormIndex.form_id.contains(search) - ) - ) - - templates = query.offset(skip).limit(limit).all() + # DRY: tokenize and apply case-insensitive search for templates + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + FormIndex.form_name, + FormIndex.form_id, + FormIndex.category, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "form_id": [FormIndex.form_id], + "form_name": [FormIndex.form_name], + "category": [FormIndex.category], + "created": [FormIndex.created_at], + "updated": [FormIndex.updated_at], + }, + ) + + templates, total = paginate_with_total(query, skip, limit, include_total) # Enhanced response with template content results = [] @@ -317,6 +412,8 @@ async def list_templates( "variables": _extract_variables_from_content(content) }) + if include_total: + return {"items": results, "total": total or 0} return results @@ -356,6 +453,10 @@ async def create_template( db.commit() db.refresh(form_index) + try: + await invalidate_search_cache() + except Exception: + pass return { "form_id": form_index.form_id, @@ -440,6 +541,10 @@ async def update_template( db.commit() db.refresh(template) + try: + await invalidate_search_cache() + except Exception: + pass # Get updated content template_lines = db.query(FormList).filter( @@ -480,6 +585,10 @@ async def delete_template( # Delete template db.delete(template) db.commit() + try: + await invalidate_search_cache() + except Exception: + pass return {"message": "Template deleted successfully"} @@ -574,7 +683,7 @@ async def generate_document( "file_name": file_name, "file_path": file_path, "size": file_size, - "created_at": datetime.now() + "created_at": datetime.now(timezone.utc) } @@ -629,32 +738,49 @@ async def get_document_stats( @router.get("/file/{file_no}/documents") async def get_file_documents( file_no: str, + sort_by: Optional[str] = Query("updated", description="Sort by: updated, created"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """Get all documents associated with a specific file""" - # Get QDROs for this file - qdros = db.query(QDRO).filter(QDRO.file_no == file_no).order_by(desc(QDRO.updated_at)).all() - - # Format response - documents = [ + """Get all documents associated with a specific file, with optional sorting/pagination""" + # Base query for QDROs tied to the file + query = db.query(QDRO).filter(QDRO.file_no == file_no) + + # Apply sorting using shared helper (map friendly 
names to columns) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "updated": [QDRO.updated_at, QDRO.id], + "created": [QDRO.created_at, QDRO.id], + }, + ) + + qdros, total = paginate_with_total(query, skip, limit, include_total) + + items = [ { "id": qdro.id, "type": "QDRO", "title": f"QDRO v{qdro.version}", "status": qdro.status, - "created_date": qdro.created_date.isoformat() if qdro.created_date else None, - "updated_at": qdro.updated_at.isoformat() if qdro.updated_at else None, - "file_no": qdro.file_no + "created_date": qdro.created_date.isoformat() if getattr(qdro, "created_date", None) else None, + "updated_at": qdro.updated_at.isoformat() if getattr(qdro, "updated_at", None) else None, + "file_no": qdro.file_no, } for qdro in qdros ] - - return { - "file_no": file_no, - "documents": documents, - "total_count": len(documents) - } + + payload = {"file_no": file_no, "documents": items, "total_count": (total if include_total else None)} + # Maintain previous shape by omitting total_count when include_total is False? The prior code always returned total_count. + # Keep total_count for backward compatibility but set to actual total when include_total else len(items) + payload["total_count"] = (total if include_total else len(items)) + return payload def _extract_variables_from_content(content: str) -> Dict[str, str]: diff --git a/app/api/files.py b/app/api/files.py index 8a4ace5..1fb39bc 100644 --- a/app/api/files.py +++ b/app/api/files.py @@ -1,25 +1,29 @@ """ File Management API endpoints """ -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Union from fastapi import APIRouter, Depends, HTTPException, status, Query from sqlalchemy.orm import Session, joinedload from sqlalchemy import or_, func, and_, desc from datetime import date, datetime from app.database.base import get_db +from app.api.search_highlight import build_query_tokens +from app.services.query_utils import tokenized_ilike_filter, apply_pagination, apply_sorting, paginate_with_total from app.models.files import File from app.models.rolodex import Rolodex from app.models.ledger import Ledger from app.models.lookups import Employee, FileType, FileStatus from app.models.user import User from app.auth.security import get_current_user +from app.services.cache import invalidate_search_cache router = APIRouter() # Pydantic schemas from pydantic import BaseModel +from pydantic.config import ConfigDict class FileBase(BaseModel): @@ -67,17 +71,24 @@ class FileResponse(FileBase): amount_owing: float = 0.0 transferable: float = 0.0 - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) -@router.get("/", response_model=List[FileResponse]) +class PaginatedFilesResponse(BaseModel): + items: List[FileResponse] + total: int + + +@router.get("/", response_model=Union[List[FileResponse], PaginatedFilesResponse]) async def list_files( skip: int = Query(0, ge=0), limit: int = Query(50, ge=1, le=200), search: Optional[str] = Query(None), status_filter: Optional[str] = Query(None), employee_filter: Optional[str] = Query(None), + sort_by: Optional[str] = Query(None, description="Sort by: file_no, client, opened, closed, status, amount_owing, total_charges"), + sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -85,14 +96,17 
@@ async def list_files( query = db.query(File) if search: - query = query.filter( - or_( - File.file_no.contains(search), - File.id.contains(search), - File.regarding.contains(search), - File.file_type.contains(search) - ) - ) + # DRY: tokenize and apply case-insensitive search consistently with search endpoints + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + File.file_no, + File.id, + File.regarding, + File.file_type, + File.memo, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) if status_filter: query = query.filter(File.status == status_filter) @@ -100,7 +114,25 @@ async def list_files( if employee_filter: query = query.filter(File.empl_num == employee_filter) - files = query.offset(skip).limit(limit).all() + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "file_no": [File.file_no], + "client": [File.id], + "opened": [File.opened], + "closed": [File.closed], + "status": [File.status], + "amount_owing": [File.amount_owing], + "total_charges": [File.total_charges], + }, + ) + + files, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": files, "total": total or 0} return files @@ -142,6 +174,10 @@ async def create_file( db.commit() db.refresh(file_obj) + try: + await invalidate_search_cache() + except Exception: + pass return file_obj @@ -167,7 +203,10 @@ async def update_file( db.commit() db.refresh(file_obj) - + try: + await invalidate_search_cache() + except Exception: + pass return file_obj @@ -188,7 +227,10 @@ async def delete_file( db.delete(file_obj) db.commit() - + try: + await invalidate_search_cache() + except Exception: + pass return {"message": "File deleted successfully"} @@ -433,11 +475,13 @@ async def advanced_file_search( query = query.filter(File.file_no.contains(file_no)) if client_name: + # SQLite-safe concatenation for first + last name + full_name_expr = (func.coalesce(Rolodex.first, '') + ' ' + func.coalesce(Rolodex.last, '')) query = query.join(Rolodex).filter( or_( Rolodex.first.contains(client_name), Rolodex.last.contains(client_name), - func.concat(Rolodex.first, ' ', Rolodex.last).contains(client_name) + full_name_expr.contains(client_name) ) ) diff --git a/app/api/financial.py b/app/api/financial.py index 54299ac..ff7e5e7 100644 --- a/app/api/financial.py +++ b/app/api/financial.py @@ -1,11 +1,11 @@ """ Financial/Ledger API endpoints """ -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Union from fastapi import APIRouter, Depends, HTTPException, status, Query from sqlalchemy.orm import Session, joinedload from sqlalchemy import or_, func, and_, desc, asc, text -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from app.database.base import get_db from app.models.ledger import Ledger @@ -14,12 +14,14 @@ from app.models.rolodex import Rolodex from app.models.lookups import Employee, TransactionType, TransactionCode from app.models.user import User from app.auth.security import get_current_user +from app.services.cache import invalidate_search_cache +from app.services.query_utils import apply_sorting, paginate_with_total router = APIRouter() # Pydantic schemas -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class LedgerBase(BaseModel): @@ -57,8 +59,7 @@ class LedgerResponse(LedgerBase): id: int item_no: int - class Config: - from_attributes = True + model_config = 
ConfigDict(from_attributes=True) class FinancialSummary(BaseModel): @@ -75,23 +76,46 @@ class FinancialSummary(BaseModel): billed_amount: float -@router.get("/ledger/{file_no}", response_model=List[LedgerResponse]) +class PaginatedLedgerResponse(BaseModel): + items: List[LedgerResponse] + total: int + + +@router.get("/ledger/{file_no}", response_model=Union[List[LedgerResponse], PaginatedLedgerResponse]) async def get_file_ledger( file_no: str, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=500), - billed_only: Optional[bool] = Query(None), + skip: int = Query(0, ge=0, description="Offset for pagination"), + limit: int = Query(100, ge=1, le=500, description="Page size"), + billed_only: Optional[bool] = Query(None, description="Filter billed vs unbilled entries"), + sort_by: Optional[str] = Query("date", description="Sort by: date, item_no, amount, billed"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): """Get ledger entries for specific file""" - query = db.query(Ledger).filter(Ledger.file_no == file_no).order_by(Ledger.date.desc()) + query = db.query(Ledger).filter(Ledger.file_no == file_no) if billed_only is not None: billed_filter = "Y" if billed_only else "N" query = query.filter(Ledger.billed == billed_filter) - - entries = query.offset(skip).limit(limit).all() + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "date": [Ledger.date, Ledger.item_no], + "item_no": [Ledger.item_no], + "amount": [Ledger.amount], + "billed": [Ledger.billed, Ledger.date], + }, + ) + + entries, total = paginate_with_total(query, skip, limit, include_total) + if include_total: + return {"items": entries, "total": total or 0} return entries @@ -127,6 +151,10 @@ async def create_ledger_entry( # Update file balances (simplified version) await _update_file_balances(file_obj, db) + try: + await invalidate_search_cache() + except Exception: + pass return entry @@ -158,6 +186,10 @@ async def update_ledger_entry( if file_obj: await _update_file_balances(file_obj, db) + try: + await invalidate_search_cache() + except Exception: + pass return entry @@ -185,6 +217,10 @@ async def delete_ledger_entry( if file_obj: await _update_file_balances(file_obj, db) + try: + await invalidate_search_cache() + except Exception: + pass return {"message": "Ledger entry deleted successfully"} diff --git a/app/api/import_data.py b/app/api/import_data.py index 11f0127..e89b37a 100644 --- a/app/api/import_data.py +++ b/app/api/import_data.py @@ -7,7 +7,7 @@ import re import os from pathlib import Path from difflib import SequenceMatcher -from datetime import datetime, date +from datetime import datetime, date, timezone from decimal import Decimal from typing import List, Dict, Any, Optional, Tuple from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as UploadFileForm, Form, Query @@ -19,8 +19,8 @@ from app.models.rolodex import Rolodex, Phone from app.models.files import File from app.models.ledger import Ledger from app.models.qdro import QDRO -from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable -from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, 
FormList, PrinterSetup, SystemSetup +from app.models.pensions import Pension, PensionSchedule, MarriageHistory, DeathBenefit, SeparationAgreement, LifeTable, NumberTable, PensionResult +from app.models.lookups import Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, PrinterSetup, SystemSetup, FormKeyword from app.models.additional import Payment, Deposit, FileNote, FormVariable, ReportVariable from app.models.flexible import FlexibleImport from app.models.audit import ImportAudit, ImportAuditFile @@ -28,6 +28,25 @@ from app.config import settings router = APIRouter(tags=["import"]) +# Common encodings to try for legacy CSV files (order matters) +ENCODINGS = [ + 'utf-8-sig', + 'utf-8', + 'windows-1252', + 'iso-8859-1', + 'cp1252', +] + +# Unified import order used across batch operations +IMPORT_ORDER = [ + "STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv", + "TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv", + "INX_LKUP.csv", + "ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv", + "QDROS.csv", "PENSIONS.csv", "LIFETABL.csv", "NUMBERAL.csv", "PLANINFO.csv", "RESULTS.csv", "PAYMENTS.csv", "DEPOSITS.csv", + "FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv" +] + # CSV to Model mapping CSV_MODEL_MAPPING = { @@ -56,7 +75,6 @@ CSV_MODEL_MAPPING = { "FOOTERS.csv": Footer, "PLANINFO.csv": PlanInfo, # Legacy alternate names from export directories - "SCHEDULE.csv": PensionSchedule, "FORM_INX.csv": FormIndex, "FORM_LST.csv": FormList, "PRINTERS.csv": PrinterSetup, @@ -67,7 +85,9 @@ CSV_MODEL_MAPPING = { "FVARLKUP.csv": FormVariable, "RVARLKUP.csv": ReportVariable, "PAYMENTS.csv": Payment, - "TRNSACTN.csv": Ledger # Maps to existing Ledger model (same structure) + "TRNSACTN.csv": Ledger, # Maps to existing Ledger model (same structure) + "INX_LKUP.csv": FormKeyword, + "RESULTS.csv": PensionResult } # Field mappings for CSV columns to database fields @@ -230,8 +250,12 @@ FIELD_MAPPINGS = { "Default_Rate": "default_rate" }, "FILESTAT.csv": { + "Status": "status_code", "Status_Code": "status_code", + "Definition": "description", "Description": "description", + "Send": "send", + "Footer_Code": "footer_code", "Sort_Order": "sort_order" }, "FOOTERS.csv": { @@ -253,22 +277,44 @@ FIELD_MAPPINGS = { "Phone": "phone", "Notes": "notes" }, + "INX_LKUP.csv": { + "Keyword": "keyword", + "Description": "description" + }, "FORM_INX.csv": { - "Form_Id": "form_id", - "Form_Name": "form_name", - "Category": "category" + "Name": "form_id", + "Keyword": "keyword" }, "FORM_LST.csv": { - "Form_Id": "form_id", - "Line_Number": "line_number", - "Content": "content" + "Name": "form_id", + "Memo": "content", + "Status": "status" }, "PRINTERS.csv": { + # Legacy variants "Printer_Name": "printer_name", "Description": "description", "Driver": "driver", "Port": "port", - "Default_Printer": "default_printer" + "Default_Printer": "default_printer", + # Observed legacy headers from export + "Number": "number", + "Name": "printer_name", + "Page_Break": "page_break", + "Setup_St": "setup_st", + "Reset_St": "reset_st", + "B_Underline": "b_underline", + "E_Underline": "e_underline", + "B_Bold": "b_bold", + "E_Bold": "e_bold", + # Optional report toggles + "Phone_Book": "phone_book", + "Rolodex_Info": "rolodex_info", + "Envelope": "envelope", + "File_Cabinet": "file_cabinet", + "Accounts": "accounts", + "Statements": "statements", + "Calendar": "calendar", }, 
"SETUP.csv": { "Setting_Key": "setting_key", @@ -285,32 +331,98 @@ FIELD_MAPPINGS = { "MARRIAGE.csv": { "File_No": "file_no", "Version": "version", - "Marriage_Date": "marriage_date", - "Separation_Date": "separation_date", - "Divorce_Date": "divorce_date" + "Married_From": "married_from", + "Married_To": "married_to", + "Married_Years": "married_years", + "Service_From": "service_from", + "Service_To": "service_to", + "Service_Years": "service_years", + "Marital_%": "marital_percent" }, "DEATH.csv": { "File_No": "file_no", "Version": "version", - "Benefit_Type": "benefit_type", - "Benefit_Amount": "benefit_amount", - "Beneficiary": "beneficiary" + "Lump1": "lump1", + "Lump2": "lump2", + "Growth1": "growth1", + "Growth2": "growth2", + "Disc1": "disc1", + "Disc2": "disc2" }, "SEPARATE.csv": { "File_No": "file_no", "Version": "version", - "Agreement_Date": "agreement_date", - "Terms": "terms" + "Separation_Rate": "terms" }, "LIFETABL.csv": { - "Age": "age", - "Male_Mortality": "male_mortality", - "Female_Mortality": "female_mortality" + "AGE": "age", + "LE_AA": "le_aa", + "NA_AA": "na_aa", + "LE_AM": "le_am", + "NA_AM": "na_am", + "LE_AF": "le_af", + "NA_AF": "na_af", + "LE_WA": "le_wa", + "NA_WA": "na_wa", + "LE_WM": "le_wm", + "NA_WM": "na_wm", + "LE_WF": "le_wf", + "NA_WF": "na_wf", + "LE_BA": "le_ba", + "NA_BA": "na_ba", + "LE_BM": "le_bm", + "NA_BM": "na_bm", + "LE_BF": "le_bf", + "NA_BF": "na_bf", + "LE_HA": "le_ha", + "NA_HA": "na_ha", + "LE_HM": "le_hm", + "NA_HM": "na_hm", + "LE_HF": "le_hf", + "NA_HF": "na_hf" }, "NUMBERAL.csv": { - "Table_Name": "table_name", + "Month": "month", + "NA_AA": "na_aa", + "NA_AM": "na_am", + "NA_AF": "na_af", + "NA_WA": "na_wa", + "NA_WM": "na_wm", + "NA_WF": "na_wf", + "NA_BA": "na_ba", + "NA_BM": "na_bm", + "NA_BF": "na_bf", + "NA_HA": "na_ha", + "NA_HM": "na_hm", + "NA_HF": "na_hf" + }, + "RESULTS.csv": { + "Accrued": "accrued", + "Start_Age": "start_age", + "COLA": "cola", + "Withdrawal": "withdrawal", + "Pre_DR": "pre_dr", + "Post_DR": "post_dr", + "Tax_Rate": "tax_rate", "Age": "age", - "Value": "value" + "Years_From": "years_from", + "Life_Exp": "life_exp", + "EV_Monthly": "ev_monthly", + "Payments": "payments", + "Pay_Out": "pay_out", + "Fund_Value": "fund_value", + "PV": "pv", + "Mortality": "mortality", + "PV_AM": "pv_am", + "PV_AMT": "pv_amt", + "PV_Pre_DB": "pv_pre_db", + "PV_Annuity": "pv_annuity", + "WV_AT": "wv_at", + "PV_Plan": "pv_plan", + "Years_Married": "years_married", + "Years_Service": "years_service", + "Marr_Per": "marr_per", + "Marr_Amt": "marr_amt" }, # Additional CSV file mappings "DEPOSITS.csv": { @@ -357,7 +469,7 @@ FIELD_MAPPINGS = { } -def parse_date(date_str: str) -> Optional[datetime]: +def parse_date(date_str: str) -> Optional[date]: """Parse date string in various formats""" if not date_str or date_str.strip() == "": return None @@ -612,7 +724,11 @@ def convert_value(value: str, field_name: str) -> Any: return parsed_date # Boolean fields - if any(word in field_name.lower() for word in ["active", "default_printer", "billed", "transferable"]): + if any(word in field_name.lower() for word in [ + "active", "default_printer", "billed", "transferable", "send", + # PrinterSetup legacy toggles + "phone_book", "rolodex_info", "envelope", "file_cabinet", "accounts", "statements", "calendar" + ]): if value.lower() in ["true", "1", "yes", "y", "on", "active"]: return True elif value.lower() in ["false", "0", "no", "n", "off", "inactive"]: @@ -621,7 +737,11 @@ def convert_value(value: str, field_name: str) -> Any: return None # 
Numeric fields (float) - if any(word in field_name.lower() for word in ["rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu", "accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality", "value"]): + if any(word in field_name.lower() for word in [ + "rate", "hour", "bal", "fee", "amount", "owing", "transfer", "valu", + "accrued", "vested", "cola", "tax", "percent", "benefit_amount", "mortality", + "value" + ]) or field_name.lower().startswith(("na_", "le_")): try: # Remove currency symbols and commas cleaned_value = value.replace("$", "").replace(",", "").replace("%", "") @@ -630,7 +750,9 @@ def convert_value(value: str, field_name: str) -> Any: return 0.0 # Integer fields - if any(word in field_name.lower() for word in ["item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num"]): + if any(word in field_name.lower() for word in [ + "item_no", "age", "start_age", "version", "line_number", "sort_order", "empl_num", "month", "number" + ]): try: return int(float(value)) # Handle cases like "1.0" except ValueError: @@ -673,11 +795,18 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user) "available_files": list(CSV_MODEL_MAPPING.keys()), "descriptions": { "ROLODEX.csv": "Customer/contact information", + "ROLEX_V.csv": "Customer/contact information (alias)", "PHONE.csv": "Phone numbers linked to customers", "FILES.csv": "Client files and cases", + "FILES_R.csv": "Client files and cases (alias)", + "FILES_V.csv": "Client files and cases (alias)", "LEDGER.csv": "Financial transactions per file", "QDROS.csv": "Legal documents and court orders", "PENSIONS.csv": "Pension calculation data", + "SCHEDULE.csv": "Vesting schedules for pensions", + "MARRIAGE.csv": "Marriage history data", + "DEATH.csv": "Death benefit calculations", + "SEPARATE.csv": "Separation agreements", "EMPLOYEE.csv": "Staff and employee information", "STATES.csv": "US States lookup table", "FILETYPE.csv": "File type categories", @@ -688,7 +817,12 @@ async def get_available_csv_files(current_user: User = Depends(get_current_user) "FVARLKUP.csv": "Form template variables", "RVARLKUP.csv": "Report template variables", "PAYMENTS.csv": "Individual payments within deposits", - "TRNSACTN.csv": "Transaction details (maps to Ledger)" + "TRNSACTN.csv": "Transaction details (maps to Ledger)", + "INX_LKUP.csv": "Form keywords lookup", + "PLANINFO.csv": "Pension plan information", + "RESULTS.csv": "Pension computed results", + "LIFETABL.csv": "Life expectancy table by age, sex, and race (rich typed)", + "NUMBERAL.csv": "Monthly survivor counts by sex and race (rich typed)" }, "auto_discovery": True } @@ -724,7 +858,7 @@ async def import_csv_data( content = await file.read() # Try multiple encodings for legacy CSV files - encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: @@ -736,34 +870,7 @@ async def import_csv_data( if csv_content is None: raise HTTPException(status_code=400, detail="Could not decode CSV file. 
Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.") - # Preprocess CSV content to fix common legacy issues - def preprocess_csv(content): - lines = content.split('\n') - cleaned_lines = [] - i = 0 - - while i < len(lines): - line = lines[i] - # If line doesn't have the expected number of commas, it might be a broken multi-line field - if i == 0: # Header line - cleaned_lines.append(line) - expected_comma_count = line.count(',') - i += 1 - continue - - # Check if this line has the expected number of commas - if line.count(',') < expected_comma_count: - # This might be a continuation of the previous line - # Try to merge with previous line - if cleaned_lines: - cleaned_lines[-1] += " " + line.replace('\n', ' ').replace('\r', ' ') - else: - cleaned_lines.append(line) - else: - cleaned_lines.append(line) - i += 1 - - return '\n'.join(cleaned_lines) + # Note: preprocess_csv helper removed as unused; robust parsing handled below # Custom robust parser for problematic legacy CSV files class MockCSVReader: @@ -791,7 +898,7 @@ async def import_csv_data( header_reader = csv.reader(io.StringIO(lines[0])) headers = next(header_reader) headers = [h.strip() for h in headers] - print(f"DEBUG: Found {len(headers)} headers: {headers}") + # Debug logging removed in API path; rely on audit/logging if needed # Build dynamic header mapping for this file/model mapping_info = _build_dynamic_mapping(headers, model_class, file_type) @@ -829,17 +936,21 @@ async def import_csv_data( continue csv_reader = MockCSVReader(rows_data, headers) - print(f"SUCCESS: Parsed {len(rows_data)} rows (skipped {skipped_rows} malformed rows)") + # Parsing summary suppressed to avoid noisy stdout in API except Exception as e: - print(f"Custom parsing failed: {e}") + # Keep error minimal for client; internal logging can capture 'e' raise HTTPException(status_code=400, detail=f"Could not parse CSV file. The file appears to have serious formatting issues. 
Error: {str(e)}") imported_count = 0 + created_count = 0 + updated_count = 0 errors = [] flexible_saved = 0 mapped_headers = mapping_info.get("mapped_headers", {}) unmapped_headers = mapping_info.get("unmapped_headers", []) + # Special handling: assign line numbers per form for FORM_LST.csv + form_lst_line_counters: Dict[str, int] = {} # If replace_existing is True, delete all existing records and related flexible extras if replace_existing: @@ -860,6 +971,16 @@ async def import_csv_data( converted_value = convert_value(row[csv_field], db_field) if converted_value is not None: model_data[db_field] = converted_value + + # Inject sequential line_number for FORM_LST rows grouped by form_id + if file_type == "FORM_LST.csv": + form_id_value = model_data.get("form_id") + if form_id_value: + current = form_lst_line_counters.get(str(form_id_value), 0) + 1 + form_lst_line_counters[str(form_id_value)] = current + # Only set if not provided + if "line_number" not in model_data: + model_data["line_number"] = current # Skip empty rows if not any(model_data.values()): @@ -902,10 +1023,43 @@ async def import_csv_data( if 'file_no' not in model_data or not model_data['file_no']: continue # Skip ledger records without file number - # Create model instance - instance = model_class(**model_data) - db.add(instance) - db.flush() # Ensure PK is available + # Create or update model instance + instance = None + # Upsert behavior for printers + if model_class == PrinterSetup: + # Determine primary key field name + _, pk_names = _get_model_columns(model_class) + pk_field_name_local = pk_names[0] if len(pk_names) == 1 else None + pk_value_local = model_data.get(pk_field_name_local) if pk_field_name_local else None + if pk_field_name_local and pk_value_local: + existing = db.query(model_class).filter(getattr(model_class, pk_field_name_local) == pk_value_local).first() + if existing: + # Update mutable fields + for k, v in model_data.items(): + if k != pk_field_name_local: + setattr(existing, k, v) + instance = existing + updated_count += 1 + else: + instance = model_class(**model_data) + db.add(instance) + created_count += 1 + else: + # Fallback to insert if PK missing + instance = model_class(**model_data) + db.add(instance) + created_count += 1 + db.flush() + # Enforce single default + try: + if bool(model_data.get("default_printer")): + db.query(model_class).filter(getattr(model_class, pk_field_name_local) != getattr(instance, pk_field_name_local)).update({model_class.default_printer: False}) + except Exception: + pass + else: + instance = model_class(**model_data) + db.add(instance) + db.flush() # Ensure PK is available # Capture PK details for flexible storage linkage (single-column PKs only) _, pk_names = _get_model_columns(model_class) @@ -980,6 +1134,10 @@ async def import_csv_data( "flexible_saved_rows": flexible_saved, }, } + # Include create/update breakdown for printers + if file_type == "PRINTERS.csv": + result["created_count"] = created_count + result["updated_count"] = updated_count if errors: result["warning"] = f"Import completed with {len(errors)} errors" @@ -987,9 +1145,7 @@ async def import_csv_data( return result except Exception as e: - print(f"IMPORT ERROR DEBUG: {type(e).__name__}: {str(e)}") - import traceback - print(f"TRACEBACK: {traceback.format_exc()}") + # Suppress stdout debug prints in API layer db.rollback() raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}") @@ -1071,7 +1227,7 @@ async def validate_csv_file( content = await file.read() # Try multiple encodings for 
legacy CSV files - encodings = ['utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: @@ -1083,18 +1239,6 @@ async def validate_csv_file( if csv_content is None: raise HTTPException(status_code=400, detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding.") - # Parse CSV with fallback to robust line-by-line parsing - def parse_csv_with_fallback(text: str) -> Tuple[List[Dict[str, str]], List[str]]: - try: - reader = csv.DictReader(io.StringIO(text), delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - headers_local = reader.fieldnames or [] - rows_local = [] - for r in reader: - rows_local.append(r) - return rows_local, headers_local - except Exception: - return parse_csv_robust(text) - rows_list, csv_headers = parse_csv_with_fallback(csv_content) model_class = CSV_MODEL_MAPPING[file_type] mapping_info = _build_dynamic_mapping(csv_headers, model_class, file_type) @@ -1142,9 +1286,7 @@ async def validate_csv_file( } except Exception as e: - print(f"VALIDATION ERROR DEBUG: {type(e).__name__}: {str(e)}") - import traceback - print(f"VALIDATION TRACEBACK: {traceback.format_exc()}") + # Suppress stdout debug prints in API layer raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}") @@ -1199,7 +1341,7 @@ async def batch_validate_csv_files( content = await file.read() # Try multiple encodings for legacy CSV files (include BOM-friendly utf-8-sig) - encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: @@ -1302,13 +1444,7 @@ async def batch_import_csv_files( raise HTTPException(status_code=400, detail="Maximum 25 files allowed per batch") # Define optimal import order based on dependencies - import_order = [ - "STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv", - "TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv", - "ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv", - "QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv", - "FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv" - ] + import_order = IMPORT_ORDER # Sort uploaded files by optimal import order file_map = {f.filename: f for f in files} @@ -1365,7 +1501,7 @@ async def batch_import_csv_files( saved_path = str(file_path) except Exception: saved_path = None - encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: @@ -1466,7 +1602,7 @@ async def batch_import_csv_files( saved_path = None # Try multiple encodings for legacy CSV files - encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: @@ -1505,6 +1641,8 @@ async def batch_import_csv_files( imported_count = 0 errors = [] flexible_saved = 0 + # Special handling: assign line numbers per form for FORM_LST.csv + form_lst_line_counters: Dict[str, int] = {} # If replace_existing is True and this is the first file of this type if replace_existing: @@ -1523,6 +1661,15 @@ async def batch_import_csv_files( converted_value = convert_value(row[csv_field], db_field) if converted_value is not None: model_data[db_field] = converted_value + + # Inject sequential line_number for FORM_LST rows grouped by form_id + if file_type == "FORM_LST.csv": + form_id_value = 
model_data.get("form_id") + if form_id_value: + current = form_lst_line_counters.get(str(form_id_value), 0) + 1 + form_lst_line_counters[str(form_id_value)] = current + if "line_number" not in model_data: + model_data["line_number"] = current if not any(model_data.values()): continue @@ -1697,7 +1844,7 @@ async def batch_import_csv_files( "completed_with_errors" if summary["successful_files"] > 0 else "failed" ) audit_row.message = f"Batch import completed: {audit_row.successful_files}/{audit_row.total_files} files" - audit_row.finished_at = datetime.utcnow() + audit_row.finished_at = datetime.now(timezone.utc) audit_row.details = { "files": [ {"file_type": r.get("file_type"), "status": r.get("status"), "imported_count": r.get("imported_count", 0), "errors": r.get("errors", 0)} @@ -1844,13 +1991,7 @@ async def rerun_failed_files( raise HTTPException(status_code=400, detail="No saved files available to rerun. Upload again.") # Import order for sorting - import_order = [ - "STATES.csv", "GRUPLKUP.csv", "EMPLOYEE.csv", "FILETYPE.csv", "FILESTAT.csv", - "TRNSTYPE.csv", "TRNSLKUP.csv", "FOOTERS.csv", "SETUP.csv", "PRINTERS.csv", - "ROLODEX.csv", "PHONE.csv", "FILES.csv", "LEDGER.csv", "TRNSACTN.csv", - "QDROS.csv", "PENSIONS.csv", "PLANINFO.csv", "PAYMENTS.csv", "DEPOSITS.csv", - "FILENOTS.csv", "FORM_INX.csv", "FORM_LST.csv", "FVARLKUP.csv", "RVARLKUP.csv" - ] + import_order = IMPORT_ORDER order_index = {name: i for i, name in enumerate(import_order)} items.sort(key=lambda x: order_index.get(x[0], len(import_order) + 1)) @@ -1898,7 +2039,7 @@ async def rerun_failed_files( if file_type not in CSV_MODEL_MAPPING: # Flexible-only path - encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for enc in encodings: try: @@ -1964,7 +2105,7 @@ async def rerun_failed_files( # Known model path model_class = CSV_MODEL_MAPPING[file_type] - encodings = ['utf-8-sig', 'utf-8', 'windows-1252', 'iso-8859-1', 'cp1252'] + encodings = ENCODINGS csv_content = None for enc in encodings: try: @@ -1996,6 +2137,8 @@ async def rerun_failed_files( unmapped_headers = mapping_info["unmapped_headers"] imported_count = 0 errors: List[Dict[str, Any]] = [] + # Special handling: assign line numbers per form for FORM_LST.csv + form_lst_line_counters: Dict[str, int] = {} if replace_existing: db.query(model_class).delete() @@ -2013,6 +2156,14 @@ async def rerun_failed_files( converted_value = convert_value(row[csv_field], db_field) if converted_value is not None: model_data[db_field] = converted_value + # Inject sequential line_number for FORM_LST rows grouped by form_id + if file_type == "FORM_LST.csv": + form_id_value = model_data.get("form_id") + if form_id_value: + current = form_lst_line_counters.get(str(form_id_value), 0) + 1 + form_lst_line_counters[str(form_id_value)] = current + if "line_number" not in model_data: + model_data["line_number"] = current if not any(model_data.values()): continue required_fields = _get_required_fields(model_class) @@ -2147,7 +2298,7 @@ async def rerun_failed_files( "completed_with_errors" if summary["successful_files"] > 0 else "failed" ) rerun_audit.message = f"Rerun completed: {rerun_audit.successful_files}/{rerun_audit.total_files} files" - rerun_audit.finished_at = datetime.utcnow() + rerun_audit.finished_at = datetime.now(timezone.utc) rerun_audit.details = {"rerun_of": audit_id} db.add(rerun_audit) db.commit() @@ -2183,7 +2334,7 @@ async def upload_flexible_only( db.commit() content = await file.read() - encodings = 
["utf-8-sig", "utf-8", "windows-1252", "iso-8859-1", "cp1252"] + encodings = ENCODINGS csv_content = None for encoding in encodings: try: diff --git a/app/api/mortality.py b/app/api/mortality.py new file mode 100644 index 0000000..5151775 --- /dev/null +++ b/app/api/mortality.py @@ -0,0 +1,72 @@ +""" +Mortality/Life Table API endpoints + +Provides read endpoints to query life tables by age and number tables by month, +filtered by sex (M/F/A) and race (W/B/H/A). +""" + +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query, status, Path +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from app.database.base import get_db +from app.models.user import User +from app.auth.security import get_current_user +from app.services.mortality import get_life_values, get_number_value, InvalidCodeError + + +router = APIRouter() + + +class LifeResponse(BaseModel): + age: int + sex: str = Field(description="M, F, or A (all)") + race: str = Field(description="W, B, H, or A (all)") + le: Optional[float] + na: Optional[float] + + +class NumberResponse(BaseModel): + month: int + sex: str = Field(description="M, F, or A (all)") + race: str = Field(description="W, B, H, or A (all)") + na: Optional[float] + + +@router.get("/life/{age}", response_model=LifeResponse) +async def get_life_entry( + age: int = Path(..., ge=0, description="Age in years (>= 0)"), + sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"), + race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Get life expectancy (LE) and number alive (NA) for an age/sex/race.""" + try: + result = get_life_values(db, age=age, sex=sex, race=race) + except InvalidCodeError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if result is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Age not found") + return result + + +@router.get("/number/{month}", response_model=NumberResponse) +async def get_number_entry( + month: int = Path(..., ge=0, description="Month index (>= 0)"), + sex: str = Query("A", min_length=1, max_length=1, description="M, F, or A (all)"), + race: str = Query("A", min_length=1, max_length=1, description="W, B, H, or A (all)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """Get monthly number alive (NA) for a month/sex/race.""" + try: + result = get_number_value(db, month=month, sex=sex, race=race) + except InvalidCodeError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if result is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Month not found") + return result + + diff --git a/app/api/search.py b/app/api/search.py index 9a67b60..4ad1b46 100644 --- a/app/api/search.py +++ b/app/api/search.py @@ -3,11 +3,10 @@ Advanced Search API endpoints - Comprehensive search across all data types """ from typing import List, Optional, Union, Dict, Any, Tuple from fastapi import APIRouter, Depends, HTTPException, status, Query, Body -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import or_, and_, func, desc, asc, text, case, cast, String, DateTime, Date, Numeric +from sqlalchemy.orm import Session, joinedload, Load +from sqlalchemy import or_, and_, func, desc, asc, text, literal from datetime import date, datetime, timedelta -from pydantic import 
BaseModel, Field -import json +from pydantic import BaseModel, Field, field_validator, model_validator import re from app.database.base import get_db @@ -26,12 +25,73 @@ from app.models.qdro import QDRO from app.models.lookups import FormIndex, Employee, FileType, FileStatus, TransactionType, TransactionCode, State from app.models.user import User from app.auth.security import get_current_user +from app.services.cache import cache_get_json, cache_set_json router = APIRouter() +@router.get("/_debug") +async def search_debug( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Report whether FTS tables and Redis cache are active.""" + # Detect FTS by probing sqlite_master + fts_status = { + "rolodex": False, + "files": False, + "ledger": False, + "qdros": False, + } + try: + rows = db.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '%_fts'")) + names = {r[0] for r in rows} + fts_status["rolodex"] = "rolodex_fts" in names + fts_status["files"] = "files_fts" in names + fts_status["ledger"] = "ledger_fts" in names + fts_status["qdros"] = "qdros_fts" in names + except Exception: + pass + + # Detect Redis by trying to obtain a client + try: + from app.services.cache import _get_client # type: ignore + client = await _get_client() + redis_ok = client is not None + except Exception: + redis_ok = False + return { + "fts": fts_status, + "redis": redis_ok, + } # Enhanced Search Schemas +# Allowed values for validation +ALLOWED_SEARCH_TYPES = {"customer", "file", "ledger", "qdro", "document", "template"} +ALLOWED_DATE_FIELDS = {"created", "updated", "opened", "closed"} +ALLOWED_AMOUNT_FIELDS = {"amount", "balance", "total_charges"} + +# Per-type field support for cross-field validation +SUPPORTED_DATE_FIELDS_BY_TYPE: Dict[str, set[str]] = { + "customer": {"created", "updated"}, + "file": {"created", "updated", "opened", "closed"}, + "ledger": {"created", "updated"}, + "qdro": {"created", "updated"}, + "document": {"created", "updated"}, + "template": {"created", "updated"}, +} + +SUPPORTED_AMOUNT_FIELDS_BY_TYPE: Dict[str, set[str]] = { + "customer": set(), + "file": {"balance", "total_charges"}, + "ledger": {"amount"}, + "qdro": set(), + "document": set(), + "template": set(), +} +ALLOWED_SORT_BY = {"relevance", "date", "amount", "title"} +ALLOWED_SORT_ORDER = {"asc", "desc"} + class SearchResult(BaseModel): """Enhanced search result with metadata""" type: str # "customer", "file", "ledger", "qdro", "document", "template", "phone" @@ -83,6 +143,164 @@ class AdvancedSearchCriteria(BaseModel): limit: int = Field(50, ge=1, le=200) offset: int = Field(0, ge=0) + # Field-level validators + @field_validator("search_types", mode="before") + @classmethod + def validate_search_types(cls, value): + # Coerce to list of unique, lower-cased items preserving order + raw_list = value or [] + if not isinstance(raw_list, list): + raw_list = [raw_list] + seen = set() + cleaned: List[str] = [] + for item in raw_list: + token = str(item or "").strip().lower() + if not token: + continue + if token not in ALLOWED_SEARCH_TYPES: + allowed = ", ".join(sorted(ALLOWED_SEARCH_TYPES)) + raise ValueError(f"search_types contains unknown type '{item}'. 
Allowed: {allowed}") + if token not in seen: + cleaned.append(token) + seen.add(token) + return cleaned + + @field_validator("sort_by") + @classmethod + def validate_sort_by(cls, value: str) -> str: + v = (value or "").strip().lower() + if v not in ALLOWED_SORT_BY: + allowed = ", ".join(sorted(ALLOWED_SORT_BY)) + raise ValueError(f"sort_by must be one of: {allowed}") + return v + + @field_validator("sort_order") + @classmethod + def validate_sort_order(cls, value: str) -> str: + v = (value or "").strip().lower() + if v not in ALLOWED_SORT_ORDER: + allowed = ", ".join(sorted(ALLOWED_SORT_ORDER)) + raise ValueError(f"sort_order must be one of: {allowed}") + return v + + @field_validator("date_field") + @classmethod + def validate_date_field(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return None + v = str(value).strip().lower() + if v not in ALLOWED_DATE_FIELDS: + allowed = ", ".join(sorted(ALLOWED_DATE_FIELDS)) + raise ValueError(f"date_field must be one of: {allowed}") + return v + + @field_validator("amount_field") + @classmethod + def validate_amount_field(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return None + v = str(value).strip().lower() + if v not in ALLOWED_AMOUNT_FIELDS: + allowed = ", ".join(sorted(ALLOWED_AMOUNT_FIELDS)) + raise ValueError(f"amount_field must be one of: {allowed}") + return v + + # Cross-field validations + @model_validator(mode="after") + def validate_cross_fields(self): + # Ensure search_types is not empty + if not self.search_types: + allowed = ", ".join(sorted(ALLOWED_SEARCH_TYPES)) + raise ValueError(f"search_types cannot be empty. Allowed values: {allowed}") + + # exact_phrase and whole_words are mutually exclusive + if self.exact_phrase and self.whole_words: + raise ValueError("exact_phrase and whole_words cannot both be true. Choose one.") + + # Date range bounds + if self.date_from and self.date_to and self.date_from > self.date_to: + raise ValueError("date_from must be less than or equal to date_to") + + # Amount range bounds + if self.amount_min is not None and self.amount_max is not None and self.amount_min > self.amount_max: + raise ValueError("amount_min must be less than or equal to amount_max") + + # Ensure date_field is supported by at least one selected search_type + if self.date_field: + selected = set(self.search_types or []) + # Validate allowed first (handled by field validator) then cross-type support + if not any(self.date_field in SUPPORTED_DATE_FIELDS_BY_TYPE.get(t, set()) for t in selected): + # Build helpful message + examples = { + "opened": "file", + "closed": "file", + } + hint = " Include 'file' in search_types." if examples.get(self.date_field) == "file" else "" + raise ValueError(f"date_field '{self.date_field}' is not supported by the selected search_types.{hint}") + + # Ensure amount_field is supported by at least one selected search_type + if self.amount_field: + selected = set(self.search_types or []) + if not any(self.amount_field in SUPPORTED_AMOUNT_FIELDS_BY_TYPE.get(t, set()) for t in selected): + # Provide actionable hint + field_owner = { + "amount": "ledger", + "balance": "file", + "total_charges": "file", + }.get(self.amount_field) + hint = f" Include '{field_owner}' in search_types." 
if field_owner else "" + raise ValueError(f"amount_field '{self.amount_field}' is not supported by the selected search_types.{hint}") + + return self + + +def _format_fts_query(raw_query: str, exact_phrase: bool, whole_words: bool) -> str: + """Format a user query for SQLite FTS5 according to flags. + + - exact_phrase: wrap the whole query in quotes for phrase match + - whole_words: leave tokens as-is (default FTS behavior) + - not whole_words: use prefix matching per token via '*' + We keep AND semantics across tokens to mirror SQL fallback behavior. + """ + if not raw_query: + return "" + if exact_phrase: + # Escape internal double quotes by doubling them per SQLite rules + escaped = str(raw_query).replace('"', '""') + return f'"{escaped}"' + tokens = build_query_tokens(raw_query) + if not tokens: + return raw_query + rendered = [] + for t in tokens: + rendered.append(f"{t}" if whole_words else f"{t}*") + # AND semantics between tokens + return " AND ".join(rendered) + + +def _like_whole_word(column, term: str, case_sensitive: bool): + """SQLite-friendly whole-word LIKE using padding with spaces. + This approximates word boundaries by searching in ' ' || lower(column) || ' '. + """ + if case_sensitive: + col_expr = literal(' ') + column + literal(' ') + return col_expr.like(f"% {term} %") + lowered = func.lower(column) + col_expr = literal(' ') + lowered + literal(' ') + return col_expr.like(f"% {term.lower()} %") + + +def _like_phrase_word_boundaries(column, phrase: str, case_sensitive: bool): + """LIKE match for an exact phrase bounded by spaces. + This approximates word-boundary phrase matching similar to FTS5 token phrase. + """ + if case_sensitive: + col_expr = literal(' ') + column + literal(' ') + return col_expr.like(f"% {phrase} %") + lowered = func.lower(column) + col_expr = literal(' ') + lowered + literal(' ') + return col_expr.like(f"% {phrase.lower()} %") + class SearchFilter(BaseModel): """Individual search filter""" field: str @@ -145,6 +363,18 @@ async def advanced_search( ): """Advanced search with complex criteria and filtering""" start_time = datetime.now() + + # Cache lookup keyed by user and entire criteria (including pagination) + try: + cached = await cache_get_json( + kind="advanced", + user_id=str(getattr(current_user, "id", "")), + parts={"criteria": criteria.model_dump(mode="json")}, + ) + except Exception: + cached = None + if cached: + return AdvancedSearchResponse(**cached) all_results = [] facets = {} @@ -176,35 +406,50 @@ async def advanced_search( # Sort results sorted_results = _sort_search_results(all_results, criteria.sort_by, criteria.sort_order) - + # Apply pagination total_count = len(sorted_results) paginated_results = sorted_results[criteria.offset:criteria.offset + criteria.limit] - + # Calculate facets facets = _calculate_facets(sorted_results) - + # Calculate stats execution_time = (datetime.now() - start_time).total_seconds() stats = await _calculate_search_stats(db, execution_time) - + # Page info page_info = { "current_page": (criteria.offset // criteria.limit) + 1, "total_pages": (total_count + criteria.limit - 1) // criteria.limit, "has_next": criteria.offset + criteria.limit < total_count, - "has_previous": criteria.offset > 0 + "has_previous": criteria.offset > 0, } - - return AdvancedSearchResponse( + + # Build response object once + response = AdvancedSearchResponse( criteria=criteria, results=paginated_results, stats=stats, facets=facets, total_results=total_count, - page_info=page_info + page_info=page_info, ) + # Store in cache 
(best-effort) + try: + await cache_set_json( + kind="advanced", + user_id=str(getattr(current_user, "id", "")), + parts={"criteria": criteria.model_dump(mode="json")}, + value=response.model_dump(mode="json"), + ttl_seconds=90, + ) + except Exception: + pass + + return response + @router.get("/global", response_model=GlobalSearchResponse) async def global_search( @@ -215,6 +460,14 @@ async def global_search( ): """Enhanced global search across all entities""" start_time = datetime.now() + # Cache lookup + cached = await cache_get_json( + kind="global", + user_id=str(getattr(current_user, "id", "")), + parts={"q": q, "limit": limit}, + ) + if cached: + return GlobalSearchResponse(**cached) # Create criteria for global search criteria = AdvancedSearchCriteria( @@ -237,7 +490,7 @@ async def global_search( execution_time = (datetime.now() - start_time).total_seconds() - return GlobalSearchResponse( + response = GlobalSearchResponse( query=q, total_results=total_results, execution_time=execution_time, @@ -249,6 +502,17 @@ async def global_search( templates=template_results[:limit], phones=phone_results[:limit] ) + try: + await cache_set_json( + kind="global", + user_id=str(getattr(current_user, "id", "")), + parts={"q": q, "limit": limit}, + value=response.model_dump(mode="json"), + ttl_seconds=90, + ) + except Exception: + pass + return response @router.get("/suggestions") @@ -259,6 +523,13 @@ async def search_suggestions( current_user: User = Depends(get_current_user) ): """Get search suggestions and autocomplete""" + cached = await cache_get_json( + kind="suggestions", + user_id=str(getattr(current_user, "id", "")), + parts={"q": q, "limit": limit}, + ) + if cached: + return cached suggestions = [] # Customer name suggestions @@ -291,7 +562,55 @@ async def search_suggestions( "description": file_obj.regarding }) - return {"suggestions": suggestions[:limit]} + payload = {"suggestions": suggestions[:limit]} + try: + await cache_set_json( + kind="suggestions", + user_id=str(getattr(current_user, "id", "")), + parts={"q": q, "limit": limit}, + value=payload, + ttl_seconds=60, + ) + except Exception: + pass + return payload + + +@router.get("/last_criteria") +async def get_last_criteria( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Return the last advanced search criteria for this user if present (best-effort).""" + try: + cached = await cache_get_json( + kind="last_criteria", + user_id=str(getattr(current_user, "id", "")), + parts={"v": 1}, + ) + return cached or {} + except Exception: + return {} + + +@router.post("/last_criteria") +async def set_last_criteria( + criteria: AdvancedSearchCriteria = Body(...), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Persist the last advanced search criteria for this user (best-effort).""" + try: + await cache_set_json( + kind="last_criteria", + user_id=str(getattr(current_user, "id", "")), + parts={"v": 1}, + value=criteria.model_dump(mode="json"), + ttl_seconds=60 * 60 * 24 * 7, # 7 days + ) + except Exception: + pass + return {"ok": True} @router.get("/facets") @@ -388,23 +707,107 @@ async def search_files( # Search Implementation Functions async def _search_customers(criteria: AdvancedSearchCriteria, db: Session) -> List[SearchResult]: - """Search customers with advanced criteria""" - query = db.query(Rolodex).options(joinedload(Rolodex.phone_numbers)) - + """Search customers with advanced criteria. 
Uses FTS5 when available.""" + results: List[SearchResult] = [] + + # Attempt FTS5 path when there's a query string + if criteria.query: + fts_sql = """ + SELECT r.* + FROM rolodex_fts f + JOIN rolodex r ON r.rowid = f.rowid + WHERE f MATCH :q + ORDER BY bm25(f) ASC + LIMIT :limit + """ + try: + fts_q = _format_fts_query(criteria.query, criteria.exact_phrase, criteria.whole_words) + rows = db.execute( + text(fts_sql), + {"q": fts_q, "limit": criteria.limit} + ).mappings().all() + + # Optionally apply state/date filters post-FTS (small result set) + filtered = [] + for row in rows: + if criteria.states and row.get("abrev") not in set(criteria.states): + continue + if criteria.date_from or criteria.date_to: + # Use created_at/updated_at when requested + if criteria.date_field == "created" and criteria.date_from and row.get("created_at") and row["created_at"] < criteria.date_from: + continue + if criteria.date_field == "created" and criteria.date_to and row.get("created_at") and row["created_at"] > criteria.date_to: + continue + if criteria.date_field == "updated" and criteria.date_from and row.get("updated_at") and row["updated_at"] < criteria.date_from: + continue + if criteria.date_field == "updated" and criteria.date_to and row.get("updated_at") and row["updated_at"] > criteria.date_to: + continue + filtered.append(row) + + for row in filtered[: criteria.limit]: + full_name = f"{row.get('first') or ''} {row.get('last') or ''}".strip() + location = f"{row.get('city') or ''}, {row.get('abrev') or ''}".strip(', ') + # Build a lightweight object-like view for downstream helpers + class _C: + pass + c = _C() + for k, v in row.items(): + setattr(c, k, v) + # Phones require relationship; fetch lazily for these ids + phones = db.query(Phone.phone).filter(Phone.rolodex_id == row["id"]).all() + phone_numbers = [p[0] for p in phones] + + results.append(SearchResult( + type="customer", + id=row["id"], + title=full_name or f"Customer {row['id']}", + description=f"ID: {row['id']} | {location}", + url=f"/customers?id={row['id']}", + metadata={ + "location": location, + "email": row.get("email"), + "phones": phone_numbers, + "group": row.get("group"), + "state": row.get("abrev"), + }, + relevance_score=1.0, # bm25 used for sort; keep minimal score + highlight=_create_customer_highlight(c, criteria.query or ""), + created_at=row.get("created_at"), + updated_at=row.get("updated_at"), + )) + + return results + except Exception: + # Fallback to legacy path when FTS isn't available + pass + + # Legacy SQL path (no query or FTS not available) + query = db.query(Rolodex).options( + Load(Rolodex).load_only( + Rolodex.id, + Rolodex.first, + Rolodex.last, + Rolodex.city, + Rolodex.abrev, + Rolodex.email, + Rolodex.memo, + Rolodex.created_at, + Rolodex.updated_at, + ), + joinedload(Rolodex.phone_numbers).load_only(Phone.phone), + ) + if criteria.query: search_conditions = [] - if criteria.exact_phrase: - # Exact phrase search search_term = criteria.query search_conditions.append( or_( func.concat(Rolodex.first, ' ', Rolodex.last).contains(search_term), - Rolodex.memo.contains(search_term) + Rolodex.memo.contains(search_term), ) ) else: - # Regular search with individual terms search_terms = criteria.query.split() for term in search_terms: if criteria.case_sensitive: @@ -415,7 +818,7 @@ async def _search_customers(criteria: AdvancedSearchCriteria, db: Session) -> Li Rolodex.first.contains(term), Rolodex.city.contains(term), Rolodex.email.contains(term), - Rolodex.memo.contains(term) + Rolodex.memo.contains(term), ) 
) else: @@ -426,47 +829,33 @@ async def _search_customers(criteria: AdvancedSearchCriteria, db: Session) -> Li Rolodex.first.ilike(f"%{term}%"), Rolodex.city.ilike(f"%{term}%"), Rolodex.email.ilike(f"%{term}%"), - Rolodex.memo.ilike(f"%{term}%") + Rolodex.memo.ilike(f"%{term}%"), ) ) - if search_conditions: query = query.filter(and_(*search_conditions)) - - # Apply filters + if criteria.states: query = query.filter(Rolodex.abrev.in_(criteria.states)) - - # Apply date filters + if criteria.date_from or criteria.date_to: - date_field_map = { - "created": Rolodex.created_at, - "updated": Rolodex.updated_at - } - + date_field_map = {"created": Rolodex.created_at, "updated": Rolodex.updated_at} if criteria.date_field in date_field_map: field = date_field_map[criteria.date_field] if criteria.date_from: query = query.filter(field >= criteria.date_from) if criteria.date_to: query = query.filter(field <= criteria.date_to) - + customers = query.limit(criteria.limit).all() - - results = [] + for customer in customers: full_name = f"{customer.first or ''} {customer.last}".strip() location = f"{customer.city or ''}, {customer.abrev or ''}".strip(', ') - - # Calculate relevance score relevance = _calculate_customer_relevance(customer, criteria.query or "") - - # Create highlight snippet highlight = _create_customer_highlight(customer, criteria.query or "") - - # Get phone numbers phone_numbers = [p.phone for p in customer.phone_numbers] if customer.phone_numbers else [] - + results.append(SearchResult( type="customer", id=customer.id, @@ -477,107 +866,192 @@ async def _search_customers(criteria: AdvancedSearchCriteria, db: Session) -> Li "location": location, "email": customer.email, "phones": phone_numbers, - "group": customer.group + "group": customer.group, + "state": customer.abrev, }, relevance_score=relevance, highlight=highlight, created_at=customer.created_at, - updated_at=customer.updated_at + updated_at=customer.updated_at, )) - + return results async def _search_files(criteria: AdvancedSearchCriteria, db: Session) -> List[SearchResult]: - """Search files with advanced criteria""" - query = db.query(File).options(joinedload(File.owner)) - + """Search files with advanced criteria. 
Uses FTS5 when available.""" + results: List[SearchResult] = [] + + if criteria.query: + fts_sql = """ + SELECT f.* + FROM files_fts x + JOIN files f ON f.rowid = x.rowid + WHERE x MATCH :q + ORDER BY bm25(x) ASC + LIMIT :limit + """ + try: + fts_q = _format_fts_query(criteria.query, criteria.exact_phrase, criteria.whole_words) + rows = db.execute(text(fts_sql), {"q": fts_q, "limit": criteria.limit}).mappings().all() + + # Post-filtering on small set + filtered = [] + for row in rows: + if criteria.file_types and row.get("file_type") not in set(criteria.file_types): + continue + if criteria.file_statuses and row.get("status") not in set(criteria.file_statuses): + continue + if criteria.employees and row.get("empl_num") not in set(criteria.employees): + continue + if criteria.has_balance is not None: + owing = float(row.get("amount_owing") or 0) + if criteria.has_balance and not (owing > 0): + continue + if not criteria.has_balance and not (owing <= 0): + continue + if criteria.amount_min is not None and float(row.get("amount_owing") or 0) < criteria.amount_min: + continue + if criteria.amount_max is not None and float(row.get("amount_owing") or 0) > criteria.amount_max: + continue + if criteria.date_from or criteria.date_to: + field = None + if criteria.date_field == "created": + field = row.get("created_at") + elif criteria.date_field == "updated": + field = row.get("updated_at") + elif criteria.date_field == "opened": + field = row.get("opened") + elif criteria.date_field == "closed": + field = row.get("closed") + if criteria.date_from and field and field < criteria.date_from: + continue + if criteria.date_to and field and field > criteria.date_to: + continue + filtered.append(row) + + for row in filtered[: criteria.limit]: + # Load owner name for display + owner = db.query(Rolodex.first, Rolodex.last).filter(Rolodex.id == row.get("id")).first() + client_name = f"{(owner.first if owner else '') or ''} {(owner.last if owner else '') or ''}".strip() + class _F: pass + fobj = _F() + for k, v in row.items(): + setattr(fobj, k, v) + + results.append(SearchResult( + type="file", + id=row["file_no"], + title=f"File #{row['file_no']}", + description=f"Client: {client_name} | {row.get('regarding') or 'No description'} | Status: {row.get('status')}", + url=f"/files?file_no={row['file_no']}", + metadata={ + "client_id": row.get("id"), + "client_name": client_name, + "file_type": row.get("file_type"), + "status": row.get("status"), + "employee": row.get("empl_num"), + "amount_owing": float(row.get("amount_owing") or 0), + "total_charges": float(row.get("total_charges") or 0), + }, + relevance_score=1.0, + highlight=_create_file_highlight(fobj, criteria.query or ""), + created_at=row.get("created_at"), + updated_at=row.get("updated_at"), + )) + + return results + except Exception: + pass + + # Fallback legacy path + query = db.query(File).options( + Load(File).load_only( + File.file_no, + File.id, + File.regarding, + File.status, + File.file_type, + File.empl_num, + File.amount_owing, + File.total_charges, + File.created_at, + File.updated_at, + ), + joinedload(File.owner).load_only(Rolodex.first, Rolodex.last), + ) + if criteria.query: - search_terms = criteria.query.split() search_conditions = [] - - for term in search_terms: - if criteria.case_sensitive: - search_conditions.append( - or_( - File.file_no.contains(term), - File.id.contains(term), - File.regarding.contains(term), - File.file_type.contains(term), - File.memo.contains(term) - ) - ) - else: - search_conditions.append( - or_( - 
File.file_no.ilike(f"%{term}%"), - File.id.ilike(f"%{term}%"), - File.regarding.ilike(f"%{term}%"), - File.file_type.ilike(f"%{term}%"), - File.memo.ilike(f"%{term}%") - ) - ) - + if criteria.exact_phrase: + phrase = criteria.query + search_conditions.append(or_( + _like_phrase_word_boundaries(File.regarding, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(File.file_type, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(File.memo, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(File.file_no, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(File.id, phrase, criteria.case_sensitive), + )) + else: + tokens = build_query_tokens(criteria.query) + for term in tokens: + if criteria.whole_words: + search_conditions.append(or_( + _like_whole_word(File.regarding, term, criteria.case_sensitive), + _like_whole_word(File.file_type, term, criteria.case_sensitive), + _like_whole_word(File.memo, term, criteria.case_sensitive), + _like_whole_word(File.file_no, term, criteria.case_sensitive), + _like_whole_word(File.id, term, criteria.case_sensitive), + )) + else: + if criteria.case_sensitive: + search_conditions.append(or_( + File.file_no.contains(term), + File.id.contains(term), + File.regarding.contains(term), + File.file_type.contains(term), + File.memo.contains(term), + )) + else: + search_conditions.append(or_( + File.file_no.ilike(f"%{term}%"), + File.id.ilike(f"%{term}%"), + File.regarding.ilike(f"%{term}%"), + File.file_type.ilike(f"%{term}%"), + File.memo.ilike(f"%{term}%"), + )) if search_conditions: query = query.filter(and_(*search_conditions)) - - # Apply filters + if criteria.file_types: query = query.filter(File.file_type.in_(criteria.file_types)) - if criteria.file_statuses: query = query.filter(File.status.in_(criteria.file_statuses)) - if criteria.employees: query = query.filter(File.empl_num.in_(criteria.employees)) - if criteria.has_balance is not None: - if criteria.has_balance: - query = query.filter(File.amount_owing > 0) - else: - query = query.filter(File.amount_owing <= 0) - - # Amount filters - if criteria.amount_min is not None or criteria.amount_max is not None: - amount_field_map = { - "balance": File.amount_owing, - "total_charges": File.total_charges - } - - if criteria.amount_field in amount_field_map: - field = amount_field_map[criteria.amount_field] - if criteria.amount_min is not None: - query = query.filter(field >= criteria.amount_min) - if criteria.amount_max is not None: - query = query.filter(field <= criteria.amount_max) - - # Date filters + query = query.filter(File.amount_owing > 0) if criteria.has_balance else query.filter(File.amount_owing <= 0) + if criteria.amount_min is not None: + query = query.filter(File.amount_owing >= criteria.amount_min) + if criteria.amount_max is not None: + query = query.filter(File.amount_owing <= criteria.amount_max) if criteria.date_from or criteria.date_to: - date_field_map = { - "created": File.created_at, - "updated": File.updated_at, - "opened": File.opened, - "closed": File.closed - } - + date_field_map = {"created": File.created_at, "updated": File.updated_at, "opened": File.opened, "closed": File.closed} if criteria.date_field in date_field_map: field = date_field_map[criteria.date_field] if criteria.date_from: query = query.filter(field >= criteria.date_from) if criteria.date_to: query = query.filter(field <= criteria.date_to) - + files = query.limit(criteria.limit).all() - - results = [] + for file_obj in files: client_name = "" if file_obj.owner: 
client_name = f"{file_obj.owner.first or ''} {file_obj.owner.last}".strip() - relevance = _calculate_file_relevance(file_obj, criteria.query or "") highlight = _create_file_highlight(file_obj, criteria.query or "") - results.append(SearchResult( type="file", id=file_obj.file_no, @@ -591,81 +1065,172 @@ async def _search_files(criteria: AdvancedSearchCriteria, db: Session) -> List[S "status": file_obj.status, "employee": file_obj.empl_num, "amount_owing": float(file_obj.amount_owing or 0), - "total_charges": float(file_obj.total_charges or 0) + "total_charges": float(file_obj.total_charges or 0), }, relevance_score=relevance, highlight=highlight, created_at=file_obj.created_at, - updated_at=file_obj.updated_at + updated_at=file_obj.updated_at, )) - + return results async def _search_ledger(criteria: AdvancedSearchCriteria, db: Session) -> List[SearchResult]: - """Search ledger entries with advanced criteria""" - query = db.query(Ledger).options(joinedload(Ledger.file).joinedload(File.owner)) - + """Search ledger entries with advanced criteria. Uses FTS5 when available.""" + results: List[SearchResult] = [] + + if criteria.query: + fts_sql = """ + SELECT l.* + FROM ledger_fts x + JOIN ledger l ON l.rowid = x.rowid + WHERE x MATCH :q + ORDER BY bm25(x) ASC + LIMIT :limit + """ + try: + fts_q = _format_fts_query(criteria.query, criteria.exact_phrase, criteria.whole_words) + rows = db.execute(text(fts_sql), {"q": fts_q, "limit": criteria.limit}).mappings().all() + filtered = [] + for row in rows: + if criteria.transaction_types and row.get("t_type") not in set(criteria.transaction_types): + continue + if criteria.employees and row.get("empl_num") not in set(criteria.employees): + continue + if criteria.is_billed is not None: + billed_flag = (row.get("billed") == "Y") + if criteria.is_billed != billed_flag: + continue + if criteria.amount_min is not None and float(row.get("amount") or 0) < criteria.amount_min: + continue + if criteria.amount_max is not None and float(row.get("amount") or 0) > criteria.amount_max: + continue + if criteria.date_from and row.get("date") and row["date"] < criteria.date_from: + continue + if criteria.date_to and row.get("date") and row["date"] > criteria.date_to: + continue + filtered.append(row) + + # Fetch owner names for display + for row in filtered[: criteria.limit]: + client_name = "" + # Join to files -> rolodex for name + owner = db.query(Rolodex.first, Rolodex.last).join(File, File.id == Rolodex.id).filter(File.file_no == row.get("file_no")).first() + if owner: + client_name = f"{owner.first or ''} {owner.last or ''}".strip() + class _L: pass + lobj = _L() + for k, v in row.items(): + setattr(lobj, k, v) + results.append(SearchResult( + type="ledger", + id=row["id"], + title=f"Transaction {row.get('t_code')} - ${row.get('amount')}", + description=f"File: {row.get('file_no')} | Client: {client_name} | Date: {row.get('date')} | {row.get('note') or 'No note'}", + url=f"/financial?file_no={row.get('file_no')}", + metadata={ + "file_no": row.get("file_no"), + "transaction_type": row.get("t_type"), + "transaction_code": row.get("t_code"), + "amount": float(row.get("amount") or 0), + "quantity": float(row.get("quantity") or 0), + "rate": float(row.get("rate") or 0), + "employee": row.get("empl_num"), + "billed": row.get("billed") == "Y", + "date": row.get("date").isoformat() if row.get("date") else None, + }, + relevance_score=1.0, + highlight=_create_ledger_highlight(lobj, criteria.query or ""), + created_at=row.get("created_at"), + 
updated_at=row.get("updated_at"), + )) + + return results + except Exception: + pass + + # Fallback legacy path + query = db.query(Ledger).options( + Load(Ledger).load_only( + Ledger.id, + Ledger.file_no, + Ledger.t_code, + Ledger.t_type, + Ledger.empl_num, + Ledger.quantity, + Ledger.rate, + Ledger.amount, + Ledger.billed, + Ledger.note, + Ledger.date, + Ledger.created_at, + Ledger.updated_at, + ), + joinedload(Ledger.file) + .load_only(File.file_no, File.id) + .joinedload(File.owner) + .load_only(Rolodex.first, Rolodex.last), + ) if criteria.query: - search_terms = criteria.query.split() search_conditions = [] - - for term in search_terms: - if criteria.case_sensitive: - search_conditions.append( - or_( - Ledger.file_no.contains(term), - Ledger.t_code.contains(term), - Ledger.note.contains(term), - Ledger.empl_num.contains(term) - ) - ) - else: - search_conditions.append( - or_( - Ledger.file_no.ilike(f"%{term}%"), - Ledger.t_code.ilike(f"%{term}%"), - Ledger.note.ilike(f"%{term}%"), - Ledger.empl_num.ilike(f"%{term}%") - ) - ) - + if criteria.exact_phrase: + phrase = criteria.query + search_conditions.append(or_( + _like_phrase_word_boundaries(Ledger.note, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(Ledger.t_code, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(Ledger.file_no, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(Ledger.empl_num, phrase, criteria.case_sensitive), + )) + else: + tokens = build_query_tokens(criteria.query) + for term in tokens: + if criteria.whole_words: + search_conditions.append(or_( + _like_whole_word(Ledger.note, term, criteria.case_sensitive), + _like_whole_word(Ledger.t_code, term, criteria.case_sensitive), + _like_whole_word(Ledger.file_no, term, criteria.case_sensitive), + _like_whole_word(Ledger.empl_num, term, criteria.case_sensitive), + )) + else: + if criteria.case_sensitive: + search_conditions.append(or_( + Ledger.file_no.contains(term), + Ledger.t_code.contains(term), + Ledger.note.contains(term), + Ledger.empl_num.contains(term), + )) + else: + search_conditions.append(or_( + Ledger.file_no.ilike(f"%{term}%"), + Ledger.t_code.ilike(f"%{term}%"), + Ledger.note.ilike(f"%{term}%"), + Ledger.empl_num.ilike(f"%{term}%"), + )) if search_conditions: query = query.filter(and_(*search_conditions)) - - # Apply filters if criteria.transaction_types: query = query.filter(Ledger.t_type.in_(criteria.transaction_types)) - if criteria.employees: query = query.filter(Ledger.empl_num.in_(criteria.employees)) - if criteria.is_billed is not None: query = query.filter(Ledger.billed == ("Y" if criteria.is_billed else "N")) - - # Amount filters if criteria.amount_min is not None: query = query.filter(Ledger.amount >= criteria.amount_min) if criteria.amount_max is not None: query = query.filter(Ledger.amount <= criteria.amount_max) - - # Date filters if criteria.date_from: query = query.filter(Ledger.date >= criteria.date_from) if criteria.date_to: query = query.filter(Ledger.date <= criteria.date_to) - ledgers = query.limit(criteria.limit).all() - - results = [] + for ledger in ledgers: client_name = "" if ledger.file and ledger.file.owner: client_name = f"{ledger.file.owner.first or ''} {ledger.file.owner.last}".strip() - relevance = _calculate_ledger_relevance(ledger, criteria.query or "") highlight = _create_ledger_highlight(ledger, criteria.query or "") - results.append(SearchResult( type="ledger", id=ledger.id, @@ -681,59 +1246,115 @@ async def _search_ledger(criteria: AdvancedSearchCriteria, db: Session) -> 
List[ "rate": float(ledger.rate or 0), "employee": ledger.empl_num, "billed": ledger.billed == "Y", - "date": ledger.date.isoformat() if ledger.date else None + "date": ledger.date.isoformat() if ledger.date else None, }, relevance_score=relevance, highlight=highlight, created_at=ledger.created_at, - updated_at=ledger.updated_at + updated_at=ledger.updated_at, )) - + return results async def _search_qdros(criteria: AdvancedSearchCriteria, db: Session) -> List[SearchResult]: - """Search QDRO documents with advanced criteria""" - query = db.query(QDRO).options(joinedload(QDRO.file)) - + """Search QDRO documents with advanced criteria. Uses FTS5 when available.""" + results: List[SearchResult] = [] + + if criteria.query: + fts_sql = """ + SELECT q.* + FROM qdros_fts x + JOIN qdros q ON q.rowid = x.rowid + WHERE x MATCH :q + ORDER BY bm25(x) ASC + LIMIT :limit + """ + try: + fts_q = _format_fts_query(criteria.query, criteria.exact_phrase, criteria.whole_words) + rows = db.execute(text(fts_sql), {"q": fts_q, "limit": criteria.limit}).mappings().all() + for row in rows[: criteria.limit]: + class _Q: pass + q = _Q() + for k, v in row.items(): + setattr(q, k, v) + results.append(SearchResult( + type="qdro", + id=row["id"], + title=row.get("form_name") or f"QDRO v{row.get('version')}", + description=f"File: {row.get('file_no')} | Status: {row.get('status')} | Case: {row.get('case_number') or 'N/A'}", + url=f"/documents?qdro_id={row['id']}", + metadata={ + "file_no": row.get("file_no"), + "version": row.get("version"), + "status": row.get("status"), + "petitioner": row.get("pet"), + "respondent": row.get("res"), + "case_number": row.get("case_number"), + }, + relevance_score=1.0, + highlight=_create_qdro_highlight(q, criteria.query or ""), + created_at=row.get("created_at"), + updated_at=row.get("updated_at"), + )) + return results + except Exception: + pass + + # Fallback legacy path + query = db.query(QDRO).options(joinedload(QDRO.file)) + if criteria.query: - search_terms = criteria.query.split() search_conditions = [] - - for term in search_terms: - if criteria.case_sensitive: - search_conditions.append( - or_( - QDRO.file_no.contains(term), - QDRO.form_name.contains(term), - QDRO.pet.contains(term), - QDRO.res.contains(term), - QDRO.case_number.contains(term), - QDRO.notes.contains(term) - ) - ) - else: - search_conditions.append( - or_( - QDRO.file_no.ilike(f"%{term}%"), - QDRO.form_name.ilike(f"%{term}%"), - QDRO.pet.ilike(f"%{term}%"), - QDRO.res.ilike(f"%{term}%"), - QDRO.case_number.ilike(f"%{term}%"), - QDRO.notes.ilike(f"%{term}%") - ) - ) - + if criteria.exact_phrase: + phrase = criteria.query + search_conditions.append(or_( + _like_phrase_word_boundaries(QDRO.form_name, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(QDRO.pet, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(QDRO.res, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(QDRO.case_number, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(QDRO.notes, phrase, criteria.case_sensitive), + _like_phrase_word_boundaries(QDRO.file_no, phrase, criteria.case_sensitive), + )) + else: + tokens = build_query_tokens(criteria.query) + for term in tokens: + if criteria.whole_words: + search_conditions.append(or_( + _like_whole_word(QDRO.form_name, term, criteria.case_sensitive), + _like_whole_word(QDRO.pet, term, criteria.case_sensitive), + _like_whole_word(QDRO.res, term, criteria.case_sensitive), + _like_whole_word(QDRO.case_number, term, criteria.case_sensitive), + 
_like_whole_word(QDRO.notes, term, criteria.case_sensitive), + _like_whole_word(QDRO.file_no, term, criteria.case_sensitive), + )) + else: + if criteria.case_sensitive: + search_conditions.append(or_( + QDRO.file_no.contains(term), + QDRO.form_name.contains(term), + QDRO.pet.contains(term), + QDRO.res.contains(term), + QDRO.case_number.contains(term), + QDRO.notes.contains(term), + )) + else: + search_conditions.append(or_( + QDRO.file_no.ilike(f"%{term}%"), + QDRO.form_name.ilike(f"%{term}%"), + QDRO.pet.ilike(f"%{term}%"), + QDRO.res.ilike(f"%{term}%"), + QDRO.case_number.ilike(f"%{term}%"), + QDRO.notes.ilike(f"%{term}%"), + )) if search_conditions: query = query.filter(and_(*search_conditions)) - + qdros = query.limit(criteria.limit).all() - - results = [] + for qdro in qdros: relevance = _calculate_qdro_relevance(qdro, criteria.query or "") highlight = _create_qdro_highlight(qdro, criteria.query or "") - results.append(SearchResult( type="qdro", id=qdro.id, @@ -746,14 +1367,14 @@ async def _search_qdros(criteria: AdvancedSearchCriteria, db: Session) -> List[S "status": qdro.status, "petitioner": qdro.pet, "respondent": qdro.res, - "case_number": qdro.case_number + "case_number": qdro.case_number, }, relevance_score=relevance, highlight=highlight, created_at=qdro.created_at, - updated_at=qdro.updated_at + updated_at=qdro.updated_at, )) - + return results @@ -821,7 +1442,17 @@ async def _search_templates(criteria: AdvancedSearchCriteria, db: Session) -> Li async def _search_phones(criteria: AdvancedSearchCriteria, db: Session) -> List[SearchResult]: """Search phone numbers""" - query = db.query(Phone).options(joinedload(Phone.rolodex_entry)) + query = db.query(Phone).options( + Load(Phone).load_only( + Phone.id, + Phone.phone, + Phone.location, + Phone.rolodex_id, + Phone.created_at, + Phone.updated_at, + ), + joinedload(Phone.rolodex_entry).load_only(Rolodex.first, Rolodex.last), + ) if criteria.query: # Clean phone number for search (remove non-digits) @@ -887,7 +1518,9 @@ def _calculate_facets(results: List[SearchResult]) -> Dict[str, Dict[str, int]]: "file_type": {}, "status": {}, "employee": {}, - "category": {} + "category": {}, + "state": {}, + "transaction_type": {}, } for result in results: @@ -896,7 +1529,7 @@ def _calculate_facets(results: List[SearchResult]) -> Dict[str, Dict[str, int]]: # Metadata facets if result.metadata: - for facet_key in ["file_type", "status", "employee", "category"]: + for facet_key in ["file_type", "status", "employee", "category", "state", "transaction_type"]: if facet_key in result.metadata: value = result.metadata[facet_key] if value: diff --git a/app/api/search_highlight.py b/app/api/search_highlight.py index 3263e08..662b5e1 100644 --- a/app/api/search_highlight.py +++ b/app/api/search_highlight.py @@ -2,8 +2,10 @@ Server-side highlight utilities for search results. These functions generate HTML snippets with around matched tokens, -preserving the original casing of the source text. The output is intended to be -sanitized on the client before insertion into the DOM. +preserving the original casing of the source text. All non-HTML segments are +HTML-escaped server-side to prevent injection. Only the tags added by +this module are emitted as HTML; any pre-existing HTML in source text is +escaped. 
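For reference, a minimal usage sketch of the contract described in this docstring (it assumes the repo is on PYTHONPATH and that the wrapper element emitted around matches is `<mark>`):

```python
# Illustrative only: the import path and the <mark> wrapper are assumptions
# based on the module docstring above, not verified against a running app.
from app.api.search_highlight import highlight_text

snippet = highlight_text("Smith & <Co> retainer", ["smith", "retainer"])
# Matches are wrapped, everything else is HTML-escaped, e.g.:
#   <mark>Smith</mark> &amp; &lt;Co&gt; <mark>retainer</mark>
print(snippet)
```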
""" from typing import List, Tuple, Any import re @@ -42,18 +44,40 @@ def _merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]: def highlight_text(value: str, tokens: List[str]) -> str: - """Return `value` with case-insensitive matches of `tokens` wrapped in , preserving original casing.""" + """Return `value` with case-insensitive matches of `tokens` wrapped in , preserving original casing. + + Non-highlighted segments and the highlighted text content are HTML-escaped. + Only the surrounding wrappers are emitted as markup. + """ if value is None: return "" + + def _escape_html(text: str) -> str: + # Minimal, safe HTML escaping + if text is None: + return "" + # Replace ampersand first to avoid double-escaping + text = str(text) + text = text.replace("&", "&") + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace('"', """) + text = text.replace("'", "'") + return text source = str(value) if not source or not tokens: - return source + return _escape_html(source) haystack = source.lower() ranges: List[Tuple[int, int]] = [] + # Deduplicate tokens case-insensitively to avoid redundant scans (parity with client) + unique_needles = [] + seen_needles = set() for t in tokens: needle = str(t or "").lower() - if not needle: - continue + if needle and needle not in seen_needles: + unique_needles.append(needle) + seen_needles.add(needle) + for needle in unique_needles: start = 0 last_possible = max(0, len(haystack) - len(needle)) while start <= last_possible and len(needle) > 0: @@ -63,17 +87,17 @@ def highlight_text(value: str, tokens: List[str]) -> str: ranges.append((idx, idx + len(needle))) start = idx + 1 if not ranges: - return source + return _escape_html(source) parts: List[str] = [] merged = _merge_ranges(ranges) pos = 0 for s, e in merged: if pos < s: - parts.append(source[pos:s]) - parts.append("" + source[s:e] + "") + parts.append(_escape_html(source[pos:s])) + parts.append("" + _escape_html(source[s:e]) + "") pos = e if pos < len(source): - parts.append(source[pos:]) + parts.append(_escape_html(source[pos:])) return "".join(parts) diff --git a/app/api/support.py b/app/api/support.py index ae23d12..5805a54 100644 --- a/app/api/support.py +++ b/app/api/support.py @@ -2,21 +2,24 @@ Support ticket API endpoints """ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, status, Request, Query from sqlalchemy.orm import Session, joinedload from sqlalchemy import func, desc, and_, or_ -from datetime import datetime +from datetime import datetime, timezone import secrets from app.database.base import get_db from app.models import User, SupportTicket, TicketResponse as TicketResponseModel, TicketStatus, TicketPriority, TicketCategory from app.auth.security import get_current_user, get_admin_user from app.services.audit import audit_service +from app.services.query_utils import apply_sorting, paginate_with_total, tokenized_ilike_filter +from app.api.search_highlight import build_query_tokens router = APIRouter() # Pydantic models for API from pydantic import BaseModel, Field, EmailStr +from pydantic.config import ConfigDict class TicketCreate(BaseModel): @@ -57,8 +60,7 @@ class TicketResponseOut(BaseModel): author_email: Optional[str] created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class TicketDetail(BaseModel): @@ -81,15 +83,19 @@ class TicketDetail(BaseModel): assigned_to: 
Optional[int] assigned_admin_name: Optional[str] submitter_name: Optional[str] - responses: List[TicketResponseOut] = [] + responses: List[TicketResponseOut] = Field(default_factory=list) - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) + + +class PaginatedTicketsResponse(BaseModel): + items: List[TicketDetail] + total: int def generate_ticket_number() -> str: """Generate unique ticket number like ST-2024-001""" - year = datetime.now().year + year = datetime.now(timezone.utc).year random_suffix = secrets.token_hex(2).upper() return f"ST-{year}-{random_suffix}" @@ -129,7 +135,7 @@ async def create_support_ticket( ip_address=client_ip, user_id=current_user.id if current_user else None, status=TicketStatus.OPEN, - created_at=datetime.utcnow() + created_at=datetime.now(timezone.utc) ) db.add(new_ticket) @@ -158,14 +164,18 @@ async def create_support_ticket( } -@router.get("/tickets", response_model=List[TicketDetail]) +@router.get("/tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse) async def list_tickets( - status: Optional[TicketStatus] = None, - priority: Optional[TicketPriority] = None, - category: Optional[TicketCategory] = None, - assigned_to_me: bool = False, - skip: int = 0, - limit: int = 50, + status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"), + priority: Optional[TicketPriority] = Query(None, description="Filter by ticket priority"), + category: Optional[TicketCategory] = Query(None, description="Filter by ticket category"), + assigned_to_me: bool = Query(False, description="Only include tickets assigned to the current admin"), + search: Optional[str] = Query(None, description="Tokenized search across number, subject, description, contact name/email, current page, and IP"), + skip: int = Query(0, ge=0, description="Offset for pagination"), + limit: int = Query(50, ge=1, le=200, description="Page size"), + sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_admin_user) ): @@ -186,8 +196,38 @@ async def list_tickets( query = query.filter(SupportTicket.category == category) if assigned_to_me: query = query.filter(SupportTicket.assigned_to == current_user.id) - - tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all() + + # Search across key text fields + if search: + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + SupportTicket.ticket_number, + SupportTicket.subject, + SupportTicket.description, + SupportTicket.contact_name, + SupportTicket.contact_email, + SupportTicket.current_page, + SupportTicket.ip_address, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "created": [SupportTicket.created_at], + "updated": [SupportTicket.updated_at], + "resolved": [SupportTicket.resolved_at], + "priority": [SupportTicket.priority], + "status": [SupportTicket.status], + "subject": [SupportTicket.subject], + }, + ) + + tickets, total = paginate_with_total(query, skip, limit, include_total) # Format response result = [] @@ -226,6 +266,8 @@ async def list_tickets( } result.append(ticket_dict) + if 
include_total: + return {"items": result, "total": total or 0} return result @@ -312,10 +354,10 @@ async def update_ticket( # Set resolved timestamp if status changed to resolved if ticket_data.status == TicketStatus.RESOLVED and ticket.resolved_at is None: - ticket.resolved_at = datetime.utcnow() + ticket.resolved_at = datetime.now(timezone.utc) changes["resolved_at"] = {"from": None, "to": ticket.resolved_at} - ticket.updated_at = datetime.utcnow() + ticket.updated_at = datetime.now(timezone.utc) db.commit() # Audit logging (non-blocking) @@ -358,13 +400,13 @@ async def add_response( message=response_data.message, is_internal=response_data.is_internal, user_id=current_user.id, - created_at=datetime.utcnow() + created_at=datetime.now(timezone.utc) ) db.add(response) # Update ticket timestamp - ticket.updated_at = datetime.utcnow() + ticket.updated_at = datetime.now(timezone.utc) db.commit() db.refresh(response) @@ -386,11 +428,15 @@ async def add_response( return {"message": "Response added successfully", "response_id": response.id} -@router.get("/my-tickets", response_model=List[TicketDetail]) +@router.get("/my-tickets", response_model=List[TicketDetail] | PaginatedTicketsResponse) async def get_my_tickets( - status: Optional[TicketStatus] = None, - skip: int = 0, - limit: int = 20, + status: Optional[TicketStatus] = Query(None, description="Filter by ticket status"), + search: Optional[str] = Query(None, description="Tokenized search across number, subject, description"), + skip: int = Query(0, ge=0, description="Offset for pagination"), + limit: int = Query(20, ge=1, le=200, description="Page size"), + sort_by: Optional[str] = Query(None, description="Sort by: created, updated, resolved, priority, status, subject"), + sort_dir: Optional[str] = Query("desc", description="Sort direction: asc or desc"), + include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -403,7 +449,33 @@ async def get_my_tickets( if status: query = query.filter(SupportTicket.status == status) - tickets = query.order_by(desc(SupportTicket.created_at)).offset(skip).limit(limit).all() + # Search within user's tickets + if search: + tokens = build_query_tokens(search) + filter_expr = tokenized_ilike_filter(tokens, [ + SupportTicket.ticket_number, + SupportTicket.subject, + SupportTicket.description, + ]) + if filter_expr is not None: + query = query.filter(filter_expr) + + # Sorting (whitelisted) + query = apply_sorting( + query, + sort_by, + sort_dir, + allowed={ + "created": [SupportTicket.created_at], + "updated": [SupportTicket.updated_at], + "resolved": [SupportTicket.resolved_at], + "priority": [SupportTicket.priority], + "status": [SupportTicket.status], + "subject": [SupportTicket.subject], + }, + ) + + tickets, total = paginate_with_total(query, skip, limit, include_total) # Format response (exclude internal responses for regular users) result = [] @@ -442,6 +514,8 @@ async def get_my_tickets( } result.append(ticket_dict) + if include_total: + return {"items": result, "total": total or 0} return result @@ -473,7 +547,7 @@ async def get_ticket_stats( # Recent tickets (last 7 days) from datetime import timedelta - week_ago = datetime.utcnow() - timedelta(days=7) + week_ago = datetime.now(timezone.utc) - timedelta(days=7) recent_tickets = db.query(func.count(SupportTicket.id)).filter( SupportTicket.created_at >= week_ago ).scalar() diff --git a/app/auth/schemas.py 
b/app/auth/schemas.py index 537c6ae..d6d2c3a 100644 --- a/app/auth/schemas.py +++ b/app/auth/schemas.py @@ -3,6 +3,7 @@ Authentication schemas """ from typing import Optional from pydantic import BaseModel, EmailStr +from pydantic.config import ConfigDict class UserBase(BaseModel): @@ -32,8 +33,7 @@ class UserResponse(UserBase): is_admin: bool theme_preference: Optional[str] = "light" - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ThemePreferenceUpdate(BaseModel): @@ -45,7 +45,7 @@ class Token(BaseModel): """Token response schema""" access_token: str token_type: str - refresh_token: str | None = None + refresh_token: Optional[str] = None class TokenData(BaseModel): diff --git a/app/auth/security.py b/app/auth/security.py index dc069db..1f2b045 100644 --- a/app/auth/security.py +++ b/app/auth/security.py @@ -1,7 +1,7 @@ """ Authentication and security utilities """ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional, Union, Tuple from uuid import uuid4 from jose import JWTError, jwt @@ -54,12 +54,12 @@ def _decode_with_rotation(token: str) -> dict: def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create JWT access token""" to_encode = data.copy() - expire = datetime.utcnow() + ( + expire = datetime.now(timezone.utc) + ( expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes) ) to_encode.update({ "exp": expire, - "iat": datetime.utcnow(), + "iat": datetime.now(timezone.utc), "type": "access", }) return _encode_with_rotation(to_encode) @@ -68,14 +68,14 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str: """Create refresh token, store its JTI in DB for revocation.""" jti = uuid4().hex - expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes) + expire = datetime.now(timezone.utc) + timedelta(minutes=settings.refresh_token_expire_minutes) payload = { "sub": user.username, "uid": user.id, "jti": jti, "type": "refresh", "exp": expire, - "iat": datetime.utcnow(), + "iat": datetime.now(timezone.utc), } token = _encode_with_rotation(payload) @@ -84,7 +84,7 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti jti=jti, user_agent=user_agent, ip_address=ip_address, - issued_at=datetime.utcnow(), + issued_at=datetime.now(timezone.utc), expires_at=expire, revoked=False, ) @@ -93,6 +93,15 @@ def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Opti return token +def _to_utc_aware(dt: Optional[datetime]) -> Optional[datetime]: + """Convert a datetime to UTC-aware. 
If naive, assume it's already UTC and attach tzinfo.""" + if dt is None: + return None + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + def verify_token(token: str) -> Optional[str]: """Verify JWT token and return username""" try: @@ -122,14 +131,20 @@ def decode_refresh_token(token: str) -> Optional[dict]: def is_refresh_token_revoked(jti: str, db: Session) -> bool: token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() - return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow() + if not token_row: + return True + if token_row.revoked: + return True + expires_at_utc = _to_utc_aware(token_row.expires_at) + now_utc = datetime.now(timezone.utc) + return expires_at_utc is not None and expires_at_utc <= now_utc def revoke_refresh_token(jti: str, db: Session) -> None: token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() if token_row and not token_row.revoked: token_row.revoked = True - token_row.revoked_at = datetime.utcnow() + token_row.revoked_at = datetime.now(timezone.utc) db.commit() diff --git a/app/config.py b/app/config.py index 06e3d92..1ed3d80 100644 --- a/app/config.py +++ b/app/config.py @@ -57,6 +57,10 @@ class Settings(BaseSettings): log_rotation: str = "10 MB" log_retention: str = "30 days" + # Cache / Redis + cache_enabled: bool = False + redis_url: Optional[str] = None + # pydantic-settings v2 configuration model_config = SettingsConfigDict( env_file=".env", diff --git a/app/database/base.py b/app/database/base.py index 2ac8d20..168ea16 100644 --- a/app/database/base.py +++ b/app/database/base.py @@ -2,8 +2,7 @@ Database configuration and session management """ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import declarative_base, sessionmaker, Session from typing import Generator from app.config import settings diff --git a/app/database/fts.py b/app/database/fts.py new file mode 100644 index 0000000..2761e96 --- /dev/null +++ b/app/database/fts.py @@ -0,0 +1,248 @@ +""" +SQLite Full-Text Search (FTS5) helpers. + +Creates and maintains FTS virtual tables and triggers to keep them in sync +with their content tables. Designed to be called at app startup. +""" +from typing import Optional + +from sqlalchemy.engine import Engine +from sqlalchemy import text + + +def _execute_ignore_errors(engine: Engine, sql: str) -> None: + """Execute SQL, ignoring operational errors (e.g., when FTS5 is unavailable).""" + from sqlalchemy.exc import OperationalError + with engine.begin() as conn: + try: + conn.execute(text(sql)) + except OperationalError: + # Likely FTS5 extension not available in this SQLite build + pass + + +def ensure_rolodex_fts(engine: Engine) -> None: + """Ensure the `rolodex_fts` virtual table and triggers exist and are populated. + + This uses content=rolodex so the FTS table shadows the base table and is kept + in sync via triggers. 
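For reference, a self-contained sketch of the external-content FTS5 pattern this helper sets up, using only the stdlib `sqlite3` module (requires an SQLite build with FTS5; the rows and column set below are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE rolodex (id TEXT, first TEXT, last TEXT, memo TEXT);
CREATE VIRTUAL TABLE rolodex_fts USING fts5(
    id, first, last, memo, content='rolodex', content_rowid='rowid'
);
CREATE TRIGGER rolodex_ai AFTER INSERT ON rolodex BEGIN
    INSERT INTO rolodex_fts(rowid, id, first, last, memo)
    VALUES (new.rowid, new.id, new.first, new.last, new.memo);
END;
""")
conn.execute("INSERT INTO rolodex VALUES ('CUST-1', 'Jane', 'Smith', 'pension valuation')")
conn.execute("INSERT INTO rolodex VALUES ('CUST-2', 'John', 'Doe', 'estate planning')")

# Ranked lookup mirroring the MATCH + bm25 pattern used by the search endpoints
rows = conn.execute(
    """
    SELECT r.id, r.last
    FROM rolodex_fts JOIN rolodex r ON r.rowid = rolodex_fts.rowid
    WHERE rolodex_fts MATCH ?
    ORDER BY bm25(rolodex_fts) ASC
    """,
    ("pension",),
).fetchall()
print(rows)  # [('CUST-1', 'Smith')]
```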
+ """ + # Create virtual table (if FTS5 is available) + _create_table = """ + CREATE VIRTUAL TABLE IF NOT EXISTS rolodex_fts USING fts5( + id, + first, + last, + city, + email, + memo, + content='rolodex', + content_rowid='rowid' + ); + """ + _execute_ignore_errors(engine, _create_table) + + # Triggers to keep FTS in sync + _triggers = [ + """ + CREATE TRIGGER IF NOT EXISTS rolodex_ai AFTER INSERT ON rolodex BEGIN + INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo) + VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS rolodex_ad AFTER DELETE ON rolodex BEGIN + INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo) + VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS rolodex_au AFTER UPDATE ON rolodex BEGIN + INSERT INTO rolodex_fts(rolodex_fts, rowid, id, first, last, city, email, memo) + VALUES ('delete', old.rowid, old.id, old.first, old.last, old.city, old.email, old.memo); + INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo) + VALUES (new.rowid, new.id, new.first, new.last, new.city, new.email, new.memo); + END; + """, + ] + for trig in _triggers: + _execute_ignore_errors(engine, trig) + + # Backfill if the FTS table exists but is empty + with engine.begin() as conn: + try: + count_fts = conn.execute(text("SELECT count(*) FROM rolodex_fts")).scalar() # type: ignore + if count_fts == 0: + # Populate from existing rolodex rows + conn.execute(text( + """ + INSERT INTO rolodex_fts(rowid, id, first, last, city, email, memo) + SELECT rowid, id, first, last, city, email, memo FROM rolodex; + """ + )) + except Exception: + # If FTS table doesn't exist or any error occurs, ignore silently + pass + + +def ensure_files_fts(engine: Engine) -> None: + """Ensure the `files_fts` virtual table and triggers exist and are populated.""" + _create_table = """ + CREATE VIRTUAL TABLE IF NOT EXISTS files_fts USING fts5( + file_no, + id, + regarding, + file_type, + memo, + content='files', + content_rowid='rowid' + ); + """ + _execute_ignore_errors(engine, _create_table) + + _triggers = [ + """ + CREATE TRIGGER IF NOT EXISTS files_ai AFTER INSERT ON files BEGIN + INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo) + VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS files_ad AFTER DELETE ON files BEGIN + INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo) + VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS files_au AFTER UPDATE ON files BEGIN + INSERT INTO files_fts(files_fts, rowid, file_no, id, regarding, file_type, memo) + VALUES ('delete', old.rowid, old.file_no, old.id, old.regarding, old.file_type, old.memo); + INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo) + VALUES (new.rowid, new.file_no, new.id, new.regarding, new.file_type, new.memo); + END; + """, + ] + for trig in _triggers: + _execute_ignore_errors(engine, trig) + + with engine.begin() as conn: + try: + count_fts = conn.execute(text("SELECT count(*) FROM files_fts")).scalar() # type: ignore + if count_fts == 0: + conn.execute(text( + """ + INSERT INTO files_fts(rowid, file_no, id, regarding, file_type, memo) + SELECT rowid, file_no, id, regarding, file_type, memo FROM files; + """ + )) 
+ except Exception: + pass + + +def ensure_ledger_fts(engine: Engine) -> None: + """Ensure the `ledger_fts` virtual table and triggers exist and are populated.""" + _create_table = """ + CREATE VIRTUAL TABLE IF NOT EXISTS ledger_fts USING fts5( + file_no, + t_code, + note, + empl_num, + content='ledger', + content_rowid='rowid' + ); + """ + _execute_ignore_errors(engine, _create_table) + + _triggers = [ + """ + CREATE TRIGGER IF NOT EXISTS ledger_ai AFTER INSERT ON ledger BEGIN + INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num) + VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS ledger_ad AFTER DELETE ON ledger BEGIN + INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num) + VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS ledger_au AFTER UPDATE ON ledger BEGIN + INSERT INTO ledger_fts(ledger_fts, rowid, file_no, t_code, note, empl_num) + VALUES ('delete', old.rowid, old.file_no, old.t_code, old.note, old.empl_num); + INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num) + VALUES (new.rowid, new.file_no, new.t_code, new.note, new.empl_num); + END; + """, + ] + for trig in _triggers: + _execute_ignore_errors(engine, trig) + + with engine.begin() as conn: + try: + count_fts = conn.execute(text("SELECT count(*) FROM ledger_fts")).scalar() # type: ignore + if count_fts == 0: + conn.execute(text( + """ + INSERT INTO ledger_fts(rowid, file_no, t_code, note, empl_num) + SELECT rowid, file_no, t_code, note, empl_num FROM ledger; + """ + )) + except Exception: + pass + + +def ensure_qdros_fts(engine: Engine) -> None: + """Ensure the `qdros_fts` virtual table and triggers exist and are populated.""" + _create_table = """ + CREATE VIRTUAL TABLE IF NOT EXISTS qdros_fts USING fts5( + file_no, + form_name, + pet, + res, + case_number, + content='qdros', + content_rowid='rowid' + ); + """ + _execute_ignore_errors(engine, _create_table) + + _triggers = [ + """ + CREATE TRIGGER IF NOT EXISTS qdros_ai AFTER INSERT ON qdros BEGIN + INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number) + VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS qdros_ad AFTER DELETE ON qdros BEGIN + INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number) + VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number); + END; + """, + """ + CREATE TRIGGER IF NOT EXISTS qdros_au AFTER UPDATE ON qdros BEGIN + INSERT INTO qdros_fts(qdros_fts, rowid, file_no, form_name, pet, res, case_number) + VALUES ('delete', old.rowid, old.file_no, old.form_name, old.pet, old.res, old.case_number); + INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number) + VALUES (new.rowid, new.file_no, new.form_name, new.pet, new.res, new.case_number); + END; + """, + ] + for trig in _triggers: + _execute_ignore_errors(engine, trig) + + with engine.begin() as conn: + try: + count_fts = conn.execute(text("SELECT count(*) FROM qdros_fts")).scalar() # type: ignore + if count_fts == 0: + conn.execute(text( + """ + INSERT INTO qdros_fts(rowid, file_no, form_name, pet, res, case_number) + SELECT rowid, file_no, form_name, pet, res, case_number FROM qdros; + """ + )) + except Exception: + pass + + diff --git a/app/database/indexes.py b/app/database/indexes.py new file mode 100644 index 0000000..8090d98 --- 
/dev/null +++ b/app/database/indexes.py @@ -0,0 +1,31 @@ +""" +Database secondary indexes helper. + +Creates small B-tree indexes for common equality filters to speed up searches. +Uses CREATE INDEX IF NOT EXISTS so it is safe to call repeatedly at startup +and works for existing databases without running a migration. +""" +from sqlalchemy.engine import Engine +from sqlalchemy import text + + +def ensure_secondary_indexes(engine: Engine) -> None: + statements = [ + # Files + "CREATE INDEX IF NOT EXISTS idx_files_status ON files(status)", + "CREATE INDEX IF NOT EXISTS idx_files_file_type ON files(file_type)", + "CREATE INDEX IF NOT EXISTS idx_files_empl_num ON files(empl_num)", + # Ledger + "CREATE INDEX IF NOT EXISTS idx_ledger_t_type ON ledger(t_type)", + "CREATE INDEX IF NOT EXISTS idx_ledger_empl_num ON ledger(empl_num)", + ] + with engine.begin() as conn: + for stmt in statements: + try: + conn.execute(text(stmt)) + except Exception: + # Ignore failures (e.g., non-SQLite engines that still support IF NOT EXISTS; + # if not supported, users should manage indexes via migrations) + pass + + diff --git a/app/database/schema_updates.py b/app/database/schema_updates.py new file mode 100644 index 0000000..87c7e12 --- /dev/null +++ b/app/database/schema_updates.py @@ -0,0 +1,130 @@ +""" +Lightweight, idempotent schema updates for SQLite. + +Adds newly introduced columns to existing tables when running on an +already-initialized database. Safe to call multiple times. +""" +from typing import Dict +from sqlalchemy.engine import Engine + + +def _existing_columns(conn, table: str) -> set[str]: + rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall() + return {row[1] for row in rows} # name is column 2 + + +def ensure_schema_updates(engine: Engine) -> None: + """Ensure missing columns are added for backward-compatible updates.""" + # Map of table -> {column: SQL type} + updates: Dict[str, Dict[str, str]] = { + # Forms + "form_index": { + "keyword": "TEXT", + }, + # Richer Life/Number tables (forms & pensions harmonized) + "life_tables": { + "le_aa": "FLOAT", + "na_aa": "FLOAT", + "le_am": "FLOAT", + "na_am": "FLOAT", + "le_af": "FLOAT", + "na_af": "FLOAT", + "le_wa": "FLOAT", + "na_wa": "FLOAT", + "le_wm": "FLOAT", + "na_wm": "FLOAT", + "le_wf": "FLOAT", + "na_wf": "FLOAT", + "le_ba": "FLOAT", + "na_ba": "FLOAT", + "le_bm": "FLOAT", + "na_bm": "FLOAT", + "le_bf": "FLOAT", + "na_bf": "FLOAT", + "le_ha": "FLOAT", + "na_ha": "FLOAT", + "le_hm": "FLOAT", + "na_hm": "FLOAT", + "le_hf": "FLOAT", + "na_hf": "FLOAT", + "table_year": "INTEGER", + "table_type": "VARCHAR(45)", + }, + "number_tables": { + "month": "INTEGER", + "na_aa": "FLOAT", + "na_am": "FLOAT", + "na_af": "FLOAT", + "na_wa": "FLOAT", + "na_wm": "FLOAT", + "na_wf": "FLOAT", + "na_ba": "FLOAT", + "na_bm": "FLOAT", + "na_bf": "FLOAT", + "na_ha": "FLOAT", + "na_hm": "FLOAT", + "na_hf": "FLOAT", + "table_type": "VARCHAR(45)", + "description": "TEXT", + }, + "form_list": { + "status": "VARCHAR(45)", + }, + # Printers: add advanced legacy fields + "printers": { + "number": "INTEGER", + "page_break": "VARCHAR(50)", + "setup_st": "VARCHAR(200)", + "reset_st": "VARCHAR(200)", + "b_underline": "VARCHAR(100)", + "e_underline": "VARCHAR(100)", + "b_bold": "VARCHAR(100)", + "e_bold": "VARCHAR(100)", + "phone_book": "BOOLEAN", + "rolodex_info": "BOOLEAN", + "envelope": "BOOLEAN", + "file_cabinet": "BOOLEAN", + "accounts": "BOOLEAN", + "statements": "BOOLEAN", + "calendar": "BOOLEAN", + }, + # Pensions + "pension_schedules": { + "vests_on": 
"DATE", + "vests_at": "FLOAT", + }, + "marriage_history": { + "married_from": "DATE", + "married_to": "DATE", + "married_years": "FLOAT", + "service_from": "DATE", + "service_to": "DATE", + "service_years": "FLOAT", + "marital_percent": "FLOAT", + }, + "death_benefits": { + "lump1": "FLOAT", + "lump2": "FLOAT", + "growth1": "FLOAT", + "growth2": "FLOAT", + "disc1": "FLOAT", + "disc2": "FLOAT", + }, + } + + with engine.begin() as conn: + for table, cols in updates.items(): + try: + existing = _existing_columns(conn, table) + except Exception: + # Table may not exist yet + continue + for col_name, col_type in cols.items(): + if col_name not in existing: + try: + conn.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}") + except Exception: + # Ignore if not applicable (other engines) or race condition + pass + + diff --git a/app/main.py b/app/main.py index 8f11c95..6bcb86c 100644 --- a/app/main.py +++ b/app/main.py @@ -9,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware from app.config import settings from app.database.base import engine +from app.database.fts import ensure_rolodex_fts, ensure_files_fts, ensure_ledger_fts, ensure_qdros_fts +from app.database.indexes import ensure_secondary_indexes +from app.database.schema_updates import ensure_schema_updates from app.models import BaseModel from app.models.user import User from app.auth.security import get_admin_user @@ -24,6 +27,21 @@ logger = get_logger("main") logger.info("Creating database tables") BaseModel.metadata.create_all(bind=engine) +# Initialize SQLite FTS (if available) +logger.info("Initializing FTS (if available)") +ensure_rolodex_fts(engine) +ensure_files_fts(engine) +ensure_ledger_fts(engine) +ensure_qdros_fts(engine) + +# Ensure helpful secondary indexes +logger.info("Ensuring secondary indexes (status, type, employee, etc.)") +ensure_secondary_indexes(engine) + +# Ensure idempotent schema updates for added columns +logger.info("Ensuring schema updates (new columns)") +ensure_schema_updates(engine) + # Initialize FastAPI app logger.info("Initializing FastAPI application", version=settings.app_version, debug=settings.debug) app = FastAPI( @@ -71,6 +89,7 @@ from app.api.import_data import router as import_router from app.api.flexible import router as flexible_router from app.api.support import router as support_router from app.api.settings import router as settings_router +from app.api.mortality import router as mortality_router logger.info("Including API routers") app.include_router(auth_router, prefix="/api/auth", tags=["authentication"]) @@ -84,6 +103,7 @@ app.include_router(import_router, prefix="/api/import", tags=["import"]) app.include_router(support_router, prefix="/api/support", tags=["support"]) app.include_router(settings_router, prefix="/api/settings", tags=["settings"]) app.include_router(flexible_router, prefix="/api") +app.include_router(mortality_router, prefix="/api/mortality", tags=["mortality"]) @app.get("/", response_class=HTMLResponse) diff --git a/app/middleware/errors.py b/app/middleware/errors.py index f1575cb..7917f56 100644 --- a/app/middleware/errors.py +++ b/app/middleware/errors.py @@ -46,6 +46,25 @@ def _get_correlation_id(request: Request) -> str: return str(uuid4()) +def _json_safe(value: Any) -> Any: + """Recursively convert non-JSON-serializable objects (like Exceptions) into strings. + + Keeps overall structure intact so tests inspecting error details (e.g. 'loc', 'msg') still work. 
+ """ + # Exception -> string message + if isinstance(value, BaseException): + return str(value) + # Mapping types + if isinstance(value, dict): + return {k: _json_safe(v) for k, v in value.items()} + # Sequence types + if isinstance(value, (list, tuple)): + return [ + _json_safe(v) for v in value + ] + return value + + def _build_error_response( request: Request, *, @@ -66,7 +85,7 @@ def _build_error_response( "correlation_id": correlation_id, } if details is not None: - body["error"]["details"] = details + body["error"]["details"] = _json_safe(details) response = JSONResponse(content=body, status_code=status_code) response.headers[ERROR_HEADER_NAME] = correlation_id diff --git a/app/models/__init__.py b/app/models/__init__.py index 33f5b20..145754a 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -14,12 +14,12 @@ from .flexible import FlexibleImport from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory from .pensions import ( Pension, PensionSchedule, MarriageHistory, DeathBenefit, - SeparationAgreement, LifeTable, NumberTable + SeparationAgreement, LifeTable, NumberTable, PensionResult ) from .lookups import ( Employee, FileType, FileStatus, TransactionType, TransactionCode, State, GroupLookup, Footer, PlanInfo, FormIndex, FormList, - PrinterSetup, SystemSetup + PrinterSetup, SystemSetup, FormKeyword ) __all__ = [ @@ -28,8 +28,8 @@ __all__ = [ "Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document", "FlexibleImport", "SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory", "Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit", - "SeparationAgreement", "LifeTable", "NumberTable", + "SeparationAgreement", "LifeTable", "NumberTable", "PensionResult", "Employee", "FileType", "FileStatus", "TransactionType", "TransactionCode", "State", "GroupLookup", "Footer", "PlanInfo", "FormIndex", "FormList", - "PrinterSetup", "SystemSetup" + "PrinterSetup", "SystemSetup", "FormKeyword" ] \ No newline at end of file diff --git a/app/models/audit.py b/app/models/audit.py index b824db9..6108879 100644 --- a/app/models/audit.py +++ b/app/models/audit.py @@ -3,7 +3,7 @@ Audit logging models """ from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON from sqlalchemy.orm import relationship -from datetime import datetime +from datetime import datetime, timezone from app.models.base import BaseModel @@ -22,7 +22,7 @@ class AuditLog(BaseModel): details = Column(JSON, nullable=True) # Additional details as JSON ip_address = Column(String(45), nullable=True) # IPv4/IPv6 address user_agent = Column(Text, nullable=True) # Browser/client information - timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) + timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True) # Relationships user = relationship("User", back_populates="audit_logs") @@ -42,7 +42,7 @@ class LoginAttempt(BaseModel): ip_address = Column(String(45), nullable=False) user_agent = Column(Text, nullable=True) success = Column(Integer, default=0) # 1 for success, 0 for failure - timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) + timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True) failure_reason = Column(String(200), nullable=True) # Reason for failure def __repr__(self): @@ -56,8 +56,8 @@ class ImportAudit(BaseModel): 
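The switch from naive `datetime.utcnow()` defaults to `DateTime(timezone=True)` columns with UTC-aware values matters once timestamps are compared; a minimal sketch of the difference, and of the normalization `_to_utc_aware` applies:

```python
from datetime import datetime, timezone

naive = datetime.utcnow()            # legacy pattern: naive timestamp, no tzinfo
aware = datetime.now(timezone.utc)   # new pattern: timezone-aware UTC

try:
    print(naive <= aware)
except TypeError as exc:
    print(f"comparison fails: {exc}")

# Normalizing the way _to_utc_aware does (treat naive values as UTC)
print(naive.replace(tzinfo=timezone.utc) <= aware)
```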
__tablename__ = "import_audit" id = Column(Integer, primary_key=True, autoincrement=True, index=True) - started_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) - finished_at = Column(DateTime, nullable=True, index=True) + started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True) + finished_at = Column(DateTime(timezone=True), nullable=True, index=True) status = Column(String(30), nullable=False, default="running", index=True) # running|success|completed_with_errors|failed total_files = Column(Integer, nullable=False, default=0) @@ -94,7 +94,7 @@ class ImportAuditFile(BaseModel): errors = Column(Integer, nullable=False, default=0) message = Column(String(255), nullable=True) details = Column(JSON, nullable=True) - created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False, index=True) audit = relationship("ImportAudit", back_populates="files") diff --git a/app/models/auth.py b/app/models/auth.py index 2fa1d26..ba8040c 100644 --- a/app/models/auth.py +++ b/app/models/auth.py @@ -1,7 +1,7 @@ """ Authentication-related persistence models """ -from datetime import datetime +from datetime import datetime, timezone from typing import Optional from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint @@ -19,10 +19,10 @@ class RefreshToken(BaseModel): jti = Column(String(64), nullable=False, unique=True, index=True) user_agent = Column(String(255), nullable=True) ip_address = Column(String(45), nullable=True) - issued_at = Column(DateTime, default=datetime.utcnow, nullable=False) - expires_at = Column(DateTime, nullable=False, index=True) + issued_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) + expires_at = Column(DateTime(timezone=True), nullable=False, index=True) revoked = Column(Boolean, default=False, nullable=False) - revoked_at = Column(DateTime, nullable=True) + revoked_at = Column(DateTime(timezone=True), nullable=True) # relationships user = relationship("User") diff --git a/app/models/lookups.py b/app/models/lookups.py index e8758e3..1171919 100644 --- a/app/models/lookups.py +++ b/app/models/lookups.py @@ -1,7 +1,7 @@ """ Lookup table models based on legacy system analysis """ -from sqlalchemy import Column, Integer, String, Text, Boolean, Float +from sqlalchemy import Column, Integer, String, Text, Boolean, Float, ForeignKey from app.models.base import BaseModel @@ -53,6 +53,9 @@ class FileStatus(BaseModel): description = Column(String(200), nullable=False) # Description active = Column(Boolean, default=True) # Is status active sort_order = Column(Integer, default=0) # Display order + # Legacy fields for typed import support + send = Column(Boolean, default=True) # Should statements print by default + footer_code = Column(String(45), ForeignKey("footers.footer_code")) # Default footer def __repr__(self): return f"" @@ -169,9 +172,10 @@ class FormIndex(BaseModel): """ __tablename__ = "form_index" - form_id = Column(String(45), primary_key=True, index=True) # Form identifier + form_id = Column(String(45), primary_key=True, index=True) # Form identifier (maps to Name) form_name = Column(String(200), nullable=False) # Form name category = Column(String(45)) # Form category + keyword = Column(String(200)) # Legacy FORM_INX Name/Keyword pair active = Column(Boolean, default=True) # Is form 
active def __repr__(self): @@ -189,6 +193,7 @@ class FormList(BaseModel): form_id = Column(String(45), nullable=False) # Form identifier line_number = Column(Integer, nullable=False) # Line number in form content = Column(Text) # Line content + status = Column(String(45)) # Legacy FORM_LST Status def __repr__(self): return f"" @@ -201,12 +206,34 @@ class PrinterSetup(BaseModel): """ __tablename__ = "printers" + # Core identity and basic configuration printer_name = Column(String(100), primary_key=True, index=True) # Printer name description = Column(String(200)) # Description driver = Column(String(100)) # Print driver port = Column(String(20)) # Port/connection default_printer = Column(Boolean, default=False) # Is default printer active = Column(Boolean, default=True) # Is printer active + + # Legacy numeric printer number (from PRINTERS.csv "Number") + number = Column(Integer) + + # Legacy control sequences and formatting (from PRINTERS.csv) + page_break = Column(String(50)) + setup_st = Column(String(200)) + reset_st = Column(String(200)) + b_underline = Column(String(100)) + e_underline = Column(String(100)) + b_bold = Column(String(100)) + e_bold = Column(String(100)) + + # Optional report configuration toggles (legacy flags) + phone_book = Column(Boolean, default=False) + rolodex_info = Column(Boolean, default=False) + envelope = Column(Boolean, default=False) + file_cabinet = Column(Boolean, default=False) + accounts = Column(Boolean, default=False) + statements = Column(Boolean, default=False) + calendar = Column(Boolean, default=False) def __repr__(self): return f"" @@ -225,4 +252,19 @@ class SystemSetup(BaseModel): setting_type = Column(String(20), default="STRING") # DATA type (STRING, INTEGER, FLOAT, BOOLEAN) def __repr__(self): - return f"" \ No newline at end of file + return f"" + + +class FormKeyword(BaseModel): + """ + Form keyword lookup + Corresponds to INX_LKUP table in legacy system + """ + __tablename__ = "form_keywords" + + keyword = Column(String(200), primary_key=True, index=True) + description = Column(String(200)) + active = Column(Boolean, default=True) + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/app/models/pensions.py b/app/models/pensions.py index 77909f6..cf5ec02 100644 --- a/app/models/pensions.py +++ b/app/models/pensions.py @@ -60,11 +60,13 @@ class PensionSchedule(BaseModel): file_no = Column(String(45), ForeignKey("files.file_no"), nullable=False) version = Column(String(10), default="01") - # Schedule details + # Schedule details (legacy vesting fields) start_date = Column(Date) # Start date for payments end_date = Column(Date) # End date for payments payment_amount = Column(Float, default=0.0) # Payment amount frequency = Column(String(20)) # Monthly, quarterly, etc. 
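Columns like the vesting fields added below are also backfilled on existing databases by `app/database/schema_updates.py`; a standalone sketch of that idempotent add-column pattern with the stdlib `sqlite3` module (table and column names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pension_schedules (id INTEGER PRIMARY KEY, file_no TEXT)")

def ensure_column(conn: sqlite3.Connection, table: str, column: str, col_type: str) -> None:
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info('{table}')")}
    if column not in existing:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}")

# Safe to call repeatedly; only missing columns are added
ensure_column(conn, "pension_schedules", "vests_on", "DATE")
ensure_column(conn, "pension_schedules", "vests_at", "FLOAT")
ensure_column(conn, "pension_schedules", "vests_on", "DATE")  # no-op on the second call
print([row[1] for row in conn.execute("PRAGMA table_info('pension_schedules')")])
```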
+ vests_on = Column(Date) # Legacy SCHEDULE.csv Vests_On + vests_at = Column(Float, default=0.0) # Legacy SCHEDULE.csv Vests_At (percent) # Relationships file = relationship("File", back_populates="pension_schedules") @@ -85,6 +87,15 @@ class MarriageHistory(BaseModel): divorce_date = Column(Date) # Date of divorce/separation spouse_name = Column(String(100)) # Spouse name notes = Column(Text) # Additional notes + + # Legacy MARRIAGE.csv fields + married_from = Column(Date) + married_to = Column(Date) + married_years = Column(Float, default=0.0) + service_from = Column(Date) + service_to = Column(Date) + service_years = Column(Float, default=0.0) + marital_percent = Column(Float, default=0.0) # Relationships file = relationship("File", back_populates="marriage_history") @@ -105,6 +116,14 @@ class DeathBenefit(BaseModel): benefit_amount = Column(Float, default=0.0) # Benefit amount benefit_type = Column(String(45)) # Type of death benefit notes = Column(Text) # Additional notes + + # Legacy DEATH.csv fields + lump1 = Column(Float, default=0.0) + lump2 = Column(Float, default=0.0) + growth1 = Column(Float, default=0.0) + growth2 = Column(Float, default=0.0) + disc1 = Column(Float, default=0.0) + disc2 = Column(Float, default=0.0) # Relationships file = relationship("File", back_populates="death_benefits") @@ -138,10 +157,36 @@ class LifeTable(BaseModel): id = Column(Integer, primary_key=True, autoincrement=True) age = Column(Integer, nullable=False) # Age - male_expectancy = Column(Float) # Male life expectancy - female_expectancy = Column(Float) # Female life expectancy - table_year = Column(Integer) # Year of table (e.g., 2023) - table_type = Column(String(45)) # Type of table + # Rich typed columns reflecting legacy LIFETABL.csv headers + # LE_* = Life Expectancy, NA_* = Number Alive/Survivors + le_aa = Column(Float) + na_aa = Column(Float) + le_am = Column(Float) + na_am = Column(Float) + le_af = Column(Float) + na_af = Column(Float) + le_wa = Column(Float) + na_wa = Column(Float) + le_wm = Column(Float) + na_wm = Column(Float) + le_wf = Column(Float) + na_wf = Column(Float) + le_ba = Column(Float) + na_ba = Column(Float) + le_bm = Column(Float) + na_bm = Column(Float) + le_bf = Column(Float) + na_bf = Column(Float) + le_ha = Column(Float) + na_ha = Column(Float) + le_hm = Column(Float) + na_hm = Column(Float) + le_hf = Column(Float) + na_hf = Column(Float) + + # Optional metadata retained for future variations + table_year = Column(Integer) # Year/version of table if known + table_type = Column(String(45)) # Source/type of table (optional) class NumberTable(BaseModel): @@ -152,7 +197,63 @@ class NumberTable(BaseModel): __tablename__ = "number_tables" id = Column(Integer, primary_key=True, autoincrement=True) - table_type = Column(String(45), nullable=False) # Type of table - key_value = Column(String(45), nullable=False) # Key identifier - numeric_value = Column(Float) # Numeric value - description = Column(Text) # Description \ No newline at end of file + month = Column(Integer, nullable=False) + # Rich typed NA_* columns reflecting legacy NUMBERAL.csv headers + na_aa = Column(Float) + na_am = Column(Float) + na_af = Column(Float) + na_wa = Column(Float) + na_wm = Column(Float) + na_wf = Column(Float) + na_ba = Column(Float) + na_bm = Column(Float) + na_bf = Column(Float) + na_ha = Column(Float) + na_hm = Column(Float) + na_hf = Column(Float) + + # Optional metadata retained for future variations + table_type = Column(String(45)) + description = Column(Text) + + +class 
PensionResult(BaseModel): + """ + Computed pension results summary + Corresponds to RESULTS table in legacy system + """ + __tablename__ = "pension_results" + + id = Column(Integer, primary_key=True, autoincrement=True) + + # Optional linkage if present in future exports + file_no = Column(String(45)) + version = Column(String(10)) + + # Columns observed in legacy RESULTS.csv header + accrued = Column(Float) + start_age = Column(Integer) + cola = Column(Float) + withdrawal = Column(String(45)) + pre_dr = Column(Float) + post_dr = Column(Float) + tax_rate = Column(Float) + age = Column(Integer) + years_from = Column(Float) + life_exp = Column(Float) + ev_monthly = Column(Float) + payments = Column(Float) + pay_out = Column(Float) + fund_value = Column(Float) + pv = Column(Float) + mortality = Column(Float) + pv_am = Column(Float) + pv_amt = Column(Float) + pv_pre_db = Column(Float) + pv_annuity = Column(Float) + wv_at = Column(Float) + pv_plan = Column(Float) + years_married = Column(Float) + years_service = Column(Float) + marr_per = Column(Float) + marr_amt = Column(Float) \ No newline at end of file diff --git a/app/models/support.py b/app/models/support.py index 045ac4c..3bceaf3 100644 --- a/app/models/support.py +++ b/app/models/support.py @@ -3,7 +3,7 @@ Support ticket models for help desk functionality """ from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, Enum from sqlalchemy.orm import relationship -from datetime import datetime +from datetime import datetime, timezone import enum from app.models.base import BaseModel @@ -63,9 +63,9 @@ class SupportTicket(BaseModel): ip_address = Column(String(45)) # IP address # Timestamps - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - resolved_at = Column(DateTime) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) + updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + resolved_at = Column(DateTime(timezone=True)) # Admin assignment assigned_to = Column(Integer, ForeignKey("users.id")) @@ -95,7 +95,7 @@ class TicketResponse(BaseModel): author_email = Column(String(100)) # For non-user responses # Timestamps - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False) # Relationships ticket = relationship("SupportTicket", back_populates="responses") diff --git a/app/services/audit.py b/app/services/audit.py index 91cfd06..174d24a 100644 --- a/app/services/audit.py +++ b/app/services/audit.py @@ -3,7 +3,7 @@ Audit logging service """ import json from typing import Dict, Any, Optional -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from sqlalchemy.orm import Session from fastapi import Request @@ -65,7 +65,7 @@ class AuditService: details=details, ip_address=ip_address, user_agent=user_agent, - timestamp=datetime.utcnow() + timestamp=datetime.now(timezone.utc) ) try: @@ -76,7 +76,7 @@ class AuditService: except Exception as e: db.rollback() # Log the error but don't fail the main operation - logger.error("Failed to log audit entry", error=str(e), action=action, user_id=user_id) + logger.error("Failed to log audit entry", error=str(e), action=action) return audit_log @staticmethod @@ -119,7 +119,7 @@ class 
AuditService: ip_address=ip_address or "unknown", user_agent=user_agent, success=1 if success else 0, - timestamp=datetime.utcnow(), + timestamp=datetime.now(timezone.utc), failure_reason=failure_reason if not success else None ) @@ -252,7 +252,7 @@ class AuditService: Returns: List of failed login attempts """ - cutoff_time = datetime.utcnow() - timedelta(hours=hours) + cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours) query = db.query(LoginAttempt).filter( LoginAttempt.success == 0, LoginAttempt.timestamp >= cutoff_time diff --git a/app/services/cache.py b/app/services/cache.py new file mode 100644 index 0000000..7b1fce1 --- /dev/null +++ b/app/services/cache.py @@ -0,0 +1,98 @@ +""" +Cache utilities with optional Redis backend. + +If Redis is not configured or unavailable, all functions degrade to no-ops. +""" +from __future__ import annotations + +import asyncio +import json +import hashlib +from typing import Any, Optional + +try: + import redis.asyncio as redis # type: ignore +except Exception: # pragma: no cover - allow running without redis installed + redis = None # type: ignore + +from app.config import settings + + +_client: Optional["redis.Redis"] = None # type: ignore +_lock = asyncio.Lock() + + +async def _get_client() -> Optional["redis.Redis"]: # type: ignore + """Lazily initialize and return a shared Redis client if enabled.""" + global _client + if not getattr(settings, "redis_url", None) or not getattr(settings, "cache_enabled", False): + return None + if redis is None: + return None + if _client is not None: + return _client + async with _lock: + if _client is None: + try: + _client = redis.from_url(settings.redis_url, decode_responses=True) # type: ignore + except Exception: + _client = None + return _client + + +def _stable_hash(obj: Any) -> str: + data = json.dumps(obj, sort_keys=True, separators=(",", ":")) + return hashlib.sha1(data.encode("utf-8")).hexdigest() + + +def build_key(kind: str, user_id: Optional[str], parts: dict) -> str: + payload = {"u": user_id or "anon", "p": parts} + return f"search:{kind}:v1:{_stable_hash(payload)}" + + +async def cache_get_json(kind: str, user_id: Optional[str], parts: dict) -> Optional[Any]: + client = await _get_client() + if client is None: + return None + key = build_key(kind, user_id, parts) + try: + raw = await client.get(key) + if raw is None: + return None + return json.loads(raw) + except Exception: + return None + + +async def cache_set_json(kind: str, user_id: Optional[str], parts: dict, value: Any, ttl_seconds: int) -> None: + client = await _get_client() + if client is None: + return + key = build_key(kind, user_id, parts) + try: + await client.set(key, json.dumps(value, separators=(",", ":")), ex=ttl_seconds) + except Exception: + return + + +async def invalidate_prefix(prefix: str) -> None: + client = await _get_client() + if client is None: + return + try: + # Use SCAN to avoid blocking Redis + async for key in client.scan_iter(match=f"{prefix}*"): + try: + await client.delete(key) + except Exception: + pass + except Exception: + return + + +async def invalidate_search_cache() -> None: + # Wipe both global search and suggestions namespaces + await invalidate_prefix("search:global:") + await invalidate_prefix("search:suggestions:") + + diff --git a/app/services/customers_search.py b/app/services/customers_search.py new file mode 100644 index 0000000..e39167d --- /dev/null +++ b/app/services/customers_search.py @@ -0,0 +1,141 @@ +from typing import Optional, List +from sqlalchemy import or_, and_, 
func, asc, desc + +from app.models.rolodex import Rolodex + + +def apply_customer_filters(base_query, search: Optional[str], group: Optional[str], state: Optional[str], groups: Optional[List[str]], states: Optional[List[str]]): + """Apply shared search and group/state filters to the provided base_query. + + This helper is used by both list and export endpoints to keep logic in sync. + """ + s = (search or "").strip() + if s: + s_lower = s.lower() + tokens = [t for t in s_lower.split() if t] + contains_any = or_( + func.lower(Rolodex.id).contains(s_lower), + func.lower(Rolodex.last).contains(s_lower), + func.lower(Rolodex.first).contains(s_lower), + func.lower(Rolodex.middle).contains(s_lower), + func.lower(Rolodex.city).contains(s_lower), + func.lower(Rolodex.email).contains(s_lower), + ) + name_tokens = [ + or_( + func.lower(Rolodex.first).contains(tok), + func.lower(Rolodex.middle).contains(tok), + func.lower(Rolodex.last).contains(tok), + ) + for tok in tokens + ] + combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens)) + + last_first_filter = None + if "," in s_lower: + last_part, first_part = [p.strip() for p in s_lower.split(",", 1)] + if last_part and first_part: + last_first_filter = and_( + func.lower(Rolodex.last).contains(last_part), + func.lower(Rolodex.first).contains(first_part), + ) + elif last_part: + last_first_filter = func.lower(Rolodex.last).contains(last_part) + + final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined + base_query = base_query.filter(final_filter) + + effective_groups = [g for g in (groups or []) if g] or ([group] if group else []) + if effective_groups: + base_query = base_query.filter(Rolodex.group.in_(effective_groups)) + + effective_states = [s for s in (states or []) if s] or ([state] if state else []) + if effective_states: + base_query = base_query.filter(Rolodex.abrev.in_(effective_states)) + + return base_query + + +def apply_customer_sorting(base_query, sort_by: Optional[str], sort_dir: Optional[str]): + """Apply shared sorting to the provided base_query. + + Supported fields: id, name (last,first), city (city,state), email. + Unknown fields fall back to id. Sorting is case-insensitive for strings. + """ + normalized_sort_by = (sort_by or "id").lower() + normalized_sort_dir = (sort_dir or "asc").lower() + is_desc = normalized_sort_dir == "desc" + + order_columns = [] + if normalized_sort_by == "id": + order_columns = [Rolodex.id] + elif normalized_sort_by == "name": + order_columns = [Rolodex.last, Rolodex.first] + elif normalized_sort_by == "city": + order_columns = [Rolodex.city, Rolodex.abrev] + elif normalized_sort_by == "email": + order_columns = [Rolodex.email] + else: + order_columns = [Rolodex.id] + + ordered = [] + for col in order_columns: + try: + expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined] + except Exception: + expr = col + ordered.append(desc(expr) if is_desc else asc(expr)) + + if ordered: + base_query = base_query.order_by(*ordered) + return base_query + + +def prepare_customer_csv_rows(customers: List[Rolodex], fields: Optional[List[str]]): + """Prepare CSV header and rows for the given customers and requested fields. + + Returns a tuple: (header_row, rows), where header_row is a list of column + titles and rows is a list of row lists ready to be written by csv.writer. 
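+ 
+ Illustrative example (hypothetical call; actual values depend on the Rolodex rows passed in): 
+ header, rows = prepare_customer_csv_rows(customers, ["id", "name", "email"]) 
+ # header -> ["Customer ID", "Name", "Email"]; rows -> one list per customer, in the same column order 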
+ """ + allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"] + header_names = { + "id": "Customer ID", + "name": "Name", + "group": "Group", + "city": "City", + "state": "State", + "phone": "Primary Phone", + "email": "Email", + } + + requested = [f.lower() for f in (fields or []) if isinstance(f, str)] + selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order + if not selected_fields: + selected_fields = allowed_fields_in_order + + header_row = [header_names[f] for f in selected_fields] + + rows: List[List[str]] = [] + for c in customers: + full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip() + primary_phone = "" + try: + if getattr(c, "phone_numbers", None): + primary_phone = c.phone_numbers[0].phone or "" + except Exception: + primary_phone = "" + + row_map = { + "id": c.id, + "name": full_name, + "group": c.group or "", + "city": c.city or "", + "state": c.abrev or "", + "phone": primary_phone, + "email": c.email or "", + } + rows.append([row_map[f] for f in selected_fields]) + + return header_row, rows + + diff --git a/app/services/mortality.py b/app/services/mortality.py new file mode 100644 index 0000000..6c0c064 --- /dev/null +++ b/app/services/mortality.py @@ -0,0 +1,127 @@ +""" +Mortality/Life table utilities. + +Helpers to query `life_tables` and `number_tables` by age/month and +return values filtered by sex/race using compact codes: + - sex: M, F, A (All) + - race: W (White), B (Black), H (Hispanic), A (All) + +Column naming in tables follows the pattern: + - LifeTable: le_{race}{sex}, na_{race}{sex} + - NumberTable: na_{race}{sex} + +Examples: + - race=W, sex=M => suffix "wm" (columns `le_wm`, `na_wm`) + - race=A, sex=F => suffix "af" (columns `le_af`, `na_af`) + - race=H, sex=A => suffix "ha" (columns `le_ha`, `na_ha`) +""" + +from __future__ import annotations + +from typing import Dict, Optional, Tuple +from sqlalchemy.orm import Session + +from app.models.pensions import LifeTable, NumberTable + + +_RACE_MAP: Dict[str, str] = { + "W": "w", # White + "B": "b", # Black + "H": "h", # Hispanic + "A": "a", # All races +} + +_SEX_MAP: Dict[str, str] = { + "M": "m", + "F": "f", + "A": "a", # All sexes +} + + +class InvalidCodeError(ValueError): + pass + + +def _normalize_codes(sex: str, race: str) -> Tuple[str, str, str]: + """Validate/normalize sex and race to construct the column suffix. + + Returns (suffix, sex_u, race_u) where suffix is lowercase like "wm". + Raises InvalidCodeError on invalid inputs. + """ + sex_u = (sex or "").strip().upper() + race_u = (race or "").strip().upper() + if sex_u not in _SEX_MAP: + raise InvalidCodeError(f"Invalid sex code '{sex}'. Expected one of: {', '.join(_SEX_MAP.keys())}") + if race_u not in _RACE_MAP: + raise InvalidCodeError(f"Invalid race code '{race}'. Expected one of: {', '.join(_RACE_MAP.keys())}") + return _RACE_MAP[race_u] + _SEX_MAP[sex_u], sex_u, race_u + + +def get_life_values( + db: Session, + *, + age: int, + sex: str, + race: str, +) -> Optional[Dict[str, Optional[float]]]: + """Return life table LE and NA values for a given age, sex, and race. + + Returns dict: {"age": int, "sex": str, "race": str, "le": float|None, "na": float|None} + Returns None if the age row does not exist. + Raises InvalidCodeError for invalid codes. 
+ """ + suffix, sex_u, race_u = _normalize_codes(sex, race) + row: Optional[LifeTable] = db.query(LifeTable).filter(LifeTable.age == age).first() + if not row: + return None + + le_col = f"le_{suffix}" + na_col = f"na_{suffix}" + le_val = getattr(row, le_col, None) + na_val = getattr(row, na_col, None) + + return { + "age": int(age), + "sex": sex_u, + "race": race_u, + "le": float(le_val) if le_val is not None else None, + "na": float(na_val) if na_val is not None else None, + } + + +def get_number_value( + db: Session, + *, + month: int, + sex: str, + race: str, +) -> Optional[Dict[str, Optional[float]]]: + """Return number table NA value for a given month, sex, and race. + + Returns dict: {"month": int, "sex": str, "race": str, "na": float|None} + Returns None if the month row does not exist. + Raises InvalidCodeError for invalid codes. + """ + suffix, sex_u, race_u = _normalize_codes(sex, race) + row: Optional[NumberTable] = db.query(NumberTable).filter(NumberTable.month == month).first() + if not row: + return None + + na_col = f"na_{suffix}" + na_val = getattr(row, na_col, None) + + return { + "month": int(month), + "sex": sex_u, + "race": race_u, + "na": float(na_val) if na_val is not None else None, + } + + +__all__ = [ + "InvalidCodeError", + "get_life_values", + "get_number_value", +] + + diff --git a/app/services/query_utils.py b/app/services/query_utils.py new file mode 100644 index 0000000..7e37795 --- /dev/null +++ b/app/services/query_utils.py @@ -0,0 +1,72 @@ +from typing import Iterable, Optional, Sequence +from sqlalchemy import or_, and_, asc, desc, func +from sqlalchemy.sql.elements import BinaryExpression +from sqlalchemy.sql.schema import Column + + +def tokenized_ilike_filter(tokens: Sequence[str], columns: Sequence[Column]) -> Optional[BinaryExpression]: + """Build an AND-of-ORs case-insensitive LIKE filter across columns for each token. + + Example: AND(OR(col1 ILIKE %t1%, col2 ILIKE %t1%), OR(col1 ILIKE %t2%, ...)) + Returns None when tokens or columns are empty. + """ + if not tokens or not columns: + return None + per_token_clauses = [] + for term in tokens: + term = str(term or "").strip() + if not term: + continue + per_token_clauses.append(or_(*[c.ilike(f"%{term}%") for c in columns])) + if not per_token_clauses: + return None + return and_(*per_token_clauses) + + +def apply_pagination(query, skip: int, limit: int): + """Apply offset/limit pagination to a SQLAlchemy query in a DRY way.""" + return query.offset(skip).limit(limit) + + +def paginate_with_total(query, skip: int, limit: int, include_total: bool): + """Return (items, total|None) applying pagination and optionally counting total. + + This avoids duplicating count + pagination logic at each endpoint. + """ + total_count = query.count() if include_total else None + items = apply_pagination(query, skip, limit).all() + return items, total_count + + +def apply_sorting(query, sort_by: Optional[str], sort_dir: Optional[str], allowed: dict[str, list[Column]]): + """Apply case-insensitive sorting per a whitelist of allowed fields. + + allowed: mapping from field name -> list of columns to sort by, in priority order. + For string columns, compares using lower(column) for stable ordering. + Unknown sort_by falls back to the first key in allowed. 
+ sort_dir: "asc" or "desc" (default asc) + """ + if not allowed: + return query + normalized_sort_by = (sort_by or next(iter(allowed.keys()))).lower() + normalized_sort_dir = (sort_dir or "asc").lower() + is_desc = normalized_sort_dir == "desc" + + columns = allowed.get(normalized_sort_by) + if not columns: + columns = allowed.get(next(iter(allowed.keys()))) + if not columns: + return query + + order_exprs = [] + for col in columns: + try: + expr = func.lower(col) if getattr(col.type, "python_type", str) is str else col + except Exception: + expr = col + order_exprs.append(desc(expr) if is_desc else asc(expr)) + if order_exprs: + query = query.order_by(*order_exprs) + return query + + diff --git a/e2e/global-setup.js b/e2e/global-setup.js new file mode 100644 index 0000000..77da348 --- /dev/null +++ b/e2e/global-setup.js @@ -0,0 +1,69 @@ +// Global setup to seed admin user before Playwright tests +const { spawnSync } = require('child_process'); +const fs = require('fs'); +const jwt = require('jsonwebtoken'); + +module.exports = async () => { + const SECRET_KEY = process.env.SECRET_KEY || 'x'.repeat(32); + const path = require('path'); + const dbPath = path.resolve(__dirname, '..', '.e2e-db.sqlite'); + const DATABASE_URL = process.env.DATABASE_URL || `sqlite:////${dbPath}`; + + // Ensure a clean database for deterministic tests + try { fs.rmSync(dbPath, { force: true }); } catch (_) {} + + const pyCode = ` +from sqlalchemy.orm import sessionmaker +from app.database.base import engine +from app.models import BaseModel +from app.models.user import User +from app.auth.security import get_password_hash +import os + +# Ensure tables +BaseModel.metadata.create_all(bind=engine) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +db = SessionLocal() +try: + admin = db.query(User).filter(User.username=='admin').first() + if not admin: + admin = User( + username=os.getenv('ADMIN_USERNAME','admin'), + email=os.getenv('ADMIN_EMAIL','admin@delphicg.local'), + full_name=os.getenv('ADMIN_FULLNAME','System Administrator'), + hashed_password=get_password_hash(os.getenv('ADMIN_PASSWORD','admin123')), + is_active=True, + is_admin=True, + ) + db.add(admin) + db.commit() + print('Seeded admin user') + else: + print('Admin user already exists') +finally: + db.close() +`; + + const env = { + ...process.env, + SECRET_KEY, + DATABASE_URL, + ADMIN_EMAIL: 'admin@example.com', + ADMIN_USERNAME: 'admin', + ADMIN_PASSWORD: process.env.ADMIN_PASSWORD || 'admin123', + }; + let res = spawnSync('python3', ['-c', pyCode], { env, stdio: 'inherit' }); + if (res.error) { + res = spawnSync('python', ['-c', pyCode], { env, stdio: 'inherit' }); + if (res.error) throw res.error; + } + + // Pre-generate a valid access token to bypass login DB writes in tests + const token = jwt.sign({ sub: env.ADMIN_USERNAME, type: 'access' }, env.SECRET_KEY, { expiresIn: '4h' }); + // Persist to a file for the tests to read + const tokenPath = path.resolve(__dirname, '..', '.e2e-token'); + fs.writeFileSync(tokenPath, token, 'utf-8'); +}; + + diff --git a/e2e/search.e2e.spec.js b/e2e/search.e2e.spec.js new file mode 100644 index 0000000..cf3353c --- /dev/null +++ b/e2e/search.e2e.spec.js @@ -0,0 +1,239 @@ +// Playwright E2E tests for Advanced Search UI +const { test, expect } = require('@playwright/test'); + +async function loginAndSetTokens(page) { + // Read pre-generated access token + const fs = require('fs'); + const path = require('path'); + const tokenPath = path.resolve(__dirname, '..', '.e2e-token'); + const 
access = fs.readFileSync(tokenPath, 'utf-8').trim(); + const refresh = ''; + await page.addInitScript((a, r) => { + try { window.localStorage.setItem('auth_token', a); } catch (_) {} + try { if (r) window.localStorage.setItem('refresh_token', r); } catch (_) {} + }, access, refresh); + return access; +} + +async function apiCreateCustomer(page, payload, token) { + // Use import endpoint to avoid multiple writes and simplify schema + const req = await page.request.post('/api/import/customers', { + data: { customers: [payload] }, + headers: token ? { Authorization: `Bearer ${token}` } : {}, + }); + expect(req.ok()).toBeTruthy(); + // Return id directly + return payload.id; +} + +async function apiCreateFile(page, payload, token) { + const req = await page.request.post('/api/import/files', { + data: { files: [payload] }, + headers: token ? { Authorization: `Bearer ${token}` } : {}, + }); + expect(req.ok()).toBeTruthy(); + return payload.file_no; +} + +test.describe('Advanced Search UI', () => { + test.beforeEach(async ({ page }) => { + // no-op here; call per test to capture token + }); + + test('returns highlighted results and enforces XSS safety', async ({ page }) => { + const token = `E2E-${Date.now()}`; + const accessToken = await loginAndSetTokens(page); + const malicious = `${token} `; + await apiCreateCustomer(page, { + id: `E2E-CUST-${Date.now()}`, + first: 'Alice', + last: malicious, + email: `alice.${Date.now()}@example.com`, + city: 'Austin', + abrev: 'TX', + }, accessToken); + + await page.goto('/search'); + await page.fill('#searchQuery', token); + await page.click('#advancedSearchForm button[type="submit"]'); + await page.waitForResponse(res => res.url().includes('/api/search/advanced') && res.request().method() === 'POST'); + + const results = page.locator('#searchResults .search-result-item'); + await expect(results.first()).toBeVisible({ timeout: 10000 }); + + const matchHtml = page.locator('#searchResults .search-result-item .text-sm.text-info-600'); + if (await matchHtml.count()) { + const html = await matchHtml.first().innerHTML(); + expect(html).toContain(''); + expect(html).not.toContain('onerror'); + expect(html).not.toContain(' { + const token = `E2E-PAGE-${Date.now()}`; + const accessToken = await loginAndSetTokens(page); + const today = new Date().toISOString().slice(0, 10); + const ownerId = await apiCreateCustomer(page, { + id: `E2E-P-OWNER-${Date.now()}`, + first: 'Bob', + last: 'Pagination', + email: `bob.${Date.now()}@example.com`, + city: 'Austin', + abrev: 'TX', + }, accessToken); + for (let i = 0; i < 60; i++) { + await apiCreateFile(page, { + file_no: `E2E-F-${Date.now()}-${i}`, + id: ownerId, + regarding: `About ${token} #${i}`, + empl_num: 'E01', + file_type: 'CIVIL', + opened: today, + status: 'ACTIVE', + rate_per_hour: 150, + memo: 'seeded', + }, accessToken); + } + + await page.goto('/search'); + await page.fill('#searchQuery', token); + await page.click('#advancedSearchForm button[type="submit"]'); + await page.waitForResponse(res => res.url().includes('/api/search/advanced') && res.request().method() === 'POST'); + + const pager = page.locator('#searchPagination'); + await expect(pager).toBeVisible({ timeout: 10000 }); + const firstPageActive = page.locator('#searchPagination button.bg-primary-600'); + await expect(firstPageActive).toContainText('1'); + + const next = page.locator('#searchPagination button', { hasText: 'Next' }); + await Promise.all([ + page.waitForResponse((res) => res.url().includes('/api/search/advanced') && res.request().method() 
=== 'POST'), + next.click(), + ]); + const active = page.locator('#searchPagination button.bg-primary-600'); + await expect(active).not.toContainText('1'); + }); + + test('suggestions dropdown renders safely and clicking populates input and triggers search', async ({ page }) => { + const token = `E2E-SUG-${Date.now()}`; + await loginAndSetTokens(page); + + const suggestionOne = `${token} first`; + const suggestionTwo = `${token} second`; + + // Stub the suggestions endpoint for our token + await page.route('**/api/search/suggestions*', async (route) => { + try { + const url = new URL(route.request().url()); + const q = url.searchParams.get('q') || ''; + if (q.includes(token)) { + return route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify({ + suggestions: [ + { text: suggestionOne, category: 'customer', description: 'Name match' }, + { text: suggestionTwo, category: 'file', description: 'File regarding' }, + ], + }), + }); + } + } catch (_) {} + return route.fallback(); + }); + + // Stub the advanced search to assert it gets triggered with clicked suggestion + let receivedQuery = null; + await page.route('**/api/search/advanced', async (route) => { + try { + const body = route.request().postDataJSON(); + receivedQuery = body?.query || null; + } catch (_) {} + return route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify({ + total_results: 0, + stats: { search_execution_time: 0.001 }, + facets: { customer: {}, file: {}, ledger: {}, qdro: {}, document: {}, phone: {} }, + results: [], + page_info: { current_page: 1, total_pages: 0, has_previous: false, has_next: false }, + }), + }); + }); + + await page.goto('/search'); + + // Type to trigger suggestions (debounced) + await page.fill('#searchQuery', token); + + const dropdown = page.locator('#searchSuggestions'); + const items = dropdown.locator('a'); + await expect(items).toHaveCount(2, { timeout: 5000 }); + await expect(dropdown).toBeVisible(); + + // Basic safety check — ensure no script tags ended up in suggestions markup + const dropdownHtml = await dropdown.innerHTML(); + expect(dropdownHtml).not.toContain(' res.url().includes('/api/search/advanced') && res.request().method() === 'POST'), + items.first().click(), + ]); + + await expect(page.locator('#searchQuery')).toHaveValue(new RegExp(`^${suggestionOne}`)); + expect(receivedQuery || '').toContain(suggestionOne); + }); + + test('Escape hides suggestions dropdown without triggering a search', async ({ page }) => { + const token = `E2E-ESC-${Date.now()}`; + await loginAndSetTokens(page); + + // Track whether advanced search is called + let calledAdvanced = false; + await page.route('**/api/search/advanced', async (route) => { + calledAdvanced = true; + return route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify({ + total_results: 0, + stats: { search_execution_time: 0.001 }, + facets: { customer: {}, file: {}, ledger: {}, qdro: {}, document: {}, phone: {} }, + results: [], + page_info: { current_page: 1, total_pages: 0, has_previous: false, has_next: false }, + }), + }); + }); + + // Stub suggestions so they appear + await page.route('**/api/search/suggestions*', async (route) => { + return route.fulfill({ + status: 200, + contentType: 'application/json', + body: JSON.stringify({ + suggestions: [ + { text: `${token} foo`, category: 'customer', description: '' }, + { text: `${token} bar`, category: 'file', description: '' }, + ], + }), + }); + }); + + await page.goto('/search'); + 
await page.fill('#searchQuery', token); + + const dropdown = page.locator('#searchSuggestions'); + await expect(dropdown.locator('a')).toHaveCount(2, { timeout: 5000 }); + await expect(dropdown).toBeVisible(); + + // Press Escape: should hide dropdown and not trigger search + await page.keyboard.press('Escape'); + await expect(dropdown).toHaveClass(/hidden/); + expect(calledAdvanced).toBeFalsy(); + }); +}); + + diff --git a/package-lock.json b/package-lock.json index 776bf53..377e275 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,10 +10,12 @@ "license": "ISC", "devDependencies": { "@jest/environment": "^30.0.5", + "@playwright/test": "^1.45.0", "@tailwindcss/forms": "^0.5.10", "jest": "^29.7.0", "jest-environment-jsdom": "^30.0.5", "jsdom": "^22.1.0", + "jsonwebtoken": "^9.0.2", "tailwindcss": "^3.4.10" } }, @@ -1720,6 +1722,22 @@ "node": ">=14" } }, + "node_modules/@playwright/test": { + "version": "1.54.2", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.54.2.tgz", + "integrity": "sha512-A+znathYxPf+72riFd1r1ovOLqsIIB0jKIoPjyK2kqEIe30/6jF6BC7QNluHuwUmsD2tv1XZVugN8GqfTMOxsA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright": "1.54.2" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/@sinclair/typebox": { "version": "0.34.38", "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.34.38.tgz", @@ -2215,6 +2233,13 @@ "node-int64": "^0.4.0" } }, + "node_modules/buffer-equal-constant-time": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/buffer-equal-constant-time/-/buffer-equal-constant-time-1.0.1.tgz", + "integrity": "sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA==", + "dev": true, + "license": "BSD-3-Clause" + }, "node_modules/buffer-from": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", @@ -2824,6 +2849,16 @@ "dev": true, "license": "MIT" }, + "node_modules/ecdsa-sig-formatter": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/ecdsa-sig-formatter/-/ecdsa-sig-formatter-1.0.11.tgz", + "integrity": "sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "safe-buffer": "^5.0.1" + } + }, "node_modules/electron-to-chromium": { "version": "1.5.200", "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.200.tgz", @@ -6040,6 +6075,65 @@ "node": ">=6" } }, + "node_modules/jsonwebtoken": { + "version": "9.0.2", + "resolved": "https://registry.npmjs.org/jsonwebtoken/-/jsonwebtoken-9.0.2.tgz", + "integrity": "sha512-PRp66vJ865SSqOlgqS8hujT5U4AOgMfhrwYIuIhfKaoSCZcirrmASQr8CX7cUg+RMih+hgznrjp99o+W4pJLHQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "jws": "^3.2.2", + "lodash.includes": "^4.3.0", + "lodash.isboolean": "^3.0.3", + "lodash.isinteger": "^4.0.4", + "lodash.isnumber": "^3.0.3", + "lodash.isplainobject": "^4.0.6", + "lodash.isstring": "^4.0.1", + "lodash.once": "^4.0.0", + "ms": "^2.1.1", + "semver": "^7.5.4" + }, + "engines": { + "node": ">=12", + "npm": ">=6" + } + }, + "node_modules/jsonwebtoken/node_modules/semver": { + "version": "7.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", + "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", + "dev": true, + "license": "ISC", + "bin": { + 
"semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/jwa": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/jwa/-/jwa-1.4.2.tgz", + "integrity": "sha512-eeH5JO+21J78qMvTIDdBXidBd6nG2kZjg5Ohz/1fpa28Z4CcsWUzJ1ZZyFq/3z3N17aZy+ZuBoHljASbL1WfOw==", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-equal-constant-time": "^1.0.1", + "ecdsa-sig-formatter": "1.0.11", + "safe-buffer": "^5.0.1" + } + }, + "node_modules/jws": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/jws/-/jws-3.2.2.tgz", + "integrity": "sha512-YHlZCB6lMTllWDtSPHz/ZXTsi8S00usEV6v1tjq8tOUZzw7DpSDWVXjXDre6ed1w/pd495ODpHZYSdkRTsa0HA==", + "dev": true, + "license": "MIT", + "dependencies": { + "jwa": "^1.4.1", + "safe-buffer": "^5.0.1" + } + }, "node_modules/kleur": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/kleur/-/kleur-3.0.3.tgz", @@ -6090,6 +6184,55 @@ "node": ">=8" } }, + "node_modules/lodash.includes": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/lodash.includes/-/lodash.includes-4.3.0.tgz", + "integrity": "sha512-W3Bx6mdkRTGtlJISOvVD/lbqjTlPPUDTMnlXZFnVwi9NKJ6tiAk6LVdlhZMm17VZisqhKcgzpO5Wz91PCt5b0w==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.isboolean": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/lodash.isboolean/-/lodash.isboolean-3.0.3.tgz", + "integrity": "sha512-Bz5mupy2SVbPHURB98VAcw+aHh4vRV5IPNhILUCsOzRmsTmSQ17jIuqopAentWoehktxGd9e/hbIXq980/1QJg==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.isinteger": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/lodash.isinteger/-/lodash.isinteger-4.0.4.tgz", + "integrity": "sha512-DBwtEWN2caHQ9/imiNeEA5ys1JoRtRfY3d7V9wkqtbycnAmTvRRmbHKDV4a0EYc678/dia0jrte4tjYwVBaZUA==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.isnumber": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/lodash.isnumber/-/lodash.isnumber-3.0.3.tgz", + "integrity": "sha512-QYqzpfwO3/CWf3XP+Z+tkQsfaLL/EnUlXWVkIk5FUPc4sBdTehEqZONuyRt2P67PXAk+NXmTBcc97zw9t1FQrw==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.isplainobject": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/lodash.isplainobject/-/lodash.isplainobject-4.0.6.tgz", + "integrity": "sha512-oSXzaWypCMHkPC3NvBEaPHf0KsA5mvPrOPgQWDsbg8n7orZ290M0BmC/jgRZ4vcJ6DTAhjrsSYgdsW/F+MFOBA==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.isstring": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/lodash.isstring/-/lodash.isstring-4.0.1.tgz", + "integrity": "sha512-0wJxfxH1wgO3GrbuP+dTTk7op+6L41QCXbGINEmD+ny/G/eCqGzxyCsh7159S+mgDDcoarnBw6PC1PS5+wUGgw==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.once": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/lodash.once/-/lodash.once-4.1.1.tgz", + "integrity": "sha512-Sb487aTOCr9drQVL8pIxOzVhafOjZN9UU54hiN8PU3uAiSV7lx1yYNpbNmex2PK6dSJoNTSJUUswT651yww3Mg==", + "dev": true, + "license": "MIT" + }, "node_modules/lru-cache": { "version": "10.4.3", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", @@ -6582,6 +6725,53 @@ "node": ">=8" } }, + "node_modules/playwright": { + "version": "1.54.2", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.54.2.tgz", + "integrity": "sha512-Hu/BMoA1NAdRUuulyvQC0pEqZ4vQbGfn8f7wPXcnqQmM+zct9UliKxsIkLNmz/ku7LElUNqmaiv1TG/aL5ACsw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright-core": "1.54.2" 
+ }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.54.2", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.54.2.tgz", + "integrity": "sha512-n5r4HFbMmWsB4twG7tJLDN9gmBUeSPcsBZiWSE4DnYz9mJMAFqr2ID7+eGC9kpEnxExJ1epttwR59LEWCk8mtA==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/playwright/node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, "node_modules/postcss": { "version": "8.5.6", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz", @@ -7018,6 +7208,27 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, "node_modules/safer-buffer": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", diff --git a/package.json b/package.json index e957726..85fe1f8 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,11 @@ "description": "A modern Python web application built with FastAPI to replace the legacy Pascal-based database system. This system maintains the familiar keyboard shortcuts and workflows while providing a robust, modular backend with a clean web interface.", "main": "tailwind.config.js", "scripts": { - "test": "jest" + "test": "jest", + "e2e": "playwright test", + "e2e:headed": "playwright test --headed", + "e2e:debug": "PWDEBUG=1 playwright test", + "e2e:install": "playwright install --with-deps" }, "repository": { "type": "git", @@ -23,6 +27,8 @@ "devDependencies": { "@jest/environment": "^30.0.5", "@tailwindcss/forms": "^0.5.10", + "@playwright/test": "^1.45.0", + "jsonwebtoken": "^9.0.2", "jest": "^29.7.0", "jest-environment-jsdom": "^30.0.5", "jsdom": "^22.1.0", diff --git a/playwright.config.js b/playwright.config.js new file mode 100644 index 0000000..8cb1bf2 --- /dev/null +++ b/playwright.config.js @@ -0,0 +1,34 @@ +// @ts-check +const { defineConfig } = require('@playwright/test'); +const path = require('path'); +const DB_ABS_PATH = path.resolve(__dirname, '.e2e-db.sqlite'); + +module.exports = defineConfig({ + testDir: './e2e', + fullyParallel: true, + retries: process.env.CI ? 2 : 0, + workers: process.env.CI ? 
2 : undefined, + use: { + baseURL: process.env.PW_BASE_URL || 'http://127.0.0.1:6123', + trace: 'on-first-retry', + }, + globalSetup: require.resolve('./e2e/global-setup.js'), + webServer: { + command: 'uvicorn app.main:app --host 127.0.0.1 --port 6123', + env: { + SECRET_KEY: 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + DATABASE_URL: `sqlite:////${DB_ABS_PATH}`, + LOG_LEVEL: 'WARNING', + DISABLE_LOG_ENQUEUE: '1', + LOG_TO_FILE: 'False', + ADMIN_EMAIL: 'admin@example.com', + ADMIN_USERNAME: 'admin', + ADMIN_PASSWORD: process.env.ADMIN_PASSWORD || 'admin123', + }, + url: 'http://127.0.0.1:6123/health', + reuseExistingServer: !process.env.CI, + timeout: 60 * 1000, + }, +}); + + diff --git a/requirements.txt b/requirements.txt index 25d7f37..3d787bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,4 +34,7 @@ httpx==0.28.1 python-dotenv==1.0.1 # Logging -loguru==0.7.2 \ No newline at end of file +loguru==0.7.2 + +# Caching (optional) +redis==5.0.8 \ No newline at end of file diff --git a/static/js/__tests__/search_snippet.ui.test.js b/static/js/__tests__/search_snippet.ui.test.js new file mode 100644 index 0000000..1b04265 --- /dev/null +++ b/static/js/__tests__/search_snippet.ui.test.js @@ -0,0 +1,50 @@ +/** @jest-environment jsdom */ + +// Load sanitizer and highlight utils used by the UI +require('../sanitizer.js'); +require('../highlight.js'); + +describe('Search highlight integration (server snippet rendering)', () => { + const { formatSnippet, highlight, buildTokens } = window.highlightUtils; + + test('formatSnippet preserves server and sanitizes dangerous HTML', () => { + const tokens = buildTokens('alpha'); + const serverSnippet = 'Hello Alpha link'; + const html = formatSnippet(serverSnippet, tokens); + // Server-provided strong is preserved + expect(html).toContain('Alpha'); + // Dangerous attributes removed + expect(html).not.toContain('onerror='); + // javascript: protocol removed + expect(html.toLowerCase()).not.toContain('href="javascript:'); + // Image tag should remain but sanitized (no onerror) + expect(html).toContain(' { + const container = document.createElement('div'); + const rawHtml = '
Text bold
'; + // Using global helper installed by sanitizer.js + window.setSafeHTML(container, rawHtml); + // Script tags removed + expect(container.innerHTML).not.toContain(' Alpha & Beta', ['alpha', 'beta']) + # Tags are escaped; only wrappers exist + assert '<script>alert(1)</script>' in out + assert 'Alpha' in out + assert 'Beta' in out + assert '' not in out + + +def test_highlight_text_handles_quotes_and_apostrophes_safely(): + out = highlight_text('He said "Hello" & it\'s fine', ['hello']) + # Quotes and ampersand should be escaped + assert '"Hello"' in out + assert ''s' in out + assert '&' in out + + +def test_highlight_text_no_tokens_returns_escaped_source(): + out = highlight_text('bold', []) + assert out == '<b>bold</b>' + diff --git a/tests/test_search_sort_documents.py b/tests/test_search_sort_documents.py new file mode 100644 index 0000000..0adee0a --- /dev/null +++ b/tests/test_search_sort_documents.py @@ -0,0 +1,101 @@ +import os +import uuid +from datetime import date + +import pytest +from fastapi.testclient import TestClient + +os.environ.setdefault("SECRET_KEY", "x" * 32) +os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite") + +from app.main import app # noqa: E402 +from app.auth.security import get_current_user # noqa: E402 + + +class _User: + def __init__(self): + self.id = 1 + self.username = "tester" + self.is_admin = True + self.is_active = True + + +@pytest.fixture() +def client(): + app.dependency_overrides[get_current_user] = lambda: _User() + try: + yield TestClient(app) + finally: + app.dependency_overrides.pop(get_current_user, None) + + +def _create_customer_and_file(client: TestClient): + cust_id = f"DOCSS-{uuid.uuid4().hex[:8]}" + resp = client.post("/api/customers/", json={"id": cust_id, "last": "DocSS", "email": "dss@example.com"}) + assert resp.status_code == 200 + file_no = f"D-{uuid.uuid4().hex[:6]}" + payload = { + "file_no": file_no, + "id": cust_id, + "regarding": "Doc matter", + "empl_num": "E01", + "file_type": "CIVIL", + "opened": date.today().isoformat(), + "status": "ACTIVE", + "rate_per_hour": 100.0, + } + resp = client.post("/api/files/", json=payload) + assert resp.status_code == 200 + return cust_id, file_no + + +def test_templates_tokenized_search_and_sort(client: TestClient): + # Create templates + t1 = f"TMP-{uuid.uuid4().hex[:6]}" + t2 = f"TMP-{uuid.uuid4().hex[:6]}" + + resp = client.post( + "/api/documents/templates/", + json={"form_id": t1, "form_name": "Alpha Letter", "category": "GENERAL", "content": "Hello"}, + ) + assert resp.status_code == 200 + resp = client.post( + "/api/documents/templates/", + json={"form_id": t2, "form_name": "Beta Memo", "category": "GENERAL", "content": "Hello"}, + ) + assert resp.status_code == 200 + + # Tokenized search for both tokens only matches when both present + resp = client.get("/api/documents/templates/", params={"search": "Alpha Letter"}) + assert resp.status_code == 200 + items = resp.json() + ids = {i["form_id"] for i in items} + assert t1 in ids and t2 not in ids + + # Sorting by form_name desc + resp = client.get("/api/documents/templates/", params={"sort_by": "form_name", "sort_dir": "desc"}) + assert resp.status_code == 200 + items = resp.json() + if len(items) >= 2: + assert items[0]["form_name"] >= items[1]["form_name"] + + +def test_qdros_tokenized_search(client: TestClient): + _, file_no = _create_customer_and_file(client) + # Create QDROs + q1 = {"file_no": file_no, "version": "01", "status": "DRAFT", "form_name": "Alpha Order", "notes": "Beta token present"} + q2 = 
{"file_no": file_no, "version": "02", "status": "DRAFT", "form_name": "Gamma", "notes": "Beta only"} + resp = client.post("/api/documents/qdros/", json=q1) + assert resp.status_code == 200 + resp = client.post("/api/documents/qdros/", json=q2) + assert resp.status_code == 200 + + # Only the one containing both tokens should match + resp = client.get("/api/documents/qdros/", params={"search": "Alpha Beta"}) + assert resp.status_code == 200 + items = resp.json() + names = {i.get("form_name") for i in items} + assert "Alpha Order" in names + assert "Gamma" not in names + + diff --git a/tests/test_search_sort_files.py b/tests/test_search_sort_files.py new file mode 100644 index 0000000..6bfc566 --- /dev/null +++ b/tests/test_search_sort_files.py @@ -0,0 +1,94 @@ +import os +import uuid +from datetime import date, timedelta + +import pytest +from fastapi.testclient import TestClient + +os.environ.setdefault("SECRET_KEY", "x" * 32) +os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite") + +from app.main import app # noqa: E402 +from app.auth.security import get_current_user # noqa: E402 + + +class _User: + def __init__(self): + self.id = "test" + self.username = "tester" + self.is_admin = True + self.is_active = True + + +@pytest.fixture() +def client(): + app.dependency_overrides[get_current_user] = lambda: _User() + try: + yield TestClient(app) + finally: + app.dependency_overrides.pop(get_current_user, None) + + +def _create_customer(client: TestClient) -> str: + cid = f"FSSR-{uuid.uuid4().hex[:8]}" + resp = client.post("/api/customers/", json={"id": cid, "last": "SearchSort", "email": f"{cid}@example.com"}) + assert resp.status_code == 200 + return cid + + +def _create_file(client: TestClient, file_no: str, owner_id: str, regarding: str, opened: date): + payload = { + "file_no": file_no, + "id": owner_id, + "regarding": regarding, + "empl_num": "E01", + "file_type": "CIVIL", + "opened": opened.isoformat(), + "status": "ACTIVE", + "rate_per_hour": 100.0, + "memo": "test search/sort", + } + resp = client.post("/api/files/", json=payload) + assert resp.status_code == 200 + + +def test_files_tokenized_search_sort_and_pagination(client: TestClient): + owner_id = _create_customer(client) + base_day = date.today() + f1 = f"FS-{uuid.uuid4().hex[:6]}" + f2 = f"FS-{uuid.uuid4().hex[:6]}" + + # f1 contains both tokens across a single field + _create_file(client, f1, owner_id, regarding="Alpha project Beta milestone", opened=base_day - timedelta(days=1)) + # f2 contains only one token + _create_file(client, f2, owner_id, regarding="Only Alpha token here", opened=base_day) + + # Tokenized search: both tokens required (AND-of-OR across fields) + resp = client.get("/api/files/", params={"search": "Alpha Beta"}) + assert resp.status_code == 200 + items = resp.json() + file_nos = {it["file_no"] for it in items} + assert f1 in file_nos and f2 not in file_nos + + # Sorting by opened desc should put f2 first if both were present; we restrict to both-token result (just f1) + resp = client.get("/api/files/", params={"search": "Alpha Beta", "sort_by": "opened", "sort_dir": "desc"}) + assert resp.status_code == 200 + items = resp.json() + assert len(items) >= 1 and items[0]["file_no"] == f1 + + # Pagination over a broader query (single-token) to verify skip/limit + resp = client.get( + "/api/files/", + params={"search": "Alpha", "sort_by": "file_no", "sort_dir": "asc", "limit": 1, "skip": 0}, + ) + assert resp.status_code == 200 + first_page = resp.json() + assert len(first_page) == 1 + resp = 
client.get( + "/api/files/", + params={"search": "Alpha", "sort_by": "file_no", "sort_dir": "asc", "limit": 1, "skip": 1}, + ) + second_page = resp.json() + assert len(second_page) >= 0 # may be 0 or 1 depending on other fixtures + + diff --git a/tests/test_search_validation.py b/tests/test_search_validation.py new file mode 100644 index 0000000..9b05882 --- /dev/null +++ b/tests/test_search_validation.py @@ -0,0 +1,177 @@ +import os +import sys +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +# Ensure required env vars for app import/config +os.environ.setdefault("SECRET_KEY", "x" * 32) +os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite") + +# Ensure repository root on sys.path for direct test runs +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app.main import app # noqa: E402 +from app.auth.security import get_current_user # noqa: E402 +from tests.helpers import assert_validation_error # noqa: E402 +from app.config import settings # noqa: E402 + + +@pytest.fixture(scope="module") +def client(): + # Override auth to bypass JWT for these tests + class _User: + def __init__(self): + self.id = "test" + self.username = "tester" + self.is_admin = True + self.is_active = True + + app.dependency_overrides[get_current_user] = lambda: _User() + + # Disable cache to make validation tests deterministic + settings.cache_enabled = False + + try: + yield TestClient(app) + finally: + app.dependency_overrides.pop(get_current_user, None) + + +def test_advanced_search_invalid_search_types(client: TestClient): + payload = { + "query": "anything", + "search_types": ["customer", "bogus"], + } + resp = client.post("/api/search/advanced", json=payload) + assert_validation_error(resp, "search_types") + + +def test_advanced_search_invalid_sort_options(client: TestClient): + # Invalid sort_by + payload = { + "query": "x", + "search_types": ["customer"], + "sort_by": "nope", + } + resp = client.post("/api/search/advanced", json=payload) + assert_validation_error(resp, "sort_by") + + # Invalid sort_order + payload = { + "query": "x", + "search_types": ["customer"], + "sort_order": "sideways", + } + resp = client.post("/api/search/advanced", json=payload) + assert_validation_error(resp, "sort_order") + + +def test_advanced_search_limit_bounds(client: TestClient): + # Too low + payload = { + "query": "x", + "search_types": ["customer"], + "limit": 0, + } + resp = client.post("/api/search/advanced", json=payload) + assert_validation_error(resp, "limit") + + # Too high + payload["limit"] = 201 + resp = client.post("/api/search/advanced", json=payload) + assert_validation_error(resp, "limit") + + +def test_advanced_search_conflicting_flags_exact_phrase_and_whole_words(client: TestClient): + payload = { + "query": "apple pie", + "search_types": ["file"], + "exact_phrase": True, + "whole_words": True, + } + resp = client.post("/api/search/advanced", json=payload) + # Cannot rely on field location for model-level validation, check message text in details + assert resp.status_code == 422 + body = resp.json() + assert body.get("success") is False + assert body.get("error", {}).get("code") == "validation_error" + msgs = [d.get("msg", "") for d in body.get("error", {}).get("details", [])] + assert any("exact_phrase and whole_words" in m for m in msgs) + + +def test_advanced_search_inverted_date_range(client: TestClient): + payload = { + "search_types": ["file"], + "date_field": "created", + 
"date_from": "2024-02-01", + "date_to": "2024-01-31", + } + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 422 + body = resp.json() + assert body.get("success") is False + assert body.get("error", {}).get("code") == "validation_error" + msgs = [d.get("msg", "") for d in body.get("error", {}).get("details", [])] + assert any("date_from must be less than or equal to date_to" in m for m in msgs) + + +def test_advanced_search_inverted_amount_range(client: TestClient): + payload = { + "search_types": ["file"], + "amount_field": "amount", + "amount_min": 100.0, + "amount_max": 50.0, + } + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 422 + body = resp.json() + assert body.get("success") is False + assert body.get("error", {}).get("code") == "validation_error" + msgs = [d.get("msg", "") for d in body.get("error", {}).get("details", [])] + assert any("amount_min must be less than or equal to amount_max" in m for m in msgs) + + +def test_advanced_search_date_field_supported_per_type(client: TestClient): + # 'opened' is only valid for files + payload = { + "search_types": ["customer", "ledger"], + "date_field": "opened", + "date_from": "2024-01-01", + "date_to": "2024-12-31", + } + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 422 + body = resp.json() + msgs = [d.get("msg", "") for d in body.get("error", {}).get("details", [])] + assert any("date_field 'opened' is not supported" in m for m in msgs) + + # Valid when 'file' included + payload["search_types"] = ["file"] + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 200 + + +def test_advanced_search_amount_field_supported_per_type(client: TestClient): + # 'amount' is only valid for ledger + payload = { + "search_types": ["file"], + "amount_field": "amount", + "amount_min": 1, + "amount_max": 10, + } + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 422 + body = resp.json() + msgs = [d.get("msg", "") for d in body.get("error", {}).get("details", [])] + assert any("amount_field 'amount' is not supported" in m for m in msgs) + + # Valid when 'ledger' included + payload["search_types"] = ["ledger"] + resp = client.post("/api/search/advanced", json=payload) + assert resp.status_code == 200 + + diff --git a/tests/test_support_api.py b/tests/test_support_api.py index 6fda9f0..cb9d731 100644 --- a/tests/test_support_api.py +++ b/tests/test_support_api.py @@ -119,4 +119,9 @@ def test_ticket_lifecycle_and_404s_with_audit(client: TestClient): assert resp.status_code == 200 assert isinstance(resp.json(), list) + # Search should filter results + resp = client.get("/api/support/tickets", params={"search": "Support issue"}) + assert resp.status_code == 200 + assert isinstance(resp.json(), list) +