"""
Tests for WebSocket management endpoints in the Admin API

Tests cover:
- WebSocket statistics endpoint
- Connection listing and filtering
- Connection management (disconnect, cleanup)
- Broadcasting functionality
- Admin-only access control

Authentication note: FastAPI resolves dependencies from the callable captured
by ``Depends()`` when a route is defined, so patching e.g.
``app.api.admin.get_admin_user`` with ``unittest.mock.patch`` has no effect on
request handling. Authentication is therefore controlled exclusively through
``app.dependency_overrides`` (see the ``admin_client`` / ``regular_client``
fixtures).
"""

from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import status
from fastapi.testclient import TestClient

from app.main import app
from app.auth.security import get_current_user, get_admin_user
from app.models.user import User


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

@contextmanager
def _dependency_overrides(overrides):
    """Temporarily apply FastAPI dependency overrides.

    The complete previous ``app.dependency_overrides`` mapping is restored on
    exit, so overrides installed by other fixtures (or a conftest) survive.
    """
    previous = dict(app.dependency_overrides)
    app.dependency_overrides.update(overrides)
    try:
        yield
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(previous)


def _connection_metrics(connection_id, user_id, topics, metadata=None):
    """Build a healthy, connected metrics payload for one connection.

    Mirrors the dict shape these tests feed through the mocked
    ``get_connection_metrics``; only the identifying fields vary.
    """
    return {
        "connection_id": connection_id,
        "user_id": user_id,
        "state": "connected",
        "topics": list(topics),
        "created_at": "2023-01-01T12:00:00Z",
        "last_activity": "2023-01-01T12:01:00Z",
        "age_seconds": 60,
        "idle_seconds": 10,
        "error_count": 0,
        "last_ping": "2023-01-01T12:01:00Z",
        "last_pong": "2023-01-01T12:01:00Z",
        "metadata": {} if metadata is None else metadata,
        "is_alive": True,
        "is_stale": False,
    }


def _mock_manager_with_pool(connection_ids):
    """Return a mock WebSocket manager whose pool holds the given connection ids.

    ``_connections_lock`` is an ``AsyncMock`` so the endpoint can use it as an
    ``async with`` context manager.
    """
    pool = MagicMock()
    pool._connections_lock = AsyncMock()
    pool._connections = {cid: MagicMock() for cid in connection_ids}
    manager = MagicMock()
    manager.pool = pool
    return manager


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def test_client():
    """Create an unauthenticated test client (no dependency overrides)."""
    return TestClient(app)


@pytest.fixture
def admin_user():
    """Create admin user for testing"""
    return User(
        id=1,
        username="admin",
        email="admin@test.com",
        is_admin=True,
        is_active=True,
    )


@pytest.fixture
def admin_client(admin_user):
    """Test client authenticated as an admin via dependency overrides."""
    with _dependency_overrides({
        get_current_user: lambda: admin_user,
        get_admin_user: lambda: admin_user,
    }):
        yield TestClient(app)


@pytest.fixture
def regular_user():
    """Create regular (non-admin) user for testing"""
    return User(
        id=2,
        username="user",
        email="user@test.com",
        is_admin=False,
        is_active=True,
    )


@pytest.fixture
def regular_client(regular_user):
    """Test client authenticated as a non-admin user.

    Only ``get_current_user`` is overridden so the real ``get_admin_user``
    dependency still runs. NOTE(review): this assumes ``get_admin_user``
    obtains the user through ``get_current_user``; if it does not, requests
    are simply unauthenticated, which the tests also treat as a rejection.
    """
    with _dependency_overrides({get_current_user: lambda: regular_user}):
        yield TestClient(app)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestWebSocketStatsEndpoint:
    """Test WebSocket statistics endpoint"""

    @patch('app.api.admin.get_websocket_manager')
    def test_get_websocket_stats_success(self, mock_get_manager, admin_client):
        """Test successful retrieval of WebSocket statistics"""
        mock_manager = MagicMock()
        mock_manager.get_stats = AsyncMock(return_value={
            "total_connections": 10,
            "active_connections": 8,
            "total_topics": 5,
            "total_users": 3,
            "messages_sent": 100,
            "messages_failed": 2,
            "connections_cleaned": 5,
            "last_cleanup": "2023-01-01T12:00:00Z",
            "last_heartbeat": "2023-01-01T12:01:00Z",
            "connections_by_state": {"connected": 8, "disconnected": 2},
            "topic_distribution": {"topic1": 5, "topic2": 3},
        })
        mock_get_manager.return_value = mock_manager

        response = admin_client.get("/api/admin/websockets/stats")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total_connections"] == 10
        assert data["active_connections"] == 8
        assert data["messages_sent"] == 100
        assert "topic1" in data["topic_distribution"]

    def test_get_websocket_stats_unauthorized(self, test_client):
        """Unauthenticated requests are rejected by the auth dependencies."""
        response = test_client.get("/api/admin/websockets/stats")

        # A 500 here would mean the auth layer crashed, not that access was
        # denied, so only the proper auth status codes are accepted.
        assert response.status_code in (
            status.HTTP_401_UNAUTHORIZED,
            status.HTTP_403_FORBIDDEN,
        )

    def test_get_websocket_stats_forbidden_for_non_admin(self, regular_client):
        """An authenticated non-admin user must not reach the stats endpoint."""
        response = regular_client.get("/api/admin/websockets/stats")

        assert response.status_code in (
            status.HTTP_401_UNAUTHORIZED,
            status.HTTP_403_FORBIDDEN,
        )


class TestWebSocketConnectionsEndpoint:
    """Test WebSocket connections listing endpoint"""

    @patch('app.api.admin.get_websocket_manager')
    @patch('app.api.admin.get_connection_tracker')
    def test_get_connections_success(self, mock_get_tracker, mock_get_manager,
                                      admin_client):
        """Test successful retrieval of WebSocket connections"""
        mock_get_manager.return_value = _mock_manager_with_pool(["conn1", "conn2"])

        mock_tracker = MagicMock()
        # One metrics payload per pooled connection, in pool order.
        mock_tracker.get_connection_metrics = AsyncMock(side_effect=[
            _connection_metrics("conn1", 1, ["topic1"]),
            _connection_metrics("conn2", 2, ["topic2"]),
        ])
        mock_get_tracker.return_value = mock_tracker

        response = admin_client.get("/api/admin/websockets/connections")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total_count"] == 2
        assert data["active_count"] == 2
        assert data["stale_count"] == 0
        assert len(data["connections"]) == 2

    @patch('app.api.admin.get_websocket_manager')
    @patch('app.api.admin.get_connection_tracker')
    def test_get_connections_with_filters(self, mock_get_tracker, mock_get_manager,
                                          admin_client):
        """Test WebSocket connections listing with a user_id filter"""
        mock_get_manager.return_value = _mock_manager_with_pool(["conn1", "conn2"])

        mock_tracker = MagicMock()
        # Two connections owned by different users: the filter must actually
        # exclude one of them, otherwise a no-op filter would pass this test.
        mock_tracker.get_connection_metrics = AsyncMock(side_effect=[
            _connection_metrics("conn1", 1, ["topic1"]),
            _connection_metrics("conn2", 2, ["topic2"]),
        ])
        mock_get_tracker.return_value = mock_tracker

        response = admin_client.get("/api/admin/websockets/connections?user_id=1")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total_count"] == 1
        assert len(data["connections"]) == 1
        assert data["connections"][0]["user_id"] == 1
        assert data["connections"][0]["connection_id"] == "conn1"


class TestWebSocketDisconnectEndpoint:
    """Test WebSocket disconnect endpoint"""

    @patch('app.api.admin.get_websocket_manager')
    def test_disconnect_by_connection_ids(self, mock_get_manager, admin_client):
        """Test disconnecting specific connections"""
        mock_pool = MagicMock()
        mock_pool.remove_connection = AsyncMock()
        mock_manager = MagicMock()
        mock_manager.pool = mock_pool
        mock_get_manager.return_value = mock_manager

        response = admin_client.post(
            "/api/admin/websockets/disconnect",
            json={
                "connection_ids": ["conn1", "conn2"],
                "reason": "admin_test",
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["disconnected_count"] == 2
        assert data["reason"] == "admin_test"

        # await_count (not call_count) also catches a forgotten ``await``.
        assert mock_pool.remove_connection.await_count == 2

    @patch('app.api.admin.get_websocket_manager')
    def test_disconnect_by_user_id(self, mock_get_manager, admin_client):
        """Test disconnecting all connections for a user"""
        mock_pool = MagicMock()
        mock_pool.get_user_connections = AsyncMock(return_value=["conn1", "conn2"])
        mock_pool.remove_connection = AsyncMock()
        mock_manager = MagicMock()
        mock_manager.pool = mock_pool
        mock_get_manager.return_value = mock_manager

        response = admin_client.post(
            "/api/admin/websockets/disconnect",
            json={
                "user_id": 1,
                "reason": "user_maintenance",
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["disconnected_count"] == 2

    def test_disconnect_missing_criteria(self, admin_client):
        """Test disconnect request with missing criteria"""
        # Neither connection_ids nor user_id is supplied.
        response = admin_client.post(
            "/api/admin/websockets/disconnect",
            json={"reason": "test"},
        )

        assert response.status_code == status.HTTP_400_BAD_REQUEST


class TestWebSocketCleanupEndpoint:
    """Test WebSocket cleanup endpoint"""

    @patch('app.api.admin.get_websocket_manager')
    def test_cleanup_websockets(self, mock_get_manager, admin_client):
        """Test manual WebSocket cleanup"""
        mock_pool = MagicMock()
        mock_pool.get_stats = AsyncMock(side_effect=[
            {"active_connections": 10},  # Before cleanup
            {"active_connections": 8},   # After cleanup
        ])
        mock_pool._cleanup_stale_connections = AsyncMock()
        mock_manager = MagicMock()
        mock_manager.pool = mock_pool
        mock_get_manager.return_value = mock_manager

        response = admin_client.post("/api/admin/websockets/cleanup")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["connections_before"] == 10
        assert data["connections_after"] == 8
        assert data["cleaned_count"] == 2

        # The cleanup coroutine must actually be awaited, not just created.
        mock_pool._cleanup_stale_connections.assert_awaited_once()


class TestWebSocketBroadcastEndpoint:
    """Test WebSocket broadcast endpoint"""

    @patch('app.api.admin.get_websocket_manager')
    def test_broadcast_message(self, mock_get_manager, admin_client):
        """Test broadcasting a message to a topic"""
        mock_manager = MagicMock()
        mock_manager.broadcast_to_topic = AsyncMock(return_value=5)
        mock_get_manager.return_value = mock_manager

        response = admin_client.post(
            "/api/admin/websockets/broadcast",
            json={
                "topic": "admin_announcement",
                "message_type": "system_message",
                "data": {"message": "System maintenance in 5 minutes"},
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["sent_count"] == 5
        assert data["topic"] == "admin_announcement"
        assert data["message_type"] == "system_message"

        mock_manager.broadcast_to_topic.assert_awaited_once_with(
            topic="admin_announcement",
            message_type="system_message",
            data={"message": "System maintenance in 5 minutes"},
        )


class TestWebSocketConnectionDetailEndpoint:
    """Test individual WebSocket connection detail endpoint"""

    @patch('app.api.admin.get_connection_tracker')
    def test_get_connection_detail_success(self, mock_get_tracker, admin_client):
        """Test getting details for a specific connection"""
        mock_tracker = MagicMock()
        mock_tracker.get_connection_metrics = AsyncMock(return_value=_connection_metrics(
            "conn1", 1, ["topic1", "topic2"], metadata={"endpoint": "batch_progress"},
        ))
        mock_get_tracker.return_value = mock_tracker

        response = admin_client.get("/api/admin/websockets/connections/conn1")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["connection_id"] == "conn1"
        assert data["user_id"] == 1
        assert data["state"] == "connected"
        assert len(data["topics"]) == 2
        assert data["metadata"]["endpoint"] == "batch_progress"

    @patch('app.api.admin.get_connection_tracker')
    def test_get_connection_detail_not_found(self, mock_get_tracker, admin_client):
        """Test getting details for non-existent connection"""
        # The tracker reports no metrics for an unknown connection id.
        mock_tracker = MagicMock()
        mock_tracker.get_connection_metrics = AsyncMock(return_value=None)
        mock_get_tracker.return_value = mock_tracker

        response = admin_client.get("/api/admin/websockets/connections/nonexistent")

        assert response.status_code == status.HTTP_404_NOT_FOUND


if __name__ == "__main__":
    pytest.main([__file__, "-v"])