410 lines
14 KiB
Python
410 lines
14 KiB
Python
"""
|
|
WebSocket Connection Pool Usage Examples
|
|
|
|
This file demonstrates how to use the WebSocket connection pooling system
|
|
in the Delphi Database application.
|
|
|
|
Examples include:
|
|
- Basic WebSocket endpoint with pooling
|
|
- Custom message handling
|
|
- Topic-based broadcasting
|
|
- Connection monitoring
|
|
- Admin management integration
|
|
"""
|
|
|
|
import asyncio
import inspect
import json
from datetime import datetime, timezone
from typing import Set, Optional, Dict, Any

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.responses import HTMLResponse

from app.middleware.websocket_middleware import (
    get_websocket_manager,
    websocket_endpoint,
    websocket_auth_dependency,
    WebSocketManager,
)
from app.services.websocket_pool import (
    WebSocketMessage,
    MessageType,
    websocket_connection,
)
from app.models.user import User
|
|
|
|
|
|
# Example 1: Basic WebSocket endpoint with automatic pooling
|
|
@websocket_endpoint(topics={"notifications"}, require_auth=True)
async def basic_websocket_handler(
    websocket: WebSocket,
    connection_id: str,
    manager: WebSocketManager,
    message: WebSocketMessage
):
    """
    Basic WebSocket handler using the pooling decorator.

    Connection management, authentication and cleanup are delegated to the
    ``websocket_endpoint`` decorator. This handler only answers
    ``user_action`` messages, replying with an ``action_response`` that echoes
    the requested action back to the sending connection.

    Args:
        websocket: The client socket (not used directly; replies go through
            the pool).
        connection_id: Pool identifier of the connection that sent ``message``.
        manager: WebSocket manager whose pool delivers the response.
        message: The incoming message. ``None`` is tolerated and ignored so
            callers that have no message to process do not crash the handler.
    """
    # Guard clause: nothing to do without a message, or for other message types.
    if message is None or message.type != "user_action":
        return

    # NOTE(review): guards against message.data being None — confirm whether
    # WebSocketMessage guarantees a dict here.
    payload = message.data or {}

    response = WebSocketMessage(
        type="action_response",
        data={"status": "received", "action": payload.get("action")}
    )
    # Uses the pool's private send helper to target only the sender.
    await manager.pool._send_to_connection(connection_id, response)
|
|
|
|
|
|
# Example 2: Manual WebSocket management with custom logic
|
|
async def manual_websocket_handler(websocket: WebSocket, topic: str):
    """
    Serve a chat room by driving the manager's connection loop directly.

    Instead of the decorator, a custom message handler is plugged into
    ``manager.handle_connection``. ``chat_message`` frames are re-broadcast to
    ``topic`` as ``chat_broadcast`` (stamped with the current UTC time), and
    ``typing`` frames as ``user_typing``. The sending connection is excluded
    from its own broadcast; any other message type is ignored.
    """
    manager = get_websocket_manager()

    async def handle_custom_message(connection_id: str, message: WebSocketMessage):
        # Work out the outgoing frame type and payload for the supported
        # message kinds; message.data is only read once the kind is known.
        if message.type == "chat_message":
            outgoing_type = "chat_broadcast"
            body = {
                "user": message.data.get("user", "Anonymous"),
                "message": message.data.get("message", ""),
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        elif message.type == "typing":
            outgoing_type = "user_typing"
            body = {
                "user": message.data.get("user", "Anonymous"),
                "typing": message.data.get("typing", False)
            }
        else:
            # Not part of the chat protocol: drop silently.
            return

        await manager.broadcast_to_topic(
            topic=topic,
            message_type=outgoing_type,
            data=body,
            exclude_connection_id=connection_id
        )

    # Hand the socket to the manager, which runs the receive loop and calls
    # handle_custom_message for each incoming message.
    await manager.handle_connection(
        websocket=websocket,
        topics={topic},
        require_auth=True,
        metadata={"chat_room": topic},
        message_handler=handle_custom_message
    )
|
|
|
|
|
|
# Example 3: Low-level pool usage with context manager
|
|
async def low_level_websocket_example(websocket: WebSocket, user_id: int):
    """
    Low-level WebSocket handling using the connection context manager directly.

    Provides maximum control over the connection lifecycle. The socket is
    accepted here, registered with the pool through ``websocket_connection``
    (subscribed to ``user_updates``), greeted with a ``welcome`` message, and
    then read in a manual receive loop.

    Supported client messages:
    - ``ping``: answered with ``pong`` carrying the message's timestamp.
    - ``echo``: answered with ``echo_response`` containing the original data.

    Malformed frames (invalid JSON, or a payload ``WebSocketMessage`` rejects
    with ``ValueError``/``TypeError``) are reported and skipped rather than
    closing the session. A client disconnect ends the loop, as does any other
    error, which is printed first.

    Args:
        websocket: The client socket; accepted by this function.
        user_id: User the pooled connection is registered for.
    """
    await websocket.accept()

    async with websocket_connection(
        websocket=websocket,
        user_id=user_id,
        topics={"user_updates"},
        metadata={"example": "low_level"}
    ) as (connection_id, pool):

        # Send welcome message
        welcome = WebSocketMessage(
            type="welcome",
            data={
                "connection_id": connection_id,
                "message": "Connected to low-level example"
            }
        )
        await pool._send_to_connection(connection_id, welcome)

        while True:
            try:
                data = await websocket.receive_text()
            except WebSocketDisconnect:
                break
            except Exception as e:
                print(f"Error handling message: {e}")
                break

            # Parse step: a bad frame from one client should not end the session.
            try:
                message_dict = json.loads(data)  # JSONDecodeError is a ValueError
                message = WebSocketMessage(**message_dict)  # non-mapping -> TypeError
            except (ValueError, TypeError) as e:
                print(f"Ignoring malformed message: {e}")
                continue

            try:
                if message.type == "ping":
                    pong = WebSocketMessage(type="pong", data={"timestamp": message.timestamp})
                    await pool._send_to_connection(connection_id, pong)

                elif message.type == "echo":
                    echo = WebSocketMessage(
                        type="echo_response",
                        data={"original": message.data, "echoed_at": datetime.now(timezone.utc).isoformat()}
                    )
                    await pool._send_to_connection(connection_id, echo)

            except WebSocketDisconnect:
                break
            except Exception as e:
                print(f"Error handling message: {e}")
                break
|
|
|
|
|
|
# Example 4: Broadcasting service
|
|
class NotificationBroadcaster:
    """
    Fan notifications out to connected WebSocket clients.

    Illustrates the three addressing modes offered by the manager:
    - the ``system_announcements`` topic,
    - per-group topics named ``group_<group>``,
    - direct per-user delivery.

    Each method returns whatever count the underlying manager call reports.
    """

    def __init__(self):
        self.manager = get_websocket_manager()

    async def broadcast_system_announcement(self, message: str, priority: str = "info"):
        """Publish ``message`` with ``priority`` and a UTC timestamp on ``system_announcements``."""
        announcement = {
            "message": message,
            "priority": priority,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }
        return await self.manager.broadcast_to_topic(
            topic="system_announcements",
            message_type="system_announcement",
            data=announcement
        )

    async def notify_user_group(self, group: str, notification_type: str, data: Dict[str, Any]):
        """Publish ``data`` as ``notification_type`` on the ``group_<group>`` topic."""
        return await self.manager.broadcast_to_topic(
            topic="group_" + group,
            message_type=notification_type,
            data=data
        )

    async def send_personal_notification(self, user_id: int, notification_type: str, data: Dict[str, Any]):
        """Deliver ``data`` as ``notification_type`` directly to ``user_id``."""
        return await self.manager.send_to_user(
            user_id=user_id,
            message_type=notification_type,
            data=data
        )
|
|
|
|
|
|
# Example 5: Connection monitoring and health checks
|
|
class ConnectionMonitor:
    """
    Service for monitoring WebSocket connections and health.

    Demonstrates how to use the pool for system monitoring. The health check
    and manual cleanup reach into pool internals (``_connections_lock``,
    ``_connections`` and ``_cleanup_stale_connections``).
    """

    def __init__(self):
        self.manager = get_websocket_manager()

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Return the manager's connection statistics unchanged."""
        return await self.manager.get_stats()

    async def health_check_all_connections(self) -> Dict[str, Any]:
        """Perform a health check on all connections.

        Returns:
            A dict with ``total_connections``, ``healthy_connections``
            (``is_alive()`` true), ``stale_connections`` (``is_stale()`` true)
            and ``health_percentage`` (healthy / total * 100, or 100 when
            there are no connections). The two checks are independent, so a
            connection may count as both healthy and stale. Ids whose info
            lookup returns a falsy value (presumably connections removed after
            the snapshot) count toward the total only.
        """
        pool = self.manager.pool

        # Snapshot the ids under the lock; release it before the per-connection awaits.
        async with pool._connections_lock:
            connection_ids = list(pool._connections.keys())

        total = len(connection_ids)
        healthy = 0
        stale = 0

        for connection_id in connection_ids:
            connection_info = await pool.get_connection_info(connection_id)
            if not connection_info:
                continue
            if connection_info.is_alive():
                healthy += 1
            if connection_info.is_stale():
                stale += 1

        return {
            "total_connections": total,
            "healthy_connections": healthy,
            "stale_connections": stale,
            "health_percentage": (healthy / total * 100) if total > 0 else 100
        }

    async def cleanup_stale_connections(self) -> int:
        """Manually clean up stale connections.

        Returns:
            The drop in ``active_connections`` across the cleanup call,
            clamped at 0. Connections opened concurrently can make the raw
            difference negative, which is not a meaningful "cleaned" count.
            Under concurrent churn the figure is an approximation.
        """
        pool = self.manager.pool
        active_before = (await pool.get_stats())["active_connections"]
        await pool._cleanup_stale_connections()
        active_after = (await pool.get_stats())["active_connections"]
        return max(0, active_before - active_after)
|
|
|
|
|
|
# Example 6: Real-time data streaming
|
|
class RealTimeDataStreamer:
    """
    Service for streaming real-time data to WebSocket clients.

    Demonstrates how to use the pool for continuous data updates. Each topic
    has at most one live background task. The task repeatedly calls a data
    source and broadcasts the result to the topic as a ``data_update``
    message.
    """

    def __init__(self):
        self.manager = get_websocket_manager()
        # topic -> background task running that topic's stream loop
        self._streaming_tasks: Dict[str, asyncio.Task] = {}

    async def start_data_stream(self, topic: str, data_source: callable, interval: float = 1.0):
        """Start streaming data to a topic.

        Args:
            topic: Topic whose subscribers receive the updates.
            data_source: Zero-argument callable producing the payload. It may
                be a plain function, a coroutine function, or any callable
                returning an awaitable (the result is awaited if needed).
            interval: Seconds to wait between broadcasts; after an error the
                loop waits ``interval * 2`` before retrying.

        Returns:
            True if a new stream task was started, False if a still-running
            stream already exists for ``topic``.
        """
        existing = self._streaming_tasks.get(topic)
        # A finished task must not block restarting the topic.
        if existing is not None and not existing.done():
            return False  # Already streaming

        async def stream_data():
            while True:
                try:
                    # Get data from source; await it when the call returned an
                    # awaitable (covers coroutine functions and e.g. a lambda
                    # or functools.partial wrapping one).
                    data = data_source()
                    if inspect.isawaitable(data):
                        data = await data

                    # Broadcast to subscribers
                    await self.manager.broadcast_to_topic(
                        topic=topic,
                        message_type="data_update",
                        data={
                            "data": data,
                            "timestamp": datetime.now(timezone.utc).isoformat()
                        }
                    )

                    await asyncio.sleep(interval)

                except asyncio.CancelledError:
                    # Propagate so the task finishes in the cancelled state.
                    raise
                except Exception as e:
                    print(f"Error in data stream {topic}: {e}")
                    await asyncio.sleep(interval * 2)  # Back off on error

        task = asyncio.create_task(stream_data())
        self._streaming_tasks[topic] = task
        return True

    async def stop_data_stream(self, topic: str):
        """Stop streaming data to a topic.

        Cancels the topic's task and waits for it to finish, so no further
        broadcast for ``topic`` happens after this returns.

        Returns:
            True if a stream was registered for ``topic``, False otherwise.
        """
        task = self._streaming_tasks.pop(topic, None)
        if task is None:
            return False
        task.cancel()
        # return_exceptions=True: the expected CancelledError is returned, not raised.
        await asyncio.gather(task, return_exceptions=True)
        return True

    async def stop_all_streams(self):
        """Stop all data streams and wait for their tasks to finish."""
        tasks = list(self._streaming_tasks.values())
        self._streaming_tasks.clear()
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
|
|
# Example FastAPI application demonstrating usage
|
|
def create_example_app() -> FastAPI:
    """Create example FastAPI application with WebSocket endpoints.

    Routes:
    - ``/ws/basic``: decorator-based pooled endpoint.
    - ``/ws/chat/{room}``: chat room backed by the topic ``chat_<room>``.
    - ``/ws/user/{user_id}``: low-level context-manager example.
    - ``/``: HTML test page that connects to ``/ws/basic`` on the same host.
    - ``/api/broadcast/system``: broadcast a system announcement.
    - ``/api/monitor/stats``, ``/api/monitor/health``, ``/api/monitor/cleanup``:
      connection monitoring.
    """
    app = FastAPI(title="WebSocket Pool Example")

    # Initialize services
    broadcaster = NotificationBroadcaster()
    monitor = ConnectionMonitor()
    # Not bound to any route below; constructed to show the service alongside the others.
    streamer = RealTimeDataStreamer()

    @app.websocket("/ws/basic")
    async def basic_endpoint(websocket: WebSocket):
        """Basic WebSocket endpoint with automatic pooling"""
        # NOTE(review): basic_websocket_handler is wrapped by websocket_endpoint;
        # confirm the wrapper accepts these positional arguments (and a None message).
        await basic_websocket_handler(websocket, "basic", get_websocket_manager(), None)

    @app.websocket("/ws/chat/{room}")
    async def chat_endpoint(websocket: WebSocket, room: str):
        """Chat room WebSocket endpoint"""
        await manual_websocket_handler(websocket, f"chat_{room}")

    @app.websocket("/ws/user/{user_id}")
    async def user_endpoint(websocket: WebSocket, user_id: int):
        """User-specific WebSocket endpoint"""
        await low_level_websocket_example(websocket, user_id)

    @app.get("/")
    async def index():
        """Simple HTML page for testing WebSocket connections.

        The page connects back to the host that served it (ws:// or wss://
        matching the page scheme) instead of a hard-coded localhost URL, and
        only sends while the socket is OPEN (send() throws while CONNECTING).
        """
        return HTMLResponse("""
<!DOCTYPE html>
<html>
<head>
    <title>WebSocket Pool Example</title>
</head>
<body>
    <h1>WebSocket Pool Example</h1>
    <div id="messages"></div>
    <input type="text" id="messageInput" placeholder="Type a message...">
    <button onclick="sendMessage()">Send</button>

    <script>
        const wsScheme = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
        const ws = new WebSocket(wsScheme + window.location.host + '/ws/basic');
        const messages = document.getElementById('messages');

        ws.onmessage = function(event) {
            const message = JSON.parse(event.data);
            const div = document.createElement('div');
            div.textContent = JSON.stringify(message, null, 2);
            messages.appendChild(div);
        };

        function sendJson(payload) {
            if (ws.readyState !== WebSocket.OPEN) {
                return false;
            }
            ws.send(JSON.stringify(payload));
            return true;
        }

        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = {
                type: 'user_action',
                data: { action: input.value }
            };
            if (sendJson(message)) {
                input.value = '';
            }
        }

        // Send ping every 30 seconds
        setInterval(() => {
            sendJson({ type: 'ping' });
        }, 30000);
    </script>
</body>
</html>
""")

    @app.post("/api/broadcast/system")
    async def broadcast_system(message: str, priority: str = "info"):
        """Broadcast system message to all users"""
        sent_count = await broadcaster.broadcast_system_announcement(message, priority)
        return {"sent_count": sent_count}

    @app.get("/api/monitor/stats")
    async def get_monitor_stats():
        """Get connection monitoring statistics"""
        return await monitor.get_connection_stats()

    @app.get("/api/monitor/health")
    async def get_health_status():
        """Get connection health status"""
        return await monitor.health_check_all_connections()

    @app.post("/api/monitor/cleanup")
    async def cleanup_connections():
        """Manually cleanup stale connections"""
        cleaned = await monitor.cleanup_stale_connections()
        return {"cleaned_connections": cleaned}

    return app
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: build the demo app and serve it with uvicorn on
    # all interfaces, port 8000.
    import uvicorn

    app = create_example_app()
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|