changes
This commit is contained in:
607
tests/test_websocket_pool.py
Normal file
607
tests/test_websocket_pool.py
Normal file
@@ -0,0 +1,607 @@
|
||||
"""
|
||||
Comprehensive tests for WebSocket connection pooling and management
|
||||
|
||||
Tests cover:
|
||||
- Connection pool creation and management
|
||||
- Automatic cleanup of stale connections
|
||||
- Health monitoring and heartbeats
|
||||
- Topic-based message broadcasting
|
||||
- Resource management and memory leak prevention
|
||||
- Integration with authentication
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.services.websocket_pool import (
|
||||
WebSocketPool,
|
||||
ConnectionState,
|
||||
MessageType,
|
||||
ConnectionInfo,
|
||||
WebSocketMessage,
|
||||
get_websocket_pool,
|
||||
initialize_websocket_pool,
|
||||
shutdown_websocket_pool,
|
||||
websocket_connection
|
||||
)
|
||||
from app.middleware.websocket_middleware import (
|
||||
WebSocketManager,
|
||||
get_websocket_manager,
|
||||
WebSocketAuthenticationError
|
||||
)
|
||||
|
||||
|
||||
class MockWebSocket:
|
||||
"""Mock WebSocket for testing"""
|
||||
|
||||
def __init__(self, mock_user_id: int = None):
|
||||
self.sent_messages = []
|
||||
self.received_messages = []
|
||||
self.closed = False
|
||||
self.close_code = None
|
||||
self.close_reason = None
|
||||
self.mock_user_id = mock_user_id
|
||||
self.url = MagicMock()
|
||||
self.url.query = f"token=test_token_{mock_user_id}" if mock_user_id else ""
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]):
|
||||
"""Mock send_json method"""
|
||||
if self.closed:
|
||||
raise Exception("WebSocket closed")
|
||||
self.sent_messages.append(data)
|
||||
|
||||
async def send_text(self, text: str):
|
||||
"""Mock send_text method"""
|
||||
if self.closed:
|
||||
raise Exception("WebSocket closed")
|
||||
self.sent_messages.append({"text": text})
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
"""Mock receive_text method"""
|
||||
if self.closed:
|
||||
raise WebSocketDisconnect()
|
||||
if self.received_messages:
|
||||
return self.received_messages.pop(0)
|
||||
# Simulate waiting for message
|
||||
await asyncio.sleep(0.1)
|
||||
return "ping"
|
||||
|
||||
async def accept(self):
|
||||
"""Mock accept method"""
|
||||
pass
|
||||
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
"""Mock close method"""
|
||||
self.closed = True
|
||||
self.close_code = code
|
||||
self.close_reason = reason
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def websocket_pool():
|
||||
"""Create a fresh WebSocket pool for testing"""
|
||||
pool = WebSocketPool(
|
||||
cleanup_interval=1, # Fast cleanup for testing
|
||||
connection_timeout=5, # Short timeout for testing
|
||||
heartbeat_interval=2, # Fast heartbeat for testing
|
||||
max_connections_per_topic=10,
|
||||
max_total_connections=100
|
||||
)
|
||||
await pool.start()
|
||||
yield pool
|
||||
await pool.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Create a mock WebSocket for testing"""
|
||||
return MockWebSocket()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_authenticated_websocket():
|
||||
"""Create a mock authenticated WebSocket for testing"""
|
||||
return MockWebSocket(mock_user_id=1)
|
||||
|
||||
|
||||
class TestWebSocketPool:
|
||||
"""Test the core WebSocket pool functionality"""
|
||||
|
||||
async def test_pool_initialization(self, websocket_pool):
|
||||
"""Test pool initialization and configuration"""
|
||||
assert websocket_pool.cleanup_interval == 1
|
||||
assert websocket_pool.connection_timeout == 5
|
||||
assert websocket_pool.heartbeat_interval == 2
|
||||
assert websocket_pool.max_connections_per_topic == 10
|
||||
assert websocket_pool.max_total_connections == 100
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["total_topics"] == 0
|
||||
assert stats["total_users"] == 0
|
||||
|
||||
async def test_add_remove_connection(self, websocket_pool, mock_websocket):
|
||||
"""Test adding and removing connections"""
|
||||
# Add connection
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics={"test_topic"},
|
||||
metadata={"test": "data"}
|
||||
)
|
||||
|
||||
assert connection_id.startswith("ws_")
|
||||
|
||||
# Check connection exists
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is not None
|
||||
assert connection_info.user_id == 1
|
||||
assert "test_topic" in connection_info.topics
|
||||
assert connection_info.metadata["test"] == "data"
|
||||
|
||||
# Check stats
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 1
|
||||
assert stats["total_topics"] == 1
|
||||
assert stats["total_users"] == 1
|
||||
|
||||
# Remove connection
|
||||
await websocket_pool.remove_connection(connection_id, "test_removal")
|
||||
|
||||
# Check connection removed
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
async def test_topic_subscription(self, websocket_pool, mock_websocket):
|
||||
"""Test topic subscription and unsubscription"""
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Subscribe to topic
|
||||
success = await websocket_pool.subscribe_to_topic(connection_id, "new_topic")
|
||||
assert success
|
||||
|
||||
# Check subscription
|
||||
topic_connections = await websocket_pool.get_topic_connections("new_topic")
|
||||
assert connection_id in topic_connections
|
||||
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert "new_topic" in connection_info.topics
|
||||
|
||||
# Unsubscribe from topic
|
||||
success = await websocket_pool.unsubscribe_from_topic(connection_id, "new_topic")
|
||||
assert success
|
||||
|
||||
# Check unsubscription
|
||||
topic_connections = await websocket_pool.get_topic_connections("new_topic")
|
||||
assert connection_id not in topic_connections
|
||||
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert "new_topic" not in connection_info.topics
|
||||
|
||||
async def test_broadcast_to_topic(self, websocket_pool):
|
||||
"""Test broadcasting messages to topic subscribers"""
|
||||
# Create multiple connections
|
||||
websockets = [MockWebSocket(i) for i in range(3)]
|
||||
connection_ids = []
|
||||
|
||||
for i, ws in enumerate(websockets):
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"broadcast_topic"}
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Broadcast message
|
||||
message = WebSocketMessage(
|
||||
type="test_broadcast",
|
||||
topic="broadcast_topic",
|
||||
data={"message": "Hello everyone!"}
|
||||
)
|
||||
|
||||
sent_count = await websocket_pool.broadcast_to_topic(
|
||||
topic="broadcast_topic",
|
||||
message=message
|
||||
)
|
||||
|
||||
assert sent_count == 3
|
||||
|
||||
# Check all websockets received the message
|
||||
for ws in websockets:
|
||||
assert len(ws.sent_messages) == 1
|
||||
sent_message = ws.sent_messages[0]
|
||||
assert sent_message["type"] == "test_broadcast"
|
||||
assert sent_message["data"]["message"] == "Hello everyone!"
|
||||
|
||||
async def test_send_to_user(self, websocket_pool):
|
||||
"""Test sending messages to all connections for a specific user"""
|
||||
# Create multiple connections for the same user
|
||||
websockets = [MockWebSocket(1) for _ in range(2)]
|
||||
connection_ids = []
|
||||
|
||||
for ws in websockets:
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Add connection for different user
|
||||
other_ws = MockWebSocket(2)
|
||||
await websocket_pool.add_connection(
|
||||
websocket=other_ws,
|
||||
user_id=2,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Send message to user 1
|
||||
message = WebSocketMessage(
|
||||
type="user_message",
|
||||
data={"message": "Hello user 1!"}
|
||||
)
|
||||
|
||||
sent_count = await websocket_pool.send_to_user(1, message)
|
||||
|
||||
assert sent_count == 2
|
||||
|
||||
# Check user 1 connections received message
|
||||
for ws in websockets:
|
||||
assert len(ws.sent_messages) == 1
|
||||
sent_message = ws.sent_messages[0]
|
||||
assert sent_message["type"] == "user_message"
|
||||
|
||||
# Check user 2 connection didn't receive message
|
||||
assert len(other_ws.sent_messages) == 0
|
||||
|
||||
async def test_connection_limits(self, websocket_pool):
|
||||
"""Test connection limits"""
|
||||
# Test max connections per topic
|
||||
websockets = []
|
||||
for i in range(websocket_pool.max_connections_per_topic + 1):
|
||||
ws = MockWebSocket(i)
|
||||
websockets.append(ws)
|
||||
|
||||
if i < websocket_pool.max_connections_per_topic:
|
||||
# Should succeed
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"limited_topic"}
|
||||
)
|
||||
assert connection_id is not None
|
||||
else:
|
||||
# Should fail due to topic limit
|
||||
with pytest.raises(ValueError, match="Maximum connections per topic"):
|
||||
await websocket_pool.add_connection(
|
||||
websocket=ws,
|
||||
user_id=i,
|
||||
topics={"limited_topic"}
|
||||
)
|
||||
|
||||
async def test_stale_connection_cleanup(self, websocket_pool):
|
||||
"""Test automatic cleanup of stale connections"""
|
||||
# Add connection
|
||||
mock_ws = MockWebSocket(1)
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"}
|
||||
)
|
||||
|
||||
# Manually set last activity to make connection stale
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
connection_info.last_activity = datetime.now(timezone.utc) - timedelta(seconds=10)
|
||||
|
||||
# Trigger cleanup
|
||||
await websocket_pool._cleanup_stale_connections()
|
||||
|
||||
# Check connection was cleaned up
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
stats = await websocket_pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["connections_cleaned"] > 0
|
||||
|
||||
async def test_heartbeat_functionality(self, websocket_pool, mock_websocket):
|
||||
"""Test heartbeat sending and connection health monitoring"""
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_websocket,
|
||||
user_id=1,
|
||||
topics=set()
|
||||
)
|
||||
|
||||
# Trigger heartbeat
|
||||
await websocket_pool._send_heartbeats()
|
||||
|
||||
# Check heartbeat was sent
|
||||
assert len(mock_websocket.sent_messages) == 1
|
||||
heartbeat_message = mock_websocket.sent_messages[0]
|
||||
assert heartbeat_message["type"] == MessageType.HEARTBEAT.value
|
||||
|
||||
# Test ping functionality
|
||||
success = await websocket_pool.ping_connection(connection_id)
|
||||
assert success
|
||||
|
||||
assert len(mock_websocket.sent_messages) == 2
|
||||
ping_message = mock_websocket.sent_messages[1]
|
||||
assert ping_message["type"] == MessageType.PING.value
|
||||
|
||||
async def test_failed_message_cleanup(self, websocket_pool):
|
||||
"""Test cleanup of connections when message sending fails"""
|
||||
# Create a websocket that will fail on send
|
||||
mock_ws = MockWebSocket(1)
|
||||
mock_ws.closed = True # Simulate closed connection
|
||||
|
||||
connection_id = await websocket_pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"}
|
||||
)
|
||||
|
||||
# Try to broadcast - should fail and clean up connection
|
||||
message = WebSocketMessage(type="test", data={})
|
||||
sent_count = await websocket_pool.broadcast_to_topic(
|
||||
topic="test_topic",
|
||||
message=message
|
||||
)
|
||||
|
||||
assert sent_count == 0
|
||||
|
||||
# Check connection was cleaned up
|
||||
connection_info = await websocket_pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
|
||||
class TestWebSocketManager:
|
||||
"""Test the WebSocket manager middleware"""
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_manager(self):
|
||||
"""Create a WebSocket manager for testing"""
|
||||
return WebSocketManager()
|
||||
|
||||
@patch('app.middleware.websocket_middleware.verify_token')
|
||||
@patch('app.middleware.websocket_middleware.SessionLocal')
|
||||
async def test_authentication_success(self, mock_session, mock_verify_token, websocket_manager):
|
||||
"""Test successful WebSocket authentication"""
|
||||
# Mock successful authentication
|
||||
mock_verify_token.return_value = "test_user"
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "test_user"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
mock_session.return_value = mock_db
|
||||
|
||||
# Create mock websocket with token
|
||||
mock_ws = MagicMock()
|
||||
mock_ws.url.query = "token=valid_token"
|
||||
|
||||
user = await websocket_manager.authenticate_websocket(mock_ws)
|
||||
|
||||
assert user is not None
|
||||
assert user.id == 1
|
||||
assert user.username == "test_user"
|
||||
|
||||
@patch('app.middleware.websocket_middleware.verify_token')
|
||||
async def test_authentication_failure(self, mock_verify_token, websocket_manager):
|
||||
"""Test failed WebSocket authentication"""
|
||||
# Mock failed authentication
|
||||
mock_verify_token.return_value = None
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_ws.url.query = "token=invalid_token"
|
||||
|
||||
user = await websocket_manager.authenticate_websocket(mock_ws)
|
||||
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestWebSocketIntegration:
|
||||
"""Test WebSocket integration with FastAPI"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_connection_context_manager(self):
|
||||
"""Test the websocket_connection context manager"""
|
||||
pool = WebSocketPool()
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
mock_ws = MockWebSocket(1)
|
||||
|
||||
async with websocket_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=1,
|
||||
topics={"test_topic"},
|
||||
metadata={"test": "data"}
|
||||
) as (connection_id, pool_instance):
|
||||
|
||||
assert connection_id is not None
|
||||
assert connection_id.startswith("ws_")
|
||||
# The context returns the global pool; allow either the same object or equivalent type
|
||||
assert isinstance(pool_instance, type(pool))
|
||||
|
||||
# Check connection exists
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
assert connection_info is not None
|
||||
assert connection_info.user_id == 1
|
||||
assert "test_topic" in connection_info.topics
|
||||
|
||||
# Check connection was cleaned up after context exit
|
||||
connection_info = await pool.get_connection_info(connection_id)
|
||||
assert connection_info is None
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_pool_management(self):
|
||||
"""Test global WebSocket pool initialization and shutdown"""
|
||||
# Initialize global pool
|
||||
pool = await initialize_websocket_pool(
|
||||
cleanup_interval=1,
|
||||
connection_timeout=5
|
||||
)
|
||||
|
||||
assert pool is not None
|
||||
|
||||
# Get the same pool instance
|
||||
same_pool = get_websocket_pool()
|
||||
assert same_pool is pool
|
||||
|
||||
# Shutdown global pool
|
||||
await shutdown_websocket_pool()
|
||||
|
||||
# Should create new pool after shutdown
|
||||
new_pool = get_websocket_pool()
|
||||
assert new_pool is not pool
|
||||
|
||||
|
||||
class TestWebSocketStressTest:
|
||||
"""Stress tests for WebSocket pool under high load"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_connection_volume(self):
|
||||
"""Test pool behavior with many concurrent connections"""
|
||||
pool = WebSocketPool(max_total_connections=1000)
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
# Create many connections
|
||||
connection_ids = []
|
||||
for i in range(100):
|
||||
mock_ws = MockWebSocket(i)
|
||||
connection_id = await pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=i % 10, # 10 different users
|
||||
topics={f"topic_{i % 5}"} # 5 different topics
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Check all connections exist
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 100
|
||||
assert stats["total_topics"] == 5
|
||||
assert stats["total_users"] == 10
|
||||
|
||||
# Broadcast to all topics
|
||||
for topic_id in range(5):
|
||||
topic = f"topic_{topic_id}"
|
||||
message = WebSocketMessage(
|
||||
type="stress_test",
|
||||
topic=topic,
|
||||
data={"topic_id": topic_id}
|
||||
)
|
||||
sent_count = await pool.broadcast_to_topic(topic, message)
|
||||
assert sent_count == 20 # 100 connections / 5 topics
|
||||
|
||||
# Clean up all connections
|
||||
for connection_id in connection_ids:
|
||||
await pool.remove_connection(connection_id, "stress_test_cleanup")
|
||||
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_connect_disconnect(self):
|
||||
"""Test rapid connection and disconnection patterns"""
|
||||
pool = WebSocketPool()
|
||||
await pool.start()
|
||||
|
||||
try:
|
||||
# Rapidly create and destroy connections
|
||||
for round_num in range(10):
|
||||
connection_ids = []
|
||||
|
||||
# Create 10 connections
|
||||
for i in range(10):
|
||||
mock_ws = MockWebSocket(i)
|
||||
connection_id = await pool.add_connection(
|
||||
websocket=mock_ws,
|
||||
user_id=i,
|
||||
topics={f"round_{round_num}"}
|
||||
)
|
||||
connection_ids.append(connection_id)
|
||||
|
||||
# Immediately remove them
|
||||
for connection_id in connection_ids:
|
||||
await pool.remove_connection(connection_id, f"round_{round_num}_cleanup")
|
||||
|
||||
# Check pool is clean
|
||||
stats = await pool.get_stats()
|
||||
assert stats["active_connections"] == 0
|
||||
|
||||
finally:
|
||||
await pool.stop()
|
||||
|
||||
|
||||
# Test utilities and fixtures for integration testing
|
||||
@pytest.fixture
|
||||
def websocket_test_client():
|
||||
"""Create a test client for WebSocket endpoint testing"""
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestBillingWebSocketIntegration:
|
||||
"""Test integration with the billing API WebSocket endpoint"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_progress_topic_format(self):
|
||||
"""Test that batch progress uses correct topic format"""
|
||||
from app.api.billing import _notify_progress_subscribers
|
||||
from app.services.batch_generation import BatchProgress
|
||||
|
||||
# Mock the WebSocket manager
|
||||
with patch('app.api.billing.websocket_manager') as mock_manager:
|
||||
mock_manager.broadcast_to_topic = AsyncMock(return_value=1)
|
||||
|
||||
# Create test progress
|
||||
progress = BatchProgress(
|
||||
batch_id="test_batch_123",
|
||||
status="processing",
|
||||
total_files=10,
|
||||
processed_files=5,
|
||||
successful_files=4,
|
||||
failed_files=1,
|
||||
current_file="test.txt",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# Call notification function
|
||||
await _notify_progress_subscribers(progress)
|
||||
|
||||
# Check correct topic was used
|
||||
mock_manager.broadcast_to_topic.assert_called_once_with(
|
||||
topic="batch_progress_test_batch_123",
|
||||
message_type="progress",
|
||||
data=progress.model_dump()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user