This commit is contained in:
HotSwapp
2025-08-18 20:20:04 -05:00
parent 89b2bc0aa2
commit bac8cc4bd5
114 changed files with 30258 additions and 1341 deletions

View File

@@ -29,6 +29,9 @@ from app.config import settings
from app.services.query_utils import apply_sorting, tokenized_ilike_filter, paginate_with_total
from app.utils.exceptions import handle_database_errors, safe_execute
from app.utils.logging import app_logger
from app.middleware.websocket_middleware import get_websocket_manager, get_connection_tracker, WebSocketMessage
from app.services.document_notifications import ADMIN_DOCUMENTS_TOPIC
from fastapi import WebSocket
router = APIRouter()
@@ -64,6 +67,55 @@ class HealthCheck(BaseModel):
cpu_usage: float
alerts: List[str]
class WebSocketStats(BaseModel):
"""WebSocket connection pool statistics"""
total_connections: int
active_connections: int
total_topics: int
total_users: int
messages_sent: int
messages_failed: int
connections_cleaned: int
last_cleanup: Optional[str]
last_heartbeat: Optional[str]
connections_by_state: Dict[str, int]
topic_distribution: Dict[str, int]
class ConnectionInfo(BaseModel):
"""Individual WebSocket connection information"""
connection_id: str
user_id: Optional[int]
state: str
topics: List[str]
created_at: str
last_activity: str
age_seconds: float
idle_seconds: float
error_count: int
last_ping: Optional[str]
last_pong: Optional[str]
metadata: Dict[str, Any]
is_alive: bool
is_stale: bool
class WebSocketConnectionsResponse(BaseModel):
"""Response for WebSocket connections listing"""
connections: List[ConnectionInfo]
total_count: int
active_count: int
stale_count: int
class DisconnectRequest(BaseModel):
"""Request to disconnect WebSocket connections"""
connection_ids: Optional[List[str]] = None
user_id: Optional[int] = None
topic: Optional[str] = None
reason: str = "admin_disconnect"
class UserCreate(BaseModel):
"""Create new user"""
username: str = Field(..., min_length=3, max_length=50)
@@ -551,6 +603,253 @@ async def system_statistics(
)
# WebSocket Management Endpoints
@router.get("/websockets/stats", response_model=WebSocketStats)
async def get_websocket_stats(
current_user: User = Depends(get_admin_user)
):
"""Get WebSocket connection pool statistics"""
websocket_manager = get_websocket_manager()
stats = await websocket_manager.get_stats()
return WebSocketStats(**stats)
@router.get("/websockets/connections", response_model=WebSocketConnectionsResponse)
async def get_websocket_connections(
user_id: Optional[int] = Query(None, description="Filter by user ID"),
topic: Optional[str] = Query(None, description="Filter by topic"),
state: Optional[str] = Query(None, description="Filter by connection state"),
current_user: User = Depends(get_admin_user)
):
"""Get list of active WebSocket connections with optional filtering"""
websocket_manager = get_websocket_manager()
connection_tracker = get_connection_tracker()
# Get all connection IDs
pool = websocket_manager.pool
async with pool._connections_lock:
all_connection_ids = list(pool._connections.keys())
connections = []
active_count = 0
stale_count = 0
for connection_id in all_connection_ids:
metrics = await connection_tracker.get_connection_metrics(connection_id)
if not metrics:
continue
# Apply filters
if user_id and metrics.get("user_id") != user_id:
continue
if topic and topic not in metrics.get("topics", []):
continue
if state and metrics.get("state") != state:
continue
connections.append(ConnectionInfo(**metrics))
if metrics.get("is_alive"):
active_count += 1
if metrics.get("is_stale"):
stale_count += 1
return WebSocketConnectionsResponse(
connections=connections,
total_count=len(connections),
active_count=active_count,
stale_count=stale_count
)
@router.get("/websockets/connections/{connection_id}", response_model=ConnectionInfo)
async def get_websocket_connection(
connection_id: str,
current_user: User = Depends(get_admin_user)
):
"""Get detailed information about a specific WebSocket connection"""
connection_tracker = get_connection_tracker()
metrics = await connection_tracker.get_connection_metrics(connection_id)
if not metrics:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"WebSocket connection {connection_id} not found"
)
return ConnectionInfo(**metrics)
@router.post("/websockets/disconnect")
async def disconnect_websockets(
request: DisconnectRequest,
current_user: User = Depends(get_admin_user)
):
"""Disconnect WebSocket connections based on criteria"""
websocket_manager = get_websocket_manager()
pool = websocket_manager.pool
disconnected_count = 0
if request.connection_ids:
# Disconnect specific connections
for connection_id in request.connection_ids:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
elif request.user_id:
# Disconnect all connections for a user
user_connections = await pool.get_user_connections(request.user_id)
for connection_id in user_connections:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
elif request.topic:
# Disconnect all connections for a topic
topic_connections = await pool.get_topic_connections(request.topic)
for connection_id in topic_connections:
await pool.remove_connection(connection_id, request.reason)
disconnected_count += 1
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Must specify connection_ids, user_id, or topic"
)
app_logger.info("Admin disconnected WebSocket connections",
admin_user=current_user.username,
disconnected_count=disconnected_count,
reason=request.reason)
return {
"message": f"Disconnected {disconnected_count} WebSocket connections",
"disconnected_count": disconnected_count,
"reason": request.reason
}
@router.post("/websockets/cleanup")
async def cleanup_websockets(
current_user: User = Depends(get_admin_user)
):
"""Manually trigger cleanup of stale WebSocket connections"""
websocket_manager = get_websocket_manager()
pool = websocket_manager.pool
# Get stats before cleanup
stats_before = await pool.get_stats()
connections_before = stats_before["active_connections"]
# Force cleanup
await pool._cleanup_stale_connections()
# Get stats after cleanup
stats_after = await pool.get_stats()
connections_after = stats_after["active_connections"]
cleaned_count = connections_before - connections_after
app_logger.info("Admin triggered WebSocket cleanup",
admin_user=current_user.username,
cleaned_count=cleaned_count)
return {
"message": f"Cleaned up {cleaned_count} stale WebSocket connections",
"connections_before": connections_before,
"connections_after": connections_after,
"cleaned_count": cleaned_count
}
@router.post("/websockets/broadcast")
async def broadcast_message(
topic: str = Body(..., description="Topic to broadcast to"),
message_type: str = Body(..., description="Message type"),
data: Optional[Dict[str, Any]] = Body(None, description="Message data"),
current_user: User = Depends(get_admin_user)
):
"""Broadcast a message to all connections subscribed to a topic"""
websocket_manager = get_websocket_manager()
sent_count = await websocket_manager.broadcast_to_topic(
topic=topic,
message_type=message_type,
data=data
)
app_logger.info("Admin broadcast message to topic",
admin_user=current_user.username,
topic=topic,
message_type=message_type,
sent_count=sent_count)
return {
"message": f"Broadcast message to {sent_count} connections",
"topic": topic,
"message_type": message_type,
"sent_count": sent_count
}
@router.websocket("/ws/documents")
async def ws_admin_documents(websocket: WebSocket):
"""
Admin WebSocket endpoint for monitoring all document processing events.
Receives real-time notifications about:
- Document generation started/completed/failed across all files
- Document uploads across all files
- Workflow executions that generate documents
Requires admin authentication via token query parameter.
"""
websocket_manager = get_websocket_manager()
# Custom message handler for admin document monitoring
async def handle_admin_document_message(connection_id: str, message: WebSocketMessage):
"""Handle custom messages for admin document monitoring"""
app_logger.debug("Received admin document message",
connection_id=connection_id,
message_type=message.type)
# Use the WebSocket manager to handle the connection
connection_id = await websocket_manager.handle_connection(
websocket=websocket,
topics={ADMIN_DOCUMENTS_TOPIC},
require_auth=True,
metadata={"endpoint": "admin_documents", "admin_monitoring": True},
message_handler=handle_admin_document_message
)
if connection_id:
# Send initial welcome message with admin monitoring confirmation
try:
pool = websocket_manager.pool
welcome_message = WebSocketMessage(
type="admin_monitoring_active",
topic=ADMIN_DOCUMENTS_TOPIC,
data={
"message": "Connected to admin document monitoring stream",
"events": [
"document_processing",
"document_completed",
"document_failed",
"document_upload"
]
}
)
await pool._send_to_connection(connection_id, welcome_message)
app_logger.info("Admin document monitoring connection established",
connection_id=connection_id)
except Exception as e:
app_logger.error("Failed to send admin monitoring welcome message",
connection_id=connection_id,
error=str(e))
@router.post("/import/csv")
async def import_csv(
table_name: str,
@@ -558,17 +857,34 @@ async def import_csv(
db: Session = Depends(get_db),
current_user: User = Depends(get_admin_user)
):
"""Import data from CSV file"""
"""Import data from CSV file with comprehensive security validation"""
from app.utils.file_security import file_validator, validate_csv_content
if not file.filename.endswith('.csv'):
# Comprehensive security validation for CSV uploads
content_bytes, safe_filename, file_ext, mime_type = await file_validator.validate_upload_file(
file, category='csv'
)
# Decode content with proper encoding handling
encodings = ['utf-8', 'utf-8-sig', 'windows-1252', 'iso-8859-1']
content_str = None
for encoding in encodings:
try:
content_str = content_bytes.decode(encoding)
break
except UnicodeDecodeError:
continue
if content_str is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="File must be a CSV"
status_code=400,
detail="Could not decode CSV file. Please ensure it's saved in UTF-8, Windows-1252, or ISO-8859-1 encoding."
)
# Read CSV content
content = await file.read()
csv_data = csv.DictReader(io.StringIO(content.decode('utf-8')))
# Additional CSV security validation
validate_csv_content(content_str)
csv_data = csv.DictReader(io.StringIO(content_str))
imported_count = 0
errors = []
@@ -1786,4 +2102,33 @@ async def get_audit_statistics(
{"username": username, "activity_count": count}
for username, count in most_active_users
]
}
}
@router.get("/cache-performance")
async def get_cache_performance(
current_user: User = Depends(get_admin_user)
):
"""Get adaptive cache performance statistics"""
try:
from app.services.adaptive_cache import get_cache_stats
stats = get_cache_stats()
return {
"status": "success",
"cache_statistics": stats,
"timestamp": datetime.now().isoformat(),
"summary": {
"total_cache_types": len(stats),
"avg_hit_rate": sum(s.get("hit_rate", 0) for s in stats.values()) / len(stats) if stats else 0,
"most_active": max(stats.items(), key=lambda x: x[1].get("total_queries", 0)) if stats else None,
"longest_ttl": max(stats.items(), key=lambda x: x[1].get("current_ttl", 0)) if stats else None,
"shortest_ttl": min(stats.items(), key=lambda x: x[1].get("current_ttl", float('inf'))) if stats else None
}
}
except Exception as e:
return {
"status": "error",
"error": str(e),
"cache_statistics": {}
}