This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -0,0 +1,607 @@
"""
Comprehensive tests for WebSocket connection pooling and management
Tests cover:
- Connection pool creation and management
- Automatic cleanup of stale connections
- Health monitoring and heartbeats
- Topic-based message broadcasting
- Resource management and memory leak prevention
- Integration with authentication
"""
import asyncio
import pytest
import time
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.testclient import TestClient
from app.services.websocket_pool import (
WebSocketPool,
ConnectionState,
MessageType,
ConnectionInfo,
WebSocketMessage,
get_websocket_pool,
initialize_websocket_pool,
shutdown_websocket_pool,
websocket_connection
)
from app.middleware.websocket_middleware import (
WebSocketManager,
get_websocket_manager,
WebSocketAuthenticationError
)
class MockWebSocket:
"""Mock WebSocket for testing"""
def __init__(self, mock_user_id: int = None):
self.sent_messages = []
self.received_messages = []
self.closed = False
self.close_code = None
self.close_reason = None
self.mock_user_id = mock_user_id
self.url = MagicMock()
self.url.query = f"token=test_token_{mock_user_id}" if mock_user_id else ""
async def send_json(self, data: Dict[str, Any]):
"""Mock send_json method"""
if self.closed:
raise Exception("WebSocket closed")
self.sent_messages.append(data)
async def send_text(self, text: str):
"""Mock send_text method"""
if self.closed:
raise Exception("WebSocket closed")
self.sent_messages.append({"text": text})
async def receive_text(self) -> str:
"""Mock receive_text method"""
if self.closed:
raise WebSocketDisconnect()
if self.received_messages:
return self.received_messages.pop(0)
# Simulate waiting for message
await asyncio.sleep(0.1)
return "ping"
async def accept(self):
"""Mock accept method"""
pass
async def close(self, code: int = 1000, reason: str = ""):
"""Mock close method"""
self.closed = True
self.close_code = code
self.close_reason = reason
@pytest.fixture
async def websocket_pool():
"""Create a fresh WebSocket pool for testing"""
pool = WebSocketPool(
cleanup_interval=1, # Fast cleanup for testing
connection_timeout=5, # Short timeout for testing
heartbeat_interval=2, # Fast heartbeat for testing
max_connections_per_topic=10,
max_total_connections=100
)
await pool.start()
yield pool
await pool.stop()
@pytest.fixture
def mock_websocket():
"""Create a mock WebSocket for testing"""
return MockWebSocket()
@pytest.fixture
def mock_authenticated_websocket():
"""Create a mock authenticated WebSocket for testing"""
return MockWebSocket(mock_user_id=1)
class TestWebSocketPool:
"""Test the core WebSocket pool functionality"""
async def test_pool_initialization(self, websocket_pool):
"""Test pool initialization and configuration"""
assert websocket_pool.cleanup_interval == 1
assert websocket_pool.connection_timeout == 5
assert websocket_pool.heartbeat_interval == 2
assert websocket_pool.max_connections_per_topic == 10
assert websocket_pool.max_total_connections == 100
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
assert stats["total_topics"] == 0
assert stats["total_users"] == 0
async def test_add_remove_connection(self, websocket_pool, mock_websocket):
"""Test adding and removing connections"""
# Add connection
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics={"test_topic"},
metadata={"test": "data"}
)
assert connection_id.startswith("ws_")
# Check connection exists
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is not None
assert connection_info.user_id == 1
assert "test_topic" in connection_info.topics
assert connection_info.metadata["test"] == "data"
# Check stats
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 1
assert stats["total_topics"] == 1
assert stats["total_users"] == 1
# Remove connection
await websocket_pool.remove_connection(connection_id, "test_removal")
# Check connection removed
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
async def test_topic_subscription(self, websocket_pool, mock_websocket):
"""Test topic subscription and unsubscription"""
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics=set()
)
# Subscribe to topic
success = await websocket_pool.subscribe_to_topic(connection_id, "new_topic")
assert success
# Check subscription
topic_connections = await websocket_pool.get_topic_connections("new_topic")
assert connection_id in topic_connections
connection_info = await websocket_pool.get_connection_info(connection_id)
assert "new_topic" in connection_info.topics
# Unsubscribe from topic
success = await websocket_pool.unsubscribe_from_topic(connection_id, "new_topic")
assert success
# Check unsubscription
topic_connections = await websocket_pool.get_topic_connections("new_topic")
assert connection_id not in topic_connections
connection_info = await websocket_pool.get_connection_info(connection_id)
assert "new_topic" not in connection_info.topics
async def test_broadcast_to_topic(self, websocket_pool):
"""Test broadcasting messages to topic subscribers"""
# Create multiple connections
websockets = [MockWebSocket(i) for i in range(3)]
connection_ids = []
for i, ws in enumerate(websockets):
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"broadcast_topic"}
)
connection_ids.append(connection_id)
# Broadcast message
message = WebSocketMessage(
type="test_broadcast",
topic="broadcast_topic",
data={"message": "Hello everyone!"}
)
sent_count = await websocket_pool.broadcast_to_topic(
topic="broadcast_topic",
message=message
)
assert sent_count == 3
# Check all websockets received the message
for ws in websockets:
assert len(ws.sent_messages) == 1
sent_message = ws.sent_messages[0]
assert sent_message["type"] == "test_broadcast"
assert sent_message["data"]["message"] == "Hello everyone!"
async def test_send_to_user(self, websocket_pool):
"""Test sending messages to all connections for a specific user"""
# Create multiple connections for the same user
websockets = [MockWebSocket(1) for _ in range(2)]
connection_ids = []
for ws in websockets:
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=1,
topics=set()
)
connection_ids.append(connection_id)
# Add connection for different user
other_ws = MockWebSocket(2)
await websocket_pool.add_connection(
websocket=other_ws,
user_id=2,
topics=set()
)
# Send message to user 1
message = WebSocketMessage(
type="user_message",
data={"message": "Hello user 1!"}
)
sent_count = await websocket_pool.send_to_user(1, message)
assert sent_count == 2
# Check user 1 connections received message
for ws in websockets:
assert len(ws.sent_messages) == 1
sent_message = ws.sent_messages[0]
assert sent_message["type"] == "user_message"
# Check user 2 connection didn't receive message
assert len(other_ws.sent_messages) == 0
async def test_connection_limits(self, websocket_pool):
"""Test connection limits"""
# Test max connections per topic
websockets = []
for i in range(websocket_pool.max_connections_per_topic + 1):
ws = MockWebSocket(i)
websockets.append(ws)
if i < websocket_pool.max_connections_per_topic:
# Should succeed
connection_id = await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"limited_topic"}
)
assert connection_id is not None
else:
# Should fail due to topic limit
with pytest.raises(ValueError, match="Maximum connections per topic"):
await websocket_pool.add_connection(
websocket=ws,
user_id=i,
topics={"limited_topic"}
)
async def test_stale_connection_cleanup(self, websocket_pool):
"""Test automatic cleanup of stale connections"""
# Add connection
mock_ws = MockWebSocket(1)
connection_id = await websocket_pool.add_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"}
)
# Manually set last activity to make connection stale
connection_info = await websocket_pool.get_connection_info(connection_id)
connection_info.last_activity = datetime.now(timezone.utc) - timedelta(seconds=10)
# Trigger cleanup
await websocket_pool._cleanup_stale_connections()
# Check connection was cleaned up
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
stats = await websocket_pool.get_stats()
assert stats["active_connections"] == 0
assert stats["connections_cleaned"] > 0
async def test_heartbeat_functionality(self, websocket_pool, mock_websocket):
"""Test heartbeat sending and connection health monitoring"""
connection_id = await websocket_pool.add_connection(
websocket=mock_websocket,
user_id=1,
topics=set()
)
# Trigger heartbeat
await websocket_pool._send_heartbeats()
# Check heartbeat was sent
assert len(mock_websocket.sent_messages) == 1
heartbeat_message = mock_websocket.sent_messages[0]
assert heartbeat_message["type"] == MessageType.HEARTBEAT.value
# Test ping functionality
success = await websocket_pool.ping_connection(connection_id)
assert success
assert len(mock_websocket.sent_messages) == 2
ping_message = mock_websocket.sent_messages[1]
assert ping_message["type"] == MessageType.PING.value
async def test_failed_message_cleanup(self, websocket_pool):
"""Test cleanup of connections when message sending fails"""
# Create a websocket that will fail on send
mock_ws = MockWebSocket(1)
mock_ws.closed = True # Simulate closed connection
connection_id = await websocket_pool.add_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"}
)
# Try to broadcast - should fail and clean up connection
message = WebSocketMessage(type="test", data={})
sent_count = await websocket_pool.broadcast_to_topic(
topic="test_topic",
message=message
)
assert sent_count == 0
# Check connection was cleaned up
connection_info = await websocket_pool.get_connection_info(connection_id)
assert connection_info is None
class TestWebSocketManager:
"""Test the WebSocket manager middleware"""
@pytest.fixture
def websocket_manager(self):
"""Create a WebSocket manager for testing"""
return WebSocketManager()
@patch('app.middleware.websocket_middleware.verify_token')
@patch('app.middleware.websocket_middleware.SessionLocal')
async def test_authentication_success(self, mock_session, mock_verify_token, websocket_manager):
"""Test successful WebSocket authentication"""
# Mock successful authentication
mock_verify_token.return_value = "test_user"
mock_user = MagicMock()
mock_user.id = 1
mock_user.username = "test_user"
mock_user.is_active = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
mock_session.return_value = mock_db
# Create mock websocket with token
mock_ws = MagicMock()
mock_ws.url.query = "token=valid_token"
user = await websocket_manager.authenticate_websocket(mock_ws)
assert user is not None
assert user.id == 1
assert user.username == "test_user"
@patch('app.middleware.websocket_middleware.verify_token')
async def test_authentication_failure(self, mock_verify_token, websocket_manager):
"""Test failed WebSocket authentication"""
# Mock failed authentication
mock_verify_token.return_value = None
mock_ws = MagicMock()
mock_ws.url.query = "token=invalid_token"
user = await websocket_manager.authenticate_websocket(mock_ws)
assert user is None
class TestWebSocketIntegration:
"""Test WebSocket integration with FastAPI"""
@pytest.mark.asyncio
async def test_websocket_connection_context_manager(self):
"""Test the websocket_connection context manager"""
pool = WebSocketPool()
await pool.start()
try:
mock_ws = MockWebSocket(1)
async with websocket_connection(
websocket=mock_ws,
user_id=1,
topics={"test_topic"},
metadata={"test": "data"}
) as (connection_id, pool_instance):
assert connection_id is not None
assert connection_id.startswith("ws_")
# The context returns the global pool; allow either the same object or equivalent type
assert isinstance(pool_instance, type(pool))
# Check connection exists
connection_info = await pool.get_connection_info(connection_id)
assert connection_info is not None
assert connection_info.user_id == 1
assert "test_topic" in connection_info.topics
# Check connection was cleaned up after context exit
connection_info = await pool.get_connection_info(connection_id)
assert connection_info is None
finally:
await pool.stop()
@pytest.mark.asyncio
async def test_global_pool_management(self):
"""Test global WebSocket pool initialization and shutdown"""
# Initialize global pool
pool = await initialize_websocket_pool(
cleanup_interval=1,
connection_timeout=5
)
assert pool is not None
# Get the same pool instance
same_pool = get_websocket_pool()
assert same_pool is pool
# Shutdown global pool
await shutdown_websocket_pool()
# Should create new pool after shutdown
new_pool = get_websocket_pool()
assert new_pool is not pool
class TestWebSocketStressTest:
"""Stress tests for WebSocket pool under high load"""
@pytest.mark.asyncio
async def test_high_connection_volume(self):
"""Test pool behavior with many concurrent connections"""
pool = WebSocketPool(max_total_connections=1000)
await pool.start()
try:
# Create many connections
connection_ids = []
for i in range(100):
mock_ws = MockWebSocket(i)
connection_id = await pool.add_connection(
websocket=mock_ws,
user_id=i % 10, # 10 different users
topics={f"topic_{i % 5}"} # 5 different topics
)
connection_ids.append(connection_id)
# Check all connections exist
stats = await pool.get_stats()
assert stats["active_connections"] == 100
assert stats["total_topics"] == 5
assert stats["total_users"] == 10
# Broadcast to all topics
for topic_id in range(5):
topic = f"topic_{topic_id}"
message = WebSocketMessage(
type="stress_test",
topic=topic,
data={"topic_id": topic_id}
)
sent_count = await pool.broadcast_to_topic(topic, message)
assert sent_count == 20 # 100 connections / 5 topics
# Clean up all connections
for connection_id in connection_ids:
await pool.remove_connection(connection_id, "stress_test_cleanup")
stats = await pool.get_stats()
assert stats["active_connections"] == 0
finally:
await pool.stop()
@pytest.mark.asyncio
async def test_rapid_connect_disconnect(self):
"""Test rapid connection and disconnection patterns"""
pool = WebSocketPool()
await pool.start()
try:
# Rapidly create and destroy connections
for round_num in range(10):
connection_ids = []
# Create 10 connections
for i in range(10):
mock_ws = MockWebSocket(i)
connection_id = await pool.add_connection(
websocket=mock_ws,
user_id=i,
topics={f"round_{round_num}"}
)
connection_ids.append(connection_id)
# Immediately remove them
for connection_id in connection_ids:
await pool.remove_connection(connection_id, f"round_{round_num}_cleanup")
# Check pool is clean
stats = await pool.get_stats()
assert stats["active_connections"] == 0
finally:
await pool.stop()
# Test utilities and fixtures for integration testing
@pytest.fixture
def websocket_test_client():
"""Create a test client for WebSocket endpoint testing"""
from app.main import app
return TestClient(app)
class TestBillingWebSocketIntegration:
"""Test integration with the billing API WebSocket endpoint"""
@pytest.mark.asyncio
async def test_batch_progress_topic_format(self):
"""Test that batch progress uses correct topic format"""
from app.api.billing import _notify_progress_subscribers
from app.services.batch_generation import BatchProgress
# Mock the WebSocket manager
with patch('app.api.billing.websocket_manager') as mock_manager:
mock_manager.broadcast_to_topic = AsyncMock(return_value=1)
# Create test progress
progress = BatchProgress(
batch_id="test_batch_123",
status="processing",
total_files=10,
processed_files=5,
successful_files=4,
failed_files=1,
current_file="test.txt",
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# Call notification function
await _notify_progress_subscribers(progress)
# Check correct topic was used
mock_manager.broadcast_to_topic.assert_called_once_with(
topic="batch_progress_test_batch_123",
message_type="progress",
data=progress.model_dump()
)
if __name__ == "__main__":
# Run tests
pytest.main([__file__, "-v"])