""" Authentication and security utilities """ from datetime import datetime, timedelta, timezone from typing import Optional, Union, Tuple from uuid import uuid4 from jose import JWTError, jwt from passlib.context import CryptContext from sqlalchemy.orm import Session from fastapi import HTTPException, status, Depends from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.config import settings from app.database.base import get_db from app.models.user import User from app.models.auth import RefreshToken # Password hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # JWT Security security = HTTPBearer() def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a plain password against its hash""" return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password: str) -> str: """Generate password hash""" return pwd_context.hash(password) def _encode_with_rotation(payload: dict) -> str: """Encode JWT with active secret and algorithm.""" return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm) def _decode_with_rotation(token: str) -> dict: """Decode JWT trying current then previous secret if set.""" try: return jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) except JWTError: # Try previous secret to allow seamless rotation if settings.previous_secret_key: try: return jwt.decode(token, settings.previous_secret_key, algorithms=[settings.algorithm]) except JWTError: pass raise def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: """Create JWT access token""" to_encode = data.copy() expire = datetime.now(timezone.utc) + ( expires_delta if expires_delta else timedelta(minutes=settings.access_token_expire_minutes) ) to_encode.update({ "exp": expire, "iat": datetime.now(timezone.utc), "type": "access", }) return _encode_with_rotation(to_encode) def create_refresh_token(user: User, user_agent: Optional[str], ip_address: Optional[str], db: Session) -> str: """Create refresh token, store its JTI in DB for revocation.""" jti = uuid4().hex expire = datetime.now(timezone.utc) + timedelta(minutes=settings.refresh_token_expire_minutes) payload = { "sub": user.username, "uid": user.id, "jti": jti, "type": "refresh", "exp": expire, "iat": datetime.now(timezone.utc), } token = _encode_with_rotation(payload) db_token = RefreshToken( user_id=user.id, jti=jti, user_agent=user_agent, ip_address=ip_address, issued_at=datetime.now(timezone.utc), expires_at=expire, revoked=False, ) db.add(db_token) db.commit() return token def _to_utc_aware(dt: Optional[datetime]) -> Optional[datetime]: """Convert a datetime to UTC-aware. If naive, assume it's already UTC and attach tzinfo.""" if dt is None: return None if dt.tzinfo is None: return dt.replace(tzinfo=timezone.utc) return dt.astimezone(timezone.utc) def verify_token(token: str) -> Optional[str]: """Verify JWT token and return username""" try: payload = _decode_with_rotation(token) username: str = payload.get("sub") token_type: str = payload.get("type") if username is None: return None # Only accept access tokens for auth if token_type and token_type != "access": return None return username except JWTError: return None def decode_refresh_token(token: str) -> Optional[dict]: """Decode refresh token and return payload if valid and not revoked.""" try: payload = _decode_with_rotation(token) if payload.get("type") != "refresh": return None return payload except JWTError: return None def is_refresh_token_revoked(jti: str, db: Session) -> bool: token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() if not token_row: return True if token_row.revoked: return True expires_at_utc = _to_utc_aware(token_row.expires_at) now_utc = datetime.now(timezone.utc) return expires_at_utc is not None and expires_at_utc <= now_utc def revoke_refresh_token(jti: str, db: Session) -> None: token_row = db.query(RefreshToken).filter(RefreshToken.jti == jti).first() if token_row and not token_row.revoked: token_row.revoked = True token_row.revoked_at = datetime.now(timezone.utc) db.commit() def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: """Authenticate user credentials""" user = db.query(User).filter(User.username == username).first() if not user: return None if not verify_password(password, user.hashed_password): return None return user def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db) ) -> User: """Get current authenticated user""" token = credentials.credentials username = verify_token(token) if username is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) user = db.query(User).filter(User.username == username).first() if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found", headers={"WWW-Authenticate": "Bearer"}, ) if not user.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user" ) return user def get_admin_user(current_user: User = Depends(get_current_user)) -> User: """Require admin privileges""" if not current_user.is_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions" ) return current_user