diff --git a/README.md b/README.md index a89ebe4..2125576 100644 --- a/README.md +++ b/README.md @@ -187,6 +187,12 @@ delphi-database/ - Password hashing with bcrypt - Token expiration and refresh +JWT details: + +- Access token: returned by `POST /api/auth/login`, use in `Authorization: Bearer` header +- Refresh token: also returned on login; use `POST /api/auth/refresh` with body `{ "refresh_token": "..." }` to obtain a new access token. On refresh, the provided refresh token is revoked and a new one is issued. +- Legacy compatibility: `POST /api/auth/refresh` called without a body (but with Authorization header) will issue a new access token only. + ## 🗄️ Data Management - CSV import/export functionality - Database backup and restore @@ -194,14 +200,17 @@ delphi-database/ - Automatic financial calculations (matching legacy system) ## ⚙️ Configuration -Environment variables (create `.env` file): +Environment variables (create `.env` file). Real environment variables override `.env` which override defaults: ```bash # Database DATABASE_URL=sqlite:///./delphi_database.db -# Security +# Security SECRET_KEY=your-secret-key-change-in-production -ACCESS_TOKEN_EXPIRE_MINUTES=30 +# Optional previous key to allow rotation +PREVIOUS_SECRET_KEY= +ACCESS_TOKEN_EXPIRE_MINUTES=240 +REFRESH_TOKEN_EXPIRE_MINUTES=43200 # Application DEBUG=False diff --git a/app/api/auth.py b/app/api/auth.py index 4f96a3b..d6194ae 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -10,18 +10,23 @@ from sqlalchemy.orm import Session from app.database.base import get_db from app.models.user import User from app.auth.security import ( - authenticate_user, - create_access_token, + authenticate_user, + create_access_token, + create_refresh_token, + decode_refresh_token, + is_refresh_token_revoked, + revoke_refresh_token, get_password_hash, get_current_user, - get_admin_user + get_admin_user, ) from app.auth.schemas import ( - Token, - UserCreate, - UserResponse, + Token, + UserCreate, + UserResponse, LoginRequest, - ThemePreferenceUpdate + ThemePreferenceUpdate, + RefreshRequest, ) from app.config import settings from app.core.logging import get_logger, log_auth_attempt @@ -71,6 +76,12 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) + refresh_token = create_refresh_token( + user=user, + user_agent=request.headers.get("user-agent", ""), + ip_address=request.client.host if request.client else None, + db=db, + ) log_auth_attempt( username=login_data.username, @@ -85,7 +96,7 @@ async def login(login_data: LoginRequest, request: Request, db: Session = Depend client_ip=client_ip ) - return {"access_token": access_token, "token_type": "bearer"} + return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token} @router.post("/register", response_model=UserResponse) @@ -130,22 +141,76 @@ async def read_users_me(current_user: User = Depends(get_current_user)): @router.post("/refresh", response_model=Token) -async def refresh_token( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) +async def refresh_token_endpoint( + request: Request, + db: Session = Depends(get_db), + body: RefreshRequest | None = None, ): - """Refresh access token for current user""" - # Update last login timestamp - current_user.last_login = datetime.utcnow() + """Issue a new access token using a valid, non-revoked refresh token. + + For backwards compatibility with existing clients that may call this without a body, + consider falling back to Authorization header in the future if needed. + """ + # New flow: refresh token in body + if body and body.refresh_token: + payload = decode_refresh_token(body.refresh_token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") + + jti = payload.get("jti") + username = payload.get("sub") + if not jti or not username: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token payload") + + # Verify token not revoked/expired + if is_refresh_token_revoked(jti, db): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token revoked or expired") + + # Load user + user = db.query(User).filter(User.username == username).first() + if not user or not user.is_active: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") + + # Rotate refresh token on use + revoke_refresh_token(jti, db) + new_refresh_token = create_refresh_token( + user=user, + user_agent=request.headers.get("user-agent", ""), + ip_address=request.client.host if request.client else None, + db=db, + ) + + # Issue new access token + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + + return {"access_token": access_token, "token_type": "bearer", "refresh_token": new_refresh_token} + + # Legacy flow: Authorization header-based refresh + auth_header = request.headers.get("authorization") or request.headers.get("Authorization") + if not auth_header or not auth_header.lower().startswith("bearer "): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing credentials") + + token = auth_header.split(" ", 1)[1].strip() + from app.auth.security import verify_token # local import to avoid circular + username = verify_token(token) + if not username: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + + user = db.query(User).filter(User.username == username).first() + if not user or not user.is_active: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") + + user.last_login = datetime.utcnow() db.commit() - - # Create new token with full expiration time + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token = create_access_token( - data={"sub": current_user.username}, - expires_delta=access_token_expires + data={"sub": user.username}, expires_delta=access_token_expires ) - + return {"access_token": access_token, "token_type": "bearer"} @@ -159,6 +224,23 @@ async def list_users( return users +@router.post("/logout") +async def logout(body: RefreshRequest | None = None, db: Session = Depends(get_db)): + """Revoke the provided refresh token. Idempotent and safe to call multiple times. + + The client should send a JSON body: { "refresh_token": "..." }. + """ + try: + if body and body.refresh_token: + payload = decode_refresh_token(body.refresh_token) + if payload and payload.get("jti"): + revoke_refresh_token(payload["jti"], db) + except Exception: + # Don't leak details; logout should be best-effort + pass + return {"status": "ok"} + + @router.post("/theme-preference") async def update_theme_preference( theme_data: ThemePreferenceUpdate, diff --git a/app/api/documents.py b/app/api/documents.py index 7b7203c..481ac9f 100644 --- a/app/api/documents.py +++ b/app/api/documents.py @@ -2,7 +2,7 @@ Document Management API endpoints - QDROs, Templates, and General Documents """ from typing import List, Optional, Dict, Any -from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form +from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Form, Request from sqlalchemy.orm import Session, joinedload from sqlalchemy import or_, func, and_, desc, asc, text from datetime import date, datetime @@ -18,6 +18,8 @@ from app.models.lookups import FormIndex, FormList, Footer, Employee from app.models.user import User from app.auth.security import get_current_user from app.models.additional import Document +from app.core.logging import get_logger +from app.services.audit import audit_service router = APIRouter() @@ -666,6 +668,78 @@ def _merge_template_variables(content: str, variables: Dict[str, Any]) -> str: return merged +# --- Client Error Logging (for Documents page) --- +class ClientErrorLog(BaseModel): + """Payload for client-side error logging""" + message: str + action: Optional[str] = None + stack: Optional[str] = None + url: Optional[str] = None + line: Optional[int] = None + column: Optional[int] = None + user_agent: Optional[str] = None + extra: Optional[Dict[str, Any]] = None + + +@router.post("/client-error") +async def log_client_error( + payload: ClientErrorLog, + request: Request, + db: Session = Depends(get_db), + current_user: Optional[User] = Depends(lambda: None) +): + """Accept client-side error logs from the Documents page. + + This endpoint is lightweight and safe to call; it records the error to the + application logs and best-effort to the audit log without interrupting the UI. + """ + logger = get_logger("client.documents") + client_ip = request.headers.get("x-forwarded-for") + if client_ip: + client_ip = client_ip.split(",")[0].strip() + else: + client_ip = request.client.host if request.client else None + + logger.error( + "Client error reported", + action=payload.action, + message=payload.message, + stack=payload.stack, + page="/documents", + url=payload.url or str(request.url), + line=payload.line, + column=payload.column, + user=getattr(current_user, "username", None), + user_id=getattr(current_user, "id", None), + user_agent=payload.user_agent or request.headers.get("user-agent"), + client_ip=client_ip, + extra=payload.extra, + ) + + # Best-effort audit log; do not raise on failure + try: + audit_service.log_action( + db=db, + action="CLIENT_ERROR", + resource_type="DOCUMENTS", + user=current_user, + resource_id=None, + details={ + "action": payload.action, + "message": payload.message, + "url": payload.url or str(request.url), + "line": payload.line, + "column": payload.column, + "extra": payload.extra, + }, + request=request, + ) + except Exception: + pass + + return {"status": "logged"} + + @router.post("/upload/{file_no}") async def upload_document( file_no: str, diff --git a/app/auth/schemas.py b/app/auth/schemas.py index 880d146..537c6ae 100644 --- a/app/auth/schemas.py +++ b/app/auth/schemas.py @@ -45,6 +45,7 @@ class Token(BaseModel): """Token response schema""" access_token: str token_type: str + refresh_token: str | None = None class TokenData(BaseModel): @@ -55,4 +56,9 @@ class TokenData(BaseModel): class LoginRequest(BaseModel): """Login request schema""" username: str - password: str \ No newline at end of file + password: str + + +class RefreshRequest(BaseModel): + """Refresh token submission""" + refresh_token: str \ No newline at end of file diff --git a/app/auth/security.py b/app/auth/security.py index cdc49f2..dc069db 100644 --- a/app/auth/security.py +++ b/app/auth/security.py @@ -2,7 +2,8 @@ Authentication and security utilities """ from datetime import datetime, timedelta -from typing import Optional, Union +from typing import Optional, Union, Tuple +from uuid import uuid4 from jose import JWTError, jwt from passlib.context import CryptContext from sqlalchemy.orm import Session @@ -12,6 +13,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.config import settings from app.database.base import get_db from app.models.user import User +from app.models.auth import RefreshToken # Password hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -30,39 +32,107 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) +def _encode_with_rotation(payload: dict) -> str: + """Encode JWT with active secret and algorithm.""" + return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) + + +def _decode_with_rotation(token: str) -> dict: + """Decode JWT trying current then previous secret if set.""" + try: + return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) + except JWTError: + # Try previous secret to allow seamless rotation + if settings.previous_secret_key: + try: + return jwt.decode(token, settings.previous_secret_key, algorithms=[settings.algorithm]) + except JWTError: + pass + raise + + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create JWT access token""" to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) - + expire = datetime.utcnow() + ( + expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes) + ) to_encode.update({ "exp": expire, "iat": datetime.utcnow(), + "type": "access", }) - encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) - return encoded_jwt + return _encode_with_rotation(to_encode) + + +def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str: + """Create refresh token, store its JTI in DB for revocation.""" + jti = uuid4().hex + expire = datetime.utcnow() + timedelta(minutes=settings.refresh_token_expire_minutes) + payload = { + "sub": user.username, + "uid": user.id, + "jti": jti, + "type": "refresh", + "exp": expire, + "iat": datetime.utcnow(), + } + token = _encode_with_rotation(payload) + + db_token = RefreshToken( + user_id=user.id, + jti=jti, + user_agent=user_agent, + ip_address=ip_address, + issued_at=datetime.utcnow(), + expires_at=expire, + revoked=False, + ) + db.add(db_token) + db.commit() + return token def verify_token(token: str) -> Optional[str]: """Verify JWT token and return username""" try: - payload = jwt.decode( - token, - settings.secret_key, - algorithms=[settings.algorithm], - leeway=30 # allow small clock skew - ) + payload = _decode_with_rotation(token) username: str = payload.get("sub") + token_type: str = payload.get("type") if username is None: return None + # Only accept access tokens for auth + if token_type and token_type != "access": + return None return username except JWTError: return None +def decode_refresh_token(token: str) -> Optional[dict]: + """Decode refresh token and return payload if valid and not revoked.""" + try: + payload = _decode_with_rotation(token) + if payload.get("type") != "refresh": + return None + return payload + except JWTError: + return None + + +def is_refresh_token_revoked(jti: str, db: Session) -> bool: + token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() + return not token_row or token_row.revoked or token_row.expires_at <= datetime.utcnow() + + +def revoke_refresh_token(jti: str, db: Session) -> None: + token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() + if token_row and not token_row.revoked: + token_row.revoked = True + token_row.revoked_at = datetime.utcnow() + db.commit() + + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: """Authenticate user credentials""" user = db.query(User).filter(User.username == username).first() diff --git a/app/config.py b/app/config.py index 0b833a8..06e3d92 100644 --- a/app/config.py +++ b/app/config.py @@ -1,53 +1,70 @@ """ Delphi Consulting Group Database System - Configuration """ -from pydantic_settings import BaseSettings from typing import Optional +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): - """Application configuration""" - + """Application configuration (env and .env driven). + + Environment precedence: real environment variables take priority over .env, + which take priority over defaults. + """ + # Application app_name: str = "Delphi Consulting Group Database System" app_version: str = "1.0.0" debug: bool = False - + # Database database_url: str = "sqlite:///./data/delphi_database.db" - - # Authentication - secret_key: str = "your-secret-key-change-in-production" + + # Authentication / JWT + # Require SECRET_KEY to be provided via environment/.env (no insecure default) + secret_key: str = Field(..., min_length=32) + # Optional previous secret key to allow seamless rotation + previous_secret_key: Optional[str] = None algorithm: str = "HS256" access_token_expire_minutes: int = 240 # 4 hours - + # Long-lived refresh token expiration (default 30 days) + refresh_token_expire_minutes: int = 43200 + # Admin account settings admin_username: str = "admin" admin_password: str = "change-me" - + # File paths upload_dir: str = "./uploads" backup_dir: str = "./backups" - + # Pagination default_page_size: int = 50 max_page_size: int = 200 - + # Docker/deployment settings external_port: Optional[str] = None allowed_hosts: Optional[str] = None cors_origins: Optional[str] = None secure_cookies: bool = False compose_project_name: Optional[str] = None - + # Logging log_level: str = "INFO" log_to_file: bool = True log_rotation: str = "10 MB" log_retention: str = "30 days" - - class Config: - env_file = ".env" + + # pydantic-settings v2 configuration + model_config = SettingsConfigDict( + env_file=".env", + env_prefix="", + case_sensitive=False, + env_ignore_empty=True, + extra="ignore", + ) settings = Settings() \ No newline at end of file diff --git a/app/main.py b/app/main.py index 015b6fa..5f24436 100644 --- a/app/main.py +++ b/app/main.py @@ -14,6 +14,7 @@ from app.models.user import User from app.auth.security import get_admin_user from app.core.logging import setup_logging, get_logger from app.middleware.logging import LoggingMiddleware +from app.middleware.errors import register_exception_handlers # Initialize logging setup_logging() @@ -35,6 +36,10 @@ app = FastAPI( logger.info("Adding request logging middleware") app.add_middleware(LoggingMiddleware, log_requests=True, log_responses=settings.debug) +# Register global exception handlers +logger.info("Registering global exception handlers") +register_exception_handlers(app) + # Configure CORS logger.info("Configuring CORS middleware") app.add_middleware( diff --git a/app/middleware/errors.py b/app/middleware/errors.py new file mode 100644 index 0000000..f1575cb --- /dev/null +++ b/app/middleware/errors.py @@ -0,0 +1,132 @@ +""" +Global exception handlers that return a consistent JSON error envelope +and propagate the X-Correlation-ID header. +""" +from __future__ import annotations + +from typing import Any, Optional +from uuid import uuid4 + +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from starlette import status as http_status +from starlette.requests import Request as StarletteRequest +from starlette.responses import JSONResponse as StarletteJSONResponse +from fastapi import HTTPException + +try: + # Prefer project logger if available + from app.core.logging import get_logger + logger = get_logger("middleware.errors") +except Exception: # pragma: no cover - fallback simple logger + import logging + logger = logging.getLogger("middleware.errors") + + +ERROR_HEADER_NAME = "X-Correlation-ID" + + +def _get_correlation_id(request: Request) -> str: + """Resolve correlation ID from request state, headers, or generate a new one.""" + # From middleware + correlation_id: Optional[str] = getattr(getattr(request, "state", object()), "correlation_id", None) + if correlation_id: + return correlation_id + + # From incoming headers + correlation_id = ( + request.headers.get("x-correlation-id") + or request.headers.get("x-request-id") + ) + if correlation_id: + return correlation_id + + # Generate a new one if not present + return str(uuid4()) + + +def _build_error_response( + request: Request, + *, + status_code: int, + message: str, + code: str, + details: Any | None = None, +) -> JSONResponse: + correlation_id = _get_correlation_id(request) + + body = { + "success": False, + "error": { + "status": status_code, + "code": code, + "message": message, + }, + "correlation_id": correlation_id, + } + if details is not None: + body["error"]["details"] = details + + response = JSONResponse(content=body, status_code=status_code) + response.headers[ERROR_HEADER_NAME] = correlation_id + return response + + +async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + """Handle FastAPI HTTPException with envelope and correlation id.""" + message = exc.detail if isinstance(exc.detail, str) else "HTTP error" + logger.warning( + "HTTPException raised", + status_code=exc.status_code, + detail=message, + path=request.url.path, + ) + return _build_error_response( + request, + status_code=exc.status_code, + message=message, + code="http_error", + details=None, + ) + + +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: + """Handle validation errors from FastAPI/Pydantic.""" + logger.info( + "Validation error", + path=request.url.path, + errors=exc.errors(), + ) + return _build_error_response( + request, + status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY, + message="Validation error", + code="validation_error", + details=exc.errors(), + ) + + +async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Catch-all handler for unexpected exceptions (500).""" + # Log full exception for diagnostics without leaking internals to clients + try: + logger.exception("Unhandled exception", path=request.url.path) + except Exception: + pass + return _build_error_response( + request, + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Internal Server Error", + code="internal_error", + details=None, + ) + + +def register_exception_handlers(app: FastAPI) -> None: + """Register global exception handlers on the provided FastAPI app.""" + app.add_exception_handler(HTTPException, http_exception_handler) + app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_exception_handler(Exception, unhandled_exception_handler) + + diff --git a/app/middleware/logging.py b/app/middleware/logging.py index 4ab4de3..c13c9c9 100644 --- a/app/middleware/logging.py +++ b/app/middleware/logging.py @@ -4,6 +4,7 @@ Request/Response Logging Middleware import time import json from typing import Callable +from uuid import uuid4 from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse @@ -21,10 +22,19 @@ class LoggingMiddleware(BaseHTTPMiddleware): self.log_responses = log_responses async def dispatch(self, request: Request, call_next: Callable) -> Response: - # Skip logging for static files and health checks + # Correlation ID: use incoming header or generate a new one + correlation_id = request.headers.get("x-correlation-id") or request.headers.get("x-request-id") or str(uuid4()) + request.state.correlation_id = correlation_id + + # Skip logging for static files and health checks (still attach correlation id) skip_paths = ["/static/", "/uploads/", "/health", "/favicon.ico"] if any(request.url.path.startswith(path) for path in skip_paths): - return await call_next(request) + response = await call_next(request) + try: + response.headers["X-Correlation-ID"] = correlation_id + except Exception: + pass + return response # Record start time start_time = time.time() @@ -41,7 +51,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, query_params=str(request.query_params) if request.query_params else None, client_ip=client_ip, - user_agent=user_agent + user_agent=user_agent, + correlation_id=correlation_id, ) # Process request @@ -56,7 +67,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, duration_ms=duration_ms, error=str(e), - client_ip=client_ip + client_ip=client_ip, + correlation_id=correlation_id, ) raise @@ -74,7 +86,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): path=request.url.path, status_code=response.status_code, duration_ms=duration_ms, - user_id=user_id + user_id=user_id, + correlation_id=correlation_id, ) # Log response details if enabled @@ -84,7 +97,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): status_code=response.status_code, headers=dict(response.headers), size_bytes=response.headers.get("content-length"), - content_type=response.headers.get("content-type") + content_type=response.headers.get("content-type"), + correlation_id=correlation_id, ) # Log slow requests as warnings @@ -94,7 +108,8 @@ class LoggingMiddleware(BaseHTTPMiddleware): method=request.method, path=request.url.path, duration_ms=duration_ms, - status_code=response.status_code + status_code=response.status_code, + correlation_id=correlation_id, ) # Log authentication-related requests to auth log @@ -106,9 +121,16 @@ class LoggingMiddleware(BaseHTTPMiddleware): status_code=response.status_code, duration_ms=duration_ms, client_ip=client_ip, - user_agent=user_agent + user_agent=user_agent, + correlation_id=correlation_id, ) - + + # Attach correlation id header to all responses + try: + response.headers["X-Correlation-ID"] = correlation_id + except Exception: + pass + return response def get_client_ip(self, request: Request) -> str: diff --git a/app/models/__init__.py b/app/models/__init__.py index 108063c..a02cfe0 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -8,6 +8,7 @@ from .files import File from .ledger import Ledger from .qdro import QDRO from .audit import AuditLog, LoginAttempt +from .auth import RefreshToken from .additional import Deposit, Payment, FileNote, FormVariable, ReportVariable, Document from .support import SupportTicket, TicketResponse, TicketStatus, TicketPriority, TicketCategory from .pensions import ( @@ -22,7 +23,7 @@ from .lookups import ( __all__ = [ "BaseModel", "User", "Rolodex", "Phone", "File", "Ledger", "QDRO", - "AuditLog", "LoginAttempt", + "AuditLog", "LoginAttempt", "RefreshToken", "Deposit", "Payment", "FileNote", "FormVariable", "ReportVariable", "Document", "SupportTicket", "TicketResponse", "TicketStatus", "TicketPriority", "TicketCategory", "Pension", "PensionSchedule", "MarriageHistory", "DeathBenefit", diff --git a/app/models/auth.py b/app/models/auth.py new file mode 100644 index 0000000..2fa1d26 --- /dev/null +++ b/app/models/auth.py @@ -0,0 +1,34 @@ +""" +Authentication-related persistence models +""" +from datetime import datetime +from typing import Optional + +from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, UniqueConstraint +from sqlalchemy.orm import relationship + +from app.models.base import BaseModel + + +class RefreshToken(BaseModel): + """Persisted refresh tokens for revocation and auditing.""" + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + jti = Column(String(64), nullable=False, unique=True, index=True) + user_agent = Column(String(255), nullable=True) + ip_address = Column(String(45), nullable=True) + issued_at = Column(DateTime, default=datetime.utcnow, nullable=False) + expires_at = Column(DateTime, nullable=False, index=True) + revoked = Column(Boolean, default=False, nullable=False) + revoked_at = Column(DateTime, nullable=True) + + # relationships + user = relationship("User") + + __table_args__ = ( + UniqueConstraint("jti", name="uq_refresh_tokens_jti"), + ) + + diff --git a/scripts/rotate-secret-key.py b/scripts/rotate-secret-key.py new file mode 100644 index 0000000..fddb4c7 --- /dev/null +++ b/scripts/rotate-secret-key.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Rotate SECRET_KEY in .env with seamless fallback using PREVIOUS_SECRET_KEY. + +Usage: + python scripts/rotate-secret-key.py [--env-path .env] + +Behavior: +- Reads the .env file (default .env) +- Sets PREVIOUS_SECRET_KEY to current SECRET_KEY +- Generates a new SECRET_KEY +- Preserves other variables +- Writes back atomically and sets file mode 600 +""" +from __future__ import annotations + +import argparse +import os +import secrets +import tempfile +from pathlib import Path + + +def generate_secret_key(length: int = 32) -> str: + return secrets.token_urlsafe(length) + + +def parse_env(contents: str) -> dict[str, str]: + env: dict[str, str] = {} + for line in contents.splitlines(): + if not line or line.strip().startswith("#"): + continue + if "=" not in line: + continue + key, value = line.split("=", 1) + env[key.strip()] = value.strip() + return env + + +def render_env(original: str, updates: dict[str, str]) -> str: + lines = original.splitlines() + seen_keys: set[str] = set() + out_lines: list[str] = [] + for line in lines: + if not line or line.strip().startswith("#") or "=" not in line: + out_lines.append(line) + continue + key, _ = line.split("=", 1) + k = key.strip() + if k in updates: + out_lines.append(f"{k}={updates[k]}") + seen_keys.add(k) + else: + out_lines.append(line) + seen_keys.add(k) + # Append any new keys not present originally + for k, v in updates.items(): + if k not in seen_keys: + out_lines.append(f"{k}={v}") + return "\n".join(out_lines) + "\n" + + +def main() -> None: + parser = argparse.ArgumentParser(description="Rotate SECRET_KEY in .env") + parser.add_argument("--env-path", default=".env", help="Path to .env file") + args = parser.parse_args() + + env_path = Path(args.env_path) + if not env_path.exists(): + raise SystemExit(f".env file not found at {env_path}") + + original = env_path.read_text() + env = parse_env(original) + + current_secret = env.get("SECRET_KEY") + if not current_secret: + raise SystemExit("SECRET_KEY not found in .env") + + new_secret = generate_secret_key(32) + updates = { + "PREVIOUS_SECRET_KEY": current_secret, + "SECRET_KEY": new_secret, + } + + rendered = render_env(original, updates) + + # Atomic write + with tempfile.NamedTemporaryFile("w", delete=False, dir=str(env_path.parent)) as tmp: + tmp.write(rendered) + temp_name = tmp.name + os.replace(temp_name, env_path) + os.chmod(env_path, 0o600) + + print("✅ SECRET_KEY rotated successfully.") + print(" PREVIOUS_SECRET_KEY updated for seamless token validation.") + print(" Restart the application to apply the new key.") + + +if __name__ == "__main__": + main() + + diff --git a/scripts/setup-security.py b/scripts/setup-security.py index 9f47905..df4a8c8 100755 --- a/scripts/setup-security.py +++ b/scripts/setup-security.py @@ -63,7 +63,10 @@ DATABASE_URL=sqlite:///data/delphi_database.db # ===== SECURITY SETTINGS - GENERATED ===== SECRET_KEY={secret_key} -ACCESS_TOKEN_EXPIRE_MINUTES=30 +# Optional previous key for seamless rotation (leave blank initially) +PREVIOUS_SECRET_KEY= +ACCESS_TOKEN_EXPIRE_MINUTES=240 +REFRESH_TOKEN_EXPIRE_MINUTES=43200 ALGORITHM=HS256 # ===== ADMIN USER CREATION ===== diff --git a/static/js/financial.js b/static/js/financial.js index 237eb1a..acc3655 100644 --- a/static/js/financial.js +++ b/static/js/financial.js @@ -6,7 +6,7 @@ // ... add the JS content ... -// Modify modal showing/hiding to use classList.add/remove('hidden') instead of Bootstrap modal +// Modify modal showing/hiding to use classList.add/remove('hidden') // For example: function showQuickTimeModal() { diff --git a/static/js/main.js b/static/js/main.js index ef69b3e..7697840 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -5,6 +5,7 @@ // Global application state const app = { token: localStorage.getItem('auth_token'), + refreshToken: localStorage.getItem('refresh_token'), user: null, initialized: false, refreshTimerId: null, @@ -13,9 +14,98 @@ const app = { // Initialize application document.addEventListener('DOMContentLoaded', function() { + try { setupGlobalErrorHandlers(); } catch (_) {} initializeApp(); }); +// Theme Management (centralized) +function applyTheme(theme) { + const html = document.documentElement; + const isDark = theme === 'dark'; + html.classList.toggle('dark', isDark); + html.setAttribute('data-theme', isDark ? 'dark' : 'light'); +} + +function toggleTheme() { + const html = document.documentElement; + const nextTheme = html.classList.contains('dark') ? 'light' : 'dark'; + applyTheme(nextTheme); + try { localStorage.setItem('theme-preference', nextTheme); } catch (_) {} + saveThemePreference(nextTheme); +} + +function initializeTheme() { + // Check for saved theme preference + let savedTheme = null; + try { savedTheme = localStorage.getItem('theme-preference'); } catch (_) {} + const prefersDark = window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches; + + const theme = savedTheme || (prefersDark ? 'dark' : 'light'); + applyTheme(theme); + + // Load from server if available + loadUserThemePreference(); + + // Listen for OS theme changes if no explicit preference is set + attachSystemThemeListener(); +} + +async function saveThemePreference(theme) { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + await fetch('/api/auth/theme-preference', { + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ theme_preference: theme }) + }); + } catch (error) { + console.log('Could not save theme preference to server:', error.message); + } +} + +function attachSystemThemeListener() { + if (!('matchMedia' in window)) return; + const media = window.matchMedia('(prefers-color-scheme: dark)'); + const handleChange = (e) => { + let savedTheme = null; + try { savedTheme = localStorage.getItem('theme-preference'); } catch (_) {} + if (!savedTheme || savedTheme === 'system') { + applyTheme(e.matches ? 'dark' : 'light'); + } + }; + if (typeof media.addEventListener === 'function') { + media.addEventListener('change', handleChange); + } else if (typeof media.addListener === 'function') { + media.addListener(handleChange); // Safari fallback + } +} + +async function loadUserThemePreference() { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (response.ok) { + const user = await response.json(); + if (user.theme_preference) { + applyTheme(user.theme_preference); + try { localStorage.setItem('theme-preference', user.theme_preference); } catch (_) {} + } + } + } catch (error) { + console.log('Could not load theme preference from server:', error.message); + } +} + +// Apply theme immediately on script load +try { initializeTheme(); } catch (_) {} + async function initializeApp() { // Initialize keyboard shortcuts if (window.keyboardShortcuts) { @@ -29,6 +119,11 @@ async function initializeApp() { // Initialize API helpers setupAPIHelpers(); + + // Initialize authentication manager (centralized) + if (typeof initializeAuthManager === 'function') { + initializeAuthManager(); + } app.initialized = true; console.log('Delphi Database System initialized'); @@ -103,6 +198,14 @@ async function apiCall(url, options = {}) { try { let response = await fetch(url, config); + const updateCorrelationFromResponse = (resp) => { + try { + const cid = resp && resp.headers ? resp.headers.get('X-Correlation-ID') : null; + if (cid) { window.app.lastCorrelationId = cid; } + return cid; + } catch (_) { return null; } + }; + let lastCorrelationId = updateCorrelationFromResponse(response); if (response.status === 401 && app.token) { // Attempt one refresh then retry once @@ -113,6 +216,7 @@ async function apiCall(url, options = {}) { ...options }; response = await fetch(url, retryConfig); + lastCorrelationId = updateCorrelationFromResponse(response); } catch (_) { // fall through to logout below } @@ -124,7 +228,10 @@ async function apiCall(url, options = {}) { if (!response.ok) { const errorData = await response.json().catch(() => ({ detail: 'Request failed' })); - throw new Error(errorData.detail || `HTTP ${response.status}`); + const err = new Error(errorData.detail || `HTTP ${response.status}`); + err.status = response.status; + err.correlationId = lastCorrelationId || null; + throw err; } return await response.json(); @@ -158,16 +265,263 @@ async function apiDelete(url) { } // Authentication functions -function setAuthToken(token) { - app.token = token; - localStorage.setItem('auth_token', token); - window.apiHeaders['Authorization'] = `Bearer ${token}`; - +function setAuthTokens(accessToken, newRefreshToken = null) { + if (accessToken) { + app.token = accessToken; + localStorage.setItem('auth_token', accessToken); + window.apiHeaders['Authorization'] = `Bearer ${accessToken}`; + } + if (newRefreshToken) { + app.refreshToken = newRefreshToken; + localStorage.setItem('refresh_token', newRefreshToken); + } // Reschedule refresh on token update - scheduleTokenRefresh(); + if (accessToken) { + scheduleTokenRefresh(); + } } -function logout() { +function setAuthToken(token) { + // Backwards compatibility + setAuthTokens(token, null); +} + +// Page helpers +function isLoginPage() { + const path = window.location.pathname; + return path === '/login' || path === '/'; +} + +// Verify the current access token by hitting /api/auth/me +async function checkTokenValidity() { + const token = localStorage.getItem('auth_token'); + if (!token) return false; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (!response.ok) { + // Invalid token + return false; + } + // Cache user for later UI updates + try { app.user = await response.json(); } catch (_) {} + return true; + } catch (error) { + console.error('Error checking token validity:', error); + return false; + } +} + +// Try to refresh token if refresh token is present; fallback to validity check+logout +async function refreshTokenIfNeeded() { + const refreshTokenValue = localStorage.getItem('refresh_token'); + if (!refreshTokenValue) return; + app.refreshToken = refreshTokenValue; + try { + await refreshToken(); + console.log('Token refreshed successfully'); + } catch (error) { + const stillValid = await checkTokenValidity(); + if (!stillValid) { + await logout('Session expired or invalid token'); + } + } +} + +// Update UI elements that are permission/user dependent +async function checkUserPermissions() { + const token = localStorage.getItem('auth_token'); + if (!token || isLoginPage()) return; + try { + const response = await fetch('/api/auth/me', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (response.ok) { + const user = await response.json(); + app.user = user; + if (user.is_admin) { + const adminItem = document.getElementById('admin-menu-item'); + const adminDivider = document.getElementById('admin-menu-divider'); + if (adminItem) adminItem.classList.remove('hidden'); + if (adminDivider) adminDivider.classList.remove('hidden'); + } + const userDropdownName = document.querySelector('#userDropdown button span'); + if (user.full_name && userDropdownName) { + userDropdownName.textContent = user.full_name; + } + } + } catch (error) { + console.error('Error checking user permissions:', error); + } +} + +// Inactivity monitoring & session extension UI +async function getInactivityWarningMinutes() { + const token = localStorage.getItem('auth_token'); + if (!token) return 240; + try { + const resp = await fetch('/api/settings/inactivity_warning_minutes', { + headers: { 'Authorization': `Bearer ${token}` } + }); + if (!resp.ok) return 240; + const data = await resp.json(); + if (typeof data.minutes === 'number') return data.minutes; + const parsed = parseInt(data.setting_value || data.minutes, 10); + return Number.isFinite(parsed) ? parsed : 240; + } catch (_) { + return 240; + } +} + +function showSessionExtendedNotification() { + if (window.alerts && typeof window.alerts.success === 'function') { + window.alerts.success('Your session has been refreshed successfully.', { + title: 'Session Extended', + duration: 3000 + }); + return; + } + // Fallback + const notification = document.createElement('div'); + notification.className = 'fixed top-4 right-4 bg-green-100 border-l-4 border-green-500 text-green-700 p-4 rounded-lg shadow-lg z-50 max-w-sm'; + notification.innerHTML = ` +
Session Extended
+Your session has been refreshed successfully.
+