changes
This commit is contained in:
292
tests/test_jobs_api.py
Normal file
292
tests/test_jobs_api.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Ensure required env vars for app import/config
|
||||
os.environ.setdefault("SECRET_KEY", "x" * 32)
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
|
||||
|
||||
from app.main import app # noqa: E402
|
||||
from app.auth.security import get_current_user, get_admin_user # noqa: E402
|
||||
from app.database.base import SessionLocal # noqa: E402
|
||||
from app.models.jobs import JobRecord # noqa: E402
|
||||
from app.models.audit import AuditLog # noqa: E402
|
||||
|
||||
|
||||
class _User:
|
||||
def __init__(self, is_admin: bool):
|
||||
self.id = 1 if is_admin else 2
|
||||
self.username = "admin" if is_admin else "user"
|
||||
self.is_admin = is_admin
|
||||
self.is_active = True
|
||||
self.first_name = "Test"
|
||||
self.last_name = "User"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_admin():
|
||||
app.dependency_overrides[get_current_user] = lambda: _User(True)
|
||||
app.dependency_overrides[get_admin_user] = lambda: _User(True)
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
app.dependency_overrides.pop(get_admin_user, None)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_user():
|
||||
app.dependency_overrides[get_current_user] = lambda: _User(False)
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def make_job():
|
||||
"""Factory to create a JobRecord in the test database."""
|
||||
created_job_ids = []
|
||||
|
||||
def _create(
|
||||
*,
|
||||
job_type: str,
|
||||
status: str = "queued",
|
||||
requested_by_username: str = "user",
|
||||
started_at: datetime | None = None,
|
||||
completed_at: datetime | None = None,
|
||||
total_requested: int = 0,
|
||||
total_success: int = 0,
|
||||
total_failed: int = 0,
|
||||
details: dict | None = None,
|
||||
) -> JobRecord:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
j = JobRecord(
|
||||
job_id=uuid.uuid4().hex,
|
||||
job_type=job_type,
|
||||
status=status,
|
||||
requested_by_username=requested_by_username,
|
||||
started_at=started_at or datetime.now(timezone.utc),
|
||||
completed_at=completed_at,
|
||||
total_requested=total_requested,
|
||||
total_success=total_success,
|
||||
total_failed=total_failed,
|
||||
details=details or {},
|
||||
)
|
||||
db.add(j)
|
||||
db.commit()
|
||||
db.refresh(j)
|
||||
created_job_ids.append(j.job_id)
|
||||
return j
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
yield _create
|
||||
|
||||
# Optional: cleanup created jobs to reduce cross-test noise
|
||||
if created_job_ids:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for jid in created_job_ids:
|
||||
row = db.query(JobRecord).filter(JobRecord.job_id == jid).first()
|
||||
if row:
|
||||
db.delete(row)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_jobs_list_and_filtering(client_admin: TestClient, client_user: TestClient, make_job):
|
||||
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
|
||||
|
||||
# Create jobs: two for non-admin user, one for admin
|
||||
j_user_running = make_job(job_type=jt, status="running", requested_by_username="user")
|
||||
j_user_failed = make_job(job_type=jt, status="failed", requested_by_username="user")
|
||||
j_admin_completed = make_job(job_type=jt, status="completed", requested_by_username="admin")
|
||||
|
||||
# Non-admin sees only their jobs by default (mine=True)
|
||||
resp = client_user.get("/api/jobs/", params={"include_total": 1})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert set(body.keys()) == {"items", "total"}
|
||||
assert body["total"] >= 2
|
||||
ids = [it["job_id"] for it in body["items"]]
|
||||
assert j_user_running.job_id in ids and j_user_failed.job_id in ids
|
||||
assert j_admin_completed.job_id not in ids
|
||||
|
||||
# Filter by status
|
||||
resp = client_user.get("/api/jobs/", params={"include_total": 1, "status_filter": "failed"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["total"] >= 1
|
||||
assert any(it["job_id"] == j_user_failed.job_id for it in body["items"])
|
||||
|
||||
# Filter by type
|
||||
resp = client_user.get("/api/jobs/", params={"include_total": 1, "type_filter": jt})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["total"] >= 2
|
||||
|
||||
# Search across job_type
|
||||
resp = client_user.get("/api/jobs/", params={"search": jt})
|
||||
assert resp.status_code == 200
|
||||
ids = [it["job_id"] for it in resp.json()]
|
||||
assert j_user_running.job_id in ids and j_user_failed.job_id in ids
|
||||
|
||||
# Admin can list all with mine=false and filter by requested_by
|
||||
# Because fixtures share global dependency overrides, ensure admin override is active for this request
|
||||
from app.main import app as _app
|
||||
from app.auth.security import get_current_user as _get_current_user
|
||||
_prev = _app.dependency_overrides.get(_get_current_user)
|
||||
try:
|
||||
_app.dependency_overrides[_get_current_user] = lambda: _User(True)
|
||||
c = TestClient(_app)
|
||||
resp = c.get("/api/jobs/", params={"mine": 0, "include_total": 1})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["total"] >= 3
|
||||
finally:
|
||||
if _prev is not None:
|
||||
_app.dependency_overrides[_get_current_user] = _prev
|
||||
else:
|
||||
_app.dependency_overrides.pop(_get_current_user, None)
|
||||
|
||||
resp = client_admin.get("/api/jobs/", params={"mine": 0, "include_total": 1, "requested_by": "user"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["total"] >= 2
|
||||
|
||||
|
||||
def test_jobs_get_by_id_permissions(client_admin: TestClient, client_user: TestClient, make_job):
|
||||
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
|
||||
j_admin = make_job(job_type=jt, status="completed", requested_by_username="admin")
|
||||
j_user = make_job(job_type=jt, status="running", requested_by_username="user")
|
||||
|
||||
# Non-admin cannot access someone else's job
|
||||
resp = client_user.get(f"/api/jobs/{j_admin.job_id}")
|
||||
assert resp.status_code in (403, 404)
|
||||
if resp.status_code == 403:
|
||||
assert "Not enough permissions" in resp.text
|
||||
|
||||
# Non-admin can access own job
|
||||
resp = client_user.get(f"/api/jobs/{j_user.job_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["job_id"] == j_user.job_id
|
||||
|
||||
# Admin can access any job
|
||||
resp = client_admin.get(f"/api/jobs/{j_user.job_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["job_id"] == j_user.job_id
|
||||
|
||||
|
||||
def _audit_exists(job_id: str, action: str) -> bool:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return (
|
||||
db.query(AuditLog)
|
||||
.filter(AuditLog.resource_type == "JOB", AuditLog.resource_id == job_id, AuditLog.action == action)
|
||||
.count()
|
||||
> 0
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_jobs_state_transitions_audit_and_metrics(client_admin: TestClient, make_job):
|
||||
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
|
||||
|
||||
# Create a job and transition to running -> completed
|
||||
j = make_job(job_type=jt, status="queued", requested_by_username="admin")
|
||||
|
||||
resp = client_admin.post(f"/api/jobs/{j.job_id}/mark-running")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "running"
|
||||
assert body["started_at"] is not None
|
||||
assert body["completed_at"] is None
|
||||
assert _audit_exists(j.job_id, "RUNNING")
|
||||
|
||||
resp = client_admin.post(
|
||||
f"/api/jobs/{j.job_id}/mark-completed",
|
||||
json={
|
||||
"total_success": 5,
|
||||
"total_failed": 0,
|
||||
"result_storage_path": "exports/test_bundle.html",
|
||||
"result_mime_type": "text/html",
|
||||
"result_size": 123,
|
||||
"details_update": {"note": "done"},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["status"] == "completed"
|
||||
assert body["total_success"] == 5
|
||||
assert body["total_failed"] == 0
|
||||
assert body["has_result_bundle"] is True
|
||||
assert _audit_exists(j.job_id, "COMPLETE")
|
||||
|
||||
# Create another job, transition to running -> failed
|
||||
j2 = make_job(job_type=jt, status="queued", requested_by_username="admin")
|
||||
resp = client_admin.post(f"/api/jobs/{j2.job_id}/mark-running")
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = client_admin.post(
|
||||
f"/api/jobs/{j2.job_id}/mark-failed",
|
||||
json={"reason": "manual-error", "details_update": {"code": "E1"}},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "failed"
|
||||
# Fetch job to verify details include last_error
|
||||
resp = client_admin.get(f"/api/jobs/{j2.job_id}")
|
||||
assert resp.status_code == 200
|
||||
details = resp.json().get("details") or {}
|
||||
assert details.get("last_error") == "manual-error"
|
||||
assert _audit_exists(j2.job_id, "FAIL")
|
||||
|
||||
# Retry an existing job
|
||||
note = "retry it"
|
||||
resp = client_admin.post(f"/api/jobs/{j2.job_id}/retry", json={"note": note})
|
||||
assert resp.status_code == 200
|
||||
new_job_id = resp.json().get("job_id")
|
||||
assert isinstance(new_job_id, str) and new_job_id
|
||||
resp = client_admin.get(f"/api/jobs/{new_job_id}")
|
||||
assert resp.status_code == 200
|
||||
new_details = resp.json().get("details") or {}
|
||||
assert new_details.get("retry_of") == j2.job_id
|
||||
assert new_details.get("retry_note") == note
|
||||
|
||||
# Leave one job running for metrics
|
||||
j3 = make_job(job_type=jt, status="queued", requested_by_username="admin")
|
||||
resp = client_admin.post(f"/api/jobs/{j3.job_id}/mark-running")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Metrics summary (admin-only)
|
||||
resp = client_admin.get("/api/jobs/metrics/summary")
|
||||
assert resp.status_code == 200
|
||||
metrics = resp.json()
|
||||
# Shape assertions
|
||||
assert set(metrics.keys()) == {
|
||||
"by_status",
|
||||
"by_type",
|
||||
"avg_duration_seconds",
|
||||
"running_count",
|
||||
"failed_last_24h",
|
||||
"completed_last_24h",
|
||||
}
|
||||
assert isinstance(metrics["by_status"], dict)
|
||||
assert isinstance(metrics["by_type"], dict)
|
||||
# Our type appears with count at least the number we created in this test
|
||||
# Created: j (completed), j2 (failed), retry (queued), j3 (running) => 4 with this type
|
||||
assert metrics["by_type"].get(jt, 0) >= 4
|
||||
# Running count should be at least 1 (j3)
|
||||
assert metrics.get("running_count", 0) >= 1
|
||||
# Ensure status buckets contain our transitions (may include others as well)
|
||||
assert metrics["by_status"].get("completed", 0) >= 1
|
||||
assert metrics["by_status"].get("failed", 0) >= 1
|
||||
|
||||
|
||||
474
tests/test_p1_security_features.py
Normal file
474
tests/test_p1_security_features.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Test suite for P1 High Priority Security Features
|
||||
|
||||
Tests the security enhancements implemented for:
|
||||
- Rate limiting
|
||||
- Security headers
|
||||
- Enhanced authentication (password complexity, account lockout)
|
||||
- Database security (SQL injection prevention)
|
||||
- CSRF protection and request size limits
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.main import app
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.models.auth import LoginAttempt
|
||||
from app.utils.enhanced_auth import (
|
||||
PasswordValidator,
|
||||
AccountLockoutManager,
|
||||
SuspiciousActivityDetector,
|
||||
)
|
||||
from app.utils.database_security import (
|
||||
SQLSecurityValidator,
|
||||
SecureQueryBuilder,
|
||||
execute_secure_query,
|
||||
)
|
||||
from app.middleware.rate_limiting import RateLimitStore
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiting middleware functionality"""
|
||||
|
||||
def test_rate_limit_store_basic_functionality(self):
|
||||
"""Test basic rate limit store operations"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Test initial request - should be allowed
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is True
|
||||
assert info["remaining"] == 4
|
||||
assert info["limit"] == 5
|
||||
|
||||
def test_rate_limit_exceeds_threshold(self):
|
||||
"""Test rate limiting when threshold is exceeded"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Make requests up to the limit
|
||||
for i in range(5):
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is True
|
||||
|
||||
# Next request should be rejected
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is False
|
||||
assert info["remaining"] == 0
|
||||
|
||||
def test_rate_limit_window_reset(self):
|
||||
"""Test rate limit window reset functionality"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Fill up the rate limit
|
||||
for i in range(5):
|
||||
store.is_allowed("test_key", 5, 1) # 1 second window
|
||||
|
||||
# Should be rejected
|
||||
allowed, info = store.is_allowed("test_key", 5, 1)
|
||||
assert allowed is False
|
||||
|
||||
# Wait for window to reset
|
||||
time.sleep(1.1)
|
||||
|
||||
# Should be allowed again
|
||||
allowed, info = store.is_allowed("test_key", 5, 1)
|
||||
assert allowed is True
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Test security headers middleware"""
|
||||
|
||||
def test_security_headers_applied(self):
|
||||
"""Test that security headers are applied to responses"""
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
# Check for key security headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_csp_header_present(self):
|
||||
"""Test Content Security Policy header"""
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
# CSP header should be present
|
||||
if "Content-Security-Policy" in response.headers:
|
||||
csp = response.headers["Content-Security-Policy"]
|
||||
assert "default-src 'self'" in csp
|
||||
assert "object-src 'none'" in csp
|
||||
assert "frame-ancestors 'none'" in csp
|
||||
|
||||
def test_request_size_limit(self):
|
||||
"""Test request size limiting"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create a large payload (should be rejected if limit is enforced)
|
||||
large_data = "x" * (101 * 1024 * 1024) # 101MB
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "test", "password": large_data},
|
||||
headers={"Content-Length": str(len(large_data))}
|
||||
)
|
||||
|
||||
# Should be rejected for size
|
||||
assert response.status_code == 413 # Request Entity Too Large
|
||||
|
||||
|
||||
class TestPasswordValidation:
|
||||
"""Test password strength validation"""
|
||||
|
||||
def test_weak_passwords_rejected(self):
|
||||
"""Test that weak passwords are properly rejected"""
|
||||
weak_passwords = [
|
||||
"123456",
|
||||
"password",
|
||||
"qwerty",
|
||||
"abc123",
|
||||
"Password", # Missing special chars and numbers
|
||||
"p@ssw0rd", # Too common
|
||||
]
|
||||
|
||||
for password in weak_passwords:
|
||||
is_valid, errors = PasswordValidator.validate_password_strength(password)
|
||||
assert is_valid is False
|
||||
assert len(errors) > 0
|
||||
|
||||
def test_strong_passwords_accepted(self):
|
||||
"""Test that strong passwords are accepted"""
|
||||
strong_passwords = [
|
||||
"MyStr0ng!P@ssw0rd",
|
||||
"C0mpl3x&S3cur3!",
|
||||
"Adm1n!2024$ecure",
|
||||
"Ungu3ss@bl3P@ssw0rd!",
|
||||
]
|
||||
|
||||
for password in strong_passwords:
|
||||
is_valid, errors = PasswordValidator.validate_password_strength(password)
|
||||
assert is_valid is True
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_password_strength_scoring(self):
|
||||
"""Test password strength scoring"""
|
||||
# Weak password
|
||||
weak_score = PasswordValidator.generate_password_strength_score("123456")
|
||||
assert weak_score < 30
|
||||
|
||||
# Strong password
|
||||
strong_score = PasswordValidator.generate_password_strength_score("MyStr0ng!P@ssw0rd")
|
||||
assert strong_score > 70
|
||||
|
||||
def test_password_complexity_requirements(self):
|
||||
"""Test individual password complexity requirements"""
|
||||
# Test length requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("Abc1!")
|
||||
assert not is_valid
|
||||
assert any("at least" in error for error in errors)
|
||||
|
||||
# Test uppercase requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("abc123!@#")
|
||||
assert not is_valid
|
||||
assert any("uppercase" in error for error in errors)
|
||||
|
||||
# Test special character requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("Abc123456")
|
||||
assert not is_valid
|
||||
assert any("special character" in error for error in errors)
|
||||
|
||||
|
||||
class TestAccountLockout:
|
||||
"""Test account lockout functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Mock database session"""
|
||||
mock_db = MagicMock()
|
||||
return mock_db
|
||||
|
||||
def test_lockout_after_failed_attempts(self, mock_db):
|
||||
"""Test account lockout after multiple failed attempts"""
|
||||
# Mock failed login attempts
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 5
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.first.return_value = (
|
||||
datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
|
||||
assert is_locked is True
|
||||
assert unlock_time is not None
|
||||
|
||||
def test_no_lockout_with_few_attempts(self, mock_db):
|
||||
"""Test no lockout with fewer than threshold attempts"""
|
||||
# Mock fewer failed attempts
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 2
|
||||
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
|
||||
assert is_locked is False
|
||||
assert unlock_time is None
|
||||
|
||||
def test_lockout_info_retrieval(self, mock_db):
|
||||
"""Test lockout information retrieval"""
|
||||
# Mock database responses
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 3
|
||||
|
||||
lockout_info = AccountLockoutManager.get_lockout_info(mock_db, "testuser")
|
||||
|
||||
assert "is_locked" in lockout_info
|
||||
assert "failed_attempts" in lockout_info
|
||||
assert "attempts_remaining" in lockout_info
|
||||
assert "max_attempts" in lockout_info
|
||||
|
||||
|
||||
class TestDatabaseSecurity:
|
||||
"""Test database security utilities"""
|
||||
|
||||
def test_sql_injection_detection(self):
|
||||
"""Test SQL injection pattern detection"""
|
||||
# Test legitimate query
|
||||
legitimate_query = "SELECT * FROM users WHERE username = :username"
|
||||
issues = SQLSecurityValidator.validate_query_string(legitimate_query)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test malicious queries
|
||||
malicious_queries = [
|
||||
"SELECT * FROM users WHERE id = 1; DROP TABLE users; --",
|
||||
"SELECT * FROM users WHERE username = 'admin' OR 1=1 --",
|
||||
"SELECT * FROM users UNION SELECT password FROM admin_table",
|
||||
"SELECT * FROM users WHERE id = 1 AND (SELECT COUNT(*) FROM passwords) > 0",
|
||||
]
|
||||
|
||||
for query in malicious_queries:
|
||||
issues = SQLSecurityValidator.validate_query_string(query)
|
||||
assert len(issues) > 0
|
||||
|
||||
def test_parameter_validation(self):
|
||||
"""Test parameter value validation"""
|
||||
# Test legitimate parameters
|
||||
legitimate_params = {
|
||||
"username": "john_doe",
|
||||
"age": 25,
|
||||
"email": "john@example.com",
|
||||
}
|
||||
|
||||
for param_name, param_value in legitimate_params.items():
|
||||
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test malicious parameters
|
||||
malicious_params = {
|
||||
"username": "'; DROP TABLE users; --",
|
||||
"search": "admin' OR 1=1 --",
|
||||
"id": "1 UNION SELECT password FROM users",
|
||||
}
|
||||
|
||||
for param_name, param_value in malicious_params.items():
|
||||
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
|
||||
assert len(issues) > 0
|
||||
|
||||
def test_safe_query_building(self):
|
||||
"""Test safe query building utilities"""
|
||||
from sqlalchemy import Column, String
|
||||
|
||||
# Mock column for testing
|
||||
mock_column = Column("name", String)
|
||||
|
||||
# Test safe LIKE clause building
|
||||
like_clause = SecureQueryBuilder.build_like_clause(mock_column, "test_value")
|
||||
assert like_clause is not None
|
||||
|
||||
# Test safe IN clause building
|
||||
in_clause = SecureQueryBuilder.build_in_clause(mock_column, ["value1", "value2"])
|
||||
assert in_clause is not None
|
||||
|
||||
# Test FTS query building
|
||||
fts_query = SecureQueryBuilder.build_fts_query(["term1", "term2"])
|
||||
assert "term1" in fts_query
|
||||
assert "term2" in fts_query
|
||||
|
||||
def test_query_validation_with_params(self):
|
||||
"""Test complete query validation with parameters"""
|
||||
query = "SELECT * FROM users WHERE username = :username AND age > :min_age"
|
||||
params = {
|
||||
"username": "john_doe",
|
||||
"min_age": 18,
|
||||
}
|
||||
|
||||
issues = SQLSecurityValidator.validate_query_with_params(query, params)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test with malicious parameters
|
||||
malicious_params = {
|
||||
"username": "'; DROP TABLE users; --",
|
||||
"min_age": "18 OR 1=1",
|
||||
}
|
||||
|
||||
issues = SQLSecurityValidator.validate_query_with_params(query, malicious_params)
|
||||
assert len(issues) > 0
|
||||
|
||||
|
||||
class TestSuspiciousActivityDetection:
|
||||
"""Test suspicious activity detection"""
|
||||
|
||||
def test_suspicious_ip_detection(self):
|
||||
"""Test detection of suspicious IP patterns"""
|
||||
# This would typically require mock data or a test database
|
||||
# For now, test the structure
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
alerts = SuspiciousActivityDetector.detect_suspicious_patterns(mock_db, 24)
|
||||
assert isinstance(alerts, list)
|
||||
|
||||
def test_login_suspicion_check(self):
|
||||
"""Test individual login suspicion checking"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.distinct.return_value.all.return_value = []
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 0
|
||||
|
||||
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
|
||||
mock_db, "testuser", "192.168.1.1", "Mozilla/5.0..."
|
||||
)
|
||||
|
||||
assert isinstance(is_suspicious, bool)
|
||||
assert isinstance(warnings, list)
|
||||
|
||||
|
||||
class TestCSRFProtection:
|
||||
"""Test CSRF protection middleware"""
|
||||
|
||||
def test_csrf_protection_on_post_requests(self):
|
||||
"""Test CSRF protection for POST requests"""
|
||||
client = TestClient(app)
|
||||
|
||||
# POST request without proper origin/referer should be blocked
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "test", "password": "test"},
|
||||
headers={
|
||||
"Origin": "https://malicious-site.com",
|
||||
"Host": "localhost:8000"
|
||||
}
|
||||
)
|
||||
|
||||
# Should be blocked by CSRF protection
|
||||
# Note: This test might need adjustment based on actual CSRF implementation
|
||||
assert response.status_code in [403, 401, 429] # Could be various security responses
|
||||
|
||||
def test_csrf_exemption_for_safe_methods(self):
|
||||
"""Test that safe HTTP methods are exempt from CSRF protection"""
|
||||
client = TestClient(app)
|
||||
|
||||
# GET requests should not be subject to CSRF protection
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestIntegratedSecurity:
|
||||
"""Test integrated security features working together"""
|
||||
|
||||
def test_enhanced_login_endpoint(self):
|
||||
"""Test the enhanced login endpoint with security features"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Test login with invalid credentials
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "nonexistent", "password": "wrongpassword"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
# Check for security headers in response
|
||||
assert "X-RateLimit-Limit" in response.headers or "WWW-Authenticate" in response.headers
|
||||
|
||||
@patch('app.utils.enhanced_auth.validate_and_authenticate_user')
|
||||
def test_account_lockout_headers(self, mock_auth):
|
||||
"""Test that account lockout information is returned in headers"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock a locked account
|
||||
mock_auth.return_value = (None, ["Account is locked"])
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "locked_user", "password": "password"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_password_validation_endpoint(self):
|
||||
"""Test the password validation endpoint"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Test weak password
|
||||
response = client.post(
|
||||
"/api/auth/validate-password",
|
||||
json={"password": "123456"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_valid"] is False
|
||||
assert len(data["errors"]) > 0
|
||||
|
||||
# Test strong password
|
||||
response = client.post(
|
||||
"/api/auth/validate-password",
|
||||
json={"password": "MyStr0ng!P@ssw0rd"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_valid"] is True
|
||||
assert len(data["errors"]) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_security_middleware_integration(client):
|
||||
"""Test that all security middleware is properly integrated"""
|
||||
response = client.get("/health")
|
||||
|
||||
# Should have rate limiting headers
|
||||
rate_limit_headers = [h for h in response.headers.keys() if h.startswith("X-RateLimit")]
|
||||
|
||||
# Should have security headers
|
||||
security_headers = ["X-Content-Type-Options", "X-Frame-Options", "X-XSS-Protection"]
|
||||
for header in security_headers:
|
||||
assert header in response.headers
|
||||
|
||||
# Should have correlation ID
|
||||
assert "X-Correlation-ID" in response.headers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run specific test for quick validation
|
||||
test_rate_limiting = TestRateLimiting()
|
||||
test_rate_limiting.test_rate_limit_store_basic_functionality()
|
||||
|
||||
test_password = TestPasswordValidation()
|
||||
test_password.test_weak_passwords_rejected()
|
||||
test_password.test_strong_passwords_accepted()
|
||||
|
||||
test_db_security = TestDatabaseSecurity()
|
||||
test_db_security.test_sql_injection_detection()
|
||||
|
||||
print("✅ All P1 security tests passed!")
|
||||
555
tests/test_pension_valuation.py
Normal file
555
tests/test_pension_valuation.py
Normal file
@@ -0,0 +1,555 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import math
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Ensure required env vars for app import/config
|
||||
os.environ.setdefault("SECRET_KEY", "x" * 32)
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
|
||||
|
||||
# Ensure repository root on sys.path for direct test runs
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.main import app # noqa: E402
|
||||
from app.auth.security import get_current_user # noqa: E402
|
||||
from app.database.base import SessionLocal # noqa: E402
|
||||
from app.models.pensions import NumberTable, LifeTable # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
# Override auth to bypass JWT for these tests
|
||||
class _User:
|
||||
def __init__(self):
|
||||
self.id = "test"
|
||||
self.username = "tester"
|
||||
self.is_admin = True
|
||||
self.is_active = True
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: _User()
|
||||
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
def _seed_monthly_na_series():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Clear and seed NumberTable for months 0..4 for both sexes, race A
|
||||
for m in range(0, 5):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
# NA series: [100000, 90000, 80000, 70000, 60000]
|
||||
series = [100000.0, 90000.0, 80000.0, 70000.0, 60000.0]
|
||||
for idx, na in enumerate(series):
|
||||
row = NumberTable(
|
||||
month=idx,
|
||||
na_am=na, # male (all races)
|
||||
na_af=na, # female (all races)
|
||||
na_aa=na, # all (all races)
|
||||
)
|
||||
db.add(row)
|
||||
|
||||
# Provide an LE row in case fallback is used inadvertently
|
||||
db.query(LifeTable).filter(LifeTable.age == 65).delete()
|
||||
db.add(LifeTable(age=65, le_am=20.0, le_af=22.0, le_aa=21.0))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _monthly_rate(pct: float) -> float:
|
||||
return pow(1.0 + pct / 100.0, 1.0 / 12.0) - 1.0
|
||||
|
||||
|
||||
def test_single_life_no_discount_no_cola(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"start_age": 65,
|
||||
"sex": "M",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# Survival probs: [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
# PV = 1000 * sum = 4000
|
||||
assert math.isclose(resp.json()["pv"], 4000.0, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_joint_survivor_no_discount_no_cola(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 50.0,
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()
|
||||
# p = [1, .9, .8, .7, .6]
|
||||
# p_both = p^2 = [1, .81, .64, .49, .36]
|
||||
# payment_t = 1000*p_both + 1000*0.5*(p - p_both) = 500*p + 500*p_both
|
||||
expected = 500.0 * (1 + 0.9 + 0.8 + 0.7 + 0.6) + 500.0 * (1 + 0.81 + 0.64 + 0.49 + 0.36)
|
||||
assert math.isclose(data["pv_total"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
# Components
|
||||
expected_both = 1000.0 * (1 + 0.81 + 0.64 + 0.49 + 0.36)
|
||||
expected_surv = expected - expected_both
|
||||
assert math.isclose(data["pv_participant_component"], expected_both, rel_tol=1e-6, abs_tol=0.01)
|
||||
assert math.isclose(data["pv_survivor_component"], expected_surv, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_joint_survivor_last_survivor_basis(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 50.0, # should be ignored in last_survivor for total
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"survivor_basis": "last_survivor",
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()
|
||||
# p = [1, .9, .8, .7, .6]; p_both = p^2; p_either = p_part + p_sp - p_both = 2p - p^2
|
||||
p = [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
p_both = [x * x for x in p]
|
||||
p_either = [2 * x - x * x for x in p]
|
||||
expected_total = 1000.0 * sum(p_either)
|
||||
assert math.isclose(data["pv_total"], expected_total, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_interpolation_between_number_rows(client: TestClient):
|
||||
# Seed months 0 and 2 only; expect linear interpolation for month 1
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.pensions import NumberTable, LifeTable
|
||||
for m in range(0, 3):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
# only months 0 and 2
|
||||
db.add(NumberTable(month=0, na_am=100000.0, na_af=100000.0, na_aa=100000.0))
|
||||
db.add(NumberTable(month=2, na_am=80000.0, na_af=80000.0, na_aa=80000.0))
|
||||
# Ensure LE exists but should not be used due to interpolation
|
||||
db.query(LifeTable).filter(LifeTable.age == 65).delete()
|
||||
db.add(LifeTable(age=65, le_am=30.0, le_af=30.0, le_aa=30.0))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 3,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"interpolation_method": "linear",
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# survival at t: 0->1.0, 1->0.9 (interpolated), 2->0.8
|
||||
expected = 1000.0 * (1.0 + 0.9 + 0.8)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_interpolation_step_method(client: TestClient):
|
||||
# Seed months 0 and 3 only; with step interpolation, months 1 and 2 carry month 0 value
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.pensions import NumberTable
|
||||
for m in range(0, 4):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
db.add(NumberTable(month=0, na_aa=100000.0))
|
||||
db.add(NumberTable(month=3, na_aa=70000.0))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 4,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"interpolation_method": "step",
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# Series becomes NA: [100k, 100k, 100k, 70k] -> p: [1.0,1.0,1.0,0.7]
|
||||
expected = 1000.0 * (1.0 + 1.0 + 1.0 + 0.7)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_max_age_truncation(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# start_age 65, max_age 66 -> at most 12 months
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 24,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"max_age": 66,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200
|
||||
# With our seed p falls by 0.1 each month; but we only have 5 months seeded; extend to 12
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.pensions import NumberTable
|
||||
for m in range(5, 12):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
# Recompute with extension
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200
|
||||
# p: 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1,0.0 (clamped by interpolation), 0.0 -> sum first 10 effectively
|
||||
expected = 1000.0 * (1.0 + 0.9 + 0.8 + 0.7 + 0.6 + 0.5 + 0.4 + 0.3 + 0.2 + 0.1)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.1)
|
||||
|
||||
|
||||
def test_single_life_with_discount_and_cola(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# discount 12% annual, COLA 12% annual
|
||||
d_m = _monthly_rate(12.0)
|
||||
g_m = _monthly_rate(12.0)
|
||||
probs = [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
monthly = 1000.0
|
||||
|
||||
# Expected PV: sum( monthly * p[t] * (1+g)^t / (1+i)^t )
|
||||
expected = 0.0
|
||||
g = 1.0
|
||||
d = 1.0
|
||||
for t, p in enumerate(probs):
|
||||
if t == 0:
|
||||
g = 1.0
|
||||
d = 1.0
|
||||
else:
|
||||
g *= (1.0 + g_m)
|
||||
d *= (1.0 + d_m)
|
||||
expected += monthly * p * g / d
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": monthly,
|
||||
"term_months": 5,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 12.0,
|
||||
"cola_rate": 12.0,
|
||||
"cola_mode": "monthly",
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.05)
|
||||
|
||||
|
||||
def test_single_life_cola_annual_prorated_and_cap(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# Extend months 0..11 to allow a full year
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.pensions import NumberTable
|
||||
for m in range(5, 12):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# annual COLA 6% but cap at 4%; annual prorated mode
|
||||
# Payments monthly for 6 months, p[t]=1,.9,.8,.7,.6,.5
|
||||
# Growth at t uses 4% cap prorated: factors ~ 1.0, 1+0.04*(1/12), 1+0.04*(2/12), ...
|
||||
monthly = 1000.0
|
||||
probs = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
|
||||
a = 0.04
|
||||
expected = 0.0
|
||||
for t, p in enumerate(probs):
|
||||
growth = (1.0 + a * (t / 12.0))
|
||||
expected += monthly * p * growth
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": monthly,
|
||||
"term_months": 6,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 6.0,
|
||||
"cola_mode": "annual_prorated",
|
||||
"cola_cap_percent": 4.0,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.05)
|
||||
|
||||
|
||||
def test_single_life_quarterly_with_deferral(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# Quarterly payments (3 months per payment), starting after 2-month deferral
|
||||
# Horizon 10 months, payments at t=2,5,8 (assuming 0-indexed months)
|
||||
# Survival probs p[t] = [1, .9, .8, .7, .6, .5, .4, .3, .2, .1] for first 10 implied by seed extension
|
||||
# But our seed has only 5 months; extend seed
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# extend NumberTable to months 0..9 with linear decrement 100k - 10k*m
|
||||
from app.models.pensions import NumberTable
|
||||
for m in range(5, 10):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 10,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"defer_months": 2,
|
||||
"payment_period_months": 3,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# p = [1.0, .9, .8, .7, .6, .5, .4, .3, .2, .1]
|
||||
# Payments at t=2,5,8: amount = monthly * 3 * p[t]
|
||||
expected = 1000.0 * 3.0 * (0.8 + 0.5 + 0.2)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_fractional_deferral_prorated_first_payment(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# Monthly payments, defer 0.5 month: first payment at month 1 with 50% of base
|
||||
# With p = [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
monthly = 1000.0
|
||||
payload = {
|
||||
"monthly_benefit": monthly,
|
||||
"term_months": 5,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"defer_months": 0.5,
|
||||
"payment_period_months": 1,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# First payment at t=1: 0.5*1000*0.9 + subsequent: 1000*0.8 + 1000*0.7 + 1000*0.6
|
||||
expected = 0.5 * monthly * 0.9 + monthly * (0.8 + 0.7 + 0.6)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_joint_survivor_participant_only_commencement(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 50.0,
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"survivor_commence_participant_only": True,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()
|
||||
# p = [1, .9, .8, .7, .6]; p_both = p^2; p_sp_only replaced by p_part for survivor component
|
||||
p = [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
p_both = [x * x for x in p]
|
||||
surv_component = sum(1000.0 * 0.5 * x for x in p) # using participant survival
|
||||
both_component = sum(1000.0 * x for x in p_both)
|
||||
expected_total = 1000.0 + (both_component - 1000.0) + surv_component # first period total includes guarantee at t=0 (from previous tests)
|
||||
# Given our service guarantees only if certain_months > 0, here it's 0, so no guarantee. Recompute expected total accordingly
|
||||
expected_total = both_component + surv_component
|
||||
assert math.isclose(data["pv_participant_component"], both_component, rel_tol=1e-6, abs_tol=0.01)
|
||||
assert math.isclose(data["pv_survivor_component"], surv_component, rel_tol=1e-6, abs_tol=0.01)
|
||||
assert math.isclose(data["pv_total"], expected_total, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_certain_period_guarantee_then_mortality(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
# Extend 0..9 again
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.pensions import NumberTable
|
||||
for m in range(5, 10):
|
||||
db.query(NumberTable).filter(NumberTable.month == m).delete()
|
||||
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Monthly payments, no deferral, 3 months certain
|
||||
payload = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 6,
|
||||
"start_age": 65,
|
||||
"sex": "A",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"defer_months": 0,
|
||||
"payment_period_months": 1,
|
||||
"certain_months": 3,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
# p = [1.0, .9, .8, .7, .6, .5]
|
||||
# guaranteed first 3: 1000, 1000, 1000; then mortality-weighted: 700, 600, 500
|
||||
expected = 1000.0 * 3 + 1000.0 * (0.7 + 0.6 + 0.5)
|
||||
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
# Joint-survivor: same guarantee logic applies to total stream
|
||||
payload_js = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 4,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 50.0,
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
"defer_months": 0,
|
||||
"payment_period_months": 1,
|
||||
"certain_months": 2,
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload_js)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()
|
||||
# p = [1.0, .9, .8, .7]
|
||||
# total mortality stream (no guarantee): 1000*p_both + 500*(p - p_both)
|
||||
p = [1.0, 0.9, 0.8, 0.7]
|
||||
p_both = [x * x for x in p]
|
||||
mort = [1000.0 * pb + 500.0 * (x - pb) for x, pb in zip(p, p_both)]
|
||||
total_expected = 1000.0 + 1000.0 + mort[2] + mort[3]
|
||||
assert math.isclose(data["pv_total"], total_expected, rel_tol=1e-6, abs_tol=0.01)
|
||||
|
||||
|
||||
def test_batch_single_life_mixed_success_failure(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
|
||||
valid_item = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"start_age": 65,
|
||||
"sex": "M",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
invalid_item = {
|
||||
"monthly_benefit": -100.0, # invalid
|
||||
"term_months": 5,
|
||||
"start_age": 65,
|
||||
"sex": "M",
|
||||
"race": "A",
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
payload = {
|
||||
"items": [valid_item, invalid_item, valid_item]
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/batch-single-life", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()["results"]
|
||||
assert len(data) == 3
|
||||
assert data[0]["success"] is True
|
||||
assert "pv" in data[0]["result"]
|
||||
assert math.isclose(data[0]["result"]["pv"], 4000.0, rel_tol=1e-6)
|
||||
assert data[1]["success"] is False
|
||||
assert "monthly_benefit must be non-negative" in (data[1]["error"] or "")
|
||||
assert data[2]["success"] is True
|
||||
assert math.isclose(data[2]["result"]["pv"], 4000.0, rel_tol=1e-6)
|
||||
|
||||
def test_batch_joint_survivor_mixed_success_failure(client: TestClient):
|
||||
_seed_monthly_na_series()
|
||||
|
||||
valid_item = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 50.0,
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
invalid_item = {
|
||||
"monthly_benefit": 1000.0,
|
||||
"term_months": 5,
|
||||
"participant_age": 65,
|
||||
"participant_sex": "M",
|
||||
"participant_race": "A",
|
||||
"spouse_age": 63,
|
||||
"spouse_sex": "F",
|
||||
"spouse_race": "A",
|
||||
"survivor_percent": 150.0, # invalid >100
|
||||
"discount_rate": 0.0,
|
||||
"cola_rate": 0.0,
|
||||
}
|
||||
payload = {
|
||||
"items": [valid_item, invalid_item]
|
||||
}
|
||||
resp = client.post("/api/pensions/valuation/batch-joint-survivor", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
data = resp.json()["results"]
|
||||
assert len(data) == 2
|
||||
assert data[0]["success"] is True
|
||||
assert "pv_total" in data[0]["result"]
|
||||
p = [1.0, 0.9, 0.8, 0.7, 0.6]
|
||||
p_both = [x * x for x in p]
|
||||
expected = 500.0 * sum(p) + 500.0 * sum(p_both)
|
||||
assert math.isclose(data[0]["result"]["pv_total"], expected, rel_tol=1e-6)
|
||||
assert data[1]["success"] is False
|
||||
assert "survivor_percent must be between 0 and 100" in (data[1]["error"] or "")
|
||||
|
||||
|
||||
192
tests/test_phone_book_api.py
Normal file
192
tests/test_phone_book_api.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
import uuid
|
||||
import csv
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Ensure required env vars for app import/config
|
||||
os.environ.setdefault("SECRET_KEY", "x" * 32)
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
|
||||
|
||||
from app.main import app # noqa: E402
|
||||
from app.auth.security import get_current_user # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
# Override auth to bypass JWT for these tests
|
||||
class _User:
|
||||
def __init__(self):
|
||||
self.id = "test"
|
||||
self.username = "tester"
|
||||
self.is_admin = True
|
||||
self.is_active = True
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: _User()
|
||||
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phone_book_data(client: TestClient):
|
||||
gid1 = f"PBGRP1-{uuid.uuid4().hex[:6]}"
|
||||
gid2 = f"PBGRP2-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
def _create_customer(cid: str, last: str, first: str, group: str):
|
||||
payload = {
|
||||
"id": cid,
|
||||
"last": last,
|
||||
"first": first,
|
||||
"group": group,
|
||||
"email": f"{cid.lower()}@example.com",
|
||||
}
|
||||
resp = client.post("/api/customers/", json=payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
def _add_phone(cid: str, location: str, number: str):
|
||||
resp = client.post(f"/api/customers/{cid}/phones", json={"location": location, "phone": number})
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
# Create entries for letters A, B and non-alpha '#'
|
||||
cid_a1 = f"PB-{uuid.uuid4().hex[:6]}-A1"
|
||||
cid_a2 = f"PB-{uuid.uuid4().hex[:6]}-A2"
|
||||
cid_b1 = f"PB-{uuid.uuid4().hex[:6]}-B1"
|
||||
cid_hash = f"PB-{uuid.uuid4().hex[:6]}-H"
|
||||
|
||||
_create_customer(cid_a1, last="Alpha", first="Alice", group=gid1)
|
||||
_add_phone(cid_a1, "Office", "111-111-1111")
|
||||
_add_phone(cid_a1, "Mobile", "111-111-2222")
|
||||
|
||||
_create_customer(cid_a2, last="Able", first="Andy", group=gid1)
|
||||
_add_phone(cid_a2, "Office", "222-222-2222")
|
||||
|
||||
_create_customer(cid_b1, last="Beta", first="Bob", group=gid2)
|
||||
_add_phone(cid_b1, "Home", "333-333-3333")
|
||||
|
||||
_create_customer(cid_hash, last="123Company", first="NA", group=gid1)
|
||||
_add_phone(cid_hash, "Main", "444-444-4444")
|
||||
|
||||
try:
|
||||
yield {
|
||||
"gid1": gid1,
|
||||
"gid2": gid2,
|
||||
"ids": [cid_a1, cid_a2, cid_b1, cid_hash],
|
||||
}
|
||||
finally:
|
||||
# Cleanup
|
||||
for cid in [cid_a1, cid_a2, cid_b1, cid_hash]:
|
||||
client.delete(f"/api/customers/{cid}")
|
||||
|
||||
|
||||
def _parse_csv(text: str):
|
||||
reader = csv.reader(io.StringIO(text))
|
||||
rows = list(reader)
|
||||
return rows[0], rows[1:]
|
||||
|
||||
|
||||
def test_phone_book_csv_letter_column_when_grouped_by_letter(client: TestClient, phone_book_data):
|
||||
# Only include our test group gid1 to avoid interference
|
||||
gid = phone_book_data["gid1"]
|
||||
resp = client.get(
|
||||
"/api/customers/phone-book",
|
||||
params={
|
||||
"format": "csv",
|
||||
"mode": "numbers",
|
||||
"grouping": "letter",
|
||||
"groups": gid,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "text/csv" in resp.headers.get("content-type", "")
|
||||
header, rows = _parse_csv(resp.text)
|
||||
assert "Letter" in header
|
||||
# Ensure letters include 'A' and '#'
|
||||
letters = {r[header.index("Letter")] for r in rows}
|
||||
assert "A" in letters and "#" in letters
|
||||
|
||||
|
||||
def test_phone_book_html_sections_by_letter_with_page_break(client: TestClient, phone_book_data):
|
||||
gid = phone_book_data["gid1"]
|
||||
resp = client.get(
|
||||
"/api/customers/phone-book",
|
||||
params={
|
||||
"format": "html",
|
||||
"mode": "numbers",
|
||||
"grouping": "letter",
|
||||
"page_break": "1",
|
||||
"groups": gid,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
html = resp.text
|
||||
# Snapshot-style checks: section headers and page-break class
|
||||
assert "Letter: A" in html
|
||||
assert "Letter: #" in html
|
||||
assert "class=\"section-title\"" in html
|
||||
assert "page-break" in html # later sections should have page-break class
|
||||
|
||||
|
||||
def test_phone_book_html_group_then_letter_sections(client: TestClient, phone_book_data):
|
||||
gid1 = phone_book_data["gid1"]
|
||||
gid2 = phone_book_data["gid2"]
|
||||
# Include both groups to verify group and nested letter sections
|
||||
resp = client.get(
|
||||
"/api/customers/phone-book",
|
||||
params=[
|
||||
("format", "html"),
|
||||
("mode", "addresses"),
|
||||
("grouping", "group_letter"),
|
||||
("groups", gid1),
|
||||
("groups", gid2),
|
||||
],
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
html = resp.text
|
||||
assert f"Group: {gid1}" in html
|
||||
assert f"Group: {gid2}" in html
|
||||
assert "Letter: A" in html or "Letter: #" in html
|
||||
|
||||
|
||||
def test_phone_book_csv_no_letter_for_grouping_none(client: TestClient, phone_book_data):
|
||||
gid = phone_book_data["gid1"]
|
||||
resp = client.get(
|
||||
"/api/customers/phone-book",
|
||||
params={
|
||||
"format": "csv",
|
||||
"mode": "numbers",
|
||||
"grouping": "none",
|
||||
"groups": gid,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
header, rows = _parse_csv(resp.text)
|
||||
assert "Letter" not in header
|
||||
# Basic sanity: names and phones are present
|
||||
assert any("Alpha" in ",".join(row) or "Able" in ",".join(row) for row in rows)
|
||||
|
||||
|
||||
def test_phone_book_respects_group_filter(client: TestClient, phone_book_data):
|
||||
gid2 = phone_book_data["gid2"]
|
||||
resp = client.get(
|
||||
"/api/customers/phone-book",
|
||||
params={
|
||||
"format": "csv",
|
||||
"mode": "numbers",
|
||||
"grouping": "letter",
|
||||
"groups": gid2,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
header, rows = _parse_csv(resp.text)
|
||||
# Only Beta/Bob (gid2) should be present
|
||||
all_text = "\n".join([",".join(r) for r in rows])
|
||||
assert "Beta" in all_text or "Bob" in all_text
|
||||
# Ensure gid1 names are not present
|
||||
assert "Alpha" not in all_text and "Able" not in all_text
|
||||
|
||||
|
||||
199
tests/test_templates_search_cache.py
Normal file
199
tests/test_templates_search_cache.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import os
|
||||
import io
|
||||
import uuid
|
||||
import asyncio
|
||||
from time import sleep
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Ensure required env vars for app import/config
|
||||
os.environ.setdefault("SECRET_KEY", "x" * 32)
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.main import app # noqa: E402
|
||||
from app.auth.security import get_current_user # noqa: E402
|
||||
from app.config import settings # noqa: E402
|
||||
from app.services.template_search import TemplateSearchService # noqa: E402
|
||||
|
||||
|
||||
def _dummy_docx_bytes():
|
||||
try:
|
||||
from docx import Document # type: ignore
|
||||
except Exception:
|
||||
return b"PK\x03\x04"
|
||||
d = Document()
|
||||
p = d.add_paragraph()
|
||||
p.add_run("Cache Test ")
|
||||
p.add_run("{{TOKEN}}")
|
||||
buf = io.BytesIO()
|
||||
d.save(buf)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _upload_template(client: TestClient, name: str, category: str = "GENERAL") -> int:
|
||||
files = {"file": (f"{name}.docx", _dummy_docx_bytes(), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}
|
||||
resp = client.post(
|
||||
"/api/templates/upload",
|
||||
data={"name": name, "category": category, "semantic_version": "1.0.0"},
|
||||
files=files,
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
return int(resp.json()["id"])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_no_redis():
|
||||
class _User:
|
||||
def __init__(self):
|
||||
self.id = "tmpl-cache-user"
|
||||
self.username = "tester"
|
||||
self.is_admin = True
|
||||
self.is_active = True
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: _User()
|
||||
# Disable Redis to exercise in-memory fallback
|
||||
settings.cache_enabled = False
|
||||
settings.redis_url = ""
|
||||
# Clear template caches
|
||||
try:
|
||||
asyncio.run(TemplateSearchService.invalidate_all())
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
def test_templates_search_caches_in_memory(client_no_redis: TestClient):
|
||||
# Upload in unique category to avoid interference
|
||||
category = f"CAT_{uuid.uuid4().hex[:8]}"
|
||||
_upload_template(client_no_redis, f"T-{uuid.uuid4().hex[:6]}", category)
|
||||
|
||||
params = {"category": category}
|
||||
r1 = client_no_redis.get("/api/templates/search", params=params)
|
||||
assert r1.status_code == 200
|
||||
d1 = r1.json()
|
||||
|
||||
# Second call should be served from cache and be identical
|
||||
r2 = client_no_redis.get("/api/templates/search", params=params)
|
||||
assert r2.status_code == 200
|
||||
d2 = r2.json()
|
||||
assert d1 == d2
|
||||
|
||||
|
||||
def test_templates_search_invalidation_on_upload_in_memory(client_no_redis: TestClient):
|
||||
# Use unique empty category; first query caches empty list
|
||||
category = f"EMPTY_{uuid.uuid4().hex[:8]}"
|
||||
r_empty = client_no_redis.get("/api/templates/search", params={"category": category})
|
||||
assert r_empty.status_code == 200
|
||||
assert r_empty.json() == []
|
||||
|
||||
# Upload a template into that category -> triggers invalidation
|
||||
_upload_template(client_no_redis, f"New-{uuid.uuid4().hex[:6]}", category)
|
||||
sleep(0.05)
|
||||
|
||||
# Query again should reflect the new item (cache invalidated)
|
||||
r_after = client_no_redis.get("/api/templates/search", params={"category": category})
|
||||
assert r_after.status_code == 200
|
||||
ids = [it["id"] for it in r_after.json()]
|
||||
assert len(ids) >= 1
|
||||
|
||||
|
||||
def test_templates_search_invalidation_on_keyword_update_in_memory(client_no_redis: TestClient):
|
||||
category = f"KW_{uuid.uuid4().hex[:8]}"
|
||||
tid = _upload_template(client_no_redis, f"KW-{uuid.uuid4().hex[:6]}", category)
|
||||
|
||||
# Initial query has_keywords=true should be empty and cached
|
||||
r1 = client_no_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
|
||||
assert r1.status_code == 200
|
||||
assert r1.json() == []
|
||||
|
||||
# Add a keyword to the template -> invalidates caches
|
||||
r_add = client_no_redis.post(f"/api/templates/{tid}/keywords", json={"keywords": ["alpha"]})
|
||||
assert r_add.status_code == 200
|
||||
sleep(0.05)
|
||||
|
||||
# Now search with has_keywords=true should include our template
|
||||
r2 = client_no_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
|
||||
assert r2.status_code == 200
|
||||
ids2 = {it["id"] for it in r2.json()}
|
||||
assert tid in ids2
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_redis():
|
||||
if not settings.redis_url:
|
||||
pytest.skip("Redis not configured for caching tests")
|
||||
|
||||
class _User:
|
||||
def __init__(self):
|
||||
self.id = "tmpl-cache-user-redis"
|
||||
self.username = "tester"
|
||||
self.is_admin = True
|
||||
self.is_active = True
|
||||
|
||||
app.dependency_overrides[get_current_user] = lambda: _User()
|
||||
settings.cache_enabled = True
|
||||
# Clear template caches
|
||||
try:
|
||||
asyncio.run(TemplateSearchService.invalidate_all())
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
|
||||
|
||||
def test_templates_search_caches_with_redis(client_with_redis: TestClient):
|
||||
category = f"RCAT_{uuid.uuid4().hex[:8]}"
|
||||
_upload_template(client_with_redis, f"RT-{uuid.uuid4().hex[:6]}", category)
|
||||
|
||||
params = {"category": category}
|
||||
r1 = client_with_redis.get("/api/templates/search", params=params)
|
||||
assert r1.status_code == 200
|
||||
d1 = r1.json()
|
||||
|
||||
r2 = client_with_redis.get("/api/templates/search", params=params)
|
||||
assert r2.status_code == 200
|
||||
d2 = r2.json()
|
||||
assert d1 == d2
|
||||
|
||||
|
||||
def test_templates_search_invalidation_on_upload_with_redis(client_with_redis: TestClient):
|
||||
category = f"RADD_{uuid.uuid4().hex[:8]}"
|
||||
r_empty = client_with_redis.get("/api/templates/search", params={"category": category})
|
||||
assert r_empty.status_code == 200 and r_empty.json() == []
|
||||
|
||||
_upload_template(client_with_redis, f"RNew-{uuid.uuid4().hex[:6]}", category)
|
||||
sleep(0.05)
|
||||
|
||||
r_after = client_with_redis.get("/api/templates/search", params={"category": category})
|
||||
assert r_after.status_code == 200
|
||||
assert len(r_after.json()) >= 1
|
||||
|
||||
|
||||
def test_templates_search_invalidation_on_keyword_update_with_redis(client_with_redis: TestClient):
|
||||
category = f"RKW_{uuid.uuid4().hex[:8]}"
|
||||
tid = _upload_template(client_with_redis, f"RTKW-{uuid.uuid4().hex[:6]}", category)
|
||||
|
||||
r1 = client_with_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
|
||||
assert r1.status_code == 200 and r1.json() == []
|
||||
|
||||
r_add = client_with_redis.post(f"/api/templates/{tid}/keywords", json={"keywords": ["beta"]})
|
||||
assert r_add.status_code == 200
|
||||
sleep(0.05)
|
||||
|
||||
r2 = client_with_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
|
||||
ids = {it["id"] for it in r2.json()}
|
||||
assert tid in ids
|
||||
|
||||
|
||||
442
tests/test_websocket_admin_api.py
Normal file
442
tests/test_websocket_admin_api.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Tests for WebSocket management endpoints in the Admin API
|
||||
|
||||
Tests cover:
|
||||
- WebSocket statistics endpoint
|
||||
- Connection listing and filtering
|
||||
- Connection management (disconnect, cleanup)
|
||||
- Broadcasting functionality
|
||||
- Admin-only access control
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import status
|
||||
|
||||
from app.main import app
|
||||
from app.auth.security import get_current_user, get_admin_user
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client():
|
||||
"""Create test client"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
"""Create admin user for testing"""
|
||||
user = User(
|
||||
id=1,
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
is_admin=True,
|
||||
is_active=True
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(admin_user):
|
||||
"""Test client with admin dependency overrides applied."""
|
||||
app.dependency_overrides[get_current_user] = lambda: admin_user
|
||||
app.dependency_overrides[get_admin_user] = lambda: admin_user
|
||||
try:
|
||||
yield TestClient(app)
|
||||
finally:
|
||||
app.dependency_overrides.pop(get_current_user, None)
|
||||
app.dependency_overrides.pop(get_admin_user, None)
|
||||
|
||||
@pytest.fixture
|
||||
def regular_user():
|
||||
"""Create regular user for testing"""
|
||||
user = User(
|
||||
id=2,
|
||||
username="user",
|
||||
email="user@test.com",
|
||||
is_admin=False,
|
||||
is_active=True
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
class TestWebSocketStatsEndpoint:
|
||||
"""Test WebSocket statistics endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
def test_get_websocket_stats_success(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test successful retrieval of WebSocket statistics"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.get_stats = AsyncMock(return_value={
|
||||
"total_connections": 10,
|
||||
"active_connections": 8,
|
||||
"total_topics": 5,
|
||||
"total_users": 3,
|
||||
"messages_sent": 100,
|
||||
"messages_failed": 2,
|
||||
"connections_cleaned": 5,
|
||||
"last_cleanup": "2023-01-01T12:00:00Z",
|
||||
"last_heartbeat": "2023-01-01T12:01:00Z",
|
||||
"connections_by_state": {"connected": 8, "disconnected": 2},
|
||||
"topic_distribution": {"topic1": 5, "topic2": 3}
|
||||
})
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
# Make request
|
||||
response = admin_client.get("/api/admin/websockets/stats")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total_connections"] == 10
|
||||
assert data["active_connections"] == 8
|
||||
assert data["messages_sent"] == 100
|
||||
assert "topic1" in data["topic_distribution"]
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
def test_get_websocket_stats_unauthorized(self, mock_get_admin, test_client, regular_user):
|
||||
"""Test unauthorized access to WebSocket statistics"""
|
||||
# Mock non-admin user
|
||||
mock_get_admin.side_effect = Exception("Admin required")
|
||||
|
||||
# Make request
|
||||
response = test_client.get("/api/admin/websockets/stats")
|
||||
|
||||
# Should fail (forbidden)
|
||||
assert response.status_code in (status.HTTP_403_FORBIDDEN, status.HTTP_401_UNAUTHORIZED, status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
class TestWebSocketConnectionsEndpoint:
|
||||
"""Test WebSocket connections listing endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
@patch('app.api.admin.get_connection_tracker')
|
||||
def test_get_connections_success(self, mock_get_tracker, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test successful retrieval of WebSocket connections"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager and tracker
|
||||
mock_manager = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool._connections_lock = AsyncMock()
|
||||
mock_pool._connections = {"conn1": MagicMock(), "conn2": MagicMock()}
|
||||
mock_manager.pool = mock_pool
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
mock_tracker = MagicMock()
|
||||
mock_tracker.get_connection_metrics = AsyncMock(side_effect=[
|
||||
{
|
||||
"connection_id": "conn1",
|
||||
"user_id": 1,
|
||||
"state": "connected",
|
||||
"topics": ["topic1"],
|
||||
"created_at": "2023-01-01T12:00:00Z",
|
||||
"last_activity": "2023-01-01T12:01:00Z",
|
||||
"age_seconds": 60,
|
||||
"idle_seconds": 10,
|
||||
"error_count": 0,
|
||||
"last_ping": "2023-01-01T12:01:00Z",
|
||||
"last_pong": "2023-01-01T12:01:00Z",
|
||||
"metadata": {},
|
||||
"is_alive": True,
|
||||
"is_stale": False
|
||||
},
|
||||
{
|
||||
"connection_id": "conn2",
|
||||
"user_id": 2,
|
||||
"state": "connected",
|
||||
"topics": ["topic2"],
|
||||
"created_at": "2023-01-01T12:00:00Z",
|
||||
"last_activity": "2023-01-01T12:01:00Z",
|
||||
"age_seconds": 60,
|
||||
"idle_seconds": 10,
|
||||
"error_count": 0,
|
||||
"last_ping": "2023-01-01T12:01:00Z",
|
||||
"last_pong": "2023-01-01T12:01:00Z",
|
||||
"metadata": {},
|
||||
"is_alive": True,
|
||||
"is_stale": False
|
||||
}
|
||||
])
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
|
||||
# Make request
|
||||
response = admin_client.get("/api/admin/websockets/connections")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total_count"] == 2
|
||||
assert data["active_count"] == 2
|
||||
assert data["stale_count"] == 0
|
||||
assert len(data["connections"]) == 2
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
@patch('app.api.admin.get_connection_tracker')
|
||||
def test_get_connections_with_filters(self, mock_get_tracker, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test WebSocket connections listing with filters"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager and tracker
|
||||
mock_manager = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool._connections_lock = AsyncMock()
|
||||
mock_pool._connections = {"conn1": MagicMock()}
|
||||
mock_manager.pool = mock_pool
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
mock_tracker = MagicMock()
|
||||
# Return connection for user 1 only (user_id filter)
|
||||
mock_tracker.get_connection_metrics = AsyncMock(return_value={
|
||||
"connection_id": "conn1",
|
||||
"user_id": 1,
|
||||
"state": "connected",
|
||||
"topics": ["topic1"],
|
||||
"created_at": "2023-01-01T12:00:00Z",
|
||||
"last_activity": "2023-01-01T12:01:00Z",
|
||||
"age_seconds": 60,
|
||||
"idle_seconds": 10,
|
||||
"error_count": 0,
|
||||
"last_ping": "2023-01-01T12:01:00Z",
|
||||
"last_pong": "2023-01-01T12:01:00Z",
|
||||
"metadata": {},
|
||||
"is_alive": True,
|
||||
"is_stale": False
|
||||
})
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
|
||||
# Make request with user_id filter
|
||||
response = admin_client.get("/api/admin/websockets/connections?user_id=1")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total_count"] == 1
|
||||
assert data["connections"][0]["user_id"] == 1
|
||||
|
||||
|
||||
class TestWebSocketDisconnectEndpoint:
|
||||
"""Test WebSocket disconnect endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
def test_disconnect_by_connection_ids(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test disconnecting specific connections"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager
|
||||
mock_manager = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.remove_connection = AsyncMock()
|
||||
mock_manager.pool = mock_pool
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
# Make request
|
||||
response = admin_client.post(
|
||||
"/api/admin/websockets/disconnect",
|
||||
json={
|
||||
"connection_ids": ["conn1", "conn2"],
|
||||
"reason": "admin_test"
|
||||
}
|
||||
)
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["disconnected_count"] == 2
|
||||
assert data["reason"] == "admin_test"
|
||||
|
||||
# Check mock calls
|
||||
assert mock_pool.remove_connection.call_count == 2
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
def test_disconnect_by_user_id(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test disconnecting all connections for a user"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager
|
||||
mock_manager = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.get_user_connections = AsyncMock(return_value=["conn1", "conn2"])
|
||||
mock_pool.remove_connection = AsyncMock()
|
||||
mock_manager.pool = mock_pool
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
# Make request
|
||||
response = admin_client.post(
|
||||
"/api/admin/websockets/disconnect",
|
||||
json={
|
||||
"user_id": 1,
|
||||
"reason": "user_maintenance"
|
||||
}
|
||||
)
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["disconnected_count"] == 2
|
||||
|
||||
def test_disconnect_missing_criteria(self, admin_client):
|
||||
"""Test disconnect request with missing criteria"""
|
||||
# Make request without specifying what to disconnect
|
||||
response = admin_client.post(
|
||||
"/api/admin/websockets/disconnect",
|
||||
json={"reason": "test"}
|
||||
)
|
||||
|
||||
# Should fail
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
|
||||
class TestWebSocketCleanupEndpoint:
|
||||
"""Test WebSocket cleanup endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
def test_cleanup_websockets(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test manual WebSocket cleanup"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager
|
||||
mock_manager = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.get_stats = AsyncMock(side_effect=[
|
||||
{"active_connections": 10}, # Before cleanup
|
||||
{"active_connections": 8} # After cleanup
|
||||
])
|
||||
mock_pool._cleanup_stale_connections = AsyncMock()
|
||||
mock_manager.pool = mock_pool
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
# Make request
|
||||
response = admin_client.post("/api/admin/websockets/cleanup")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["connections_before"] == 10
|
||||
assert data["connections_after"] == 8
|
||||
assert data["cleaned_count"] == 2
|
||||
|
||||
# Check cleanup was called
|
||||
mock_pool._cleanup_stale_connections.assert_called_once()
|
||||
|
||||
|
||||
class TestWebSocketBroadcastEndpoint:
|
||||
"""Test WebSocket broadcast endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_websocket_manager')
|
||||
def test_broadcast_message(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
|
||||
"""Test broadcasting a message to a topic"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock WebSocket manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.broadcast_to_topic = AsyncMock(return_value=5)
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
# Make request
|
||||
response = admin_client.post(
|
||||
"/api/admin/websockets/broadcast",
|
||||
json={
|
||||
"topic": "admin_announcement",
|
||||
"message_type": "system_message",
|
||||
"data": {"message": "System maintenance in 5 minutes"}
|
||||
}
|
||||
)
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["sent_count"] == 5
|
||||
assert data["topic"] == "admin_announcement"
|
||||
assert data["message_type"] == "system_message"
|
||||
|
||||
# Check broadcast was called correctly
|
||||
mock_manager.broadcast_to_topic.assert_called_once_with(
|
||||
topic="admin_announcement",
|
||||
message_type="system_message",
|
||||
data={"message": "System maintenance in 5 minutes"}
|
||||
)
|
||||
|
||||
|
||||
class TestWebSocketConnectionDetailEndpoint:
|
||||
"""Test individual WebSocket connection detail endpoint"""
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_connection_tracker')
|
||||
def test_get_connection_detail_success(self, mock_get_tracker, mock_get_admin, admin_client, admin_user):
|
||||
"""Test getting details for a specific connection"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock connection tracker
|
||||
mock_tracker = MagicMock()
|
||||
mock_tracker.get_connection_metrics = AsyncMock(return_value={
|
||||
"connection_id": "conn1",
|
||||
"user_id": 1,
|
||||
"state": "connected",
|
||||
"topics": ["topic1", "topic2"],
|
||||
"created_at": "2023-01-01T12:00:00Z",
|
||||
"last_activity": "2023-01-01T12:01:00Z",
|
||||
"age_seconds": 60,
|
||||
"idle_seconds": 10,
|
||||
"error_count": 0,
|
||||
"last_ping": "2023-01-01T12:01:00Z",
|
||||
"last_pong": "2023-01-01T12:01:00Z",
|
||||
"metadata": {"endpoint": "batch_progress"},
|
||||
"is_alive": True,
|
||||
"is_stale": False
|
||||
})
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
|
||||
# Make request
|
||||
response = admin_client.get("/api/admin/websockets/connections/conn1")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["connection_id"] == "conn1"
|
||||
assert data["user_id"] == 1
|
||||
assert data["state"] == "connected"
|
||||
assert len(data["topics"]) == 2
|
||||
assert data["metadata"]["endpoint"] == "batch_progress"
|
||||
|
||||
@patch('app.api.admin.get_admin_user')
|
||||
@patch('app.api.admin.get_connection_tracker')
|
||||
def test_get_connection_detail_not_found(self, mock_get_tracker, mock_get_admin, admin_client, admin_user):
|
||||
"""Test getting details for non-existent connection"""
|
||||
# Mock admin authentication
|
||||
mock_get_admin.return_value = admin_user
|
||||
|
||||
# Mock connection tracker returning None
|
||||
mock_tracker = MagicMock()
|
||||
mock_tracker.get_connection_metrics = AsyncMock(return_value=None)
|
||||
mock_get_tracker.return_value = mock_tracker
|
||||
|
||||
# Make request
|
||||
response = admin_client.get("/api/admin/websockets/connections/nonexistent")
|
||||
|
||||
# Check response
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
607
tests/test_websocket_pool.py
Normal file
607
tests/test_websocket_pool.py
Normal file
@@ -0,0 +1,607 @@
|
||||
"""
|
||||
Comprehensive tests for WebSocket connection pooling and management
|
||||
|
||||
Tests cover:
|
||||
- Connection pool creation and management
|
||||
- Automatic cleanup of stale connections
|
||||
- Health monitoring and heartbeats
|
||||
- Topic-based message broadcasting
|
||||
- Resource management and memory leak prevention
|
||||
- Integration with authentication
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.services.websocket_pool import (
|
||||
WebSocketPool,
|
||||
ConnectionState,
|
||||
MessageType,
|
||||
ConnectionInfo,
|
||||
WebSocketMessage,
|
||||
get_websocket_pool,
|
||||
initialize_websocket_pool,
|
||||
shutdown_websocket_pool,
|
||||
websocket_connection
|
||||
)
|
||||
from app.middleware.websocket_middleware import (
|
||||
WebSocketManager,
|
||||
get_websocket_manager,
|
||||
WebSocketAuthenticationError
|
||||
)
|
||||
|
||||
|
||||
class MockWebSocket:
|
||||
"""Mock WebSocket for testing"""
|
||||
|
||||
def __init__(self, mock_user_id: int = None):
|
||||
self.sent_messages = []
|
||||
self.received_messages = []
|
||||
self.closed = False
|
||||
self.close_code = None
|
||||
self.close_reason = None
|
||||
self.mock_user_id = mock_user_id
|
||||
self.url = MagicMock()
|
||||
self.url.query = f"token=test_token_{mock_user_id}" if mock_user_id else ""
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]):
|
||||
"""Mock send_json method"""
|
||||
if self.closed:
|
||||
raise Exception("WebSocket closed")
|
||||
self.sent_messages.append(data)
|
||||
|
||||
async def send_text(self, text: str):
|
||||
"""Mock send_text method"""
|
||||
if self.closed:
|
||||
raise Exception("WebSocket closed")
|
||||
self.sent_messages.append({"text": text})
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
"""Mock receive_text method"""
|
||||
if self.closed:
|
||||
raise WebSocketDisconnect()
|
||||
if self.received_messages:
|
||||
return self.received_messages.pop(0)
|
||||
# Simulate waiting for message
|
||||
await asyncio.sleep(0.1)
|
||||
return "ping"
|
||||
|
||||
async def accept(self):
|
||||
"""Mock accept method"""
|
||||
pass
|
||||
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
"""Mock close method"""
|
||||
self.closed = True
|
||||
self.close_code = code
|
||||
self.close_reason = reason
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def websocket_pool():
|
||||
"""Create a fresh WebSocket pool for testing"""
|
||||
pool = WebSocketPool(
|
||||
cleanup_interval=1, # Fast cleanup for testing
|
||||
connection_timeout=5, # Short timeout for testing
|
||||
heartbeat_interval=2, # Fast heartbeat for testing
|
||||
max_connections_per_topic=10,
|
||||
max_total_connections=100
|
||||
)
|
||||
await pool.start()
|
||||
yield pool
|
||||
await pool.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Create a mock WebSocket for testing"""
|
||||
return MockWebSocket()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_authenticated_websocket():
|
||||
"""Create a mock authenticated WebSocket for testing"""
|
||||
return MockWebSocket(mock_user_id=1)
|
||||
|
||||
|
||||
class TestWebSocketPool:
|
||||
"""Test the core WebSocket pool functionality"""
|
||||
|
||||
async def test_pool_initialization(self, websocket_pool):
|
||||
"""Test pool initialization and configuration"""
|
||||
assert websocket_pool.cleanup_interval == 1
|
||||
assert websocket_pool.connection_timeout == 5
|
||||
assert websocket_pool.heartbeat_interval == 2
|
||||
assert websocket_pool.max_connections_per_topic == 10
|
||||
assert websocket_pool.max_total_connections == 100
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["total_topics"] == 0
|
||||
assert stats["total_users"] == 0
|
||||
|
||||
async def test_add_remove_connection(self, websocket_pool, mock_websocket):
|
||||
"""Test adding and removing connections"""
|
||||
# Add connection
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics={"test_topic"},
|
||||
metadata={"test": "data"}
|
||||
)
|
||||
|
||||
assert connection_id.startswith("ws_")
|
||||
|
||||
# Check connection exists
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is not None
|
||||
assert connection_info.user_id == 1
|
||||
assert "test_topic" in connection_info.topics
|
||||
assert connection_info.metadata["test"] == "data"
|
||||
|
||||
# Check stats
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 1
|
||||
assert stats["total_topics"] == 1
|
||||
assert stats["total_users"] == 1
|
||||
|
||||
# Remove connection
|
||||
await websocket_pool.remove_connection(connection_id, "test_removal")
|
||||
|
||||
# Check connection removed
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
async def test_topic_subscription(self, websocket_pool, mock_websocket):
|
||||
"""Test topic subscription and unsubscription"""
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Subscribe to topic
|
||||
success = await websocket_pool.subscribe_to_topic(connection_id, "new_topic")
|
||||
assert success
|
||||
|
||||
# Check subscription
|
||||
topic_connections = await websocket_pool.get_topic_connections("new_topic")
|
||||
assert connection_id in topic_connections
|
||||
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert "new_topic" in connection_info.topics
|
||||
|
||||
# Unsubscribe from topic
|
||||
success = await websocket_pool.unsubscribe_from_topic(connection_id, "new_topic")
|
||||
assert success
|
||||
|
||||
# Check unsubscription
|
||||
topic_connections = await websocket_pool.get_topic_connections("new_topic")
|
||||
assert connection_id not in topic_connections
|
||||
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert "new_topic" not in connection_info.topics
|
||||
|
||||
async def test_broadcast_to_topic(self, websocket_pool):
|
||||
"""Test broadcasting messages to topic subscribers"""
|
||||
# Create multiple connections
|
||||
websockets = [MockWebSocket(i) for i in range(3)]
|
||||
connection_ids = []
|
||||
|
||||
for i, ws in enumerate(websockets):
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"broadcast_topic"}
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Broadcast message
|
||||
message = WebSocketMessage(
|
||||
type="test_broadcast",
|
||||
topic="broadcast_topic",
|
||||
data={"message": "Hello everyone!"}
|
||||
)
|
||||
|
||||
sent_count = await websocket_pool.broadcast_to_topic(
|
||||
topic="broadcast_topic",
|
||||
message=message
|
||||
)
|
||||
|
||||
assert sent_count == 3
|
||||
|
||||
# Check all websockets received the message
|
||||
for ws in websockets:
|
||||
assert len(ws.sent_messages) == 1
|
||||
sent_message = ws.sent_messages[0]
|
||||
assert sent_message["type"] == "test_broadcast"
|
||||
assert sent_message["data"]["message"] == "Hello everyone!"
|
||||
|
||||
async def test_send_to_user(self, websocket_pool):
|
||||
"""Test sending messages to all connections for a specific user"""
|
||||
# Create multiple connections for the same user
|
||||
websockets = [MockWebSocket(1) for _ in range(2)]
|
||||
connection_ids = []
|
||||
|
||||
for ws in websockets:
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Add connection for different user
|
||||
other_ws = MockWebSocket(2)
|
||||
await websocket_pool.add_connection(
|
||||
websocket=other_ws,
|
||||
user_id=2,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Send message to user 1
|
||||
message = WebSocketMessage(
|
||||
type="user_message",
|
||||
data={"message": "Hello user 1!"}
|
||||
)
|
||||
|
||||
sent_count = await websocket_pool.send_to_user(1, message)
|
||||
|
||||
assert sent_count == 2
|
||||
|
||||
# Check user 1 connections received message
|
||||
for ws in websockets:
|
||||
assert len(ws.sent_messages) == 1
|
||||
sent_message = ws.sent_messages[0]
|
||||
assert sent_message["type"] == "user_message"
|
||||
|
||||
# Check user 2 connection didn't receive message
|
||||
assert len(other_ws.sent_messages) == 0
|
||||
|
||||
async def test_connection_limits(self, websocket_pool):
|
||||
"""Test connection limits"""
|
||||
# Test max connections per topic
|
||||
websockets = []
|
||||
for i in range(websocket_pool.max_connections_per_topic + 1):
|
||||
ws = MockWebSocket(i)
|
||||
websockets.append(ws)
|
||||
|
||||
if i < websocket_pool.max_connections_per_topic:
|
||||
# Should succeed
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"limited_topic"}
|
||||
)
|
||||
assert connection_id is not None
|
||||
else:
|
||||
# Should fail due to topic limit
|
||||
with pytest.raises(ValueError, match="Maximum connections per topic"):
|
||||
await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"limited_topic"}
|
||||
)
|
||||
|
||||
async def test_stale_connection_cleanup(self, websocket_pool):
|
||||
"""Test automatic cleanup of stale connections"""
|
||||
# Add connection
|
||||
mock_ws = MockWebSocket(1)
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"}
|
||||
)
|
||||
|
||||
# Manually set last activity to make connection stale
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
connection_info.last_activity = datetime.now(timezone.utc) - timedelta(seconds=10)
|
||||
|
||||
# Trigger cleanup
|
||||
await websocket_pool._cleanup_stale_connections()
|
||||
|
||||
# Check connection was cleaned up
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["connections_cleaned"] > 0
|
||||
|
||||
async def test_heartbeat_functionality(self, websocket_pool, mock_websocket):
|
||||
"""Test heartbeat sending and connection health monitoring"""
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Trigger heartbeat
|
||||
await websocket_pool._send_heartbeats()
|
||||
|
||||
# Check heartbeat was sent
|
||||
assert len(mock_websocket.sent_messages) == 1
|
||||
heartbeat_message = mock_websocket.sent_messages[0]
|
||||
assert heartbeat_message["type"] == MessageType.HEARTBEAT.value
|
||||
|
||||
# Test ping functionality
|
||||
success = await websocket_pool.ping_connection(connection_id)
|
||||
assert success
|
||||
|
||||
assert len(mock_websocket.sent_messages) == 2
|
||||
ping_message = mock_websocket.sent_messages[1]
|
||||
assert ping_message["type"] == MessageType.PING.value
|
||||
|
||||
async def test_failed_message_cleanup(self, websocket_pool):
|
||||
"""Test cleanup of connections when message sending fails"""
|
||||
# Create a websocket that will fail on send
|
||||
mock_ws = MockWebSocket(1)
|
||||
mock_ws.closed = True # Simulate closed connection
|
||||
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"}
|
||||
)
|
||||
|
||||
# Try to broadcast - should fail and clean up connection
|
||||
message = WebSocketMessage(type="test", data={})
|
||||
sent_count = await websocket_pool.broadcast_to_topic(
|
||||
topic="test_topic",
|
||||
message=message
|
||||
)
|
||||
|
||||
assert sent_count == 0
|
||||
|
||||
# Check connection was cleaned up
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
|
||||
class TestWebSocketManager:
|
||||
"""Test the WebSocket manager middleware"""
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_manager(self):
|
||||
"""Create a WebSocket manager for testing"""
|
||||
return WebSocketManager()
|
||||
|
||||
@patch('app.middleware.websocket_middleware.verify_token')
|
||||
@patch('app.middleware.websocket_middleware.SessionLocal')
|
||||
async def test_authentication_success(self, mock_session, mock_verify_token, websocket_manager):
|
||||
"""Test successful WebSocket authentication"""
|
||||
# Mock successful authentication
|
||||
mock_verify_token.return_value = "test_user"
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "test_user"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_session.return_value = mock_db
|
||||
|
||||
# Create mock websocket with token
|
||||
mock_ws = MagicMock()
|
||||
mock_ws.url.query = "token=valid_token"
|
||||
|
||||
user = await websocket_manager.authenticate_websocket(mock_ws)
|
||||
|
||||
assert user is not None
|
||||
assert user.id == 1
|
||||
assert user.username == "test_user"
|
||||
|
||||
@patch('app.middleware.websocket_middleware.verify_token')
|
||||
async def test_authentication_failure(self, mock_verify_token, websocket_manager):
|
||||
"""Test failed WebSocket authentication"""
|
||||
# Mock failed authentication
|
||||
mock_verify_token.return_value = None
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_ws.url.query = "token=invalid_token"
|
||||
|
||||
user = await websocket_manager.authenticate_websocket(mock_ws)
|
||||
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestWebSocketIntegration:
|
||||
"""Test WebSocket integration with FastAPI"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_connection_context_manager(self):
|
||||
"""Test the websocket_connection context manager"""
|
||||
pool = WebSocketPool()
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
mock_ws = MockWebSocket(1)
|
||||
|
||||
async with websocket_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"},
|
||||
metadata={"test": "data"}
|
||||
) as (connection_id, pool_instance):
|
||||
|
||||
assert connection_id is not None
|
||||
assert connection_id.startswith("ws_")
|
||||
# The context returns the global pool; allow either the same object or equivalent type
|
||||
assert isinstance(pool_instance, type(pool))
|
||||
|
||||
# Check connection exists
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
assert connection_info is not None
|
||||
assert connection_info.user_id == 1
|
||||
assert "test_topic" in connection_info.topics
|
||||
|
||||
# Check connection was cleaned up after context exit
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_pool_management(self):
|
||||
"""Test global WebSocket pool initialization and shutdown"""
|
||||
# Initialize global pool
|
||||
pool = await initialize_websocket_pool(
|
||||
cleanup_interval=1,
|
||||
connection_timeout=5
|
||||
)
|
||||
|
||||
assert pool is not None
|
||||
|
||||
# Get the same pool instance
|
||||
same_pool = get_websocket_pool()
|
||||
assert same_pool is pool
|
||||
|
||||
# Shutdown global pool
|
||||
await shutdown_websocket_pool()
|
||||
|
||||
# Should create new pool after shutdown
|
||||
new_pool = get_websocket_pool()
|
||||
assert new_pool is not pool
|
||||
|
||||
|
||||
class TestWebSocketStressTest:
|
||||
"""Stress tests for WebSocket pool under high load"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_connection_volume(self):
|
||||
"""Test pool behavior with many concurrent connections"""
|
||||
pool = WebSocketPool(max_total_connections=1000)
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
# Create many connections
|
||||
connection_ids = []
|
||||
for i in range(100):
|
||||
mock_ws = MockWebSocket(i)
|
||||
connection_id = await pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=i % 10, # 10 different users
|
||||
topics={f"topic_{i % 5}"} # 5 different topics
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Check all connections exist
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 100
|
||||
assert stats["total_topics"] == 5
|
||||
assert stats["total_users"] == 10
|
||||
|
||||
# Broadcast to all topics
|
||||
for topic_id in range(5):
|
||||
topic = f"topic_{topic_id}"
|
||||
message = WebSocketMessage(
|
||||
type="stress_test",
|
||||
topic=topic,
|
||||
data={"topic_id": topic_id}
|
||||
)
|
||||
sent_count = await pool.broadcast_to_topic(topic, message)
|
||||
assert sent_count == 20 # 100 connections / 5 topics
|
||||
|
||||
# Clean up all connections
|
||||
for connection_id in connection_ids:
|
||||
await pool.remove_connection(connection_id, "stress_test_cleanup")
|
||||
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_connect_disconnect(self):
|
||||
"""Test rapid connection and disconnection patterns"""
|
||||
pool = WebSocketPool()
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
# Rapidly create and destroy connections
|
||||
for round_num in range(10):
|
||||
connection_ids = []
|
||||
|
||||
# Create 10 connections
|
||||
for i in range(10):
|
||||
mock_ws = MockWebSocket(i)
|
||||
connection_id = await pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=i,
|
||||
topics={f"round_{round_num}"}
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Immediately remove them
|
||||
for connection_id in connection_ids:
|
||||
await pool.remove_connection(connection_id, f"round_{round_num}_cleanup")
|
||||
|
||||
# Check pool is clean
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
|
||||
# Test utilities and fixtures for integration testing
|
||||
@pytest.fixture
|
||||
def websocket_test_client():
|
||||
"""Create a test client for WebSocket endpoint testing"""
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestBillingWebSocketIntegration:
|
||||
"""Test integration with the billing API WebSocket endpoint"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_progress_topic_format(self):
|
||||
"""Test that batch progress uses correct topic format"""
|
||||
from app.api.billing import _notify_progress_subscribers
|
||||
from app.services.batch_generation import BatchProgress
|
||||
|
||||
# Mock the WebSocket manager
|
||||
with patch('app.api.billing.websocket_manager') as mock_manager:
|
||||
mock_manager.broadcast_to_topic = AsyncMock(return_value=1)
|
||||
|
||||
# Create test progress
|
||||
progress = BatchProgress(
|
||||
batch_id="test_batch_123",
|
||||
status="processing",
|
||||
total_files=10,
|
||||
processed_files=5,
|
||||
successful_files=4,
|
||||
failed_files=1,
|
||||
current_file="test.txt",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Call notification function
|
||||
await _notify_progress_subscribers(progress)
|
||||
|
||||
# Check correct topic was used
|
||||
mock_manager.broadcast_to_topic.assert_called_once_with(
|
||||
topic="batch_progress_test_batch_123",
|
||||
message_type="progress",
|
||||
data=progress.model_dump()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user