all working

This commit is contained in:
HotSwapp
2025-08-10 21:34:11 -05:00
parent 14ee479edc
commit 1512b2d12a
22 changed files with 1453 additions and 489 deletions

View File

@@ -10,18 +10,23 @@ from sqlalchemy.orm import Session
from app.database.base import get_db
from app.models.user import User
from app.auth.security import (
authenticate_user,
create_access_token,
authenticate_user,
create_access_token,
create_refresh_token,
decode_refresh_token,
is_refresh_token_revoked,
revoke_refresh_token,
get_password_hash,
get_current_user,
get_admin_user
get_admin_user,
)
from app.auth.schemas import (
Token,
UserCreate,
UserResponse,
Token,
UserCreate,
UserResponse,
LoginRequest,
ThemePreferenceUpdate
ThemePreferenceUpdate,
RefreshRequest,
)
from app.config import settings
from app.core.logging import get_logger, log_auth_attempt
@@ -71,6 +76,12 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
refresh_token = create_refresh_token(
user=user,
user_agent=request.headers.get("user-agent", ""),
ip_address=request.client.host if request.client else None,
db=db,
)
log_auth_attempt(
username=login_data.username,
@@ -85,7 +96,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
client_ip=client_ip
)
return {"access_token": access_token, "token_type": "bearer"}
return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token}
@router.post("/register", response_model=UserResponse)
@@ -130,22 +141,76 @@ async def read_users_me(current_user: User = Depends(get_current_user)):
@router.post("/refresh", response_model=Token)
async def refresh_token(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
async def refresh_token_endpoint(
request: Request,
db: Session = Depends(get_db),
body: RefreshRequest | None = None,
):
"""Refresh access token for current user"""
# Update last login timestamp
current_user.last_login = datetime.utcnow()
"""Issue a new access token using a valid, non-revoked refresh token.
For backwards compatibility with existing clients that may call this without a body,
consider falling back to Authorization header in the future if needed.
"""
# New flow: refresh token in body
if body and body.refresh_token:
payload = decode_refresh_token(body.refresh_token)
if not payload:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
jti = payload.get("jti")
username = payload.get("sub")
if not jti or not username:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token payload")
# Verify token not revoked/expired
if is_refresh_token_revoked(jti, db):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token revoked or expired")
# Load user
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
# Rotate refresh token on use
revoke_refresh_token(jti, db)
new_refresh_token = create_refresh_token(
user=user,
user_agent=request.headers.get("user-agent", ""),
ip_address=request.client.host if request.client else None,
db=db,
)
# Issue new access token
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer", "refresh_token": new_refresh_token}
# Legacy flow: Authorization header-based refresh
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
if not auth_header or not auth_header.lower().startswith("bearer "):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing credentials")
token = auth_header.split(" ", 1)[1].strip()
from app.auth.security import verify_token # local import to avoid circular
username = verify_token(token)
if not username:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
user = db.query(User).filter(User.username == username).first()
if not user or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
user.last_login = datetime.utcnow()
db.commit()
# Create new token with full expiration time
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": current_user.username},
expires_delta=access_token_expires
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
@@ -159,6 +224,23 @@ async def list_users(
return users
@router.post("/logout")
async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)):
"""Revoke the provided refresh token. Idempotent and safe to call multiple times.
The client should send a JSON body: { "refresh_token": "..." }.
"""
try:
if body and body.refresh_token:
payload = decode_refresh_token(body.refresh_token)
if payload and payload.get("jti"):
revoke_refresh_token(payload["jti"], db)
except Exception:
# Don't leak details; logout should be best-effort
pass
return {"status": "ok"}
@router.post("/theme-preference")
async def update_theme_preference(
theme_data: ThemePreferenceUpdate,

View File

@@ -2,7 +2,7 @@
Document Management API endpoints - QDROs, Templates, and General Documents
"""
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, and_, desc, asc, text
from datetime import date, datetime
@@ -18,6 +18,8 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee
from app.models.user import User
from app.auth.security import get_current_user
from app.models.additional import Document
from app.core.logging import get_logger
from app.services.audit import audit_service
router = APIRouter()
@@ -666,6 +668,78 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str:
return merged
# --- Client Error Logging (for Documents page) ---
class ClientErrorLog(BaseModel):
"""Payload for client-side error logging"""
message: str
action: Optional[str] = None
stack: Optional[str] = None
url: Optional[str] = None
line: Optional[int] = None
column: Optional[int] = None
user_agent: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
@router.post("/client-error")
async def log_client_error(
payload: ClientErrorLog,
request: Request,
db: Session = Depends(get_db),
current_user: Optional[User] = Depends(lambda: None)
):
"""Accept client-side error logs from the Documents page.
This endpoint is lightweight and safe to call; it records the error to the
application logs and best-effort to the audit log without interrupting the UI.
"""
logger = get_logger("client.documents")
client_ip = request.headers.get("x-forwarded-for")
if client_ip:
client_ip = client_ip.split(",")[0].strip()
else:
client_ip = request.client.host if request.client else None
logger.error(
"Client error reported",
action=payload.action,
message=payload.message,
stack=payload.stack,
page="/documents",
url=payload.url or str(request.url),
line=payload.line,
column=payload.column,
user=getattr(current_user, "username", None),
user_id=getattr(current_user, "id", None),
user_agent=payload.user_agent or request.headers.get("user-agent"),
client_ip=client_ip,
extra=payload.extra,
)
# Best-effort audit log; do not raise on failure
try:
audit_service.log_action(
db=db,
action="CLIENT_ERROR",
resource_type="DOCUMENTS",
user=current_user,
resource_id=None,
details={
"action": payload.action,
"message": payload.message,
"url": payload.url or str(request.url),
"line": payload.line,
"column": payload.column,
"extra": payload.extra,
},
request=request,
)
except Exception:
pass
return {"status": "logged"}
@router.post("/upload/{file_no}")
async def upload_document(
file_no: str,

View File

@@ -45,6 +45,7 @@ class Token(BaseModel):
"""Token response schema"""
access_token: str
token_type: str
refresh_token: str | None = None
class TokenData(BaseModel):
@@ -55,4 +56,9 @@ class TokenData(BaseModel):
class LoginRequest(BaseModel):
"""Login request schema"""
username: str
password: str
password: str
class RefreshRequest(BaseModel):
"""Refresh token submission"""
refresh_token: str

View File

@@ -2,7 +2,8 @@
Authentication and security utilities
"""
from datetime import datetime, timedelta
from typing import Optional, Union
from typing import Optional, Union, Tuple
from uuid import uuid4
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
@@ -12,6 +13,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from app.database.base import get_db
from app.models.user import User
from app.models.auth import RefreshToken
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -30,39 +32,107 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def _encode_with_rotation(payload: dict) -> str:
"""Encode JWT with active secret and algorithm."""
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
def _decode_with_rotation(token: str) -> dict:
"""Decode JWT trying current then previous secret if set."""
try:
return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
except JWTError:
# Try previous secret to allow seamless rotation
if settings.previous_secret_key:
try:
return jwt.decode(token, settings.previous_secret_key, algorithms=[settings.algorithm])
except JWTError:
pass
raise
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT access token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
expire = datetime.utcnow() + (
expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)
)
to_encode.update({
"exp": expire,
"iat": datetime.utcnow(),
"type": "access",
})
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
return encoded_jwt
return _encode_with_rotation(to_encode)
def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str:
"""Create refresh token, store its JTI in DB for revocation."""
jti = uuid4().hex
expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
payload = {
"sub": user.username,
"uid": user.id,
"jti": jti,
"type": "refresh",
"exp": expire,
"iat": datetime.utcnow(),
}
token = _encode_with_rotation(payload)
db_token = RefreshToken(
user_id=user.id,
jti=jti,
user_agent=user_agent,
ip_address=ip_address,
issued_at=datetime.utcnow(),
expires_at=expire,
revoked=False,
)
db.add(db_token)
db.commit()
return token
def verify_token(token: str) -> Optional[str]:
"""Verify JWT token and return username"""
try:
payload = jwt.decode(
token,
settings.secret_key,
algorithms=[settings.algorithm],
leeway=30 # allow small clock skew
)
payload = _decode_with_rotation(token)
username: str = payload.get("sub")
token_type: str = payload.get("type")
if username is None:
return None
# Only accept access tokens for auth
if token_type and token_type != "access":
return None
return username
except JWTError:
return None
def decode_refresh_token(token: str) -> Optional[dict]:
"""Decode refresh token and return payload if valid and not revoked."""
try:
payload = _decode_with_rotation(token)
if payload.get("type") != "refresh":
return None
return payload
except JWTError:
return None
def is_refresh_token_revoked(jti: str, db: Session) -> bool:
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow()
def revoke_refresh_token(jti: str, db: Session) -> None:
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
if token_row and not token_row.revoked:
token_row.revoked = True
token_row.revoked_at = datetime.utcnow()
db.commit()
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""Authenticate user credentials"""
user = db.query(User).filter(User.username == username).first()

View File

@@ -1,53 +1,70 @@
"""
Delphi Consulting Group Database System - Configuration
"""
from pydantic_settings import BaseSettings
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application configuration"""
"""Application configuration (env and .env driven).
Environment precedence: real environment variables take priority over .env,
which take priority over defaults.
"""
# Application
app_name: str = "Delphi Consulting Group Database System"
app_version: str = "1.0.0"
debug: bool = False
# Database
database_url: str = "sqlite:///./data/delphi_database.db"
# Authentication
secret_key: str = "your-secret-key-change-in-production"
# Authentication / JWT
# Require SECRET_KEY to be provided via environment/.env (no insecure default)
secret_key: str = Field(..., min_length=32)
# Optional previous secret key to allow seamless rotation
previous_secret_key: Optional[str] = None
algorithm: str = "HS256"
access_token_expire_minutes: int = 240 # 4 hours
# Long-lived refresh token expiration (default 30 days)
refresh_token_expire_minutes: int = 43200
# Admin account settings
admin_username: str = "admin"
admin_password: str = "change-me"
# File paths
upload_dir: str = "./uploads"
backup_dir: str = "./backups"
# Pagination
default_page_size: int = 50
max_page_size: int = 200
# Docker/deployment settings
external_port: Optional[str] = None
allowed_hosts: Optional[str] = None
cors_origins: Optional[str] = None
secure_cookies: bool = False
compose_project_name: Optional[str] = None
# Logging
log_level: str = "INFO"
log_to_file: bool = True
log_rotation: str = "10 MB"
log_retention: str = "30 days"
class Config:
env_file = ".env"
# pydantic-settings v2 configuration
model_config = SettingsConfigDict(
env_file=".env",
env_prefix="",
case_sensitive=False,
env_ignore_empty=True,
extra="ignore",
)
settings = Settings()

View File

@@ -14,6 +14,7 @@ from app.models.user import User
from app.auth.security import get_admin_user
from app.core.logging import setup_logging, get_logger
from app.middleware.logging import LoggingMiddleware
from app.middleware.errors import register_exception_handlers
# Initialize logging
setup_logging()
@@ -35,6 +36,10 @@ app = FastAPI(
logger.info("Adding request logging middleware")
app.add_middleware(LoggingMiddleware, log_requests=True, log_responses=settings.debug)
# Register global exception handlers
logger.info("Registering global exception handlers")
register_exception_handlers(app)
# Configure CORS
logger.info("Configuring CORS middleware")
app.add_middleware(

132
app/middleware/errors.py Normal file
View File

@@ -0,0 +1,132 @@
"""
Global exception handlers that return a consistent JSON error envelope
and propagate the X-Correlation-ID header.
"""
from __future__ import annotations
from typing import Any, Optional
from uuid import uuid4
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette import status as http_status
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse as StarletteJSONResponse
from fastapi import HTTPException
try:
# Prefer project logger if available
from app.core.logging import get_logger
logger = get_logger("middleware.errors")
except Exception: # pragma: no cover - fallback simple logger
import logging
logger = logging.getLogger("middleware.errors")
ERROR_HEADER_NAME = "X-Correlation-ID"
def _get_correlation_id(request: Request) -> str:
"""Resolve correlation ID from request state, headers, or generate a new one."""
# From middleware
correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None)
if correlation_id:
return correlation_id
# From incoming headers
correlation_id = (
request.headers.get("x-correlation-id")
or request.headers.get("x-request-id")
)
if correlation_id:
return correlation_id
# Generate a new one if not present
return str(uuid4())
def _build_error_response(
request: Request,
*,
status_code: int,
message: str,
code: str,
details: Any | None = None,
) -> JSONResponse:
correlation_id = _get_correlation_id(request)
body = {
"success": False,
"error": {
"status": status_code,
"code": code,
"message": message,
},
"correlation_id": correlation_id,
}
if details is not None:
body["error"]["details"] = details
response = JSONResponse(content=body, status_code=status_code)
response.headers[ERROR_HEADER_NAME] = correlation_id
return response
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
"""Handle FastAPI HTTPException with envelope and correlation id."""
message = exc.detail if isinstance(exc.detail, str) else "HTTP error"
logger.warning(
"HTTPException raised",
status_code=exc.status_code,
detail=message,
path=request.url.path,
)
return _build_error_response(
request,
status_code=exc.status_code,
message=message,
code="http_error",
details=None,
)
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle validation errors from FastAPI/Pydantic."""
logger.info(
"Validation error",
path=request.url.path,
errors=exc.errors(),
)
return _build_error_response(
request,
status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY,
message="Validation error",
code="validation_error",
details=exc.errors(),
)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Catch-all handler for unexpected exceptions (500)."""
# Log full exception for diagnostics without leaking internals to clients
try:
logger.exception("Unhandled exception", path=request.url.path)
except Exception:
pass
return _build_error_response(
request,
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
message="Internal Server Error",
code="internal_error",
details=None,
)
def register_exception_handlers(app: FastAPI) -> None:
"""Register global exception handlers on the provided FastAPI app."""
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)

View File

@@ -4,6 +4,7 @@ Request/Response Logging Middleware
import time
import json
from typing import Callable
from uuid import uuid4
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse
@@ -21,10 +22,19 @@ class LoggingMiddleware(BaseHTTPMiddleware):
self.log_responses = log_responses
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Skip logging for static files and health checks
# Correlation ID: use incoming header or generate a new one
correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4())
request.state.correlation_id = correlation_id
# Skip logging for static files and health checks (still attach correlation id)
skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"]
if any(request.url.path.startswith(path) for path in skip_paths):
return await call_next(request)
response = await call_next(request)
try:
response.headers["X-Correlation-ID"] = correlation_id
except Exception:
pass
return response
# Record start time
start_time = time.time()
@@ -41,7 +51,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
path=request.url.path,
query_params=str(request.query_params) if request.query_params else None,
client_ip=client_ip,
user_agent=user_agent
user_agent=user_agent,
correlation_id=correlation_id,
)
# Process request
@@ -56,7 +67,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
path=request.url.path,
duration_ms=duration_ms,
error=str(e),
client_ip=client_ip
client_ip=client_ip,
correlation_id=correlation_id,
)
raise
@@ -74,7 +86,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
path=request.url.path,
status_code=response.status_code,
duration_ms=duration_ms,
user_id=user_id
user_id=user_id,
correlation_id=correlation_id,
)
# Log response details if enabled
@@ -84,7 +97,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
status_code=response.status_code,
headers=dict(response.headers),
size_bytes=response.headers.get("content-length"),
content_type=response.headers.get("content-type")
content_type=response.headers.get("content-type"),
correlation_id=correlation_id,
)
# Log slow requests as warnings
@@ -94,7 +108,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
method=request.method,
path=request.url.path,
duration_ms=duration_ms,
status_code=response.status_code
status_code=response.status_code,
correlation_id=correlation_id,
)
# Log authentication-related requests to auth log
@@ -106,9 +121,16 @@ class LoggingMiddleware(BaseHTTPMiddleware):
status_code=response.status_code,
duration_ms=duration_ms,
client_ip=client_ip,
user_agent=user_agent
user_agent=user_agent,
correlation_id=correlation_id,
)
# Attach correlation id header to all responses
try:
response.headers["X-Correlation-ID"] = correlation_id
except Exception:
pass
return response
def get_client_ip(self, request: Request) -> str:

View File

@@ -8,6 +8,7 @@ from .files import File
from .ledger import Ledger
from .qdro import QDRO
from .audit import AuditLog, LoginAttempt
from .auth import RefreshToken
from .additional import Deposit, Payment, FileNote, FormVariable, ReportVariable, Document
from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory
from .pensions import (
@@ -22,7 +23,7 @@ from .lookups import (
__all__ = [
"BaseModel", "User", "Rolodex", "Phone", "File", "Ledger", "QDRO",
"AuditLog", "LoginAttempt",
"AuditLog", "LoginAttempt", "RefreshToken",
"Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document",
"SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory",
"Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit",

34
app/models/auth.py Normal file
View File

@@ -0,0 +1,34 @@
"""
Authentication-related persistence models
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship
from app.models.base import BaseModel
class RefreshToken(BaseModel):
"""Persisted refresh tokens for revocation and auditing."""
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
jti = Column(String(64), nullable=False, unique=True, index=True)
user_agent = Column(String(255), nullable=True)
ip_address = Column(String(45), nullable=True)
issued_at = Column(DateTime, default=datetime.utcnow, nullable=False)
expires_at = Column(DateTime, nullable=False, index=True)
revoked = Column(Boolean, default=False, nullable=False)
revoked_at = Column(DateTime, nullable=True)
# relationships
user = relationship("User")
__table_args__ = (
UniqueConstraint("jti", name="uq_refresh_tokens_jti"),
)