all working
This commit is contained in:
120
app/api/auth.py
120
app/api/auth.py
@@ -10,18 +10,23 @@ from sqlalchemy.orm import Session
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.auth.security import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_refresh_token,
|
||||
is_refresh_token_revoked,
|
||||
revoke_refresh_token,
|
||||
get_password_hash,
|
||||
get_current_user,
|
||||
get_admin_user
|
||||
get_admin_user,
|
||||
)
|
||||
from app.auth.schemas import (
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
LoginRequest,
|
||||
ThemePreferenceUpdate
|
||||
ThemePreferenceUpdate,
|
||||
RefreshRequest,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.core.logging import get_logger, log_auth_attempt
|
||||
@@ -71,6 +76,12 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
user=user,
|
||||
user_agent=request.headers.get("user-agent", ""),
|
||||
ip_address=request.client.host if request.client else None,
|
||||
db=db,
|
||||
)
|
||||
|
||||
log_auth_attempt(
|
||||
username=login_data.username,
|
||||
@@ -85,7 +96,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend
|
||||
client_ip=client_ip
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token}
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse)
|
||||
@@ -130,22 +141,76 @@ async def read_users_me(current_user: User = Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
async def refresh_token_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
body: RefreshRequest | None = None,
|
||||
):
|
||||
"""Refresh access token for current user"""
|
||||
# Update last login timestamp
|
||||
current_user.last_login = datetime.utcnow()
|
||||
"""Issue a new access token using a valid, non-revoked refresh token.
|
||||
|
||||
For backwards compatibility with existing clients that may call this without a body,
|
||||
consider falling back to Authorization header in the future if needed.
|
||||
"""
|
||||
# New flow: refresh token in body
|
||||
if body and body.refresh_token:
|
||||
payload = decode_refresh_token(body.refresh_token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
|
||||
|
||||
jti = payload.get("jti")
|
||||
username = payload.get("sub")
|
||||
if not jti or not username:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token payload")
|
||||
|
||||
# Verify token not revoked/expired
|
||||
if is_refresh_token_revoked(jti, db):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token revoked or expired")
|
||||
|
||||
# Load user
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
|
||||
|
||||
# Rotate refresh token on use
|
||||
revoke_refresh_token(jti, db)
|
||||
new_refresh_token = create_refresh_token(
|
||||
user=user,
|
||||
user_agent=request.headers.get("user-agent", ""),
|
||||
ip_address=request.client.host if request.client else None,
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Issue new access token
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer", "refresh_token": new_refresh_token}
|
||||
|
||||
# Legacy flow: Authorization header-based refresh
|
||||
auth_header = request.headers.get("authorization") or request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing credentials")
|
||||
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
from app.auth.security import verify_token # local import to avoid circular
|
||||
username = verify_token(token)
|
||||
if not username:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
|
||||
|
||||
user.last_login = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
# Create new token with full expiration time
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": current_user.username},
|
||||
expires_delta=access_token_expires
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@@ -159,6 +224,23 @@ async def list_users(
|
||||
return users
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)):
|
||||
"""Revoke the provided refresh token. Idempotent and safe to call multiple times.
|
||||
|
||||
The client should send a JSON body: { "refresh_token": "..." }.
|
||||
"""
|
||||
try:
|
||||
if body and body.refresh_token:
|
||||
payload = decode_refresh_token(body.refresh_token)
|
||||
if payload and payload.get("jti"):
|
||||
revoke_refresh_token(payload["jti"], db)
|
||||
except Exception:
|
||||
# Don't leak details; logout should be best-effort
|
||||
pass
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/theme-preference")
|
||||
async def update_theme_preference(
|
||||
theme_data: ThemePreferenceUpdate,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Document Management API endpoints - QDROs, Templates, and General Documents
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import or_, func, and_, desc, asc, text
|
||||
from datetime import date, datetime
|
||||
@@ -18,6 +18,8 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee
|
||||
from app.models.user import User
|
||||
from app.auth.security import get_current_user
|
||||
from app.models.additional import Document
|
||||
from app.core.logging import get_logger
|
||||
from app.services.audit import audit_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -666,6 +668,78 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str:
|
||||
return merged
|
||||
|
||||
|
||||
# --- Client Error Logging (for Documents page) ---
|
||||
class ClientErrorLog(BaseModel):
|
||||
"""Payload for client-side error logging"""
|
||||
message: str
|
||||
action: Optional[str] = None
|
||||
stack: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
line: Optional[int] = None
|
||||
column: Optional[int] = None
|
||||
user_agent: Optional[str] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@router.post("/client-error")
|
||||
async def log_client_error(
|
||||
payload: ClientErrorLog,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(lambda: None)
|
||||
):
|
||||
"""Accept client-side error logs from the Documents page.
|
||||
|
||||
This endpoint is lightweight and safe to call; it records the error to the
|
||||
application logs and best-effort to the audit log without interrupting the UI.
|
||||
"""
|
||||
logger = get_logger("client.documents")
|
||||
client_ip = request.headers.get("x-forwarded-for")
|
||||
if client_ip:
|
||||
client_ip = client_ip.split(",")[0].strip()
|
||||
else:
|
||||
client_ip = request.client.host if request.client else None
|
||||
|
||||
logger.error(
|
||||
"Client error reported",
|
||||
action=payload.action,
|
||||
message=payload.message,
|
||||
stack=payload.stack,
|
||||
page="/documents",
|
||||
url=payload.url or str(request.url),
|
||||
line=payload.line,
|
||||
column=payload.column,
|
||||
user=getattr(current_user, "username", None),
|
||||
user_id=getattr(current_user, "id", None),
|
||||
user_agent=payload.user_agent or request.headers.get("user-agent"),
|
||||
client_ip=client_ip,
|
||||
extra=payload.extra,
|
||||
)
|
||||
|
||||
# Best-effort audit log; do not raise on failure
|
||||
try:
|
||||
audit_service.log_action(
|
||||
db=db,
|
||||
action="CLIENT_ERROR",
|
||||
resource_type="DOCUMENTS",
|
||||
user=current_user,
|
||||
resource_id=None,
|
||||
details={
|
||||
"action": payload.action,
|
||||
"message": payload.message,
|
||||
"url": payload.url or str(request.url),
|
||||
"line": payload.line,
|
||||
"column": payload.column,
|
||||
"extra": payload.extra,
|
||||
},
|
||||
request=request,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"status": "logged"}
|
||||
|
||||
|
||||
@router.post("/upload/{file_no}")
|
||||
async def upload_document(
|
||||
file_no: str,
|
||||
|
||||
@@ -45,6 +45,7 @@ class Token(BaseModel):
|
||||
"""Token response schema"""
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
@@ -55,4 +56,9 @@ class TokenData(BaseModel):
|
||||
class LoginRequest(BaseModel):
|
||||
"""Login request schema"""
|
||||
username: str
|
||||
password: str
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
"""Refresh token submission"""
|
||||
refresh_token: str
|
||||
@@ -2,7 +2,8 @@
|
||||
Authentication and security utilities
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Tuple
|
||||
from uuid import uuid4
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -12,6 +13,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.config import settings
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.models.auth import RefreshToken
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@@ -30,39 +32,107 @@ def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def _encode_with_rotation(payload: dict) -> str:
|
||||
"""Encode JWT with active secret and algorithm."""
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
|
||||
def _decode_with_rotation(token: str) -> dict:
|
||||
"""Decode JWT trying current then previous secret if set."""
|
||||
try:
|
||||
return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
|
||||
except JWTError:
|
||||
# Try previous secret to allow seamless rotation
|
||||
if settings.previous_secret_key:
|
||||
try:
|
||||
return jwt.decode(token, settings.previous_secret_key, algorithms=[settings.algorithm])
|
||||
except JWTError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT access token"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
expire = datetime.utcnow() + (
|
||||
expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
"type": "access",
|
||||
})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||
return encoded_jwt
|
||||
return _encode_with_rotation(to_encode)
|
||||
|
||||
|
||||
def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str:
|
||||
"""Create refresh token, store its JTI in DB for revocation."""
|
||||
jti = uuid4().hex
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes)
|
||||
payload = {
|
||||
"sub": user.username,
|
||||
"uid": user.id,
|
||||
"jti": jti,
|
||||
"type": "refresh",
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow(),
|
||||
}
|
||||
token = _encode_with_rotation(payload)
|
||||
|
||||
db_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
jti=jti,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
issued_at=datetime.utcnow(),
|
||||
expires_at=expire,
|
||||
revoked=False,
|
||||
)
|
||||
db.add(db_token)
|
||||
db.commit()
|
||||
return token
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
"""Verify JWT token and return username"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.secret_key,
|
||||
algorithms=[settings.algorithm],
|
||||
leeway=30 # allow small clock skew
|
||||
)
|
||||
payload = _decode_with_rotation(token)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("type")
|
||||
if username is None:
|
||||
return None
|
||||
# Only accept access tokens for auth
|
||||
if token_type and token_type != "access":
|
||||
return None
|
||||
return username
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def decode_refresh_token(token: str) -> Optional[dict]:
|
||||
"""Decode refresh token and return payload if valid and not revoked."""
|
||||
try:
|
||||
payload = _decode_with_rotation(token)
|
||||
if payload.get("type") != "refresh":
|
||||
return None
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def is_refresh_token_revoked(jti: str, db: Session) -> bool:
|
||||
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
|
||||
return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow()
|
||||
|
||||
|
||||
def revoke_refresh_token(jti: str, db: Session) -> None:
|
||||
token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first()
|
||||
if token_row and not token_row.revoked:
|
||||
token_row.revoked = True
|
||||
token_row.revoked_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
|
||||
"""Authenticate user credentials"""
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
|
||||
@@ -1,53 +1,70 @@
|
||||
"""
|
||||
Delphi Consulting Group Database System - Configuration
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application configuration"""
|
||||
|
||||
"""Application configuration (env and .env driven).
|
||||
|
||||
Environment precedence: real environment variables take priority over .env,
|
||||
which take priority over defaults.
|
||||
"""
|
||||
|
||||
# Application
|
||||
app_name: str = "Delphi Consulting Group Database System"
|
||||
app_version: str = "1.0.0"
|
||||
debug: bool = False
|
||||
|
||||
|
||||
# Database
|
||||
database_url: str = "sqlite:///./data/delphi_database.db"
|
||||
|
||||
# Authentication
|
||||
secret_key: str = "your-secret-key-change-in-production"
|
||||
|
||||
# Authentication / JWT
|
||||
# Require SECRET_KEY to be provided via environment/.env (no insecure default)
|
||||
secret_key: str = Field(..., min_length=32)
|
||||
# Optional previous secret key to allow seamless rotation
|
||||
previous_secret_key: Optional[str] = None
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 240 # 4 hours
|
||||
|
||||
# Long-lived refresh token expiration (default 30 days)
|
||||
refresh_token_expire_minutes: int = 43200
|
||||
|
||||
# Admin account settings
|
||||
admin_username: str = "admin"
|
||||
admin_password: str = "change-me"
|
||||
|
||||
|
||||
# File paths
|
||||
upload_dir: str = "./uploads"
|
||||
backup_dir: str = "./backups"
|
||||
|
||||
|
||||
# Pagination
|
||||
default_page_size: int = 50
|
||||
max_page_size: int = 200
|
||||
|
||||
|
||||
# Docker/deployment settings
|
||||
external_port: Optional[str] = None
|
||||
allowed_hosts: Optional[str] = None
|
||||
cors_origins: Optional[str] = None
|
||||
secure_cookies: bool = False
|
||||
compose_project_name: Optional[str] = None
|
||||
|
||||
|
||||
# Logging
|
||||
log_level: str = "INFO"
|
||||
log_to_file: bool = True
|
||||
log_rotation: str = "10 MB"
|
||||
log_retention: str = "30 days"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
# pydantic-settings v2 configuration
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_prefix="",
|
||||
case_sensitive=False,
|
||||
env_ignore_empty=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -14,6 +14,7 @@ from app.models.user import User
|
||||
from app.auth.security import get_admin_user
|
||||
from app.core.logging import setup_logging, get_logger
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.errors import register_exception_handlers
|
||||
|
||||
# Initialize logging
|
||||
setup_logging()
|
||||
@@ -35,6 +36,10 @@ app = FastAPI(
|
||||
logger.info("Adding request logging middleware")
|
||||
app.add_middleware(LoggingMiddleware, log_requests=True, log_responses=settings.debug)
|
||||
|
||||
# Register global exception handlers
|
||||
logger.info("Registering global exception handlers")
|
||||
register_exception_handlers(app)
|
||||
|
||||
# Configure CORS
|
||||
logger.info("Configuring CORS middleware")
|
||||
app.add_middleware(
|
||||
|
||||
132
app/middleware/errors.py
Normal file
132
app/middleware/errors.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Global exception handlers that return a consistent JSON error envelope
|
||||
and propagate the X-Correlation-ID header.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette import status as http_status
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import JSONResponse as StarletteJSONResponse
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
# Prefer project logger if available
|
||||
from app.core.logging import get_logger
|
||||
logger = get_logger("middleware.errors")
|
||||
except Exception: # pragma: no cover - fallback simple logger
|
||||
import logging
|
||||
logger = logging.getLogger("middleware.errors")
|
||||
|
||||
|
||||
ERROR_HEADER_NAME = "X-Correlation-ID"
|
||||
|
||||
|
||||
def _get_correlation_id(request: Request) -> str:
|
||||
"""Resolve correlation ID from request state, headers, or generate a new one."""
|
||||
# From middleware
|
||||
correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None)
|
||||
if correlation_id:
|
||||
return correlation_id
|
||||
|
||||
# From incoming headers
|
||||
correlation_id = (
|
||||
request.headers.get("x-correlation-id")
|
||||
or request.headers.get("x-request-id")
|
||||
)
|
||||
if correlation_id:
|
||||
return correlation_id
|
||||
|
||||
# Generate a new one if not present
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def _build_error_response(
|
||||
request: Request,
|
||||
*,
|
||||
status_code: int,
|
||||
message: str,
|
||||
code: str,
|
||||
details: Any | None = None,
|
||||
) -> JSONResponse:
|
||||
correlation_id = _get_correlation_id(request)
|
||||
|
||||
body = {
|
||||
"success": False,
|
||||
"error": {
|
||||
"status": status_code,
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
"correlation_id": correlation_id,
|
||||
}
|
||||
if details is not None:
|
||||
body["error"]["details"] = details
|
||||
|
||||
response = JSONResponse(content=body, status_code=status_code)
|
||||
response.headers[ERROR_HEADER_NAME] = correlation_id
|
||||
return response
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""Handle FastAPI HTTPException with envelope and correlation id."""
|
||||
message = exc.detail if isinstance(exc.detail, str) else "HTTP error"
|
||||
logger.warning(
|
||||
"HTTPException raised",
|
||||
status_code=exc.status_code,
|
||||
detail=message,
|
||||
path=request.url.path,
|
||||
)
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=exc.status_code,
|
||||
message=message,
|
||||
code="http_error",
|
||||
details=None,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
"""Handle validation errors from FastAPI/Pydantic."""
|
||||
logger.info(
|
||||
"Validation error",
|
||||
path=request.url.path,
|
||||
errors=exc.errors(),
|
||||
)
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
message="Validation error",
|
||||
code="validation_error",
|
||||
details=exc.errors(),
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Catch-all handler for unexpected exceptions (500)."""
|
||||
# Log full exception for diagnostics without leaking internals to clients
|
||||
try:
|
||||
logger.exception("Unhandled exception", path=request.url.path)
|
||||
except Exception:
|
||||
pass
|
||||
return _build_error_response(
|
||||
request,
|
||||
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message="Internal Server Error",
|
||||
code="internal_error",
|
||||
details=None,
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register global exception handlers on the provided FastAPI app."""
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Request/Response Logging Middleware
|
||||
import time
|
||||
import json
|
||||
from typing import Callable
|
||||
from uuid import uuid4
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import StreamingResponse
|
||||
@@ -21,10 +22,19 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
self.log_responses = log_responses
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Skip logging for static files and health checks
|
||||
# Correlation ID: use incoming header or generate a new one
|
||||
correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4())
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
# Skip logging for static files and health checks (still attach correlation id)
|
||||
skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"]
|
||||
if any(request.url.path.startswith(path) for path in skip_paths):
|
||||
return await call_next(request)
|
||||
response = await call_next(request)
|
||||
try:
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
except Exception:
|
||||
pass
|
||||
return response
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
@@ -41,7 +51,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
query_params=str(request.query_params) if request.query_params else None,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Process request
|
||||
@@ -56,7 +67,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
duration_ms=duration_ms,
|
||||
error=str(e),
|
||||
client_ip=client_ip
|
||||
client_ip=client_ip,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -74,7 +86,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
path=request.url.path,
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log response details if enabled
|
||||
@@ -84,7 +97,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
size_bytes=response.headers.get("content-length"),
|
||||
content_type=response.headers.get("content-type")
|
||||
content_type=response.headers.get("content-type"),
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log slow requests as warnings
|
||||
@@ -94,7 +108,8 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
duration_ms=duration_ms,
|
||||
status_code=response.status_code
|
||||
status_code=response.status_code,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
# Log authentication-related requests to auth log
|
||||
@@ -106,9 +121,16 @@ class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
|
||||
# Attach correlation id header to all responses
|
||||
try:
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return response
|
||||
|
||||
def get_client_ip(self, request: Request) -> str:
|
||||
|
||||
@@ -8,6 +8,7 @@ from .files import File
|
||||
from .ledger import Ledger
|
||||
from .qdro import QDRO
|
||||
from .audit import AuditLog, LoginAttempt
|
||||
from .auth import RefreshToken
|
||||
from .additional import Deposit, Payment, FileNote, FormVariable, ReportVariable, Document
|
||||
from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory
|
||||
from .pensions import (
|
||||
@@ -22,7 +23,7 @@ from .lookups import (
|
||||
|
||||
__all__ = [
|
||||
"BaseModel", "User", "Rolodex", "Phone", "File", "Ledger", "QDRO",
|
||||
"AuditLog", "LoginAttempt",
|
||||
"AuditLog", "LoginAttempt", "RefreshToken",
|
||||
"Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document",
|
||||
"SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory",
|
||||
"Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit",
|
||||
|
||||
34
app/models/auth.py
Normal file
34
app/models/auth.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Authentication-related persistence models
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class RefreshToken(BaseModel):
|
||||
"""Persisted refresh tokens for revocation and auditing."""
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
jti = Column(String(64), nullable=False, unique=True, index=True)
|
||||
user_agent = Column(String(255), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
issued_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
expires_at = Column(DateTime, nullable=False, index=True)
|
||||
revoked = Column(Boolean, default=False, nullable=False)
|
||||
revoked_at = Column(DateTime, nullable=True)
|
||||
|
||||
# relationships
|
||||
user = relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("jti", name="uq_refresh_tokens_jti"),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user