""" Comprehensive tests for WebSocket connection pooling and management Tests cover: - Connection pool creation and management - Automatic cleanup of stale connections - Health monitoring and heartbeats - Topic-based message broadcasting - Resource management and memory leak prevention - Integration with authentication """ import asyncio import pytest import time from datetime import datetime, timezone, timedelta from typing import List, Dict, Any from unittest.mock import AsyncMock, MagicMock, patch from fastapi import WebSocket, WebSocketDisconnect from fastapi.testclient import TestClient from app.services.websocket_pool import ( WebSocketPool, ConnectionState, MessageType, ConnectionInfo, WebSocketMessage, get_websocket_pool, initialize_websocket_pool, shutdown_websocket_pool, websocket_connection ) from app.middleware.websocket_middleware import ( WebSocketManager, get_websocket_manager, WebSocketAuthenticationError ) class MockWebSocket: """Mock WebSocket for testing""" def __init__(self, mock_user_id: int = None): self.sent_messages = [] self.received_messages = [] self.closed = False self.close_code = None self.close_reason = None self.mock_user_id = mock_user_id self.url = MagicMock() self.url.query = f"token=test_token_{mock_user_id}" if mock_user_id else "" async def send_json(self, data: Dict[str, Any]): """Mock send_json method""" if self.closed: raise Exception("WebSocket closed") self.sent_messages.append(data) async def send_text(self, text: str): """Mock send_text method""" if self.closed: raise Exception("WebSocket closed") self.sent_messages.append({"text": text}) async def receive_text(self) -> str: """Mock receive_text method""" if self.closed: raise WebSocketDisconnect() if self.received_messages: return self.received_messages.pop(0) # Simulate waiting for message await asyncio.sleep(0.1) return "ping" async def accept(self): """Mock accept method""" pass async def close(self, code: int = 1000, reason: str = ""): """Mock close method""" self.closed = True self.close_code = code self.close_reason = reason @pytest.fixture async def websocket_pool(): """Create a fresh WebSocket pool for testing""" pool = WebSocketPool( cleanup_interval=1, # Fast cleanup for testing connection_timeout=5, # Short timeout for testing heartbeat_interval=2, # Fast heartbeat for testing max_connections_per_topic=10, max_total_connections=100 ) await pool.start() yield pool await pool.stop() @pytest.fixture def mock_websocket(): """Create a mock WebSocket for testing""" return MockWebSocket() @pytest.fixture def mock_authenticated_websocket(): """Create a mock authenticated WebSocket for testing""" return MockWebSocket(mock_user_id=1) class TestWebSocketPool: """Test the core WebSocket pool functionality""" async def test_pool_initialization(self, websocket_pool): """Test pool initialization and configuration""" assert websocket_pool.cleanup_interval == 1 assert websocket_pool.connection_timeout == 5 assert websocket_pool.heartbeat_interval == 2 assert websocket_pool.max_connections_per_topic == 10 assert websocket_pool.max_total_connections == 100 stats = await websocket_pool.get_stats() assert stats["active_connections"] == 0 assert stats["total_topics"] == 0 assert stats["total_users"] == 0 async def test_add_remove_connection(self, websocket_pool, mock_websocket): """Test adding and removing connections""" # Add connection connection_id = await websocket_pool.add_connection( websocket=mock_websocket, user_id=1, topics={"test_topic"}, metadata={"test": "data"} ) assert connection_id.startswith("ws_") # Check connection exists connection_info = await websocket_pool.get_connection_info(connection_id) assert connection_info is not None assert connection_info.user_id == 1 assert "test_topic" in connection_info.topics assert connection_info.metadata["test"] == "data" # Check stats stats = await websocket_pool.get_stats() assert stats["active_connections"] == 1 assert stats["total_topics"] == 1 assert stats["total_users"] == 1 # Remove connection await websocket_pool.remove_connection(connection_id, "test_removal") # Check connection removed connection_info = await websocket_pool.get_connection_info(connection_id) assert connection_info is None stats = await websocket_pool.get_stats() assert stats["active_connections"] == 0 async def test_topic_subscription(self, websocket_pool, mock_websocket): """Test topic subscription and unsubscription""" connection_id = await websocket_pool.add_connection( websocket=mock_websocket, user_id=1, topics=set() ) # Subscribe to topic success = await websocket_pool.subscribe_to_topic(connection_id, "new_topic") assert success # Check subscription topic_connections = await websocket_pool.get_topic_connections("new_topic") assert connection_id in topic_connections connection_info = await websocket_pool.get_connection_info(connection_id) assert "new_topic" in connection_info.topics # Unsubscribe from topic success = await websocket_pool.unsubscribe_from_topic(connection_id, "new_topic") assert success # Check unsubscription topic_connections = await websocket_pool.get_topic_connections("new_topic") assert connection_id not in topic_connections connection_info = await websocket_pool.get_connection_info(connection_id) assert "new_topic" not in connection_info.topics async def test_broadcast_to_topic(self, websocket_pool): """Test broadcasting messages to topic subscribers""" # Create multiple connections websockets = [MockWebSocket(i) for i in range(3)] connection_ids = [] for i, ws in enumerate(websockets): connection_id = await websocket_pool.add_connection( websocket=ws, user_id=i, topics={"broadcast_topic"} ) connection_ids.append(connection_id) # Broadcast message message = WebSocketMessage( type="test_broadcast", topic="broadcast_topic", data={"message": "Hello everyone!"} ) sent_count = await websocket_pool.broadcast_to_topic( topic="broadcast_topic", message=message ) assert sent_count == 3 # Check all websockets received the message for ws in websockets: assert len(ws.sent_messages) == 1 sent_message = ws.sent_messages[0] assert sent_message["type"] == "test_broadcast" assert sent_message["data"]["message"] == "Hello everyone!" async def test_send_to_user(self, websocket_pool): """Test sending messages to all connections for a specific user""" # Create multiple connections for the same user websockets = [MockWebSocket(1) for _ in range(2)] connection_ids = [] for ws in websockets: connection_id = await websocket_pool.add_connection( websocket=ws, user_id=1, topics=set() ) connection_ids.append(connection_id) # Add connection for different user other_ws = MockWebSocket(2) await websocket_pool.add_connection( websocket=other_ws, user_id=2, topics=set() ) # Send message to user 1 message = WebSocketMessage( type="user_message", data={"message": "Hello user 1!"} ) sent_count = await websocket_pool.send_to_user(1, message) assert sent_count == 2 # Check user 1 connections received message for ws in websockets: assert len(ws.sent_messages) == 1 sent_message = ws.sent_messages[0] assert sent_message["type"] == "user_message" # Check user 2 connection didn't receive message assert len(other_ws.sent_messages) == 0 async def test_connection_limits(self, websocket_pool): """Test connection limits""" # Test max connections per topic websockets = [] for i in range(websocket_pool.max_connections_per_topic + 1): ws = MockWebSocket(i) websockets.append(ws) if i < websocket_pool.max_connections_per_topic: # Should succeed connection_id = await websocket_pool.add_connection( websocket=ws, user_id=i, topics={"limited_topic"} ) assert connection_id is not None else: # Should fail due to topic limit with pytest.raises(ValueError, match="Maximum connections per topic"): await websocket_pool.add_connection( websocket=ws, user_id=i, topics={"limited_topic"} ) async def test_stale_connection_cleanup(self, websocket_pool): """Test automatic cleanup of stale connections""" # Add connection mock_ws = MockWebSocket(1) connection_id = await websocket_pool.add_connection( websocket=mock_ws, user_id=1, topics={"test_topic"} ) # Manually set last activity to make connection stale connection_info = await websocket_pool.get_connection_info(connection_id) connection_info.last_activity = datetime.now(timezone.utc) - timedelta(seconds=10) # Trigger cleanup await websocket_pool._cleanup_stale_connections() # Check connection was cleaned up connection_info = await websocket_pool.get_connection_info(connection_id) assert connection_info is None stats = await websocket_pool.get_stats() assert stats["active_connections"] == 0 assert stats["connections_cleaned"] > 0 async def test_heartbeat_functionality(self, websocket_pool, mock_websocket): """Test heartbeat sending and connection health monitoring""" connection_id = await websocket_pool.add_connection( websocket=mock_websocket, user_id=1, topics=set() ) # Trigger heartbeat await websocket_pool._send_heartbeats() # Check heartbeat was sent assert len(mock_websocket.sent_messages) == 1 heartbeat_message = mock_websocket.sent_messages[0] assert heartbeat_message["type"] == MessageType.HEARTBEAT.value # Test ping functionality success = await websocket_pool.ping_connection(connection_id) assert success assert len(mock_websocket.sent_messages) == 2 ping_message = mock_websocket.sent_messages[1] assert ping_message["type"] == MessageType.PING.value async def test_failed_message_cleanup(self, websocket_pool): """Test cleanup of connections when message sending fails""" # Create a websocket that will fail on send mock_ws = MockWebSocket(1) mock_ws.closed = True # Simulate closed connection connection_id = await websocket_pool.add_connection( websocket=mock_ws, user_id=1, topics={"test_topic"} ) # Try to broadcast - should fail and clean up connection message = WebSocketMessage(type="test", data={}) sent_count = await websocket_pool.broadcast_to_topic( topic="test_topic", message=message ) assert sent_count == 0 # Check connection was cleaned up connection_info = await websocket_pool.get_connection_info(connection_id) assert connection_info is None class TestWebSocketManager: """Test the WebSocket manager middleware""" @pytest.fixture def websocket_manager(self): """Create a WebSocket manager for testing""" return WebSocketManager() @patch('app.middleware.websocket_middleware.verify_token') @patch('app.middleware.websocket_middleware.SessionLocal') async def test_authentication_success(self, mock_session, mock_verify_token, websocket_manager): """Test successful WebSocket authentication""" # Mock successful authentication mock_verify_token.return_value = "test_user" mock_user = MagicMock() mock_user.id = 1 mock_user.username = "test_user" mock_user.is_active = True mock_db = MagicMock() mock_db.query.return_value.filter.return_value.first.return_value = mock_user mock_session.return_value = mock_db # Create mock websocket with token mock_ws = MagicMock() mock_ws.url.query = "token=valid_token" user = await websocket_manager.authenticate_websocket(mock_ws) assert user is not None assert user.id == 1 assert user.username == "test_user" @patch('app.middleware.websocket_middleware.verify_token') async def test_authentication_failure(self, mock_verify_token, websocket_manager): """Test failed WebSocket authentication""" # Mock failed authentication mock_verify_token.return_value = None mock_ws = MagicMock() mock_ws.url.query = "token=invalid_token" user = await websocket_manager.authenticate_websocket(mock_ws) assert user is None class TestWebSocketIntegration: """Test WebSocket integration with FastAPI""" @pytest.mark.asyncio async def test_websocket_connection_context_manager(self): """Test the websocket_connection context manager""" pool = WebSocketPool() await pool.start() try: mock_ws = MockWebSocket(1) async with websocket_connection( websocket=mock_ws, user_id=1, topics={"test_topic"}, metadata={"test": "data"} ) as (connection_id, pool_instance): assert connection_id is not None assert connection_id.startswith("ws_") # The context returns the global pool; allow either the same object or equivalent type assert isinstance(pool_instance, type(pool)) # Check connection exists connection_info = await pool.get_connection_info(connection_id) assert connection_info is not None assert connection_info.user_id == 1 assert "test_topic" in connection_info.topics # Check connection was cleaned up after context exit connection_info = await pool.get_connection_info(connection_id) assert connection_info is None finally: await pool.stop() @pytest.mark.asyncio async def test_global_pool_management(self): """Test global WebSocket pool initialization and shutdown""" # Initialize global pool pool = await initialize_websocket_pool( cleanup_interval=1, connection_timeout=5 ) assert pool is not None # Get the same pool instance same_pool = get_websocket_pool() assert same_pool is pool # Shutdown global pool await shutdown_websocket_pool() # Should create new pool after shutdown new_pool = get_websocket_pool() assert new_pool is not pool class TestWebSocketStressTest: """Stress tests for WebSocket pool under high load""" @pytest.mark.asyncio async def test_high_connection_volume(self): """Test pool behavior with many concurrent connections""" pool = WebSocketPool(max_total_connections=1000) await pool.start() try: # Create many connections connection_ids = [] for i in range(100): mock_ws = MockWebSocket(i) connection_id = await pool.add_connection( websocket=mock_ws, user_id=i % 10, # 10 different users topics={f"topic_{i % 5}"} # 5 different topics ) connection_ids.append(connection_id) # Check all connections exist stats = await pool.get_stats() assert stats["active_connections"] == 100 assert stats["total_topics"] == 5 assert stats["total_users"] == 10 # Broadcast to all topics for topic_id in range(5): topic = f"topic_{topic_id}" message = WebSocketMessage( type="stress_test", topic=topic, data={"topic_id": topic_id} ) sent_count = await pool.broadcast_to_topic(topic, message) assert sent_count == 20 # 100 connections / 5 topics # Clean up all connections for connection_id in connection_ids: await pool.remove_connection(connection_id, "stress_test_cleanup") stats = await pool.get_stats() assert stats["active_connections"] == 0 finally: await pool.stop() @pytest.mark.asyncio async def test_rapid_connect_disconnect(self): """Test rapid connection and disconnection patterns""" pool = WebSocketPool() await pool.start() try: # Rapidly create and destroy connections for round_num in range(10): connection_ids = [] # Create 10 connections for i in range(10): mock_ws = MockWebSocket(i) connection_id = await pool.add_connection( websocket=mock_ws, user_id=i, topics={f"round_{round_num}"} ) connection_ids.append(connection_id) # Immediately remove them for connection_id in connection_ids: await pool.remove_connection(connection_id, f"round_{round_num}_cleanup") # Check pool is clean stats = await pool.get_stats() assert stats["active_connections"] == 0 finally: await pool.stop() # Test utilities and fixtures for integration testing @pytest.fixture def websocket_test_client(): """Create a test client for WebSocket endpoint testing""" from app.main import app return TestClient(app) class TestBillingWebSocketIntegration: """Test integration with the billing API WebSocket endpoint""" @pytest.mark.asyncio async def test_batch_progress_topic_format(self): """Test that batch progress uses correct topic format""" from app.api.billing import _notify_progress_subscribers from app.services.batch_generation import BatchProgress # Mock the WebSocket manager with patch('app.api.billing.websocket_manager') as mock_manager: mock_manager.broadcast_to_topic = AsyncMock(return_value=1) # Create test progress progress = BatchProgress( batch_id="test_batch_123", status="processing", total_files=10, processed_files=5, successful_files=4, failed_files=1, current_file="test.txt", started_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc) ) # Call notification function await _notify_progress_subscribers(progress) # Check correct topic was used mock_manager.broadcast_to_topic.assert_called_once_with( topic="batch_progress_test_batch_123", message_type="progress", data=progress.model_dump() ) if __name__ == "__main__": # Run tests pytest.main([__file__, "-v"])