all working
This commit is contained in:
87
tests/test_auth.py
Normal file
87
tests/test_auth.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.config import Settings, settings
|
||||
from app.database.base import Base
|
||||
from app.models.user import User
|
||||
from app.auth.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_refresh_token,
|
||||
is_refresh_token_revoked,
|
||||
revoke_refresh_token,
|
||||
)
|
||||
|
||||
|
||||
def test_settings_env_precedence(monkeypatch):
|
||||
# Ensure env var overrides .env/default
|
||||
monkeypatch.setenv("SECRET_KEY", "env_secret_value_12345678901234567890123456789012")
|
||||
cfg = Settings()
|
||||
assert cfg.secret_key == "env_secret_value_12345678901234567890123456789012"
|
||||
|
||||
|
||||
def test_jwt_rotation_decode(monkeypatch):
|
||||
# Simulate key rotation: token signed with previous key should validate
|
||||
old_key = "old_secret_value_12345678901234567890123456789012"
|
||||
new_key = "new_secret_value_12345678901234567890123456789012"
|
||||
|
||||
# Patch runtime settings
|
||||
settings.previous_secret_key = old_key
|
||||
settings.secret_key = new_key
|
||||
|
||||
# Sign token with old key
|
||||
payload = {
|
||||
"sub": "tester",
|
||||
"exp": datetime.utcnow() + timedelta(minutes=5),
|
||||
"iat": datetime.utcnow(),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, old_key, algorithm=settings.algorithm)
|
||||
|
||||
# Verify using public API verify_token via access token creation roundtrip
|
||||
# Using internal decode through create_access_token is indirect; ensure no exception
|
||||
from app.auth.security import verify_token
|
||||
|
||||
username = verify_token(token)
|
||||
assert username == "tester"
|
||||
|
||||
|
||||
def test_refresh_token_lifecycle(tmp_path):
|
||||
# Build isolated in-memory database
|
||||
engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
# Create a user
|
||||
user = User(username="alice", email="a@example.com", hashed_password="x", is_active=True)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# Issue refresh token
|
||||
rtoken = create_refresh_token(user=user, user_agent="pytest", ip_address="127.0.0.1", db=db)
|
||||
payload = decode_refresh_token(rtoken)
|
||||
assert payload is not None
|
||||
jti = payload["jti"]
|
||||
assert not is_refresh_token_revoked(jti, db)
|
||||
|
||||
# Revoke and assert
|
||||
revoke_refresh_token(jti, db)
|
||||
assert is_refresh_token_revoked(jti, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
82
tests/test_error_handling.py
Normal file
82
tests/test_error_handling.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.errors import register_exception_handlers
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str = Field(..., min_length=3)
|
||||
quantity: int = Field(..., ge=1)
|
||||
|
||||
|
||||
def build_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.add_middleware(LoggingMiddleware, log_requests=False, log_responses=False)
|
||||
register_exception_handlers(app)
|
||||
|
||||
@app.get("/http-error")
|
||||
async def http_error():
|
||||
raise HTTPException(status_code=403, detail="Forbidden action")
|
||||
|
||||
@app.post("/validation")
|
||||
async def validation_endpoint(item: Item): # noqa: F841
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/crash")
|
||||
async def crash():
|
||||
raise RuntimeError("Boom")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def assert_envelope(resp, status: int, code: str, has_details: bool, expected_cid: Optional[str] = None):
|
||||
assert resp.status_code == status
|
||||
data = resp.json()
|
||||
assert data["success"] is False
|
||||
assert data["error"]["status"] == status
|
||||
assert data["error"]["code"] == code
|
||||
assert isinstance(data["error"]["message"], str) and data["error"]["message"]
|
||||
if has_details:
|
||||
assert "details" in data["error"]
|
||||
else:
|
||||
assert "details" not in data["error"]
|
||||
|
||||
# Correlation id in body and header
|
||||
assert "correlation_id" in data and isinstance(data["correlation_id"], str)
|
||||
header_cid = resp.headers.get("X-Correlation-ID")
|
||||
assert header_cid == data["correlation_id"]
|
||||
if expected_cid is not None:
|
||||
assert header_cid == expected_cid
|
||||
|
||||
|
||||
def test_http_exception_envelope_and_correlation_id_echo():
|
||||
app = build_test_app()
|
||||
client = TestClient(app)
|
||||
cid = "abc-12345-test"
|
||||
resp = client.get("/http-error", headers={"X-Correlation-ID": cid})
|
||||
assert_envelope(resp, 403, "http_error", has_details=False, expected_cid=cid)
|
||||
|
||||
|
||||
def test_validation_exception_envelope_and_correlation_id_echo():
|
||||
app = build_test_app()
|
||||
client = TestClient(app)
|
||||
cid = "cid-validation-67890"
|
||||
# Missing fields to trigger 422
|
||||
resp = client.post("/validation", json={"name": "ab"}, headers={"X-Correlation-ID": cid})
|
||||
assert_envelope(resp, 422, "validation_error", has_details=True, expected_cid=cid)
|
||||
|
||||
|
||||
def test_unhandled_exception_envelope_and_generated_correlation_id():
|
||||
app = build_test_app()
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
resp = client.get("/crash")
|
||||
# Should have generated a correlation id and not echo None
|
||||
assert_envelope(resp, 500, "internal_error", has_details=False)
|
||||
assert resp.headers.get("X-Correlation-ID")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user