From 1512b2d12a7c0a70b0d9969029439693bfd181e2 Mon Sep 17 00:00:00 2001 From: HotSwapp <47397945+HotSwapp@users.noreply.github.com> Date: Sun, 10 Aug 2025 21:34:11 -0500 Subject: [PATCH] all working --- README.md | 15 +- app/api/auth.py | 120 ++++++++-- app/api/documents.py | 76 +++++- app/auth/schemas.py | 8 +- app/auth/security.py | 98 ++++++-- app/config.py | 47 ++-- app/main.py | 5 + app/middleware/errors.py | 132 +++++++++++ app/middleware/logging.py | 40 +++- app/models/__init__.py | 3 +- app/models/auth.py | 34 +++ scripts/rotate-secret-key.py | 102 ++++++++ scripts/setup-security.py | 5 +- static/js/financial.js | 2 +- static/js/main.js | 447 ++++++++++++++++++++++++++++++++++- templates/admin.html | 18 +- templates/base.html | 369 +---------------------------- templates/documents.html | 227 +++++++++++++++--- templates/login.html | 7 +- test_customers.py | 18 ++ tests/test_auth.py | 87 +++++++ tests/test_error_handling.py | 82 +++++++ 22 files changed, 1453 insertions(+), 489 deletions(-) create mode 100644 app/middleware/errors.py create mode 100644 app/models/auth.py create mode 100644 scripts/rotate-secret-key.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_error_handling.py diff --git a/README.md b/README.md index a89ebe4..2125576 100644 --- a/README.md +++ b/README.md @@ -187,6 +187,12 @@ delphi-database/ - Password hashing with bcrypt - Token expiration and refresh +JWT details: + +- Access token: returned by `POST /api/auth/login`, use in `Authorization: Bearer` header +- Refresh token: also returned on login; use `POST /api/auth/refresh` with body `{ "refresh_token": "..." }` to obtain a new access token. On refresh, the provided refresh token is revoked and a new one is issued. +- Legacy compatibility: `POST /api/auth/refresh` called without a body (but with Authorization header) will issue a new access token only. + ## 🗄️ Data Management - CSV import/export functionality - Database backup and restore @@ -194,14 +200,17 @@ delphi-database/ - Automatic financial calculations (matching legacy system) ## ⚙️ Configuration -Environment variables (create `.env` file): +Environment variables (create `.env` file). Real environment variables override `.env` which override defaults: ```bash # Database DATABASE_URL=sqlite:///./delphi_database.db -# Security +# Security SECRET_KEY=your-secret-key-change-in-production -ACCESS_TOKEN_EXPIRE_MINUTES=30 +# Optional previous key to allow rotation +PREVIOUS_SECRET_KEY= +ACCESS_TOKEN_EXPIRE_MINUTES=240 +REFRESH_TOKEN_EXPIRE_MINUTES=43200 # Application DEBUG=False diff --git a/app/api/auth.py b/app/api/auth.py index 4f96a3b..d6194ae 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -10,18 +10,23 @@ from sqlalchemy.orm import Session from app.database.base import get_db from app.models.user import User from app.auth.security import ( - authenticate_user, - create_access_token, + authenticate_user, + create_access_token, + create_refresh_token, + decode_refresh_token, + is_refresh_token_revoked, + revoke_refresh_token, get_password_hash, get_current_user, - get_admin_user + get_admin_user, ) from app.auth.schemas import ( - Token, - UserCreate, - UserResponse, + Token, + UserCreate, + UserResponse, LoginRequest, - ThemePreferenceUpdate + ThemePreferenceUpdate, + RefreshRequest, ) from app.config import settings from app.core.logging import get_logger, log_auth_attempt @@ -71,6 +76,12 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) + refresh_token = create_refresh_token( + user=user, + user_agent=request.headers.get("user-agent", ""), + ip_address=request.client.host if request.client else None, + db=db, + ) log_auth_attempt( username=login_data.username, @@ -85,7 +96,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend client_ip=client_ip ) - return {"access_token": access_token, "token_type": "bearer"} + return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token} @router.post("/register", response_model=UserResponse) @@ -130,22 +141,76 @@ async def read_users_me(current_user: User = Depends(get_current_user)): @router.post("/refresh", response_model=Token) -async def refresh_token( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) +async def refresh_token_endpoint( + request: Request, + db: Session = Depends(get_db), + body: RefreshRequest | None = None, ): - """Refresh access token for current user""" - # Update last login timestamp - current_user.last_login = datetime.utcnow() + """Issue a new access token using a valid, non-revoked refresh token. + + For backwards compatibility with existing clients that may call this without a body, + consider falling back to Authorization header in the future if needed. + """ + # New flow: refresh token in body + if body and body.refresh_token: + payload = decode_refresh_token(body.refresh_token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") + + jti = payload.get("jti") + username = payload.get("sub") + if not jti or not username: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token payload") + + # Verify token not revoked/expired + if is_refresh_token_revoked(jti, db): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token revoked or expired") + + # Load user + user = db.query(User).filter(User.username == username).first() + if not user or not user.is_active: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") + + # Rotate refresh token on use + revoke_refresh_token(jti, db) + new_refresh_token = create_refresh_token( + user=user, + user_agent=request.headers.get("user-agent", ""), + ip_address=request.client.host if request.client else None, + db=db, + ) + + # Issue new access token + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + + return {"access_token": access_token, "token_type": "bearer", "refresh_token": new_refresh_token} + + # Legacy flow: Authorization header-based refresh + auth_header = request.headers.get("authorization") or request.headers.get("Authorization") + if not auth_header or not auth_header.lower().startswith("bearer "): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing credentials") + + token = auth_header.split(" ", 1)[1].strip() + from app.auth.security import verify_token # local import to avoid circular + username = verify_token(token) + if not username: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + + user = db.query(User).filter(User.username == username).first() + if not user or not user.is_active: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") + + user.last_login = datetime.utcnow() db.commit() - - # Create new token with full expiration time + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token = create_access_token( - data={"sub": current_user.username}, - expires_delta=access_token_expires + data={"sub": user.username}, expires_delta=access_token_expires ) - + return {"access_token": access_token, "token_type": "bearer"} @@ -159,6 +224,23 @@ async def list_users( return users +@router.post("/logout") +async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)): + """Revoke the provided refresh token. Idempotent and safe to call multiple times. + + The client should send a JSON body: { "refresh_token": "..." }. + """ + try: + if body and body.refresh_token: + payload = decode_refresh_token(body.refresh_token) + if payload and payload.get("jti"): + revoke_refresh_token(payload["jti"], db) + except Exception: + # Don't leak details; logout should be best-effort + pass + return {"status": "ok"} + + @router.post("/theme-preference") async def update_theme_preference( theme_data: ThemePreferenceUpdate, diff --git a/app/api/documents.py b/app/api/documents.py index 7b7203c..481ac9f 100644 --- a/app/api/documents.py +++ b/app/api/documents.py @@ -2,7 +2,7 @@ Document Management API endpoints - QDROs, Templates, and General Documents """ from typing import List, Optional, Dict, Any -from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form +from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request from sqlalchemy.orm import Session, joinedload from sqlalchemy import or_, func, and_, desc, asc, text from datetime import date, datetime @@ -18,6 +18,8 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee from app.models.user import User from app.auth.security import get_current_user from app.models.additional import Document +from app.core.logging import get_logger +from app.services.audit import audit_service router = APIRouter() @@ -666,6 +668,78 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str: return merged +# --- Client Error Logging (for Documents page) --- +class ClientErrorLog(BaseModel): + """Payload for client-side error logging""" + message: str + action: Optional[str] = None + stack: Optional[str] = None + url: Optional[str] = None + line: Optional[int] = None + column: Optional[int] = None + user_agent: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +@router.post("/client-error") +async def log_client_error( + payload: ClientErrorLog, + request: Request, + db: Session = Depends(get_db), + current_user: Optional[User] = Depends(lambda: None) +): + """Accept client-side error logs from the Documents page. + + This endpoint is lightweight and safe to call; it records the error to the + application logs and best-effort to the audit log without interrupting the UI. + """ + logger = get_logger("client.documents") + client_ip = request.headers.get("x-forwarded-for") + if client_ip: + client_ip = client_ip.split(",")[0].strip() + else: + client_ip = request.client.host if request.client else None + + logger.error( + "Client error reported", + action=payload.action, + message=payload.message, + stack=payload.stack, + page="/documents", + url=payload.url or str(request.url), + line=payload.line, + column=payload.column, + user=getattr(current_user, "username", None), + user_id=getattr(current_user, "id", None), + user_agent=payload.user_agent or request.headers.get("user-agent"), + client_ip=client_ip, + extra=payload.extra, + ) + + # Best-effort audit log; do not raise on failure + try: + audit_service.log_action( + db=db, + action="CLIENT_ERROR", + resource_type="DOCUMENTS", + user=current_user, + resource_id=None, + details={ + "action": payload.action, + "message": payload.message, + "url": payload.url or str(request.url), + "line": payload.line, + "column": payload.column, + "extra": payload.extra, + }, + request=request, + ) + except Exception: + pass + + return {"status": "logged"} + + @router.post("/upload/{file_no}") async def upload_document( file_no: str, diff --git a/app/auth/schemas.py b/app/auth/schemas.py index 880d146..537c6ae 100644 --- a/app/auth/schemas.py +++ b/app/auth/schemas.py @@ -45,6 +45,7 @@ class Token(BaseModel): """Token response schema""" access_token: str token_type: str + refresh_token: str | None = None class TokenData(BaseModel): @@ -55,4 +56,9 @@ class TokenData(BaseModel): class LoginRequest(BaseModel): """Login request schema""" username: str - password: str \ No newline at end of file + password: str + + +class RefreshRequest(BaseModel): + """Refresh token submission""" + refresh_token: str \ No newline at end of file diff --git a/app/auth/security.py b/app/auth/security.py index cdc49f2..dc069db 100644 --- a/app/auth/security.py +++ b/app/auth/security.py @@ -2,7 +2,8 @@ Authentication and security utilities """ from datetime import datetime, timedelta -from typing import Optional, Union +from typing import Optional, Union, Tuple +from uuid import uuid4 from jose import JWTError, jwt from passlib.context import CryptContext from sqlalchemy.orm import Session @@ -12,6 +13,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.config import settings from app.database.base import get_db from app.models.user import User +from app.models.auth import RefreshToken # Password hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -30,39 +32,107 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) +def _encode_with_rotation(payload: dict) -> str: + """Encode JWT with active secret and algorithm.""" + return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) + + +def _decode_with_rotation(token: str) -> dict: + """Decode JWT trying current then previous secret if set.""" + try: + return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) + except JWTError: + # Try previous secret to allow seamless rotation + if settings.previous_secret_key: + try: + return jwt.decode(token, settings.previous_secret_key, algorithms=[settings.algorithm]) + except JWTError: + pass + raise + + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create JWT access token""" to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) - + expire = datetime.utcnow() + ( + expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes) + ) to_encode.update({ "exp": expire, "iat": datetime.utcnow(), + "type": "access", }) - encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) - return encoded_jwt + return _encode_with_rotation(to_encode) + + +def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str: + """Create refresh token, store its JTI in DB for revocation.""" + jti = uuid4().hex + expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes) + payload = { + "sub": user.username, + "uid": user.id, + "jti": jti, + "type": "refresh", + "exp": expire, + "iat": datetime.utcnow(), + } + token = _encode_with_rotation(payload) + + db_token = RefreshToken( + user_id=user.id, + jti=jti, + user_agent=user_agent, + ip_address=ip_address, + issued_at=datetime.utcnow(), + expires_at=expire, + revoked=False, + ) + db.add(db_token) + db.commit() + return token def verify_token(token: str) -> Optional[str]: """Verify JWT token and return username""" try: - payload = jwt.decode( - token, - settings.secret_key, - algorithms=[settings.algorithm], - leeway=30 # allow small clock skew - ) + payload = _decode_with_rotation(token) username: str = payload.get("sub") + token_type: str = payload.get("type") if username is None: return None + # Only accept access tokens for auth + if token_type and token_type != "access": + return None return username except JWTError: return None +def decode_refresh_token(token: str) -> Optional[dict]: + """Decode refresh token and return payload if valid and not revoked.""" + try: + payload = _decode_with_rotation(token) + if payload.get("type") != "refresh": + return None + return payload + except JWTError: + return None + + +def is_refresh_token_revoked(jti: str, db: Session) -> bool: + token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() + return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow() + + +def revoke_refresh_token(jti: str, db: Session) -> None: + token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() + if token_row and not token_row.revoked: + token_row.revoked = True + token_row.revoked_at = datetime.utcnow() + db.commit() + + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: """Authenticate user credentials""" user = db.query(User).filter(User.username == username).first() diff --git a/app/config.py b/app/config.py index 0b833a8..06e3d92 100644 --- a/app/config.py +++ b/app/config.py @@ -1,53 +1,70 @@ """ Delphi Consulting Group Database System - Configuration """ -from pydantic_settings import BaseSettings from typing import Optional +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): - """Application configuration""" - + """Application configuration (env and .env driven). + + Environment precedence: real environment variables take priority over .env, + which take priority over defaults. + """ + # Application app_name: str = "Delphi Consulting Group Database System" app_version: str = "1.0.0" debug: bool = False - + # Database database_url: str = "sqlite:///./data/delphi_database.db" - - # Authentication - secret_key: str = "your-secret-key-change-in-production" + + # Authentication / JWT + # Require SECRET_KEY to be provided via environment/.env (no insecure default) + secret_key: str = Field(..., min_length=32) + # Optional previous secret key to allow seamless rotation + previous_secret_key: Optional[str] = None algorithm: str = "HS256" access_token_expire_minutes: int = 240 # 4 hours - + # Long-lived refresh token expiration (default 30 days) + refresh_token_expire_minutes: int = 43200 + # Admin account settings admin_username: str = "admin" admin_password: str = "change-me" - + # File paths upload_dir: str = "./uploads" backup_dir: str = "./backups" - + # Pagination default_page_size: int = 50 max_page_size: int = 200 - + # Docker/deployment settings external_port: Optional[str] = None allowed_hosts: Optional[str] = None cors_origins: Optional[str] = None secure_cookies: bool = False compose_project_name: Optional[str] = None - + # Logging log_level: str = "INFO" log_to_file: bool = True log_rotation: str = "10 MB" log_retention: str = "30 days" - - class Config: - env_file = ".env" + + # pydantic-settings v2 configuration + model_config = SettingsConfigDict( + env_file=".env", + env_prefix="", + case_sensitive=False, + env_ignore_empty=True, + extra="ignore", + ) settings = Settings() \ No newline at end of file diff --git a/app/main.py b/app/main.py index 015b6fa..5f24436 100644 --- a/app/main.py +++ b/app/main.py @@ -14,6 +14,7 @@ from app.models.user import User from app.auth.security import get_admin_user from app.core.logging import setup_logging, get_logger from app.middleware.logging import LoggingMiddleware +from app.middleware.errors import register_exception_handlers # Initialize logging setup_logging() @@ -35,6 +36,10 @@ app = FastAPI( logger.info("Adding request logging middleware") app.add_middleware(LoggingMiddleware, log_requests=True, log_responses=settings.debug) +# Register global exception handlers +logger.info("Registering global exception handlers") +register_exception_handlers(app) + # Configure CORS logger.info("Configuring CORS middleware") app.add_middleware( diff --git a/app/middleware/errors.py b/app/middleware/errors.py new file mode 100644 index 0000000..f1575cb --- /dev/null +++ b/app/middleware/errors.py @@ -0,0 +1,132 @@ +""" +Global exception handlers that return a consistent JSON error envelope +and propagate the X-Correlation-ID header. +""" +from __future__ import annotations + +from typing import Any, Optional +from uuid import uuid4 + +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from starlette import status as http_status +from starlette.requests import Request as StarletteRequest +from starlette.responses import JSONResponse as StarletteJSONResponse +from fastapi import HTTPException + +try: + # Prefer project logger if available + from app.core.logging import get_logger + logger = get_logger("middleware.errors") +except Exception: # pragma: no cover - fallback simple logger + import logging + logger = logging.getLogger("middleware.errors") + + +ERROR_HEADER_NAME = "X-Correlation-ID" + + +def _get_correlation_id(request: Request) -> str: + """Resolve correlation ID from request state, headers, or generate a new one.""" + # From middleware + correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None) + if correlation_id: + return correlation_id + + # From incoming headers + correlation_id = ( + request.headers.get("x-correlation-id") + or request.headers.get("x-request-id") + ) + if correlation_id: + return correlation_id + + # Generate a new one if not present + return str(uuid4()) + + +def _build_error_response( + request: Request, + *, + status_code: int, + message: str, + code: str, + details: Any | None = None, +) -> JSONResponse: + correlation_id = _get_correlation_id(request) + + body = { + "success": False, + "error": { + "status": status_code, + "code": code, + "message": message, + }, + "correlation_id": correlation_id, + } + if details is not None: + body["error"]["details"] = details + + response = JSONResponse(content=body, status_code=status_code) + response.headers[ERROR_HEADER_NAME] = correlation_id + return response + + +async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + """Handle FastAPI HTTPException with envelope and correlation id.""" + message = exc.detail if isinstance(exc.detail, str) else "HTTP error" + logger.warning( + "HTTPException raised", + status_code=exc.status_code, + detail=message, + path=request.url.path, + ) + return _build_error_response( + request, + status_code=exc.status_code, + message=message, + code="http_error", + details=None, + ) + + +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: + """Handle validation errors from FastAPI/Pydantic.""" + logger.info( + "Validation error", + path=request.url.path, + errors=exc.errors(), + ) + return _build_error_response( + request, + status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY, + message="Validation error", + code="validation_error", + details=exc.errors(), + ) + + +async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Catch-all handler for unexpected exceptions (500).""" + # Log full exception for diagnostics without leaking internals to clients + try: + logger.exception("Unhandled exception", path=request.url.path) + except Exception: + pass + return _build_error_response( + request, + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Internal Server Error", + code="internal_error", + details=None, + ) + + +def register_exception_handlers(app: FastAPI) -> None: + """Register global exception handlers on the provided FastAPI app.""" + app.add_exception_handler(HTTPException, http_exception_handler) + app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_exception_handler(Exception, unhandled_exception_handler) + + diff --git a/app/middleware/logging.py b/app/middleware/logging.py index 4ab4de3..c13c9c9 100644 --- a/app/middleware/logging.py +++ b/app/middleware/logging.py @@ -4,6 +4,7 @@ Request/Response Logging Middleware import time import json from typing import Callable +from uuid import uuid4 from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse @@ -21,10 +22,19 @@ class LoggingMiddleware(BaseHTTPMiddleware): self.log_responses = log_responses async def dispatch(self, request: Request, call_next: Callable) -> Response: - # Skip logging for static files and health checks + # Correlation ID: use incoming header or generate a new one + correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4()) + request.state.correlation_id = correlation_id + + # Skip logging for static files and health checks (still attach correlation id) skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"] if any(request.url.path.startswith(path) for path in skip_paths): - return await call_next(request) + response = await call_next(request) + try: + response.headers["X-Correlation-ID"] = correlation_id + except Exception: + pass + return response # Record start time start_time = time.time() @@ -41,7 +51,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, query_params=str(request.query_params) if request.query_params else None, client_ip=client_ip, - user_agent=user_agent + user_agent=user_agent, + correlation_id=correlation_id, ) # Process request @@ -56,7 +67,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, duration_ms=duration_ms, error=str(e), - client_ip=client_ip + client_ip=client_ip, + correlation_id=correlation_id, ) raise @@ -74,7 +86,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, - user_id=user_id + user_id=user_id, + correlation_id=correlation_id, ) # Log response details if enabled @@ -84,7 +97,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): status_code=response.status_code, headers=dict(response.headers), size_bytes=response.headers.get("content-length"), - content_type=response.headers.get("content-type") + content_type=response.headers.get("content-type"), + correlation_id=correlation_id, ) # Log slow requests as warnings @@ -94,7 +108,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): method=request.method, path=request.url.path, duration_ms=duration_ms, - status_code=response.status_code + status_code=response.status_code, + correlation_id=correlation_id, ) # Log authentication-related requests to auth log @@ -106,9 +121,16 @@ class LoggingMiddleware(BaseHTTPMiddleware): status_code=response.status_code, duration_ms=duration_ms, client_ip=client_ip, - user_agent=user_agent + user_agent=user_agent, + correlation_id=correlation_id, ) - + + # Attach correlation id header to all responses + try: + response.headers["X-Correlation-ID"] = correlation_id + except Exception: + pass + return response def get_client_ip(self, request: Request) -> str: diff --git a/app/models/__init__.py b/app/models/__init__.py index 108063c..a02cfe0 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -8,6 +8,7 @@ from .files import File from .ledger import Ledger from .qdro import QDRO from .audit import AuditLog, LoginAttempt +from .auth import RefreshToken from .additional import Deposit, Payment, FileNote, FormVariable, ReportVariable, Document from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory from .pensions import ( @@ -22,7 +23,7 @@ from .lookups import ( __all__ = [ "BaseModel", "User", "Rolodex", "Phone", "File", "Ledger", "QDRO", - "AuditLog", "LoginAttempt", + "AuditLog", "LoginAttempt", "RefreshToken", "Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document", "SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory", "Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit", diff --git a/app/models/auth.py b/app/models/auth.py new file mode 100644 index 0000000..2fa1d26 --- /dev/null +++ b/app/models/auth.py @@ -0,0 +1,34 @@ +""" +Authentication-related persistence models +""" +from datetime import datetime +from typing import Optional + +from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint +from sqlalchemy.orm import relationship + +from app.models.base import BaseModel + + +class RefreshToken(BaseModel): + """Persisted refresh tokens for revocation and auditing.""" + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + jti = Column(String(64), nullable=False, unique=True, index=True) + user_agent = Column(String(255), nullable=True) + ip_address = Column(String(45), nullable=True) + issued_at = Column(DateTime, default=datetime.utcnow, nullable=False) + expires_at = Column(DateTime, nullable=False, index=True) + revoked = Column(Boolean, default=False, nullable=False) + revoked_at = Column(DateTime, nullable=True) + + # relationships + user = relationship("User") + + __table_args__ = ( + UniqueConstraint("jti", name="uq_refresh_tokens_jti"), + ) + + diff --git a/scripts/rotate-secret-key.py b/scripts/rotate-secret-key.py new file mode 100644 index 0000000..fddb4c7 --- /dev/null +++ b/scripts/rotate-secret-key.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Rotate SECRET_KEY in .env with seamless fallback using PREVIOUS_SECRET_KEY. + +Usage: + python scripts/rotate-secret-key.py [--env-path .env] + +Behavior: +- Reads the .env file (default .env) +- Sets PREVIOUS_SECRET_KEY to current SECRET_KEY +- Generates a new SECRET_KEY +- Preserves other variables +- Writes back atomically and sets file mode 600 +""" +from __future__ import annotations + +import argparse +import os +import secrets +import tempfile +from pathlib import Path + + +def generate_secret_key(length: int = 32) -> str: + return secrets.token_urlsafe(length) + + +def parse_env(contents: str) -> dict[str, str]: + env: dict[str, str] = {} + for line in contents.splitlines(): + if not line or line.strip().startswith("#"): + continue + if "=" not in line: + continue + key, value = line.split("=", 1) + env[key.strip()] = value.strip() + return env + + +def render_env(original: str, updates: dict[str, str]) -> str: + lines = original.splitlines() + seen_keys: set[str] = set() + out_lines: list[str] = [] + for line in lines: + if not line or line.strip().startswith("#") or "=" not in line: + out_lines.append(line) + continue + key, _ = line.split("=", 1) + k = key.strip() + if k in updates: + out_lines.append(f"{k}={updates[k]}") + seen_keys.add(k) + else: + out_lines.append(line) + seen_keys.add(k) + # Append any new keys not present originally + for k, v in updates.items(): + if k not in seen_keys: + out_lines.append(f"{k}={v}") + return "\n".join(out_lines) + "\n" + + +def main() -> None: + parser = argparse.ArgumentParser(description="Rotate SECRET_KEY in .env") + parser.add_argument("--env-path", default=".env", help="Path to .env file") + args = parser.parse_args() + + env_path = Path(args.env_path) + if not env_path.exists(): + raise SystemExit(f".env file not found at {env_path}") + + original = env_path.read_text() + env = parse_env(original) + + current_secret = env.get("SECRET_KEY") + if not current_secret: + raise SystemExit("SECRET_KEY not found in .env") + + new_secret = generate_secret_key(32) + updates = { + "PREVIOUS_SECRET_KEY": current_secret, + "SECRET_KEY": new_secret, + } + + rendered = render_env(original, updates) + + # Atomic write + with tempfile.NamedTemporaryFile("w", delete=False, dir=str(env_path.parent)) as tmp: + tmp.write(rendered) + temp_name = tmp.name + os.replace(temp_name, env_path) + os.chmod(env_path, 0o600) + + print("✅ SECRET_KEY rotated successfully.") + print(" PREVIOUS_SECRET_KEY updated for seamless token validation.") + print(" Restart the application to apply the new key.") + + +if __name__ == "__main__": + main() + + diff --git a/scripts/setup-security.py b/scripts/setup-security.py index 9f47905..df4a8c8 100755 --- a/scripts/setup-security.py +++ b/scripts/setup-security.py @@ -63,7 +63,10 @@ DATABASE_URL=sqlite:///data/delphi_database.db # ===== SECURITY SETTINGS - GENERATED ===== SECRET_KEY={secret_key} -ACCESS_TOKEN_EXPIRE_MINUTES=30 +# Optional previous key for seamless rotation (leave blank initially) +PREVIOUS_SECRET_KEY= +ACCESS_TOKEN_EXPIRE_MINUTES=240 +REFRESH_TOKEN_EXPIRE_MINUTES=43200 ALGORITHM=HS256 # ===== ADMIN USER CREATION ===== diff --git a/static/js/financial.js b/static/js/financial.js index 237eb1a..acc3655 100644 --- a/static/js/financial.js +++ b/static/js/financial.js @@ -6,7 +6,7 @@ // ... add the JS content ... -// Modify modal showing/hiding to use classList.add/remove('hidden') instead of Bootstrap modal +// Modify modal showing/hiding to use classList.add/remove('hidden') // For example: function showQuickTimeModal() { diff --git a/static/js/main.js b/static/js/main.js index ef69b3e..7697840 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -5,6 +5,7 @@ // Global application state const app = { token: localStorage.getItem('auth_token'), + refreshToken: localStorage.getItem('refresh_token'), user: null, initialized: false, refreshTimerId: null, @@ -13,9 +14,98 @@ const app = { // Initialize application document.addEventListener('DOMContentLoaded', function() { + try { setupGlobalErrorHandlers(); } catch (_) {} initializeApp(); }); +// Theme Management (centralized) +function applyTheme(theme) { + const html = document.documentElement; + const isDark = theme === 'dark'; + html.classList.toggle('dark', isDark); + html.setAttribute('data-theme', isDark ? 'dark' : 'light'); +} + +function toggleTheme() { + const html = document.documentElement; + const nextTheme = html.classList.contains('dark') ? 'light' : 'dark'; + applyTheme(nextTheme); + try { localStorage.setItem('theme-preference', nextTheme); } catch (_) {} + saveThemePreference(nextTheme); +} + +function initializeTheme() { + // Check for saved theme preference + let savedTheme = null; + try { savedTheme = localStorage.getItem('theme-preference'); } catch (_) {} + const prefersDark = window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches; + + const theme = savedTheme || (prefersDark ? 'dark' : 'light'); + applyTheme(theme); + + // Load from server if available + loadUserThemePreference(); + + // Listen for OS theme changes if no explicit preference is set + attachSystemThemeListener(); +} + +async function saveThemePreference(theme) { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + await fetch('/api/auth/theme-preference', { + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ theme_preference: theme }) + }); + } catch (error) { + console.log('Could not save theme preference to server:', error.message); + } +} + +function attachSystemThemeListener() { + if (!('matchMedia' in window)) return; + const media = window.matchMedia('(prefers-color-scheme: dark)'); + const handleChange = (e) => { + let savedTheme = null; + try { savedTheme = localStorage.getItem('theme-preference'); } catch (_) {} + if (!savedTheme || savedTheme === 'system') { + applyTheme(e.matches ? 'dark' : 'light'); + } + }; + if (typeof media.addEventListener === 'function') { + media.addEventListener('change', handleChange); + } else if (typeof media.addListener === 'function') { + media.addListener(handleChange); // Safari fallback + } +} + +async function loadUserThemePreference() { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (response.ok) { + const user = await response.json(); + if (user.theme_preference) { + applyTheme(user.theme_preference); + try { localStorage.setItem('theme-preference', user.theme_preference); } catch (_) {} + } + } + } catch (error) { + console.log('Could not load theme preference from server:', error.message); + } +} + +// Apply theme immediately on script load +try { initializeTheme(); } catch (_) {} + async function initializeApp() { // Initialize keyboard shortcuts if (window.keyboardShortcuts) { @@ -29,6 +119,11 @@ async function initializeApp() { // Initialize API helpers setupAPIHelpers(); + + // Initialize authentication manager (centralized) + if (typeof initializeAuthManager === 'function') { + initializeAuthManager(); + } app.initialized = true; console.log('Delphi Database System initialized'); @@ -103,6 +198,14 @@ async function apiCall(url, options = {}) { try { let response = await fetch(url, config); + const updateCorrelationFromResponse = (resp) => { + try { + const cid = resp && resp.headers ? resp.headers.get('X-Correlation-ID') : null; + if (cid) { window.app.lastCorrelationId = cid; } + return cid; + } catch (_) { return null; } + }; + let lastCorrelationId = updateCorrelationFromResponse(response); if (response.status === 401 && app.token) { // Attempt one refresh then retry once @@ -113,6 +216,7 @@ async function apiCall(url, options = {}) { ...options }; response = await fetch(url, retryConfig); + lastCorrelationId = updateCorrelationFromResponse(response); } catch (_) { // fall through to logout below } @@ -124,7 +228,10 @@ async function apiCall(url, options = {}) { if (!response.ok) { const errorData = await response.json().catch(() => ({ detail: 'Request failed' })); - throw new Error(errorData.detail || `HTTP ${response.status}`); + const err = new Error(errorData.detail || `HTTP ${response.status}`); + err.status = response.status; + err.correlationId = lastCorrelationId || null; + throw err; } return await response.json(); @@ -158,16 +265,263 @@ async function apiDelete(url) { } // Authentication functions -function setAuthToken(token) { - app.token = token; - localStorage.setItem('auth_token', token); - window.apiHeaders['Authorization'] = `Bearer ${token}`; - +function setAuthTokens(accessToken, newRefreshToken = null) { + if (accessToken) { + app.token = accessToken; + localStorage.setItem('auth_token', accessToken); + window.apiHeaders['Authorization'] = `Bearer ${accessToken}`; + } + if (newRefreshToken) { + app.refreshToken = newRefreshToken; + localStorage.setItem('refresh_token', newRefreshToken); + } // Reschedule refresh on token update - scheduleTokenRefresh(); + if (accessToken) { + scheduleTokenRefresh(); + } } -function logout() { +function setAuthToken(token) { + // Backwards compatibility + setAuthTokens(token, null); +} + +// Page helpers +function isLoginPage() { + const path = window.location.pathname; + return path === '/login' || path === '/'; +} + +// Verify the current access token by hitting /api/auth/me +async function checkTokenValidity() { + const token = localStorage.getItem('auth_token'); + if (!token) return false; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (!response.ok) { + // Invalid token + return false; + } + // Cache user for later UI updates + try { app.user = await response.json(); } catch (_) {} + return true; + } catch (error) { + console.error('Error checking token validity:', error); + return false; + } +} + +// Try to refresh token if refresh token is present; fallback to validity check+logout +async function refreshTokenIfNeeded() { + const refreshTokenValue = localStorage.getItem('refresh_token'); + if (!refreshTokenValue) return; + app.refreshToken = refreshTokenValue; + try { + await refreshToken(); + console.log('Token refreshed successfully'); + } catch (error) { + const stillValid = await checkTokenValidity(); + if (!stillValid) { + await logout('Session expired or invalid token'); + } + } +} + +// Update UI elements that are permission/user dependent +async function checkUserPermissions() { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (response.ok) { + const user = await response.json(); + app.user = user; + if (user.is_admin) { + const adminItem = document.getElementById('admin-menu-item'); + const adminDivider = document.getElementById('admin-menu-divider'); + if (adminItem) adminItem.classList.remove('hidden'); + if (adminDivider) adminDivider.classList.remove('hidden'); + } + const userDropdownName = document.querySelector('#userDropdown button span'); + if (user.full_name && userDropdownName) { + userDropdownName.textContent = user.full_name; + } + } + } catch (error) { + console.error('Error checking user permissions:', error); + } +} + +// Inactivity monitoring & session extension UI +async function getInactivityWarningMinutes() { + const token = localStorage.getItem('auth_token'); + if (!token) return 240; + try { + const resp = await fetch('/api/settings/inactivity_warning_minutes', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (!resp.ok) return 240; + const data = await resp.json(); + if (typeof data.minutes === 'number') return data.minutes; + const parsed = parseInt(data.setting_value || data.minutes, 10); + return Number.isFinite(parsed) ? parsed : 240; + } catch (_) { + return 240; + } +} + +function showSessionExtendedNotification() { + if (window.alerts && typeof window.alerts.success === 'function') { + window.alerts.success('Your session has been refreshed successfully.', { + title: 'Session Extended', + duration: 3000 + }); + return; + } + // Fallback + const notification = document.createElement('div'); + notification.className = 'fixed top-4 right-4 bg-green-100 border-l-4 border-green-500 text-green-700 p-4 rounded-lg shadow-lg z-50 max-w-sm'; + notification.innerHTML = ` +
+
+ +
+
+

Session Extended

+

Your session has been refreshed successfully.

+
+
+ `; + document.body.appendChild(notification); + setTimeout(() => notification.remove(), 3000); +} + +function setupActivityMonitoring() { + let lastActivity = Date.now(); + let warningShown = false; + let inactivityWarningMinutes = 240; // default 4 hours + const inactivityGraceMinutes = 5; // auto-logout after warning + 5 minutes + + // Fetch setting (best effort) + getInactivityWarningMinutes().then(minutes => { + if (Number.isFinite(minutes) && minutes > 0) { + inactivityWarningMinutes = minutes; + } + }).catch(() => {}); + + function hideInactivityWarning() { + const el = document.getElementById('inactivity-warning'); + if (el && el.remove) el.remove(); + } + + function extendSession() { + refreshTokenIfNeeded(); + hideInactivityWarning(); + showSessionExtendedNotification(); + } + + function showInactivityWarning() { + hideInactivityWarning(); + const msg = `You've been inactive. Your session may expire due to inactivity.`; + if (window.alerts && typeof window.alerts.show === 'function') { + window.alerts.show(msg, 'warning', { + title: 'Session Warning', + html: false, + duration: 0, + dismissible: true, + id: 'inactivity-warning', + actions: [ + { + label: 'Stay Logged In', + classes: 'bg-warning-600 hover:bg-warning-700 text-white text-xs px-3 py-1 rounded', + onClick: () => extendSession(), + autoClose: true + }, + { + label: 'Dismiss', + classes: 'bg-neutral-200 hover:bg-neutral-300 text-neutral-800 dark:bg-neutral-700 dark:hover:bg-neutral-600 dark:text-neutral-200 text-xs px-3 py-1 rounded', + onClick: () => hideInactivityWarning(), + autoClose: true + } + ] + }); + } else { + alert('Session Warning: ' + msg); + } + // Auto-hide after 2 minutes + setTimeout(() => hideInactivityWarning(), 2 * 60 * 1000); + } + + // Track user activity + const activityEvents = ['mousedown', 'mousemove', 'keypress', 'scroll', 'touchstart']; + activityEvents.forEach(event => { + document.addEventListener(event, () => { + lastActivity = Date.now(); + warningShown = false; + const el = document.getElementById('inactivity-warning'); + if (el && el.remove) el.remove(); + }); + }); + + // Check every 5 minutes for inactivity + setInterval(() => { + const now = Date.now(); + const warningMs = inactivityWarningMinutes * 60 * 1000; + const logoutMs = (inactivityWarningMinutes + inactivityGraceMinutes) * 60 * 1000; + const timeSinceActivity = now - lastActivity; + if (timeSinceActivity > warningMs && !warningShown) { + showInactivityWarning(); + warningShown = true; + } + if (timeSinceActivity > logoutMs) { + logout('Session expired due to inactivity'); + } + }, 5 * 60 * 1000); +} + +// Central initializer for auth +async function initializeAuthManager() { + const token = localStorage.getItem('auth_token'); + // If on the login page, do nothing + if (isLoginPage()) return; + if (token) { + // Align in-memory/app state with stored tokens + app.token = token; + const storedRefresh = localStorage.getItem('refresh_token'); + if (storedRefresh) app.refreshToken = storedRefresh; + + // Verify token and schedule refresh + checkTokenValidity(); + scheduleTokenRefresh(); + // Start inactivity monitoring + setupActivityMonitoring(); + // Update UI according to user permissions + checkUserPermissions(); + } else { + // No token and not on login page - redirect to login + window.location.href = '/login'; + } +} + +async function logout(reason = null) { + // Best-effort revoke refresh token server-side + const rtoken = localStorage.getItem('refresh_token'); + try { + if (rtoken) { + await fetch('/api/auth/logout', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ refresh_token: rtoken }) + }); + } + } catch (_) { + // ignore + } + if (app.refreshTimerId) { clearTimeout(app.refreshTimerId); app.refreshTimerId = null; @@ -176,7 +530,12 @@ function logout() { app.token = null; app.user = null; localStorage.removeItem('auth_token'); + localStorage.removeItem('refresh_token'); delete window.apiHeaders['Authorization']; + + if (reason) { + try { sessionStorage.setItem('logout_reason', reason); } catch (_) {} + } window.location.href = '/login'; } @@ -404,6 +763,70 @@ window.apiPut = apiPut; window.apiDelete = apiDelete; window.formatCurrency = formatCurrency; window.formatDate = formatDate; +window.toggleTheme = toggleTheme; +window.initializeTheme = initializeTheme; +window.saveThemePreference = saveThemePreference; +window.loadUserThemePreference = loadUserThemePreference; + +// Global error handling +function setupGlobalErrorHandlers() { + // Handle unexpected runtime errors + window.addEventListener('error', function(event) { + try { + const payload = { + message: event && event.message ? String(event.message) : 'Unhandled error', + action: 'window.onerror', + stack: event && event.error && event.error.stack ? String(event.error.stack) : null, + url: (event && event.filename) ? String(event.filename) : String(window.location.href), + line: event && typeof event.lineno === 'number' ? event.lineno : null, + column: event && typeof event.colno === 'number' ? event.colno : null, + user_agent: navigator.userAgent, + extra: { + page: window.location.pathname, + lastCorrelationId: (window.app && window.app.lastCorrelationId) || null + } + }; + postClientError(payload); + } catch (_) {} + }); + + // Handle unhandled promise rejections + window.addEventListener('unhandledrejection', function(event) { + try { + const reason = event && event.reason ? event.reason : null; + const payload = { + message: reason && reason.message ? String(reason.message) : 'Unhandled promise rejection', + action: 'window.unhandledrejection', + stack: reason && reason.stack ? String(reason.stack) : null, + url: String(window.location.href), + user_agent: navigator.userAgent, + extra: { + page: window.location.pathname, + reasonType: reason ? (reason.name || typeof reason) : null, + status: reason && typeof reason.status === 'number' ? reason.status : null, + correlationId: reason && reason.correlationId ? reason.correlationId : ((window.app && window.app.lastCorrelationId) || null) + } + }; + postClientError(payload); + } catch (_) {} + }); +} + +async function postClientError(payload) { + try { + const headers = { 'Content-Type': 'application/json', 'Accept': 'application/json' }; + const token = (window.app && window.app.token) || localStorage.getItem('auth_token'); + if (token) headers['Authorization'] = `Bearer ${token}`; + // Fire-and-forget; do not block UI + fetch('/api/documents/client-error', { + method: 'POST', + headers, + body: JSON.stringify(payload) + }).catch(() => {}); + } catch (_) { + // no-op + } +} // JWT utilities and refresh handling function decodeJwt(token) { @@ -448,13 +871,14 @@ function scheduleTokenRefresh() { } async function refreshToken() { - if (!app.token) throw new Error('No token to refresh'); + if (!app.refreshToken) throw new Error('No refresh token available'); if (app.refreshInProgress) return; // Avoid parallel refreshes app.refreshInProgress = true; try { const response = await fetch('/api/auth/refresh', { method: 'POST', - headers: { ...window.apiHeaders } + headers: { 'Content-Type': 'application/json', 'Accept': 'application/json' }, + body: JSON.stringify({ refresh_token: app.refreshToken }) }); if (!response.ok) { throw new Error('Refresh failed'); @@ -463,7 +887,8 @@ async function refreshToken() { if (!data || !data.access_token) { throw new Error('Invalid refresh response'); } - setAuthToken(data.access_token); + // Handle refresh token rotation if provided + setAuthTokens(data.access_token, data.refresh_token || null); } finally { app.refreshInProgress = false; } diff --git a/templates/admin.html b/templates/admin.html index 4d18746..372cfb1 100644 --- a/templates/admin.html +++ b/templates/admin.html @@ -55,25 +55,25 @@ @@ -951,7 +951,7 @@ function initializeTabs() { const panes = document.querySelectorAll('#adminTabContent > div[role="tabpanel"]'); tabs.forEach(tab => { tab.addEventListener('click', () => { - const target = tab.getAttribute('data-target'); + const target = tab.getAttribute('data-tab-target'); if (!target) return; // deactivate tabs.forEach(t => t.classList.remove('active')); @@ -996,7 +996,7 @@ document.addEventListener('DOMContentLoaded', function() { loadSettings(); loadLookupTables(); loadBackups(); - // Tabs setup (no Bootstrap JS) + // Tabs setup initializeTabs(); // Auto-refresh every 5 minutes diff --git a/templates/base.html b/templates/base.html index e6a9c62..702b6f7 100644 --- a/templates/base.html +++ b/templates/base.html @@ -1,5 +1,5 @@ - + @@ -347,12 +347,7 @@ } }); - // Handle escape key for modal - document.addEventListener('keydown', function(event) { - if (event.key === 'Escape') { - closeShortcutsModal(); - } - }); + // Escape handling is centralized in keyboard-shortcuts.js + + - {% block extra_scripts %}{% endblock %} \ No newline at end of file diff --git a/templates/documents.html b/templates/documents.html index e886dcc..16533cd 100644 --- a/templates/documents.html +++ b/templates/documents.html @@ -435,25 +435,51 @@ {% endblock %} \ No newline at end of file diff --git a/templates/login.html b/templates/login.html index 8417cf6..7c01111 100644 --- a/templates/login.html +++ b/templates/login.html @@ -1,5 +1,5 @@ - + @@ -124,8 +124,11 @@ const data = await response.json(); console.log('Login successful, token:', data.access_token); - // Store token + // Store tokens localStorage.setItem('auth_token', data.access_token); + if (data.refresh_token) { + localStorage.setItem('refresh_token', data.refresh_token); + } // Show success message showAlert('Login successful! Redirecting...', 'success'); diff --git a/test_customers.py b/test_customers.py index 4c420c3..c49addd 100644 --- a/test_customers.py +++ b/test_customers.py @@ -3,11 +3,29 @@ Test script for the customers module """ import requests +import pytest import json from datetime import datetime BASE_URL = "http://localhost:6920" + +@pytest.fixture(scope="module") +def token(): + """Obtain an access token from the running server, or skip if unavailable.""" + try: + response = requests.post(f"{BASE_URL}/api/auth/login", json={ + "username": "admin", + "password": "admin123" + }, timeout=3) + if response.status_code == 200: + data = response.json() + if data and data.get("access_token"): + return data["access_token"] + except Exception: + pass + pytest.skip("Auth server not available; skipping integration tests") + def test_auth(): """Test authentication""" print("🔐 Testing authentication...") diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..d4a41ac --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,87 @@ +import os +import sys +from pathlib import Path +from datetime import datetime, timedelta + +import pytest +from jose import jwt +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app.config import Settings, settings +from app.database.base import Base +from app.models.user import User +from app.auth.security import ( + create_access_token, + create_refresh_token, + decode_refresh_token, + is_refresh_token_revoked, + revoke_refresh_token, +) + + +def test_settings_env_precedence(monkeypatch): + # Ensure env var overrides .env/default + monkeypatch.setenv("SECRET_KEY", "env_secret_value_12345678901234567890123456789012") + cfg = Settings() + assert cfg.secret_key == "env_secret_value_12345678901234567890123456789012" + + +def test_jwt_rotation_decode(monkeypatch): + # Simulate key rotation: token signed with previous key should validate + old_key = "old_secret_value_12345678901234567890123456789012" + new_key = "new_secret_value_12345678901234567890123456789012" + + # Patch runtime settings + settings.previous_secret_key = old_key + settings.secret_key = new_key + + # Sign token with old key + payload = { + "sub": "tester", + "exp": datetime.utcnow() + timedelta(minutes=5), + "iat": datetime.utcnow(), + "type": "access", + } + token = jwt.encode(payload, old_key, algorithm=settings.algorithm) + + # Verify using public API verify_token via access token creation roundtrip + # Using internal decode through create_access_token is indirect; ensure no exception + from app.auth.security import verify_token + + username = verify_token(token) + assert username == "tester" + + +def test_refresh_token_lifecycle(tmp_path): + # Build isolated in-memory database + engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + + db = TestingSessionLocal() + try: + # Create a user + user = User(username="alice", email="a@example.com", hashed_password="x", is_active=True) + db.add(user) + db.commit() + db.refresh(user) + + # Issue refresh token + rtoken = create_refresh_token(user=user, user_agent="pytest", ip_address="127.0.0.1", db=db) + payload = decode_refresh_token(rtoken) + assert payload is not None + jti = payload["jti"] + assert not is_refresh_token_revoked(jti, db) + + # Revoke and assert + revoke_refresh_token(jti, db) + assert is_refresh_token_revoked(jti, db) + finally: + db.close() + + diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py new file mode 100644 index 0000000..c8ae320 --- /dev/null +++ b/tests/test_error_handling.py @@ -0,0 +1,82 @@ +import json +from typing import Optional + +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field + +from app.middleware.logging import LoggingMiddleware +from app.middleware.errors import register_exception_handlers + + +class Item(BaseModel): + name: str = Field(..., min_length=3) + quantity: int = Field(..., ge=1) + + +def build_test_app() -> FastAPI: + app = FastAPI() + app.add_middleware(LoggingMiddleware, log_requests=False, log_responses=False) + register_exception_handlers(app) + + @app.get("/http-error") + async def http_error(): + raise HTTPException(status_code=403, detail="Forbidden action") + + @app.post("/validation") + async def validation_endpoint(item: Item): # noqa: F841 + return {"ok": True} + + @app.get("/crash") + async def crash(): + raise RuntimeError("Boom") + + return app + + +def assert_envelope(resp, status: int, code: str, has_details: bool, expected_cid: Optional[str] = None): + assert resp.status_code == status + data = resp.json() + assert data["success"] is False + assert data["error"]["status"] == status + assert data["error"]["code"] == code + assert isinstance(data["error"]["message"], str) and data["error"]["message"] + if has_details: + assert "details" in data["error"] + else: + assert "details" not in data["error"] + + # Correlation id in body and header + assert "correlation_id" in data and isinstance(data["correlation_id"], str) + header_cid = resp.headers.get("X-Correlation-ID") + assert header_cid == data["correlation_id"] + if expected_cid is not None: + assert header_cid == expected_cid + + +def test_http_exception_envelope_and_correlation_id_echo(): + app = build_test_app() + client = TestClient(app) + cid = "abc-12345-test" + resp = client.get("/http-error", headers={"X-Correlation-ID": cid}) + assert_envelope(resp, 403, "http_error", has_details=False, expected_cid=cid) + + +def test_validation_exception_envelope_and_correlation_id_echo(): + app = build_test_app() + client = TestClient(app) + cid = "cid-validation-67890" + # Missing fields to trigger 422 + resp = client.post("/validation", json={"name": "ab"}, headers={"X-Correlation-ID": cid}) + assert_envelope(resp, 422, "validation_error", has_details=True, expected_cid=cid) + + +def test_unhandled_exception_envelope_and_generated_correlation_id(): + app = build_test_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/crash") + # Should have generated a correlation id and not echo None + assert_envelope(resp, 500, "internal_error", has_details=False) + assert resp.headers.get("X-Correlation-ID") + +