fixes and refactor
This commit is contained in:
@@ -4,7 +4,7 @@ Customer (Rolodex) API endpoints
|
||||
from typing import List, Optional, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, and_, func, asc, desc
|
||||
from sqlalchemy import func
|
||||
from fastapi.responses import StreamingResponse
|
||||
import csv
|
||||
import io
|
||||
@@ -13,12 +13,16 @@ from app.database.base import get_db
|
||||
from app.models.rolodex import Rolodex, Phone
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.services.cache import invalidate_search_cache
|
||||
from app.services.customers_search import apply_customer_filters, apply_customer_sorting, prepare_customer_csv_rows
|
||||
from app.services.query_utils import apply_sorting, paginate_with_total
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic schemas for request/response
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from pydantic.config import ConfigDict
|
||||
from datetime import date
|
||||
|
||||
|
||||
@@ -32,8 +36,7 @@ class PhoneResponse(BaseModel):
|
||||
location: Optional[str]
|
||||
phone: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CustomerBase(BaseModel):
|
||||
@@ -84,10 +87,11 @@ class CustomerUpdate(BaseModel):
|
||||
|
||||
|
||||
class CustomerResponse(CustomerBase):
|
||||
phone_numbers: List[PhoneResponse] = []
|
||||
phone_numbers: List[PhoneResponse] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/search/phone")
|
||||
@@ -196,80 +200,17 @@ async def list_customers(
|
||||
try:
|
||||
base_query = db.query(Rolodex)
|
||||
|
||||
if search:
|
||||
s = (search or "").strip()
|
||||
s_lower = s.lower()
|
||||
tokens = [t for t in s_lower.split() if t]
|
||||
# Basic contains search on several fields (case-insensitive)
|
||||
contains_any = or_(
|
||||
func.lower(Rolodex.id).contains(s_lower),
|
||||
func.lower(Rolodex.last).contains(s_lower),
|
||||
func.lower(Rolodex.first).contains(s_lower),
|
||||
func.lower(Rolodex.middle).contains(s_lower),
|
||||
func.lower(Rolodex.city).contains(s_lower),
|
||||
func.lower(Rolodex.email).contains(s_lower),
|
||||
)
|
||||
# Multi-token name support: every token must match either first, middle, or last
|
||||
name_tokens = [
|
||||
or_(
|
||||
func.lower(Rolodex.first).contains(tok),
|
||||
func.lower(Rolodex.middle).contains(tok),
|
||||
func.lower(Rolodex.last).contains(tok),
|
||||
)
|
||||
for tok in tokens
|
||||
]
|
||||
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
|
||||
# Comma pattern: "Last, First"
|
||||
last_first_filter = None
|
||||
if "," in s_lower:
|
||||
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
|
||||
if last_part and first_part:
|
||||
last_first_filter = and_(
|
||||
func.lower(Rolodex.last).contains(last_part),
|
||||
func.lower(Rolodex.first).contains(first_part),
|
||||
)
|
||||
elif last_part:
|
||||
last_first_filter = func.lower(Rolodex.last).contains(last_part)
|
||||
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
|
||||
base_query = base_query.filter(final_filter)
|
||||
|
||||
# Apply group/state filters (support single and multi-select)
|
||||
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
|
||||
if effective_groups:
|
||||
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
|
||||
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
|
||||
if effective_states:
|
||||
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
|
||||
base_query = apply_customer_filters(
|
||||
base_query,
|
||||
search=search,
|
||||
group=group,
|
||||
state=state,
|
||||
groups=groups,
|
||||
states=states,
|
||||
)
|
||||
|
||||
# Apply sorting (whitelisted fields only)
|
||||
normalized_sort_by = (sort_by or "id").lower()
|
||||
normalized_sort_dir = (sort_dir or "asc").lower()
|
||||
is_desc = normalized_sort_dir == "desc"
|
||||
|
||||
order_columns = []
|
||||
if normalized_sort_by == "id":
|
||||
order_columns = [Rolodex.id]
|
||||
elif normalized_sort_by == "name":
|
||||
# Sort by last, then first
|
||||
order_columns = [Rolodex.last, Rolodex.first]
|
||||
elif normalized_sort_by == "city":
|
||||
# Sort by city, then state abbreviation
|
||||
order_columns = [Rolodex.city, Rolodex.abrev]
|
||||
elif normalized_sort_by == "email":
|
||||
order_columns = [Rolodex.email]
|
||||
else:
|
||||
# Fallback to id to avoid arbitrary column injection
|
||||
order_columns = [Rolodex.id]
|
||||
|
||||
# Case-insensitive ordering where applicable, preserving None ordering default
|
||||
ordered = []
|
||||
for col in order_columns:
|
||||
# Use lower() for string-like cols; SQLAlchemy will handle non-string safely enough for SQLite/Postgres
|
||||
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
|
||||
ordered.append(desc(expr) if is_desc else asc(expr))
|
||||
|
||||
if ordered:
|
||||
base_query = base_query.order_by(*ordered)
|
||||
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
|
||||
|
||||
customers = base_query.options(joinedload(Rolodex.phone_numbers)).offset(skip).limit(limit).all()
|
||||
if include_total:
|
||||
@@ -304,72 +245,16 @@ async def export_customers(
|
||||
try:
|
||||
base_query = db.query(Rolodex)
|
||||
|
||||
if search:
|
||||
s = (search or "").strip()
|
||||
s_lower = s.lower()
|
||||
tokens = [t for t in s_lower.split() if t]
|
||||
contains_any = or_(
|
||||
func.lower(Rolodex.id).contains(s_lower),
|
||||
func.lower(Rolodex.last).contains(s_lower),
|
||||
func.lower(Rolodex.first).contains(s_lower),
|
||||
func.lower(Rolodex.middle).contains(s_lower),
|
||||
func.lower(Rolodex.city).contains(s_lower),
|
||||
func.lower(Rolodex.email).contains(s_lower),
|
||||
)
|
||||
name_tokens = [
|
||||
or_(
|
||||
func.lower(Rolodex.first).contains(tok),
|
||||
func.lower(Rolodex.middle).contains(tok),
|
||||
func.lower(Rolodex.last).contains(tok),
|
||||
)
|
||||
for tok in tokens
|
||||
]
|
||||
combined = contains_any if not name_tokens else or_(contains_any, and_(*name_tokens))
|
||||
last_first_filter = None
|
||||
if "," in s_lower:
|
||||
last_part, first_part = [p.strip() for p in s_lower.split(",", 1)]
|
||||
if last_part and first_part:
|
||||
last_first_filter = and_(
|
||||
func.lower(Rolodex.last).contains(last_part),
|
||||
func.lower(Rolodex.first).contains(first_part),
|
||||
)
|
||||
elif last_part:
|
||||
last_first_filter = func.lower(Rolodex.last).contains(last_part)
|
||||
final_filter = or_(combined, last_first_filter) if last_first_filter is not None else combined
|
||||
base_query = base_query.filter(final_filter)
|
||||
base_query = apply_customer_filters(
|
||||
base_query,
|
||||
search=search,
|
||||
group=group,
|
||||
state=state,
|
||||
groups=groups,
|
||||
states=states,
|
||||
)
|
||||
|
||||
effective_groups = [g for g in (groups or []) if g] or ([group] if group else [])
|
||||
if effective_groups:
|
||||
base_query = base_query.filter(Rolodex.group.in_(effective_groups))
|
||||
effective_states = [s for s in (states or []) if s] or ([state] if state else [])
|
||||
if effective_states:
|
||||
base_query = base_query.filter(Rolodex.abrev.in_(effective_states))
|
||||
|
||||
normalized_sort_by = (sort_by or "id").lower()
|
||||
normalized_sort_dir = (sort_dir or "asc").lower()
|
||||
is_desc = normalized_sort_dir == "desc"
|
||||
|
||||
order_columns = []
|
||||
if normalized_sort_by == "id":
|
||||
order_columns = [Rolodex.id]
|
||||
elif normalized_sort_by == "name":
|
||||
order_columns = [Rolodex.last, Rolodex.first]
|
||||
elif normalized_sort_by == "city":
|
||||
order_columns = [Rolodex.city, Rolodex.abrev]
|
||||
elif normalized_sort_by == "email":
|
||||
order_columns = [Rolodex.email]
|
||||
else:
|
||||
order_columns = [Rolodex.id]
|
||||
|
||||
ordered = []
|
||||
for col in order_columns:
|
||||
try:
|
||||
expr = func.lower(col) if col.type.python_type in (str,) else col # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
expr = col
|
||||
ordered.append(desc(expr) if is_desc else asc(expr))
|
||||
if ordered:
|
||||
base_query = base_query.order_by(*ordered)
|
||||
base_query = apply_customer_sorting(base_query, sort_by=sort_by, sort_dir=sort_dir)
|
||||
|
||||
if not export_all:
|
||||
if skip is not None:
|
||||
@@ -382,39 +267,10 @@ async def export_customers(
|
||||
# Prepare CSV
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
allowed_fields_in_order = ["id", "name", "group", "city", "state", "phone", "email"]
|
||||
header_names = {
|
||||
"id": "Customer ID",
|
||||
"name": "Name",
|
||||
"group": "Group",
|
||||
"city": "City",
|
||||
"state": "State",
|
||||
"phone": "Primary Phone",
|
||||
"email": "Email",
|
||||
}
|
||||
requested = [f.lower() for f in (fields or []) if isinstance(f, str)]
|
||||
selected_fields = [f for f in allowed_fields_in_order if f in requested] if requested else allowed_fields_in_order
|
||||
if not selected_fields:
|
||||
selected_fields = allowed_fields_in_order
|
||||
writer.writerow([header_names[f] for f in selected_fields])
|
||||
for c in customers:
|
||||
full_name = f"{(c.first or '').strip()} {(c.last or '').strip()}".strip()
|
||||
primary_phone = ""
|
||||
try:
|
||||
if c.phone_numbers:
|
||||
primary_phone = c.phone_numbers[0].phone or ""
|
||||
except Exception:
|
||||
primary_phone = ""
|
||||
row_map = {
|
||||
"id": c.id,
|
||||
"name": full_name,
|
||||
"group": c.group or "",
|
||||
"city": c.city or "",
|
||||
"state": c.abrev or "",
|
||||
"phone": primary_phone,
|
||||
"email": c.email or "",
|
||||
}
|
||||
writer.writerow([row_map[f] for f in selected_fields])
|
||||
header_row, rows = prepare_customer_csv_rows(customers, fields)
|
||||
writer.writerow(header_row)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
output.seek(0)
|
||||
filename = "customers_export.csv"
|
||||
@@ -469,6 +325,10 @@ async def create_customer(
|
||||
db.commit()
|
||||
db.refresh(customer)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return customer
|
||||
|
||||
|
||||
@@ -494,7 +354,10 @@ async def update_customer(
|
||||
|
||||
db.commit()
|
||||
db.refresh(customer)
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return customer
|
||||
|
||||
|
||||
@@ -515,17 +378,30 @@ async def delete_customer(
|
||||
|
||||
db.delete(customer)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
await invalidate_search_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return {"message": "Customer deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/{customer_id}/phones", response_model=List[PhoneResponse])
|
||||
class PaginatedPhonesResponse(BaseModel):
|
||||
items: List[PhoneResponse]
|
||||
total: int
|
||||
|
||||
|
||||
@router.get("/{customer_id}/phones", response_model=Union[List[PhoneResponse], PaginatedPhonesResponse])
|
||||
async def get_customer_phones(
|
||||
customer_id: str,
|
||||
skip: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Page size"),
|
||||
sort_by: Optional[str] = Query("location", description="Sort by: location, phone"),
|
||||
sort_dir: Optional[str] = Query("asc", description="Sort direction: asc or desc"),
|
||||
include_total: bool = Query(False, description="When true, returns {items, total} instead of a plain list"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get customer phone numbers"""
|
||||
"""Get customer phone numbers with optional sorting/pagination"""
|
||||
customer = db.query(Rolodex).filter(Rolodex.id == customer_id).first()
|
||||
|
||||
if not customer:
|
||||
@@ -534,7 +410,21 @@ async def get_customer_phones(
|
||||
detail="Customer not found"
|
||||
)
|
||||
|
||||
phones = db.query(Phone).filter(Phone.rolodex_id == customer_id).all()
|
||||
query = db.query(Phone).filter(Phone.rolodex_id == customer_id)
|
||||
|
||||
query = apply_sorting(
|
||||
query,
|
||||
sort_by,
|
||||
sort_dir,
|
||||
allowed={
|
||||
"location": [Phone.location, Phone.phone],
|
||||
"phone": [Phone.phone],
|
||||
},
|
||||
)
|
||||
|
||||
phones, total = paginate_with_total(query, skip, limit, include_total)
|
||||
if include_total:
|
||||
return {"items": phones, "total": total or 0}
|
||||
return phones
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user