86 lines
3.1 KiB
Python
Raw Normal View History

2025-12-10 12:02:17 +08:00
import asyncio
from collections import defaultdict
from typing import Dict, Optional, Set
from fastapi import WebSocket
from fastapi.encoders import jsonable_encoder
from backend.api.deps import AuthUser
from backend.modules.instances.models import Instance
from backend.modules.instances.schemas import InstanceOut
from backend.modules.users.models import RoleName
class InstanceEventManager:
    """Tracks websocket connections and dispatches instance change events.

    Connections are grouped by customer key. Admin connections are additionally
    kept in a separate set so that every broadcast reaches them regardless of
    which customer the event belongs to.
    """

    def __init__(self) -> None:
        # customer key -> open sockets for that customer (key 0 for falsy customer_id)
        self._connections: Dict[int, Set[WebSocket]] = defaultdict(set)
        # sockets belonging to admin users; they receive every broadcast
        self._admins: Set[WebSocket] = set()
        # reverse index so disconnect() can find a socket's bucket in O(1)
        self._ws_to_customer: Dict[WebSocket, int] = {}
        # guards the three collections above; asyncio.Lock is NOT reentrant,
        # so never await another method that takes it while holding it
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, auth_user: AuthUser) -> None:
        """Accept *websocket* and register it for events scoped to *auth_user*.

        A falsy ``auth_user.customer_id`` (e.g. ``None``) is bucketed under key 0.
        Sockets of users whose role is admin are also added to the admin set.
        """
        await websocket.accept()
        customer_key = auth_user.customer_id or 0
        async with self._lock:
            self._connections[customer_key].add(websocket)
            self._ws_to_customer[websocket] = customer_key
            if auth_user.role_name == RoleName.ADMIN.value:
                self._admins.add(websocket)

    async def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; calling it for an unknown or already-removed socket is a no-op."""
        async with self._lock:
            customer_key = self._ws_to_customer.pop(websocket, None)
            if customer_key is not None:
                bucket = self._connections.get(customer_key)
                if bucket is not None:
                    bucket.discard(websocket)
                    if not bucket:
                        # Drop empty buckets so the mapping doesn't keep every customer ever seen.
                        del self._connections[customer_key]
            # Unconditional: an admin socket must never linger in the admin set.
            self._admins.discard(websocket)

    @staticmethod
    async def _try_send(websocket: WebSocket, data: object) -> bool:
        """Send *data* as JSON over *websocket*; return False if the send failed."""
        try:
            await websocket.send_json(data)
        except Exception:
            # Any send failure means the socket is unusable; caller unregisters it.
            return False
        return True

    async def broadcast(self, payload: dict, customer_id: Optional[int]) -> None:
        """Send *payload* to all admin sockets and to the sockets of *customer_id*.

        Does nothing when *customer_id* is None. Sockets whose send fails are
        unregistered via :meth:`disconnect`.
        """
        if customer_id is None:
            return
        payload_jsonable = jsonable_encoder(payload)
        # Snapshot recipients under the lock, then release it before any network
        # I/O: holding it across sends would stall connect/disconnect behind slow
        # clients, and calling disconnect() while holding it would deadlock.
        async with self._lock:
            targets = list(self._admins | self._connections.get(customer_id, set()))
        if not targets:
            return
        # Send concurrently so one slow client doesn't delay everyone else.
        delivered = await asyncio.gather(
            *(self._try_send(ws, payload_jsonable) for ws in targets)
        )
        for ws, ok in zip(targets, delivered):
            if not ok:
                await self.disconnect(ws)
# Process-wide manager shared by the broadcast helpers below and the websocket endpoint(s).
instance_event_manager: InstanceEventManager = InstanceEventManager()
def serialize_instance(instance: Instance) -> dict:
    """Return *instance* as a plain dict shaped by the ``InstanceOut`` schema.

    The dump uses Pydantic's default (python) mode, so values are not
    necessarily JSON-ready; the broadcast path runs them through
    ``jsonable_encoder`` before sending.
    """
    schema_obj = InstanceOut.model_validate(instance)
    return schema_obj.model_dump()
def build_removed_payload(instance: Instance) -> dict:
    """Build the small identifying snapshot sent with an ``instance_removed`` event.

    Copies these attributes of *instance*, in this order: id, instance_id,
    account_id, region, customer_id, credential_id, status.
    """
    field_names = (
        "id",
        "instance_id",
        "account_id",
        "region",
        "customer_id",
        "credential_id",
        "status",
    )
    return {name: getattr(instance, name) for name in field_names}
async def broadcast_instance_update(instance: Instance) -> None:
    """Publish an ``instance_update`` event carrying the serialized *instance*.

    Routed by ``instance.customer_id``; the manager skips the event when it is None.
    """
    event = {"type": "instance_update", "instance": serialize_instance(instance)}
    await instance_event_manager.broadcast(event, instance.customer_id)
async def broadcast_instance_removed(payload: dict, customer_id: Optional[int]) -> None:
    """Publish an ``instance_removed`` event wrapping *payload* for *customer_id*.

    *payload* is typically the dict produced by ``build_removed_payload``.
    """
    message = {"type": "instance_removed", "instance": payload}
    await instance_event_manager.broadcast(message, customer_id)