This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

292
tests/test_jobs_api.py Normal file
View File

@@ -0,0 +1,292 @@
import os
import uuid
from datetime import datetime, timezone, timedelta
import pytest
from fastapi.testclient import TestClient
# Ensure required env vars for app import/config
os.environ.setdefault("SECRET_KEY", "x" * 32)
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
from app.main import app # noqa: E402
from app.auth.security import get_current_user, get_admin_user # noqa: E402
from app.database.base import SessionLocal # noqa: E402
from app.models.jobs import JobRecord # noqa: E402
from app.models.audit import AuditLog # noqa: E402
class _User:
def __init__(self, is_admin: bool):
self.id = 1 if is_admin else 2
self.username = "admin" if is_admin else "user"
self.is_admin = is_admin
self.is_active = True
self.first_name = "Test"
self.last_name = "User"
@pytest.fixture()
def client_admin():
app.dependency_overrides[get_current_user] = lambda: _User(True)
app.dependency_overrides[get_admin_user] = lambda: _User(True)
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
app.dependency_overrides.pop(get_admin_user, None)
@pytest.fixture()
def client_user():
app.dependency_overrides[get_current_user] = lambda: _User(False)
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
@pytest.fixture()
def make_job():
"""Factory to create a JobRecord in the test database."""
created_job_ids = []
def _create(
*,
job_type: str,
status: str = "queued",
requested_by_username: str = "user",
started_at: datetime | None = None,
completed_at: datetime | None = None,
total_requested: int = 0,
total_success: int = 0,
total_failed: int = 0,
details: dict | None = None,
) -> JobRecord:
db = SessionLocal()
try:
j = JobRecord(
job_id=uuid.uuid4().hex,
job_type=job_type,
status=status,
requested_by_username=requested_by_username,
started_at=started_at or datetime.now(timezone.utc),
completed_at=completed_at,
total_requested=total_requested,
total_success=total_success,
total_failed=total_failed,
details=details or {},
)
db.add(j)
db.commit()
db.refresh(j)
created_job_ids.append(j.job_id)
return j
finally:
db.close()
yield _create
# Optional: cleanup created jobs to reduce cross-test noise
if created_job_ids:
db = SessionLocal()
try:
for jid in created_job_ids:
row = db.query(JobRecord).filter(JobRecord.job_id == jid).first()
if row:
db.delete(row)
db.commit()
finally:
db.close()
def test_jobs_list_and_filtering(client_admin: TestClient, client_user: TestClient, make_job):
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
# Create jobs: two for non-admin user, one for admin
j_user_running = make_job(job_type=jt, status="running", requested_by_username="user")
j_user_failed = make_job(job_type=jt, status="failed", requested_by_username="user")
j_admin_completed = make_job(job_type=jt, status="completed", requested_by_username="admin")
# Non-admin sees only their jobs by default (mine=True)
resp = client_user.get("/api/jobs/", params={"include_total": 1})
assert resp.status_code == 200
body = resp.json()
assert set(body.keys()) == {"items", "total"}
assert body["total"] >= 2
ids = [it["job_id"] for it in body["items"]]
assert j_user_running.job_id in ids and j_user_failed.job_id in ids
assert j_admin_completed.job_id not in ids
# Filter by status
resp = client_user.get("/api/jobs/", params={"include_total": 1, "status_filter": "failed"})
assert resp.status_code == 200
body = resp.json()
assert body["total"] >= 1
assert any(it["job_id"] == j_user_failed.job_id for it in body["items"])
# Filter by type
resp = client_user.get("/api/jobs/", params={"include_total": 1, "type_filter": jt})
assert resp.status_code == 200
body = resp.json()
assert body["total"] >= 2
# Search across job_type
resp = client_user.get("/api/jobs/", params={"search": jt})
assert resp.status_code == 200
ids = [it["job_id"] for it in resp.json()]
assert j_user_running.job_id in ids and j_user_failed.job_id in ids
# Admin can list all with mine=false and filter by requested_by
# Because fixtures share global dependency overrides, ensure admin override is active for this request
from app.main import app as _app
from app.auth.security import get_current_user as _get_current_user
_prev = _app.dependency_overrides.get(_get_current_user)
try:
_app.dependency_overrides[_get_current_user] = lambda: _User(True)
c = TestClient(_app)
resp = c.get("/api/jobs/", params={"mine": 0, "include_total": 1})
assert resp.status_code == 200
body = resp.json()
assert body["total"] >= 3
finally:
if _prev is not None:
_app.dependency_overrides[_get_current_user] = _prev
else:
_app.dependency_overrides.pop(_get_current_user, None)
resp = client_admin.get("/api/jobs/", params={"mine": 0, "include_total": 1, "requested_by": "user"})
assert resp.status_code == 200
body = resp.json()
assert body["total"] >= 2
def test_jobs_get_by_id_permissions(client_admin: TestClient, client_user: TestClient, make_job):
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
j_admin = make_job(job_type=jt, status="completed", requested_by_username="admin")
j_user = make_job(job_type=jt, status="running", requested_by_username="user")
# Non-admin cannot access someone else's job
resp = client_user.get(f"/api/jobs/{j_admin.job_id}")
assert resp.status_code in (403, 404)
if resp.status_code == 403:
assert "Not enough permissions" in resp.text
# Non-admin can access own job
resp = client_user.get(f"/api/jobs/{j_user.job_id}")
assert resp.status_code == 200
assert resp.json()["job_id"] == j_user.job_id
# Admin can access any job
resp = client_admin.get(f"/api/jobs/{j_user.job_id}")
assert resp.status_code == 200
assert resp.json()["job_id"] == j_user.job_id
def _audit_exists(job_id: str, action: str) -> bool:
db = SessionLocal()
try:
return (
db.query(AuditLog)
.filter(AuditLog.resource_type == "JOB", AuditLog.resource_id == job_id, AuditLog.action == action)
.count()
> 0
)
finally:
db.close()
def test_jobs_state_transitions_audit_and_metrics(client_admin: TestClient, make_job):
jt = f"pytest_jobs_{uuid.uuid4().hex[:6]}"
# Create a job and transition to running -> completed
j = make_job(job_type=jt, status="queued", requested_by_username="admin")
resp = client_admin.post(f"/api/jobs/{j.job_id}/mark-running")
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["status"] == "running"
assert body["started_at"] is not None
assert body["completed_at"] is None
assert _audit_exists(j.job_id, "RUNNING")
resp = client_admin.post(
f"/api/jobs/{j.job_id}/mark-completed",
json={
"total_success": 5,
"total_failed": 0,
"result_storage_path": "exports/test_bundle.html",
"result_mime_type": "text/html",
"result_size": 123,
"details_update": {"note": "done"},
},
)
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["status"] == "completed"
assert body["total_success"] == 5
assert body["total_failed"] == 0
assert body["has_result_bundle"] is True
assert _audit_exists(j.job_id, "COMPLETE")
# Create another job, transition to running -> failed
j2 = make_job(job_type=jt, status="queued", requested_by_username="admin")
resp = client_admin.post(f"/api/jobs/{j2.job_id}/mark-running")
assert resp.status_code == 200
resp = client_admin.post(
f"/api/jobs/{j2.job_id}/mark-failed",
json={"reason": "manual-error", "details_update": {"code": "E1"}},
)
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "failed"
# Fetch job to verify details include last_error
resp = client_admin.get(f"/api/jobs/{j2.job_id}")
assert resp.status_code == 200
details = resp.json().get("details") or {}
assert details.get("last_error") == "manual-error"
assert _audit_exists(j2.job_id, "FAIL")
# Retry an existing job
note = "retry it"
resp = client_admin.post(f"/api/jobs/{j2.job_id}/retry", json={"note": note})
assert resp.status_code == 200
new_job_id = resp.json().get("job_id")
assert isinstance(new_job_id, str) and new_job_id
resp = client_admin.get(f"/api/jobs/{new_job_id}")
assert resp.status_code == 200
new_details = resp.json().get("details") or {}
assert new_details.get("retry_of") == j2.job_id
assert new_details.get("retry_note") == note
# Leave one job running for metrics
j3 = make_job(job_type=jt, status="queued", requested_by_username="admin")
resp = client_admin.post(f"/api/jobs/{j3.job_id}/mark-running")
assert resp.status_code == 200
# Metrics summary (admin-only)
resp = client_admin.get("/api/jobs/metrics/summary")
assert resp.status_code == 200
metrics = resp.json()
# Shape assertions
assert set(metrics.keys()) == {
"by_status",
"by_type",
"avg_duration_seconds",
"running_count",
"failed_last_24h",
"completed_last_24h",
}
assert isinstance(metrics["by_status"], dict)
assert isinstance(metrics["by_type"], dict)
# Our type appears with count at least the number we created in this test
# Created: j (completed), j2 (failed), retry (queued), j3 (running) => 4 with this type
assert metrics["by_type"].get(jt, 0) >= 4
# Running count should be at least 1 (j3)
assert metrics.get("running_count", 0) >= 1
# Ensure status buckets contain our transitions (may include others as well)
assert metrics["by_status"].get("completed", 0) >= 1
assert metrics["by_status"].get("failed", 0) >= 1

View File

@@ -0,0 +1,474 @@
"""
Test suite for P1 High Priority Security Features
Tests the security enhancements implemented for:
- Rate limiting
- Security headers
- Enhanced authentication (password complexity, account lockout)
- Database security (SQL injection prevention)
- CSRF protection and request size limits
"""
import pytest
import time
from datetime import datetime, timezone
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from unittest.mock import patch, MagicMock
from app.main import app
from app.database.base import get_db
from app.models.user import User
from app.models.auth import LoginAttempt
from app.utils.enhanced_auth import (
PasswordValidator,
AccountLockoutManager,
SuspiciousActivityDetector,
)
from app.utils.database_security import (
SQLSecurityValidator,
SecureQueryBuilder,
execute_secure_query,
)
from app.middleware.rate_limiting import RateLimitStore
class TestRateLimiting:
"""Test rate limiting middleware functionality"""
def test_rate_limit_store_basic_functionality(self):
"""Test basic rate limit store operations"""
store = RateLimitStore()
# Test initial request - should be allowed
allowed, info = store.is_allowed("test_key", 5, 60)
assert allowed is True
assert info["remaining"] == 4
assert info["limit"] == 5
def test_rate_limit_exceeds_threshold(self):
"""Test rate limiting when threshold is exceeded"""
store = RateLimitStore()
# Make requests up to the limit
for i in range(5):
allowed, info = store.is_allowed("test_key", 5, 60)
assert allowed is True
# Next request should be rejected
allowed, info = store.is_allowed("test_key", 5, 60)
assert allowed is False
assert info["remaining"] == 0
def test_rate_limit_window_reset(self):
"""Test rate limit window reset functionality"""
store = RateLimitStore()
# Fill up the rate limit
for i in range(5):
store.is_allowed("test_key", 5, 1) # 1 second window
# Should be rejected
allowed, info = store.is_allowed("test_key", 5, 1)
assert allowed is False
# Wait for window to reset
time.sleep(1.1)
# Should be allowed again
allowed, info = store.is_allowed("test_key", 5, 1)
assert allowed is True
class TestSecurityHeaders:
"""Test security headers middleware"""
def test_security_headers_applied(self):
"""Test that security headers are applied to responses"""
client = TestClient(app)
response = client.get("/health")
# Check for key security headers
assert "X-Content-Type-Options" in response.headers
assert response.headers["X-Content-Type-Options"] == "nosniff"
assert "X-Frame-Options" in response.headers
assert response.headers["X-Frame-Options"] == "DENY"
assert "X-XSS-Protection" in response.headers
assert response.headers["X-XSS-Protection"] == "1; mode=block"
def test_csp_header_present(self):
"""Test Content Security Policy header"""
client = TestClient(app)
response = client.get("/health")
# CSP header should be present
if "Content-Security-Policy" in response.headers:
csp = response.headers["Content-Security-Policy"]
assert "default-src 'self'" in csp
assert "object-src 'none'" in csp
assert "frame-ancestors 'none'" in csp
def test_request_size_limit(self):
"""Test request size limiting"""
client = TestClient(app)
# Create a large payload (should be rejected if limit is enforced)
large_data = "x" * (101 * 1024 * 1024) # 101MB
response = client.post(
"/api/auth/login",
json={"username": "test", "password": large_data},
headers={"Content-Length": str(len(large_data))}
)
# Should be rejected for size
assert response.status_code == 413 # Request Entity Too Large
class TestPasswordValidation:
"""Test password strength validation"""
def test_weak_passwords_rejected(self):
"""Test that weak passwords are properly rejected"""
weak_passwords = [
"123456",
"password",
"qwerty",
"abc123",
"Password", # Missing special chars and numbers
"p@ssw0rd", # Too common
]
for password in weak_passwords:
is_valid, errors = PasswordValidator.validate_password_strength(password)
assert is_valid is False
assert len(errors) > 0
def test_strong_passwords_accepted(self):
"""Test that strong passwords are accepted"""
strong_passwords = [
"MyStr0ng!P@ssw0rd",
"C0mpl3x&S3cur3!",
"Adm1n!2024$ecure",
"Ungu3ss@bl3P@ssw0rd!",
]
for password in strong_passwords:
is_valid, errors = PasswordValidator.validate_password_strength(password)
assert is_valid is True
assert len(errors) == 0
def test_password_strength_scoring(self):
"""Test password strength scoring"""
# Weak password
weak_score = PasswordValidator.generate_password_strength_score("123456")
assert weak_score < 30
# Strong password
strong_score = PasswordValidator.generate_password_strength_score("MyStr0ng!P@ssw0rd")
assert strong_score > 70
def test_password_complexity_requirements(self):
"""Test individual password complexity requirements"""
# Test length requirement
is_valid, errors = PasswordValidator.validate_password_strength("Abc1!")
assert not is_valid
assert any("at least" in error for error in errors)
# Test uppercase requirement
is_valid, errors = PasswordValidator.validate_password_strength("abc123!@#")
assert not is_valid
assert any("uppercase" in error for error in errors)
# Test special character requirement
is_valid, errors = PasswordValidator.validate_password_strength("Abc123456")
assert not is_valid
assert any("special character" in error for error in errors)
class TestAccountLockout:
"""Test account lockout functionality"""
@pytest.fixture
def mock_db(self):
"""Mock database session"""
mock_db = MagicMock()
return mock_db
def test_lockout_after_failed_attempts(self, mock_db):
"""Test account lockout after multiple failed attempts"""
# Mock failed login attempts
mock_db.query.return_value.filter.return_value.scalar.return_value = 5
mock_db.query.return_value.filter.return_value.order_by.return_value.first.return_value = (
datetime.now(timezone.utc),
)
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
assert is_locked is True
assert unlock_time is not None
def test_no_lockout_with_few_attempts(self, mock_db):
"""Test no lockout with fewer than threshold attempts"""
# Mock fewer failed attempts
mock_db.query.return_value.filter.return_value.scalar.return_value = 2
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
assert is_locked is False
assert unlock_time is None
def test_lockout_info_retrieval(self, mock_db):
"""Test lockout information retrieval"""
# Mock database responses
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
mock_db.query.return_value.filter.return_value.scalar.return_value = 3
lockout_info = AccountLockoutManager.get_lockout_info(mock_db, "testuser")
assert "is_locked" in lockout_info
assert "failed_attempts" in lockout_info
assert "attempts_remaining" in lockout_info
assert "max_attempts" in lockout_info
class TestDatabaseSecurity:
"""Test database security utilities"""
def test_sql_injection_detection(self):
"""Test SQL injection pattern detection"""
# Test legitimate query
legitimate_query = "SELECT * FROM users WHERE username = :username"
issues = SQLSecurityValidator.validate_query_string(legitimate_query)
assert len(issues) == 0
# Test malicious queries
malicious_queries = [
"SELECT * FROM users WHERE id = 1; DROP TABLE users; --",
"SELECT * FROM users WHERE username = 'admin' OR 1=1 --",
"SELECT * FROM users UNION SELECT password FROM admin_table",
"SELECT * FROM users WHERE id = 1 AND (SELECT COUNT(*) FROM passwords) > 0",
]
for query in malicious_queries:
issues = SQLSecurityValidator.validate_query_string(query)
assert len(issues) > 0
def test_parameter_validation(self):
"""Test parameter value validation"""
# Test legitimate parameters
legitimate_params = {
"username": "john_doe",
"age": 25,
"email": "john@example.com",
}
for param_name, param_value in legitimate_params.items():
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
assert len(issues) == 0
# Test malicious parameters
malicious_params = {
"username": "'; DROP TABLE users; --",
"search": "admin' OR 1=1 --",
"id": "1 UNION SELECT password FROM users",
}
for param_name, param_value in malicious_params.items():
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
assert len(issues) > 0
def test_safe_query_building(self):
"""Test safe query building utilities"""
from sqlalchemy import Column, String
# Mock column for testing
mock_column = Column("name", String)
# Test safe LIKE clause building
like_clause = SecureQueryBuilder.build_like_clause(mock_column, "test_value")
assert like_clause is not None
# Test safe IN clause building
in_clause = SecureQueryBuilder.build_in_clause(mock_column, ["value1", "value2"])
assert in_clause is not None
# Test FTS query building
fts_query = SecureQueryBuilder.build_fts_query(["term1", "term2"])
assert "term1" in fts_query
assert "term2" in fts_query
def test_query_validation_with_params(self):
"""Test complete query validation with parameters"""
query = "SELECT * FROM users WHERE username = :username AND age > :min_age"
params = {
"username": "john_doe",
"min_age": 18,
}
issues = SQLSecurityValidator.validate_query_with_params(query, params)
assert len(issues) == 0
# Test with malicious parameters
malicious_params = {
"username": "'; DROP TABLE users; --",
"min_age": "18 OR 1=1",
}
issues = SQLSecurityValidator.validate_query_with_params(query, malicious_params)
assert len(issues) > 0
class TestSuspiciousActivityDetection:
"""Test suspicious activity detection"""
def test_suspicious_ip_detection(self):
"""Test detection of suspicious IP patterns"""
# This would typically require mock data or a test database
# For now, test the structure
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.all.return_value = []
alerts = SuspiciousActivityDetector.detect_suspicious_patterns(mock_db, 24)
assert isinstance(alerts, list)
def test_login_suspicion_check(self):
"""Test individual login suspicion checking"""
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.distinct.return_value.all.return_value = []
mock_db.query.return_value.filter.return_value.scalar.return_value = 0
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
mock_db, "testuser", "192.168.1.1", "Mozilla/5.0..."
)
assert isinstance(is_suspicious, bool)
assert isinstance(warnings, list)
class TestCSRFProtection:
"""Test CSRF protection middleware"""
def test_csrf_protection_on_post_requests(self):
"""Test CSRF protection for POST requests"""
client = TestClient(app)
# POST request without proper origin/referer should be blocked
response = client.post(
"/api/auth/login",
json={"username": "test", "password": "test"},
headers={
"Origin": "https://malicious-site.com",
"Host": "localhost:8000"
}
)
# Should be blocked by CSRF protection
# Note: This test might need adjustment based on actual CSRF implementation
assert response.status_code in [403, 401, 429] # Could be various security responses
def test_csrf_exemption_for_safe_methods(self):
"""Test that safe HTTP methods are exempt from CSRF protection"""
client = TestClient(app)
# GET requests should not be subject to CSRF protection
response = client.get("/health")
assert response.status_code == 200
class TestIntegratedSecurity:
"""Test integrated security features working together"""
def test_enhanced_login_endpoint(self):
"""Test the enhanced login endpoint with security features"""
client = TestClient(app)
# Test login with invalid credentials
response = client.post(
"/api/auth/login",
json={"username": "nonexistent", "password": "wrongpassword"}
)
assert response.status_code == 401
# Check for security headers in response
assert "X-RateLimit-Limit" in response.headers or "WWW-Authenticate" in response.headers
@patch('app.utils.enhanced_auth.validate_and_authenticate_user')
def test_account_lockout_headers(self, mock_auth):
"""Test that account lockout information is returned in headers"""
client = TestClient(app)
# Mock a locked account
mock_auth.return_value = (None, ["Account is locked"])
response = client.post(
"/api/auth/login",
json={"username": "locked_user", "password": "password"}
)
assert response.status_code == 401
def test_password_validation_endpoint(self):
"""Test the password validation endpoint"""
client = TestClient(app)
# Test weak password
response = client.post(
"/api/auth/validate-password",
json={"password": "123456"}
)
assert response.status_code == 200
data = response.json()
assert data["is_valid"] is False
assert len(data["errors"]) > 0
# Test strong password
response = client.post(
"/api/auth/validate-password",
json={"password": "MyStr0ng!P@ssw0rd"}
)
assert response.status_code == 200
data = response.json()
assert data["is_valid"] is True
assert len(data["errors"]) == 0
@pytest.fixture
def client():
"""Test client fixture"""
return TestClient(app)
def test_security_middleware_integration(client):
"""Test that all security middleware is properly integrated"""
response = client.get("/health")
# Should have rate limiting headers
rate_limit_headers = [h for h in response.headers.keys() if h.startswith("X-RateLimit")]
# Should have security headers
security_headers = ["X-Content-Type-Options", "X-Frame-Options", "X-XSS-Protection"]
for header in security_headers:
assert header in response.headers
# Should have correlation ID
assert "X-Correlation-ID" in response.headers
if __name__ == "__main__":
# Run specific test for quick validation
test_rate_limiting = TestRateLimiting()
test_rate_limiting.test_rate_limit_store_basic_functionality()
test_password = TestPasswordValidation()
test_password.test_weak_passwords_rejected()
test_password.test_strong_passwords_accepted()
test_db_security = TestDatabaseSecurity()
test_db_security.test_sql_injection_detection()
print("✅ All P1 security tests passed!")

View File

@@ -0,0 +1,555 @@
import os
import sys
from pathlib import Path
import math
import pytest
from fastapi.testclient import TestClient
# Ensure required env vars for app import/config
os.environ.setdefault("SECRET_KEY", "x" * 32)
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
# Ensure repository root on sys.path for direct test runs
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.main import app # noqa: E402
from app.auth.security import get_current_user # noqa: E402
from app.database.base import SessionLocal # noqa: E402
from app.models.pensions import NumberTable, LifeTable # noqa: E402
@pytest.fixture(scope="module")
def client():
# Override auth to bypass JWT for these tests
class _User:
def __init__(self):
self.id = "test"
self.username = "tester"
self.is_admin = True
self.is_active = True
app.dependency_overrides[get_current_user] = lambda: _User()
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
def _seed_monthly_na_series():
db = SessionLocal()
try:
# Clear and seed NumberTable for months 0..4 for both sexes, race A
for m in range(0, 5):
db.query(NumberTable).filter(NumberTable.month == m).delete()
# NA series: [100000, 90000, 80000, 70000, 60000]
series = [100000.0, 90000.0, 80000.0, 70000.0, 60000.0]
for idx, na in enumerate(series):
row = NumberTable(
month=idx,
na_am=na, # male (all races)
na_af=na, # female (all races)
na_aa=na, # all (all races)
)
db.add(row)
# Provide an LE row in case fallback is used inadvertently
db.query(LifeTable).filter(LifeTable.age == 65).delete()
db.add(LifeTable(age=65, le_am=20.0, le_af=22.0, le_aa=21.0))
db.commit()
finally:
db.close()
def _monthly_rate(pct: float) -> float:
return pow(1.0 + pct / 100.0, 1.0 / 12.0) - 1.0
def test_single_life_no_discount_no_cola(client: TestClient):
_seed_monthly_na_series()
payload = {
"monthly_benefit": 1000.0,
"term_months": 5,
"start_age": 65,
"sex": "M",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# Survival probs: [1.0, 0.9, 0.8, 0.7, 0.6]
# PV = 1000 * sum = 4000
assert math.isclose(resp.json()["pv"], 4000.0, rel_tol=1e-6, abs_tol=0.01)
def test_joint_survivor_no_discount_no_cola(client: TestClient):
_seed_monthly_na_series()
payload = {
"monthly_benefit": 1000.0,
"term_months": 5,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 50.0,
"discount_rate": 0.0,
"cola_rate": 0.0,
}
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
assert resp.status_code == 200, resp.text
data = resp.json()
# p = [1, .9, .8, .7, .6]
# p_both = p^2 = [1, .81, .64, .49, .36]
# payment_t = 1000*p_both + 1000*0.5*(p - p_both) = 500*p + 500*p_both
expected = 500.0 * (1 + 0.9 + 0.8 + 0.7 + 0.6) + 500.0 * (1 + 0.81 + 0.64 + 0.49 + 0.36)
assert math.isclose(data["pv_total"], expected, rel_tol=1e-6, abs_tol=0.01)
# Components
expected_both = 1000.0 * (1 + 0.81 + 0.64 + 0.49 + 0.36)
expected_surv = expected - expected_both
assert math.isclose(data["pv_participant_component"], expected_both, rel_tol=1e-6, abs_tol=0.01)
assert math.isclose(data["pv_survivor_component"], expected_surv, rel_tol=1e-6, abs_tol=0.01)
def test_joint_survivor_last_survivor_basis(client: TestClient):
_seed_monthly_na_series()
payload = {
"monthly_benefit": 1000.0,
"term_months": 5,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 50.0, # should be ignored in last_survivor for total
"discount_rate": 0.0,
"cola_rate": 0.0,
"survivor_basis": "last_survivor",
}
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
assert resp.status_code == 200, resp.text
data = resp.json()
# p = [1, .9, .8, .7, .6]; p_both = p^2; p_either = p_part + p_sp - p_both = 2p - p^2
p = [1.0, 0.9, 0.8, 0.7, 0.6]
p_both = [x * x for x in p]
p_either = [2 * x - x * x for x in p]
expected_total = 1000.0 * sum(p_either)
assert math.isclose(data["pv_total"], expected_total, rel_tol=1e-6, abs_tol=0.01)
def test_interpolation_between_number_rows(client: TestClient):
# Seed months 0 and 2 only; expect linear interpolation for month 1
db = SessionLocal()
try:
from app.models.pensions import NumberTable, LifeTable
for m in range(0, 3):
db.query(NumberTable).filter(NumberTable.month == m).delete()
# only months 0 and 2
db.add(NumberTable(month=0, na_am=100000.0, na_af=100000.0, na_aa=100000.0))
db.add(NumberTable(month=2, na_am=80000.0, na_af=80000.0, na_aa=80000.0))
# Ensure LE exists but should not be used due to interpolation
db.query(LifeTable).filter(LifeTable.age == 65).delete()
db.add(LifeTable(age=65, le_am=30.0, le_af=30.0, le_aa=30.0))
db.commit()
finally:
db.close()
payload = {
"monthly_benefit": 1000.0,
"term_months": 3,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"interpolation_method": "linear",
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# survival at t: 0->1.0, 1->0.9 (interpolated), 2->0.8
expected = 1000.0 * (1.0 + 0.9 + 0.8)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
def test_interpolation_step_method(client: TestClient):
# Seed months 0 and 3 only; with step interpolation, months 1 and 2 carry month 0 value
db = SessionLocal()
try:
from app.models.pensions import NumberTable
for m in range(0, 4):
db.query(NumberTable).filter(NumberTable.month == m).delete()
db.add(NumberTable(month=0, na_aa=100000.0))
db.add(NumberTable(month=3, na_aa=70000.0))
db.commit()
finally:
db.close()
payload = {
"monthly_benefit": 1000.0,
"term_months": 4,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"interpolation_method": "step",
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# Series becomes NA: [100k, 100k, 100k, 70k] -> p: [1.0,1.0,1.0,0.7]
expected = 1000.0 * (1.0 + 1.0 + 1.0 + 0.7)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
def test_max_age_truncation(client: TestClient):
_seed_monthly_na_series()
# start_age 65, max_age 66 -> at most 12 months
payload = {
"monthly_benefit": 1000.0,
"term_months": 24,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"max_age": 66,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200
# With our seed p falls by 0.1 each month; but we only have 5 months seeded; extend to 12
db = SessionLocal()
try:
from app.models.pensions import NumberTable
for m in range(5, 12):
db.query(NumberTable).filter(NumberTable.month == m).delete()
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
db.commit()
finally:
db.close()
# Recompute with extension
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200
# p: 1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1,0.0 (clamped by interpolation), 0.0 -> sum first 10 effectively
expected = 1000.0 * (1.0 + 0.9 + 0.8 + 0.7 + 0.6 + 0.5 + 0.4 + 0.3 + 0.2 + 0.1)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.1)
def test_single_life_with_discount_and_cola(client: TestClient):
_seed_monthly_na_series()
# discount 12% annual, COLA 12% annual
d_m = _monthly_rate(12.0)
g_m = _monthly_rate(12.0)
probs = [1.0, 0.9, 0.8, 0.7, 0.6]
monthly = 1000.0
# Expected PV: sum( monthly * p[t] * (1+g)^t / (1+i)^t )
expected = 0.0
g = 1.0
d = 1.0
for t, p in enumerate(probs):
if t == 0:
g = 1.0
d = 1.0
else:
g *= (1.0 + g_m)
d *= (1.0 + d_m)
expected += monthly * p * g / d
payload = {
"monthly_benefit": monthly,
"term_months": 5,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 12.0,
"cola_rate": 12.0,
"cola_mode": "monthly",
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.05)
def test_single_life_cola_annual_prorated_and_cap(client: TestClient):
_seed_monthly_na_series()
# Extend months 0..11 to allow a full year
db = SessionLocal()
try:
from app.models.pensions import NumberTable
for m in range(5, 12):
db.query(NumberTable).filter(NumberTable.month == m).delete()
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
db.commit()
finally:
db.close()
# annual COLA 6% but cap at 4%; annual prorated mode
# Payments monthly for 6 months, p[t]=1,.9,.8,.7,.6,.5
# Growth at t uses 4% cap prorated: factors ~ 1.0, 1+0.04*(1/12), 1+0.04*(2/12), ...
monthly = 1000.0
probs = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
a = 0.04
expected = 0.0
for t, p in enumerate(probs):
growth = (1.0 + a * (t / 12.0))
expected += monthly * p * growth
payload = {
"monthly_benefit": monthly,
"term_months": 6,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 6.0,
"cola_mode": "annual_prorated",
"cola_cap_percent": 4.0,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.05)
def test_single_life_quarterly_with_deferral(client: TestClient):
_seed_monthly_na_series()
# Quarterly payments (3 months per payment), starting after 2-month deferral
# Horizon 10 months, payments at t=2,5,8 (assuming 0-indexed months)
# Survival probs p[t] = [1, .9, .8, .7, .6, .5, .4, .3, .2, .1] for first 10 implied by seed extension
# But our seed has only 5 months; extend seed
db = SessionLocal()
try:
# extend NumberTable to months 0..9 with linear decrement 100k - 10k*m
from app.models.pensions import NumberTable
for m in range(5, 10):
db.query(NumberTable).filter(NumberTable.month == m).delete()
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
db.commit()
finally:
db.close()
payload = {
"monthly_benefit": 1000.0,
"term_months": 10,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"defer_months": 2,
"payment_period_months": 3,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# p = [1.0, .9, .8, .7, .6, .5, .4, .3, .2, .1]
# Payments at t=2,5,8: amount = monthly * 3 * p[t]
expected = 1000.0 * 3.0 * (0.8 + 0.5 + 0.2)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
def test_fractional_deferral_prorated_first_payment(client: TestClient):
_seed_monthly_na_series()
# Monthly payments, defer 0.5 month: first payment at month 1 with 50% of base
# With p = [1.0, 0.9, 0.8, 0.7, 0.6]
monthly = 1000.0
payload = {
"monthly_benefit": monthly,
"term_months": 5,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"defer_months": 0.5,
"payment_period_months": 1,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# First payment at t=1: 0.5*1000*0.9 + subsequent: 1000*0.8 + 1000*0.7 + 1000*0.6
expected = 0.5 * monthly * 0.9 + monthly * (0.8 + 0.7 + 0.6)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
def test_joint_survivor_participant_only_commencement(client: TestClient):
_seed_monthly_na_series()
payload = {
"monthly_benefit": 1000.0,
"term_months": 5,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 50.0,
"discount_rate": 0.0,
"cola_rate": 0.0,
"survivor_commence_participant_only": True,
}
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload)
assert resp.status_code == 200, resp.text
data = resp.json()
# p = [1, .9, .8, .7, .6]; p_both = p^2; p_sp_only replaced by p_part for survivor component
p = [1.0, 0.9, 0.8, 0.7, 0.6]
p_both = [x * x for x in p]
surv_component = sum(1000.0 * 0.5 * x for x in p) # using participant survival
both_component = sum(1000.0 * x for x in p_both)
expected_total = 1000.0 + (both_component - 1000.0) + surv_component # first period total includes guarantee at t=0 (from previous tests)
# Given our service guarantees only if certain_months > 0, here it's 0, so no guarantee. Recompute expected total accordingly
expected_total = both_component + surv_component
assert math.isclose(data["pv_participant_component"], both_component, rel_tol=1e-6, abs_tol=0.01)
assert math.isclose(data["pv_survivor_component"], surv_component, rel_tol=1e-6, abs_tol=0.01)
assert math.isclose(data["pv_total"], expected_total, rel_tol=1e-6, abs_tol=0.01)
def test_certain_period_guarantee_then_mortality(client: TestClient):
_seed_monthly_na_series()
# Extend 0..9 again
db = SessionLocal()
try:
from app.models.pensions import NumberTable
for m in range(5, 10):
db.query(NumberTable).filter(NumberTable.month == m).delete()
db.add(NumberTable(month=m, na_aa=100000.0 - 10000.0 * m))
db.commit()
finally:
db.close()
# Monthly payments, no deferral, 3 months certain
payload = {
"monthly_benefit": 1000.0,
"term_months": 6,
"start_age": 65,
"sex": "A",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
"defer_months": 0,
"payment_period_months": 1,
"certain_months": 3,
}
resp = client.post("/api/pensions/valuation/single-life", json=payload)
assert resp.status_code == 200, resp.text
# p = [1.0, .9, .8, .7, .6, .5]
# guaranteed first 3: 1000, 1000, 1000; then mortality-weighted: 700, 600, 500
expected = 1000.0 * 3 + 1000.0 * (0.7 + 0.6 + 0.5)
assert math.isclose(resp.json()["pv"], expected, rel_tol=1e-6, abs_tol=0.01)
# Joint-survivor: same guarantee logic applies to total stream
payload_js = {
"monthly_benefit": 1000.0,
"term_months": 4,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 50.0,
"discount_rate": 0.0,
"cola_rate": 0.0,
"defer_months": 0,
"payment_period_months": 1,
"certain_months": 2,
}
resp = client.post("/api/pensions/valuation/joint-survivor", json=payload_js)
assert resp.status_code == 200, resp.text
data = resp.json()
# p = [1.0, .9, .8, .7]
# total mortality stream (no guarantee): 1000*p_both + 500*(p - p_both)
p = [1.0, 0.9, 0.8, 0.7]
p_both = [x * x for x in p]
mort = [1000.0 * pb + 500.0 * (x - pb) for x, pb in zip(p, p_both)]
total_expected = 1000.0 + 1000.0 + mort[2] + mort[3]
assert math.isclose(data["pv_total"], total_expected, rel_tol=1e-6, abs_tol=0.01)
def test_batch_single_life_mixed_success_failure(client: TestClient):
_seed_monthly_na_series()
valid_item = {
"monthly_benefit": 1000.0,
"term_months": 5,
"start_age": 65,
"sex": "M",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
}
invalid_item = {
"monthly_benefit": -100.0, # invalid
"term_months": 5,
"start_age": 65,
"sex": "M",
"race": "A",
"discount_rate": 0.0,
"cola_rate": 0.0,
}
payload = {
"items": [valid_item, invalid_item, valid_item]
}
resp = client.post("/api/pensions/valuation/batch-single-life", json=payload)
assert resp.status_code == 200, resp.text
data = resp.json()["results"]
assert len(data) == 3
assert data[0]["success"] is True
assert "pv" in data[0]["result"]
assert math.isclose(data[0]["result"]["pv"], 4000.0, rel_tol=1e-6)
assert data[1]["success"] is False
assert "monthly_benefit must be non-negative" in (data[1]["error"] or "")
assert data[2]["success"] is True
assert math.isclose(data[2]["result"]["pv"], 4000.0, rel_tol=1e-6)
def test_batch_joint_survivor_mixed_success_failure(client: TestClient):
_seed_monthly_na_series()
valid_item = {
"monthly_benefit": 1000.0,
"term_months": 5,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 50.0,
"discount_rate": 0.0,
"cola_rate": 0.0,
}
invalid_item = {
"monthly_benefit": 1000.0,
"term_months": 5,
"participant_age": 65,
"participant_sex": "M",
"participant_race": "A",
"spouse_age": 63,
"spouse_sex": "F",
"spouse_race": "A",
"survivor_percent": 150.0, # invalid >100
"discount_rate": 0.0,
"cola_rate": 0.0,
}
payload = {
"items": [valid_item, invalid_item]
}
resp = client.post("/api/pensions/valuation/batch-joint-survivor", json=payload)
assert resp.status_code == 200, resp.text
data = resp.json()["results"]
assert len(data) == 2
assert data[0]["success"] is True
assert "pv_total" in data[0]["result"]
p = [1.0, 0.9, 0.8, 0.7, 0.6]
p_both = [x * x for x in p]
expected = 500.0 * sum(p) + 500.0 * sum(p_both)
assert math.isclose(data[0]["result"]["pv_total"], expected, rel_tol=1e-6)
assert data[1]["success"] is False
assert "survivor_percent must be between 0 and 100" in (data[1]["error"] or "")

View File

@@ -0,0 +1,192 @@
import os
import uuid
import csv
import io
import pytest
from fastapi.testclient import TestClient
# Ensure required env vars for app import/config
os.environ.setdefault("SECRET_KEY", "x" * 32)
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
from app.main import app # noqa: E402
from app.auth.security import get_current_user # noqa: E402
@pytest.fixture(scope="module")
def client():
# Override auth to bypass JWT for these tests
class _User:
def __init__(self):
self.id = "test"
self.username = "tester"
self.is_admin = True
self.is_active = True
app.dependency_overrides[get_current_user] = lambda: _User()
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
@pytest.fixture(scope="module")
def phone_book_data(client: TestClient):
gid1 = f"PBGRP1-{uuid.uuid4().hex[:6]}"
gid2 = f"PBGRP2-{uuid.uuid4().hex[:6]}"
def _create_customer(cid: str, last: str, first: str, group: str):
payload = {
"id": cid,
"last": last,
"first": first,
"group": group,
"email": f"{cid.lower()}@example.com",
}
resp = client.post("/api/customers/", json=payload)
assert resp.status_code == 200, resp.text
def _add_phone(cid: str, location: str, number: str):
resp = client.post(f"/api/customers/{cid}/phones", json={"location": location, "phone": number})
assert resp.status_code == 200, resp.text
# Create entries for letters A, B and non-alpha '#'
cid_a1 = f"PB-{uuid.uuid4().hex[:6]}-A1"
cid_a2 = f"PB-{uuid.uuid4().hex[:6]}-A2"
cid_b1 = f"PB-{uuid.uuid4().hex[:6]}-B1"
cid_hash = f"PB-{uuid.uuid4().hex[:6]}-H"
_create_customer(cid_a1, last="Alpha", first="Alice", group=gid1)
_add_phone(cid_a1, "Office", "111-111-1111")
_add_phone(cid_a1, "Mobile", "111-111-2222")
_create_customer(cid_a2, last="Able", first="Andy", group=gid1)
_add_phone(cid_a2, "Office", "222-222-2222")
_create_customer(cid_b1, last="Beta", first="Bob", group=gid2)
_add_phone(cid_b1, "Home", "333-333-3333")
_create_customer(cid_hash, last="123Company", first="NA", group=gid1)
_add_phone(cid_hash, "Main", "444-444-4444")
try:
yield {
"gid1": gid1,
"gid2": gid2,
"ids": [cid_a1, cid_a2, cid_b1, cid_hash],
}
finally:
# Cleanup
for cid in [cid_a1, cid_a2, cid_b1, cid_hash]:
client.delete(f"/api/customers/{cid}")
def _parse_csv(text: str):
reader = csv.reader(io.StringIO(text))
rows = list(reader)
return rows[0], rows[1:]
def test_phone_book_csv_letter_column_when_grouped_by_letter(client: TestClient, phone_book_data):
# Only include our test group gid1 to avoid interference
gid = phone_book_data["gid1"]
resp = client.get(
"/api/customers/phone-book",
params={
"format": "csv",
"mode": "numbers",
"grouping": "letter",
"groups": gid,
},
)
assert resp.status_code == 200
assert "text/csv" in resp.headers.get("content-type", "")
header, rows = _parse_csv(resp.text)
assert "Letter" in header
# Ensure letters include 'A' and '#'
letters = {r[header.index("Letter")] for r in rows}
assert "A" in letters and "#" in letters
def test_phone_book_html_sections_by_letter_with_page_break(client: TestClient, phone_book_data):
gid = phone_book_data["gid1"]
resp = client.get(
"/api/customers/phone-book",
params={
"format": "html",
"mode": "numbers",
"grouping": "letter",
"page_break": "1",
"groups": gid,
},
)
assert resp.status_code == 200
html = resp.text
# Snapshot-style checks: section headers and page-break class
assert "Letter: A" in html
assert "Letter: #" in html
assert "class=\"section-title\"" in html
assert "page-break" in html # later sections should have page-break class
def test_phone_book_html_group_then_letter_sections(client: TestClient, phone_book_data):
gid1 = phone_book_data["gid1"]
gid2 = phone_book_data["gid2"]
# Include both groups to verify group and nested letter sections
resp = client.get(
"/api/customers/phone-book",
params=[
("format", "html"),
("mode", "addresses"),
("grouping", "group_letter"),
("groups", gid1),
("groups", gid2),
],
)
assert resp.status_code == 200
html = resp.text
assert f"Group: {gid1}" in html
assert f"Group: {gid2}" in html
assert "Letter: A" in html or "Letter: #" in html
def test_phone_book_csv_no_letter_for_grouping_none(client: TestClient, phone_book_data):
gid = phone_book_data["gid1"]
resp = client.get(
"/api/customers/phone-book",
params={
"format": "csv",
"mode": "numbers",
"grouping": "none",
"groups": gid,
},
)
assert resp.status_code == 200
header, rows = _parse_csv(resp.text)
assert "Letter" not in header
# Basic sanity: names and phones are present
assert any("Alpha" in ",".join(row) or "Able" in ",".join(row) for row in rows)
def test_phone_book_respects_group_filter(client: TestClient, phone_book_data):
gid2 = phone_book_data["gid2"]
resp = client.get(
"/api/customers/phone-book",
params={
"format": "csv",
"mode": "numbers",
"grouping": "letter",
"groups": gid2,
},
)
assert resp.status_code == 200
header, rows = _parse_csv(resp.text)
# Only Beta/Bob (gid2) should be present
all_text = "\n".join([",".join(r) for r in rows])
assert "Beta" in all_text or "Bob" in all_text
# Ensure gid1 names are not present
assert "Alpha" not in all_text and "Able" not in all_text

View File

@@ -0,0 +1,199 @@
import os
import io
import uuid
import asyncio
from time import sleep
import sys
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
# Ensure required env vars for app import/config
os.environ.setdefault("SECRET_KEY", "x" * 32)
os.environ.setdefault("DATABASE_URL", "sqlite:////tmp/delphi_test.sqlite")
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.main import app # noqa: E402
from app.auth.security import get_current_user # noqa: E402
from app.config import settings # noqa: E402
from app.services.template_search import TemplateSearchService # noqa: E402
def _dummy_docx_bytes():
try:
from docx import Document # type: ignore
except Exception:
return b"PK\x03\x04"
d = Document()
p = d.add_paragraph()
p.add_run("Cache Test ")
p.add_run("{{TOKEN}}")
buf = io.BytesIO()
d.save(buf)
return buf.getvalue()
def _upload_template(client: TestClient, name: str, category: str = "GENERAL") -> int:
files = {"file": (f"{name}.docx", _dummy_docx_bytes(), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}
resp = client.post(
"/api/templates/upload",
data={"name": name, "category": category, "semantic_version": "1.0.0"},
files=files,
)
assert resp.status_code == 200, resp.text
return int(resp.json()["id"])
@pytest.fixture()
def client_no_redis():
class _User:
def __init__(self):
self.id = "tmpl-cache-user"
self.username = "tester"
self.is_admin = True
self.is_active = True
app.dependency_overrides[get_current_user] = lambda: _User()
# Disable Redis to exercise in-memory fallback
settings.cache_enabled = False
settings.redis_url = ""
# Clear template caches
try:
asyncio.run(TemplateSearchService.invalidate_all())
except Exception:
pass
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
def test_templates_search_caches_in_memory(client_no_redis: TestClient):
# Upload in unique category to avoid interference
category = f"CAT_{uuid.uuid4().hex[:8]}"
_upload_template(client_no_redis, f"T-{uuid.uuid4().hex[:6]}", category)
params = {"category": category}
r1 = client_no_redis.get("/api/templates/search", params=params)
assert r1.status_code == 200
d1 = r1.json()
# Second call should be served from cache and be identical
r2 = client_no_redis.get("/api/templates/search", params=params)
assert r2.status_code == 200
d2 = r2.json()
assert d1 == d2
def test_templates_search_invalidation_on_upload_in_memory(client_no_redis: TestClient):
# Use unique empty category; first query caches empty list
category = f"EMPTY_{uuid.uuid4().hex[:8]}"
r_empty = client_no_redis.get("/api/templates/search", params={"category": category})
assert r_empty.status_code == 200
assert r_empty.json() == []
# Upload a template into that category -> triggers invalidation
_upload_template(client_no_redis, f"New-{uuid.uuid4().hex[:6]}", category)
sleep(0.05)
# Query again should reflect the new item (cache invalidated)
r_after = client_no_redis.get("/api/templates/search", params={"category": category})
assert r_after.status_code == 200
ids = [it["id"] for it in r_after.json()]
assert len(ids) >= 1
def test_templates_search_invalidation_on_keyword_update_in_memory(client_no_redis: TestClient):
category = f"KW_{uuid.uuid4().hex[:8]}"
tid = _upload_template(client_no_redis, f"KW-{uuid.uuid4().hex[:6]}", category)
# Initial query has_keywords=true should be empty and cached
r1 = client_no_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
assert r1.status_code == 200
assert r1.json() == []
# Add a keyword to the template -> invalidates caches
r_add = client_no_redis.post(f"/api/templates/{tid}/keywords", json={"keywords": ["alpha"]})
assert r_add.status_code == 200
sleep(0.05)
# Now search with has_keywords=true should include our template
r2 = client_no_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
assert r2.status_code == 200
ids2 = {it["id"] for it in r2.json()}
assert tid in ids2
@pytest.fixture()
def client_with_redis():
if not settings.redis_url:
pytest.skip("Redis not configured for caching tests")
class _User:
def __init__(self):
self.id = "tmpl-cache-user-redis"
self.username = "tester"
self.is_admin = True
self.is_active = True
app.dependency_overrides[get_current_user] = lambda: _User()
settings.cache_enabled = True
# Clear template caches
try:
asyncio.run(TemplateSearchService.invalidate_all())
except Exception:
pass
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
def test_templates_search_caches_with_redis(client_with_redis: TestClient):
category = f"RCAT_{uuid.uuid4().hex[:8]}"
_upload_template(client_with_redis, f"RT-{uuid.uuid4().hex[:6]}", category)
params = {"category": category}
r1 = client_with_redis.get("/api/templates/search", params=params)
assert r1.status_code == 200
d1 = r1.json()
r2 = client_with_redis.get("/api/templates/search", params=params)
assert r2.status_code == 200
d2 = r2.json()
assert d1 == d2
def test_templates_search_invalidation_on_upload_with_redis(client_with_redis: TestClient):
category = f"RADD_{uuid.uuid4().hex[:8]}"
r_empty = client_with_redis.get("/api/templates/search", params={"category": category})
assert r_empty.status_code == 200 and r_empty.json() == []
_upload_template(client_with_redis, f"RNew-{uuid.uuid4().hex[:6]}", category)
sleep(0.05)
r_after = client_with_redis.get("/api/templates/search", params={"category": category})
assert r_after.status_code == 200
assert len(r_after.json()) >= 1
def test_templates_search_invalidation_on_keyword_update_with_redis(client_with_redis: TestClient):
category = f"RKW_{uuid.uuid4().hex[:8]}"
tid = _upload_template(client_with_redis, f"RTKW-{uuid.uuid4().hex[:6]}", category)
r1 = client_with_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
assert r1.status_code == 200 and r1.json() == []
r_add = client_with_redis.post(f"/api/templates/{tid}/keywords", json={"keywords": ["beta"]})
assert r_add.status_code == 200
sleep(0.05)
r2 = client_with_redis.get("/api/templates/search", params={"category": category, "has_keywords": True})
ids = {it["id"] for it in r2.json()}
assert tid in ids

View File

@@ -0,0 +1,442 @@
"""
Tests for WebSocket management endpoints in the Admin API
Tests cover:
- WebSocket statistics endpoint
- Connection listing and filtering
- Connection management (disconnect, cleanup)
- Broadcasting functionality
- Admin-only access control
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from fastapi import status
from app.main import app
from app.auth.security import get_current_user, get_admin_user
from app.models.user import User
@pytest.fixture
def test_client():
"""Create test client"""
return TestClient(app)
@pytest.fixture
def admin_user():
"""Create admin user for testing"""
user = User(
id=1,
username="admin",
email="admin@test.com",
is_admin=True,
is_active=True
)
return user
@pytest.fixture
def admin_client(admin_user):
"""Test client with admin dependency overrides applied."""
app.dependency_overrides[get_current_user] = lambda: admin_user
app.dependency_overrides[get_admin_user] = lambda: admin_user
try:
yield TestClient(app)
finally:
app.dependency_overrides.pop(get_current_user, None)
app.dependency_overrides.pop(get_admin_user, None)
@pytest.fixture
def regular_user():
"""Create regular user for testing"""
user = User(
id=2,
username="user",
email="user@test.com",
is_admin=False,
is_active=True
)
return user
class TestWebSocketStatsEndpoint:
"""Test WebSocket statistics endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
def test_get_websocket_stats_success(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test successful retrieval of WebSocket statistics"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager
mock_manager = MagicMock()
mock_manager.get_stats = AsyncMock(return_value={
"total_connections": 10,
"active_connections": 8,
"total_topics": 5,
"total_users": 3,
"messages_sent": 100,
"messages_failed": 2,
"connections_cleaned": 5,
"last_cleanup": "2023-01-01T12:00:00Z",
"last_heartbeat": "2023-01-01T12:01:00Z",
"connections_by_state": {"connected": 8, "disconnected": 2},
"topic_distribution": {"topic1": 5, "topic2": 3}
})
mock_get_manager.return_value = mock_manager
# Make request
response = admin_client.get("/api/admin/websockets/stats")
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total_connections"] == 10
assert data["active_connections"] == 8
assert data["messages_sent"] == 100
assert "topic1" in data["topic_distribution"]
@patch('app.api.admin.get_admin_user')
def test_get_websocket_stats_unauthorized(self, mock_get_admin, test_client, regular_user):
"""Test unauthorized access to WebSocket statistics"""
# Mock non-admin user
mock_get_admin.side_effect = Exception("Admin required")
# Make request
response = test_client.get("/api/admin/websockets/stats")
# Should fail (forbidden)
assert response.status_code in (status.HTTP_403_FORBIDDEN, status.HTTP_401_UNAUTHORIZED, status.HTTP_500_INTERNAL_SERVER_ERROR)
class TestWebSocketConnectionsEndpoint:
"""Test WebSocket connections listing endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
@patch('app.api.admin.get_connection_tracker')
def test_get_connections_success(self, mock_get_tracker, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test successful retrieval of WebSocket connections"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager and tracker
mock_manager = MagicMock()
mock_pool = MagicMock()
mock_pool._connections_lock = AsyncMock()
mock_pool._connections = {"conn1": MagicMock(), "conn2": MagicMock()}
mock_manager.pool = mock_pool
mock_get_manager.return_value = mock_manager
mock_tracker = MagicMock()
mock_tracker.get_connection_metrics = AsyncMock(side_effect=[
{
"connection_id": "conn1",
"user_id": 1,
"state": "connected",
"topics": ["topic1"],
"created_at": "2023-01-01T12:00:00Z",
"last_activity": "2023-01-01T12:01:00Z",
"age_seconds": 60,
"idle_seconds": 10,
"error_count": 0,
"last_ping": "2023-01-01T12:01:00Z",
"last_pong": "2023-01-01T12:01:00Z",
"metadata": {},
"is_alive": True,
"is_stale": False
},
{
"connection_id": "conn2",
"user_id": 2,
"state": "connected",
"topics": ["topic2"],
"created_at": "2023-01-01T12:00:00Z",
"last_activity": "2023-01-01T12:01:00Z",
"age_seconds": 60,
"idle_seconds": 10,
"error_count": 0,
"last_ping": "2023-01-01T12:01:00Z",
"last_pong": "2023-01-01T12:01:00Z",
"metadata": {},
"is_alive": True,
"is_stale": False
}
])
mock_get_tracker.return_value = mock_tracker
# Make request
response = admin_client.get("/api/admin/websockets/connections")
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total_count"] == 2
assert data["active_count"] == 2
assert data["stale_count"] == 0
assert len(data["connections"]) == 2
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
@patch('app.api.admin.get_connection_tracker')
def test_get_connections_with_filters(self, mock_get_tracker, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test WebSocket connections listing with filters"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager and tracker
mock_manager = MagicMock()
mock_pool = MagicMock()
mock_pool._connections_lock = AsyncMock()
mock_pool._connections = {"conn1": MagicMock()}
mock_manager.pool = mock_pool
mock_get_manager.return_value = mock_manager
mock_tracker = MagicMock()
# Return connection for user 1 only (user_id filter)
mock_tracker.get_connection_metrics = AsyncMock(return_value={
"connection_id": "conn1",
"user_id": 1,
"state": "connected",
"topics": ["topic1"],
"created_at": "2023-01-01T12:00:00Z",
"last_activity": "2023-01-01T12:01:00Z",
"age_seconds": 60,
"idle_seconds": 10,
"error_count": 0,
"last_ping": "2023-01-01T12:01:00Z",
"last_pong": "2023-01-01T12:01:00Z",
"metadata": {},
"is_alive": True,
"is_stale": False
})
mock_get_tracker.return_value = mock_tracker
# Make request with user_id filter
response = admin_client.get("/api/admin/websockets/connections?user_id=1")
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total_count"] == 1
assert data["connections"][0]["user_id"] == 1
class TestWebSocketDisconnectEndpoint:
"""Test WebSocket disconnect endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
def test_disconnect_by_connection_ids(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test disconnecting specific connections"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager
mock_manager = MagicMock()
mock_pool = MagicMock()
mock_pool.remove_connection = AsyncMock()
mock_manager.pool = mock_pool
mock_get_manager.return_value = mock_manager
# Make request
response = admin_client.post(
"/api/admin/websockets/disconnect",
json={
"connection_ids": ["conn1", "conn2"],
"reason": "admin_test"
}
)
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["disconnected_count"] == 2
assert data["reason"] == "admin_test"
# Check mock calls
assert mock_pool.remove_connection.call_count == 2
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
def test_disconnect_by_user_id(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test disconnecting all connections for a user"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager
mock_manager = MagicMock()
mock_pool = MagicMock()
mock_pool.get_user_connections = AsyncMock(return_value=["conn1", "conn2"])
mock_pool.remove_connection = AsyncMock()
mock_manager.pool = mock_pool
mock_get_manager.return_value = mock_manager
# Make request
response = admin_client.post(
"/api/admin/websockets/disconnect",
json={
"user_id": 1,
"reason": "user_maintenance"
}
)
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["disconnected_count"] == 2
def test_disconnect_missing_criteria(self, admin_client):
"""Test disconnect request with missing criteria"""
# Make request without specifying what to disconnect
response = admin_client.post(
"/api/admin/websockets/disconnect",
json={"reason": "test"}
)
# Should fail
assert response.status_code == status.HTTP_400_BAD_REQUEST
class TestWebSocketCleanupEndpoint:
"""Test WebSocket cleanup endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
def test_cleanup_websockets(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test manual WebSocket cleanup"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager
mock_manager = MagicMock()
mock_pool = MagicMock()
mock_pool.get_stats = AsyncMock(side_effect=[
{"active_connections": 10}, # Before cleanup
{"active_connections": 8} # After cleanup
])
mock_pool._cleanup_stale_connections = AsyncMock()
mock_manager.pool = mock_pool
mock_get_manager.return_value = mock_manager
# Make request
response = admin_client.post("/api/admin/websockets/cleanup")
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["connections_before"] == 10
assert data["connections_after"] == 8
assert data["cleaned_count"] == 2
# Check cleanup was called
mock_pool._cleanup_stale_connections.assert_called_once()
class TestWebSocketBroadcastEndpoint:
"""Test WebSocket broadcast endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_websocket_manager')
def test_broadcast_message(self, mock_get_manager, mock_get_admin, admin_client, admin_user):
"""Test broadcasting a message to a topic"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock WebSocket manager
mock_manager = MagicMock()
mock_manager.broadcast_to_topic = AsyncMock(return_value=5)
mock_get_manager.return_value = mock_manager
# Make request
response = admin_client.post(
"/api/admin/websockets/broadcast",
json={
"topic": "admin_announcement",
"message_type": "system_message",
"data": {"message": "System maintenance in 5 minutes"}
}
)
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["sent_count"] == 5
assert data["topic"] == "admin_announcement"
assert data["message_type"] == "system_message"
# Check broadcast was called correctly
mock_manager.broadcast_to_topic.assert_called_once_with(
topic="admin_announcement",
message_type="system_message",
data={"message": "System maintenance in 5 minutes"}
)
class TestWebSocketConnectionDetailEndpoint:
"""Test individual WebSocket connection detail endpoint"""
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_connection_tracker')
def test_get_connection_detail_success(self, mock_get_tracker, mock_get_admin, admin_client, admin_user):
"""Test getting details for a specific connection"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock connection tracker
mock_tracker = MagicMock()
mock_tracker.get_connection_metrics = AsyncMock(return_value={
"connection_id": "conn1",
"user_id": 1,
"state": "connected",
"topics": ["topic1", "topic2"],
"created_at": "2023-01-01T12:00:00Z",
"last_activity": "2023-01-01T12:01:00Z",
"age_seconds": 60,
"idle_seconds": 10,
"error_count": 0,
"last_ping": "2023-01-01T12:01:00Z",
"last_pong": "2023-01-01T12:01:00Z",
"metadata": {"endpoint": "batch_progress"},
"is_alive": True,
"is_stale": False
})
mock_get_tracker.return_value = mock_tracker
# Make request
response = admin_client.get("/api/admin/websockets/connections/conn1")
# Check response
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["connection_id"] == "conn1"
assert data["user_id"] == 1
assert data["state"] == "connected"
assert len(data["topics"]) == 2
assert data["metadata"]["endpoint"] == "batch_progress"
@patch('app.api.admin.get_admin_user')
@patch('app.api.admin.get_connection_tracker')
def test_get_connection_detail_not_found(self, mock_get_tracker, mock_get_admin, admin_client, admin_user):
"""Test getting details for non-existent connection"""
# Mock admin authentication
mock_get_admin.return_value = admin_user
# Mock connection tracker returning None
mock_tracker = MagicMock()
mock_tracker.get_connection_metrics = AsyncMock(return_value=None)
mock_get_tracker.return_value = mock_tracker
# Make request
response = admin_client.get("/api/admin/websockets/connections/nonexistent")
# Check response
assert response.status_code == status.HTTP_404_NOT_FOUND
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,607 @@
"""
Comprehensive tests for WebSocket connection pooling and management
Tests cover:
- Connection pool creation and management
- Automatic cleanup of stale connections
- Health monitoring and heartbeats
- Topic-based message broadcasting
- Resource management and memory leak prevention
- Integration with authentication
"""
import asyncio
import pytest
import time
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.testclient import TestClient
from app.services.websocket_pool import (
WebSocketPool,
ConnectionState,
MessageType,
ConnectionInfo,
WebSocketMessage,
get_websocket_pool,
initialize_websocket_pool,
shutdown_websocket_pool,
websocket_connection
)
from app.middleware.websocket_middleware import (
WebSocketManager,
get_websocket_manager,
WebSocketAuthenticationError
)
class MockWebSocket:
"""Mock WebSocket for testing"""
def __init__(self, mock_user_id: int = None):
self.sent_messages = []
self.received_messages = []
self.closed = False
self.close_code = None
self.close_reason = None
self.mock_user_id = mock_user_id
self.url = MagicMock()
self.url.query = f"token=test_token_{mock_user_id}" if mock_user_id else ""
async def send_json(self, data: Dict[str, Any]):
"""Mock send_json method"""
if self.closed:
raise Exception("WebSocket closed")
self.sent_messages.append(data)
async def send_text(self, text: str):
"""Mock send_text method"""
if self.closed:
raise Exception("WebSocket closed")
self.sent_messages.append({"text": text})
async def receive_text(self) -> str:
"""Mock receive_text method"""
if self.closed:
raise WebSocketDisconnect()
if self.received_messages:
return self.received_messages.pop(0)
# Simulate waiting for message
await asyncio.sleep(0.1)
return "ping"
async def accept(self):
"""Mock accept method"""
pass
async def close(self, code: int = 1000, reason: str = ""):
"""Mock close method"""
self.closed = True
self.close_code = code
self.close_reason = reason
@pytest.fixture
async def websocket_pool():
"""Create a fresh WebSocket pool for testing"""
pool = WebSocketPool(
cleanup_interval=1, # Fast cleanup for testing
connection_timeout=5, # Short timeout for testing
heartbeat_interval=2, # Fast heartbeat for testing
max_connections_per_topic=10,
max_total_connections=100
)
await pool.start()
yield pool
await pool.stop()
@pytest.fixture
def mock_websocket():
"""Create a mock WebSocket for testing"""
return MockWebSocket()
@pytest.fixture
def mock_authenticated_websocket():
"""Create a mock authenticated WebSocket for testing"""
return MockWebSocket(mock_user_id=1)
class TestWebSocketPool:
"""Test the core WebSocket pool functionality"""
async def test_pool_initialization(self, websocket_pool):
"""Test pool initialization and configuration"""
assert websocket_pool.cleanup_interval == 1
assert websocket_pool.connection_timeout == 5
assert websocket_pool.heartbeat_interval == 2
assert websocket_pool.max_connections_per_topic == 10
assert websocket_pool.max_total_connections == 100
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
assert stats["total_topics"] == 0
assert stats["total_users"] == 0
async def test_add_remove_connection(self, websocket_pool, mock_websocket):
"""Test adding and removing connections"""
# Add connection
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics={"test_topic"},
metadata={"test": "data"}
)
assert connection_id.startswith("ws_")
# Check connection exists
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is not None
assert connection_info.user_id == 1
assert "test_topic" in connection_info.topics
assert connection_info.metadata["test"] == "data"
# Check stats
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 1
assert stats["total_topics"] == 1
assert stats["total_users"] == 1
# Remove connection
await websocket_pool.remove_connection(connection_id, "test_removal")
# Check connection removed
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
async def test_topic_subscription(self, websocket_pool, mock_websocket):
"""Test topic subscription and unsubscription"""
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics=set()
)
# Subscribe to topic
success = await websocket_pool.subscribe_to_topic(connection_id, "new_topic")
assert success
# Check subscription
topic_connections = await websocket_pool.get_topic_connections("new_topic")
assert connection_id in topic_connections
connection_info = await websocket_pool.get_connection_info(connection_id)
assert "new_topic" in connection_info.topics
# Unsubscribe from topic
success = await websocket_pool.unsubscribe_from_topic(connection_id, "new_topic")
assert success
# Check unsubscription
topic_connections = await websocket_pool.get_topic_connections("new_topic")
assert connection_id not in topic_connections
connection_info = await websocket_pool.get_connection_info(connection_id)
assert "new_topic" not in connection_info.topics
async def test_broadcast_to_topic(self, websocket_pool):
"""Test broadcasting messages to topic subscribers"""
# Create multiple connections
websockets = [MockWebSocket(i) for i in range(3)]
connection_ids = []
for i, ws in enumerate(websockets):
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"broadcast_topic"}
)
connection_ids.append(connection_id)
# Broadcast message
message = WebSocketMessage(
type="test_broadcast",
topic="broadcast_topic",
data={"message": "Hello everyone!"}
)
sent_count = await websocket_pool.broadcast_to_topic(
topic="broadcast_topic",
message=message
)
assert sent_count == 3
# Check all websockets received the message
for ws in websockets:
assert len(ws.sent_messages) == 1
sent_message = ws.sent_messages[0]
assert sent_message["type"] == "test_broadcast"
assert sent_message["data"]["message"] == "Hello everyone!"
async def test_send_to_user(self, websocket_pool):
"""Test sending messages to all connections for a specific user"""
# Create multiple connections for the same user
websockets = [MockWebSocket(1) for _ in range(2)]
connection_ids = []
for ws in websockets:
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=1,
topics=set()
)
connection_ids.append(connection_id)
# Add connection for different user
other_ws = MockWebSocket(2)
await websocket_pool.add_connection(
websocket=other_ws,
user_id=2,
topics=set()
)
# Send message to user 1
message = WebSocketMessage(
type="user_message",
data={"message": "Hello user 1!"}
)
sent_count = await websocket_pool.send_to_user(1, message)
assert sent_count == 2
# Check user 1 connections received message
for ws in websockets:
assert len(ws.sent_messages) == 1
sent_message = ws.sent_messages[0]
assert sent_message["type"] == "user_message"
# Check user 2 connection didn't receive message
assert len(other_ws.sent_messages) == 0
async def test_connection_limits(self, websocket_pool):
"""Test connection limits"""
# Test max connections per topic
websockets = []
for i in range(websocket_pool.max_connections_per_topic + 1):
ws = MockWebSocket(i)
websockets.append(ws)
if i < websocket_pool.max_connections_per_topic:
# Should succeed
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"limited_topic"}
)
assert connection_id is not None
else:
# Should fail due to topic limit
with pytest.raises(ValueError, match="Maximum connections per topic"):
await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"limited_topic"}
)
async def test_stale_connection_cleanup(self, websocket_pool):
"""Test automatic cleanup of stale connections"""
# Add connection
mock_ws = MockWebSocket(1)
connection_id = await websocket_pool.add_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"}
)
# Manually set last activity to make connection stale
connection_info = await websocket_pool.get_connection_info(connection_id)
connection_info.last_activity = datetime.now(timezone.utc) - timedelta(seconds=10)
# Trigger cleanup
await websocket_pool._cleanup_stale_connections()
# Check connection was cleaned up
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
assert stats["connections_cleaned"] > 0
async def test_heartbeat_functionality(self, websocket_pool, mock_websocket):
"""Test heartbeat sending and connection health monitoring"""
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics=set()
)
# Trigger heartbeat
await websocket_pool._send_heartbeats()
# Check heartbeat was sent
assert len(mock_websocket.sent_messages) == 1
heartbeat_message = mock_websocket.sent_messages[0]
assert heartbeat_message["type"] == MessageType.HEARTBEAT.value
# Test ping functionality
success = await websocket_pool.ping_connection(connection_id)
assert success
assert len(mock_websocket.sent_messages) == 2
ping_message = mock_websocket.sent_messages[1]
assert ping_message["type"] == MessageType.PING.value
async def test_failed_message_cleanup(self, websocket_pool):
"""Test cleanup of connections when message sending fails"""
# Create a websocket that will fail on send
mock_ws = MockWebSocket(1)
mock_ws.closed = True # Simulate closed connection
connection_id = await websocket_pool.add_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"}
)
# Try to broadcast - should fail and clean up connection
message = WebSocketMessage(type="test", data={})
sent_count = await websocket_pool.broadcast_to_topic(
topic="test_topic",
message=message
)
assert sent_count == 0
# Check connection was cleaned up
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
class TestWebSocketManager:
"""Test the WebSocket manager middleware"""
@pytest.fixture
def websocket_manager(self):
"""Create a WebSocket manager for testing"""
return WebSocketManager()
@patch('app.middleware.websocket_middleware.verify_token')
@patch('app.middleware.websocket_middleware.SessionLocal')
async def test_authentication_success(self, mock_session, mock_verify_token, websocket_manager):
"""Test successful WebSocket authentication"""
# Mock successful authentication
mock_verify_token.return_value = "test_user"
mock_user = MagicMock()
mock_user.id = 1
mock_user.username = "test_user"
mock_user.is_active = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
mock_session.return_value = mock_db
# Create mock websocket with token
mock_ws = MagicMock()
mock_ws.url.query = "token=valid_token"
user = await websocket_manager.authenticate_websocket(mock_ws)
assert user is not None
assert user.id == 1
assert user.username == "test_user"
@patch('app.middleware.websocket_middleware.verify_token')
async def test_authentication_failure(self, mock_verify_token, websocket_manager):
"""Test failed WebSocket authentication"""
# Mock failed authentication
mock_verify_token.return_value = None
mock_ws = MagicMock()
mock_ws.url.query = "token=invalid_token"
user = await websocket_manager.authenticate_websocket(mock_ws)
assert user is None
class TestWebSocketIntegration:
"""Test WebSocket integration with FastAPI"""
@pytest.mark.asyncio
async def test_websocket_connection_context_manager(self):
"""Test the websocket_connection context manager"""
pool = WebSocketPool()
await pool.start()
try:
mock_ws = MockWebSocket(1)
async with websocket_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"},
metadata={"test": "data"}
) as (connection_id, pool_instance):
assert connection_id is not None
assert connection_id.startswith("ws_")
# The context returns the global pool; allow either the same object or equivalent type
assert isinstance(pool_instance, type(pool))
# Check connection exists
connection_info = await pool.get_connection_info(connection_id)
assert connection_info is not None
assert connection_info.user_id == 1
assert "test_topic" in connection_info.topics
# Check connection was cleaned up after context exit
connection_info = await pool.get_connection_info(connection_id)
assert connection_info is None
finally:
await pool.stop()
@pytest.mark.asyncio
async def test_global_pool_management(self):
"""Test global WebSocket pool initialization and shutdown"""
# Initialize global pool
pool = await initialize_websocket_pool(
cleanup_interval=1,
connection_timeout=5
)
assert pool is not None
# Get the same pool instance
same_pool = get_websocket_pool()
assert same_pool is pool
# Shutdown global pool
await shutdown_websocket_pool()
# Should create new pool after shutdown
new_pool = get_websocket_pool()
assert new_pool is not pool
class TestWebSocketStressTest:
"""Stress tests for WebSocket pool under high load"""
@pytest.mark.asyncio
async def test_high_connection_volume(self):
"""Test pool behavior with many concurrent connections"""
pool = WebSocketPool(max_total_connections=1000)
await pool.start()
try:
# Create many connections
connection_ids = []
for i in range(100):
mock_ws = MockWebSocket(i)
connection_id = await pool.add_connection(
websocket=mock_ws,
user_id=i % 10, # 10 different users
topics={f"topic_{i % 5}"} # 5 different topics
)
connection_ids.append(connection_id)
# Check all connections exist
stats = await pool.get_stats()
assert stats["active_connections"] == 100
assert stats["total_topics"] == 5
assert stats["total_users"] == 10
# Broadcast to all topics
for topic_id in range(5):
topic = f"topic_{topic_id}"
message = WebSocketMessage(
type="stress_test",
topic=topic,
data={"topic_id": topic_id}
)
sent_count = await pool.broadcast_to_topic(topic, message)
assert sent_count == 20 # 100 connections / 5 topics
# Clean up all connections
for connection_id in connection_ids:
await pool.remove_connection(connection_id, "stress_test_cleanup")
stats = await pool.get_stats()
assert stats["active_connections"] == 0
finally:
await pool.stop()
@pytest.mark.asyncio
async def test_rapid_connect_disconnect(self):
"""Test rapid connection and disconnection patterns"""
pool = WebSocketPool()
await pool.start()
try:
# Rapidly create and destroy connections
for round_num in range(10):
connection_ids = []
# Create 10 connections
for i in range(10):
mock_ws = MockWebSocket(i)
connection_id = await pool.add_connection(
websocket=mock_ws,
user_id=i,
topics={f"round_{round_num}"}
)
connection_ids.append(connection_id)
# Immediately remove them
for connection_id in connection_ids:
await pool.remove_connection(connection_id, f"round_{round_num}_cleanup")
# Check pool is clean
stats = await pool.get_stats()
assert stats["active_connections"] == 0
finally:
await pool.stop()
# Test utilities and fixtures for integration testing
@pytest.fixture
def websocket_test_client():
"""Create a test client for WebSocket endpoint testing"""
from app.main import app
return TestClient(app)
class TestBillingWebSocketIntegration:
"""Test integration with the billing API WebSocket endpoint"""
@pytest.mark.asyncio
async def test_batch_progress_topic_format(self):
"""Test that batch progress uses correct topic format"""
from app.api.billing import _notify_progress_subscribers
from app.services.batch_generation import BatchProgress
# Mock the WebSocket manager
with patch('app.api.billing.websocket_manager') as mock_manager:
mock_manager.broadcast_to_topic = AsyncMock(return_value=1)
# Create test progress
progress = BatchProgress(
batch_id="test_batch_123",
status="processing",
total_files=10,
processed_files=5,
successful_files=4,
failed_files=1,
current_file="test.txt",
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# Call notification function
await _notify_progress_subscribers(progress)
# Check correct topic was used
mock_manager.broadcast_to_topic.assert_called_once_with(
topic="batch_progress_test_batch_123",
message_type="progress",
data=progress.model_dump()
)
if __name__ == "__main__":
# Run tests
pytest.main([__file__, "-v"])