"""WebSocket client that feeds level2Depth5 order-book messages into a BookStore."""

import asyncio
import inspect
import json
import time
import uuid
from dataclasses import dataclass, field
from typing import Callable, Optional, Awaitable

import aiohttp
import structlog
import websockets

from common.config import FHobSettings
from fh_ob.book_store import BookStore, OrderBookTop5

# Maximum number of symbols assigned to a single websocket connection.
_MAX_SYMBOLS_PER_WORKER = 400
# Number of symbols packed into one subscribe message.
_SUBSCRIBE_BATCH_SIZE = 100
# Ping interval used when the token response does not provide "pingInterval".
_DEFAULT_PING_INTERVAL_MS = 18000


class _ReaderStoppedError(ConnectionError):
    """Raised inside a session when the reader task ended, i.e. the socket is dead."""


@dataclass
class _WorkerState:
    """Mutable bookkeeping for one websocket connection worker."""

    # Symbols this worker is responsible for subscribing to.
    symbols: set[str] = field(default_factory=set)
    # Commands for the worker: ("subscribe", symbol) tuples, or None to end
    # the current session's command loop.
    command_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    # 1-based identifier of the worker, used in logs and in the connectId.
    ws_id: int = 0
    # Number of failed sessions that led to a reconnect attempt.
    reconnect_count: int = 0
    # Wall-clock time (ms since epoch) of the most recent reconnect attempt.
    last_reconnect_ts_ms: int = 0
    # True while a websocket session is established for this worker.
    connection_active: bool = False


class KuCoinWSClient:
    """Maintain one or more websocket connections and forward book updates.

    Symbols from the settings are split into groups of at most
    ``_MAX_SYMBOLS_PER_WORKER``; each group is served by its own connection
    worker that reconnects with exponential backoff.
    """

    def __init__(
        self,
        settings: FHobSettings,
        book_store: BookStore,
        on_book_update: Optional[
            Callable[[OrderBookTop5], Optional[Awaitable[None]]]
        ] = None,
    ) -> None:
        """Store collaborators; no network activity happens until ``start()``.

        Args:
            settings: Provides ``symbols``, ``token_url``,
                ``reconnect_base_delay`` and ``reconnect_max_delay``.
            book_store: Receives every level2Depth5 message via ``update()``.
            on_book_update: Optional sync or async callback invoked with the
                book returned by ``book_store.update()`` when it is truthy.
        """
        self._settings = settings
        self._book_store = book_store
        self._on_book_update_callback = on_book_update
        self._log = structlog.get_logger().bind(component="ws_client")
        self._running = False
        # NOTE: the backoff delay is shared by all workers and is also used
        # as the subscription-ack timeout in _send_subscribe.
        self._reconnect_delay = settings.reconnect_base_delay
        # Fallback interval (seconds) for _ping_loop when none is passed.
        self._ping_interval = _DEFAULT_PING_INTERVAL_MS / 1000.0
        self._subscription_events: dict[str, asyncio.Event] = {}
        self._workers: list[_WorkerState] = []
        self._worker_tasks: list[asyncio.Task] = []
        # Strong references to in-flight async callbacks; the event loop only
        # keeps weak references to tasks.
        self._callback_tasks: set = set()
        self._http_session: Optional[aiohttp.ClientSession] = None

    async def start(self) -> None:
        """Create the workers and run them until they all finish.

        This coroutine blocks for the lifetime of the client; call ``stop()``
        from another task to end it.
        """
        self._running = True
        self._workers.clear()
        self._worker_tasks.clear()
        self._http_session = aiohttp.ClientSession()

        symbol_list = list(self._settings.symbols)
        # "or 1" guarantees one (empty) worker even without configured
        # symbols, so add_symbol() has a connection to attach to.
        for i in range(0, len(symbol_list) or 1, _MAX_SYMBOLS_PER_WORKER):
            group = set(symbol_list[i : i + _MAX_SYMBOLS_PER_WORKER])
            ws_id = len(self._workers) + 1
            self._workers.append(_WorkerState(symbols=group, ws_id=ws_id))

        for state in self._workers:
            task = asyncio.create_task(self._connection_worker(state))
            self._worker_tasks.append(task)

        try:
            await asyncio.gather(*self._worker_tasks)
        except asyncio.CancelledError:
            pass
        self._log.debug("all_workers_stopped")

    async def stop(self) -> None:
        """Cancel all workers, wait up to 5 seconds, and close the HTTP session."""
        self._running = False
        for t in self._worker_tasks:
            t.cancel()
        if self._worker_tasks:
            await asyncio.wait(self._worker_tasks, timeout=5)
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        self._log.debug("ws_client_stopped")

    def is_connected(self) -> bool:
        """Return True if at least one worker has an active session."""
        return any(w.connection_active for w in self._workers)

    def subscribed_count(self) -> int:
        """Return the total number of symbols assigned across all workers."""
        return sum(len(w.symbols) for w in self._workers)

    def reconnect_stats(self) -> tuple[int, int]:
        """Return (total_reconnects, timestamp_ms of last reconnect) across all workers."""
        total = sum(w.reconnect_count for w in self._workers)
        latest = max((w.last_reconnect_ts_ms for w in self._workers), default=0)
        return total, latest

    def get_symbols(self) -> list[str]:
        """Return every symbol currently assigned to a worker."""
        return [symbol for w in self._workers for symbol in w.symbols]

    def add_symbol(self, symbol: str) -> bool:
        """Assign ``symbol`` to the least-loaded worker and queue a subscribe.

        Returns:
            True if the symbol was added; False if there are no workers, the
            symbol is already assigned, or every worker is full.
        """
        if not self._workers:
            return False
        if any(symbol in w.symbols for w in self._workers):
            return False

        eligible = [
            w for w in self._workers if len(w.symbols) < _MAX_SYMBOLS_PER_WORKER
        ]
        if not eligible:
            self._log.warning("all_workers_full", symbol=symbol)
            return False

        worker = min(eligible, key=lambda w: len(w.symbols))
        # Record the symbol in the settings only once a worker has accepted
        # it, so a rejected symbol leaves no trace behind.
        self._settings.symbols.append(symbol)
        worker.symbols.add(symbol)
        worker.command_queue.put_nowait(("subscribe", symbol))
        return True

    def remove_symbol(self, symbol: str) -> bool:
        """Drop ``symbol`` from its worker and from the settings.

        NOTE(review): no unsubscribe message is sent, so an open connection
        keeps receiving updates for the symbol until it reconnects.

        Returns:
            True if the symbol was assigned to a worker, False otherwise.
        """
        for worker in self._workers:
            if symbol in worker.symbols:
                worker.symbols.discard(symbol)
                break
        else:
            return False

        # Guarded so an out-of-sync settings list cannot raise after the
        # worker state has already been changed.
        if symbol in self._settings.symbols:
            self._settings.symbols.remove(symbol)
        return True

    async def _connection_worker(self, state: _WorkerState) -> None:
        """Run sessions for ``state`` until stopped, backing off after failures.

        A session that ends cleanly (``None`` command) is followed by an
        immediate reconnect; a session that fails is counted in the worker's
        reconnect statistics and retried after the current backoff delay.
        """
        while self._running:
            try:
                await self._run_session(state)
            except asyncio.CancelledError:
                break
            except Exception as e:
                if not self._running:
                    break
                state.connection_active = False
                state.reconnect_count += 1
                state.last_reconnect_ts_ms = int(time.time() * 1000)
                self._log.warning(
                    "ws_reconnecting",
                    ws_id=state.ws_id,
                    reconnect_count=state.reconnect_count,
                    error=str(e),
                )
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    self._settings.reconnect_max_delay,
                )
        self._log.debug("worker_exiting", ws_id=state.ws_id)

    async def _run_session(self, state: _WorkerState) -> None:
        """Open one websocket session, subscribe, and process commands.

        Returns normally when a ``None`` command is received. Raises when the
        connection drops, a subscription is not acknowledged in time, or any
        other error occurs, so the caller can apply its backoff.
        """
        token, instance = await self._get_public_token()
        ping_interval = (
            instance.get("pingInterval", _DEFAULT_PING_INTERVAL_MS) / 1000.0
        )
        self._ping_interval = ping_interval
        url = (
            instance["endpoint"]
            + f"?token={token}&connectId={uuid.uuid4()}-{state.ws_id}"
        )

        # The context manager closes the socket on every exit path.
        async with websockets.connect(url, ping_interval=None) as ws:
            self._log.debug("ws_connected", ws_id=state.ws_id)
            self._reconnect_delay = self._settings.reconnect_base_delay
            state.connection_active = True

            ping_task = asyncio.create_task(
                self._ping_loop(ws, state.ws_id, ping_interval)
            )
            reader_task = asyncio.create_task(
                self._read_messages(ws, state.ws_id)
            )
            try:
                if state.symbols:
                    await self._send_subscribe(
                        ws, list(state.symbols), state.ws_id
                    )

                while True:
                    cmd = await self._next_command(state, reader_task)
                    if cmd is None:
                        break
                    action, symbol = cmd
                    if action == "subscribe":
                        self._log.debug(
                            "subscribing_dynamic",
                            symbol=symbol,
                            ws_id=state.ws_id,
                        )
                        await self._send_subscribe(ws, [symbol], state.ws_id)
            except (asyncio.CancelledError, _ReaderStoppedError):
                # Cancellation is not an error, and the reader has already
                # logged why it stopped.
                raise
            except websockets.ConnectionClosed as e:
                self._log.warning(
                    "ws_disconnected",
                    ws_id=state.ws_id,
                    code=e.code,
                    reason=e.reason,
                )
                raise
            except Exception as e:
                self._log.error(
                    "command_loop_error", ws_id=state.ws_id, error=str(e)
                )
                raise
            finally:
                state.connection_active = False
                ping_task.cancel()
                reader_task.cancel()
                # Wait for both helpers to finish before the socket closes.
                await asyncio.gather(
                    ping_task, reader_task, return_exceptions=True
                )

    async def _next_command(self, state: _WorkerState, reader_task: asyncio.Task):
        """Return the next queued command, or raise if the reader has stopped.

        Waiting on the queue alone would block forever after the connection
        drops, so the wait also completes when ``reader_task`` finishes.

        Raises:
            _ReaderStoppedError: The reader ended before a command arrived.
        """
        get_task = asyncio.ensure_future(state.command_queue.get())
        try:
            await asyncio.wait(
                {get_task, reader_task}, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # A cancelled Queue.get() leaves any item in the queue, so no
            # command is lost here.
            if not get_task.done():
                get_task.cancel()
        if get_task.done() and not get_task.cancelled():
            return get_task.result()
        raise _ReaderStoppedError("websocket reader stopped")

    async def _read_messages(self, ws, ws_id: int) -> None:
        """Dispatch every incoming frame to ``_handle_message`` until the socket ends."""
        try:
            async for msg in ws:
                await self._handle_message(msg)
        except websockets.ConnectionClosed as e:
            self._log.warning(
                "reader_connection_closed",
                ws_id=ws_id,
                code=e.code,
                reason=e.reason,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._log.error("reader_unexpected_error", ws_id=ws_id, error=str(e))

    async def _get_public_token(self) -> tuple[str, dict]:
        """POST to the token URL and return ``(token, first instance server)``.

        Raises:
            RuntimeError: The HTTP session has not been created or is closed.
            aiohttp.ClientResponseError: The endpoint returned an error status.
            KeyError / IndexError: The response lacks the expected fields.
        """
        session = self._http_session
        if session is None or session.closed:
            raise RuntimeError("HTTP session is not open; call start() first")

        self._log.debug("fetching_public_token", url=self._settings.token_url)
        async with session.post(self._settings.token_url) as resp:
            resp.raise_for_status()
            data = await resp.json()
        token = data["data"]["token"]
        instance = data["data"]["instanceServers"][0]
        self._log.debug(
            "public_token_received", ping_interval_ms=instance.get("pingInterval")
        )
        return token, instance

    async def _send_subscribe(self, ws, symbols: list[str], ws_id: int) -> None:
        """Subscribe to ``symbols`` in batches, waiting for each acknowledgement.

        Raises:
            asyncio.TimeoutError: An ack did not arrive within the current
                reconnect delay.
        """
        for i in range(0, len(symbols), _SUBSCRIBE_BATCH_SIZE):
            batch = symbols[i : i + _SUBSCRIBE_BATCH_SIZE]
            topic = "/spotMarket/level2Depth5:" + ",".join(batch)
            ack_id = str(uuid.uuid4())
            evt = asyncio.Event()
            self._subscription_events[ack_id] = evt
            sub_msg = {
                "id": ack_id,
                "type": "subscribe",
                "topic": topic,
                "response": True,
            }
            try:
                self._log.debug("subscribing", topic=topic[:80], ws_id=ws_id)
                await ws.send(json.dumps(sub_msg))
                await asyncio.wait_for(evt.wait(), timeout=self._reconnect_delay)
            except asyncio.TimeoutError:
                self._log.warning(
                    "subscription_ack_timeout", topic=topic[:80], ws_id=ws_id
                )
                raise
            finally:
                # Never leave a stale event behind, whether the ack arrived,
                # timed out, or the send failed.
                self._subscription_events.pop(ack_id, None)

    async def _ping_loop(
        self, ws, ws_id: int, interval: Optional[float] = None
    ) -> None:
        """Send a websocket ping every ``interval`` seconds until it fails.

        Args:
            ws: The open websocket connection.
            ws_id: Worker identifier used in logs.
            interval: Seconds between pings; defaults to ``self._ping_interval``.

        NOTE(review): this sends protocol-level pings via ``ws.ping()``; confirm
        the server does not require an application-level ping message instead.
        """
        delay = self._ping_interval if interval is None else interval
        while self._running:
            await asyncio.sleep(delay)
            try:
                await ws.ping()
            except Exception:
                self._log.warning("ping_failed", ws_id=ws_id)
                break

    async def _handle_message(self, msg: str) -> None:
        """Parse one frame and route it by its ``type`` field."""
        try:
            data = json.loads(msg)
        except (json.JSONDecodeError, UnicodeDecodeError):
            self._log.warning("invalid_json", msg=msg[:100])
            return
        if not isinstance(data, dict):
            # Valid JSON but not an object; nothing below can interpret it.
            self._log.warning("invalid_json", msg=msg[:100])
            return

        msg_type = data.get("type")

        if msg_type == "welcome":
            self._log.debug("ws_welcome")
            return

        if msg_type == "pong":
            return

        if msg_type == "ack":
            ack_id = data.get("id")
            self._log.debug(
                "subscription_ack", topic=data.get("topic"), ack_id=ack_id
            )
            evt = self._subscription_events.pop(ack_id, None)
            if evt is not None:
                evt.set()
            return

        topic = data.get("topic", "")
        if msg_type == "message" and "level2Depth5" in topic:
            book = self._book_store.update(data)
            if book and self._on_book_update_callback:
                result = self._on_book_update_callback(book)
                if inspect.isawaitable(result):
                    self._track_callback(asyncio.ensure_future(result))
        elif topic:
            self._log.warning(
                "ws_unexpected_message", type=msg_type, topic=topic
            )

    def _track_callback(self, task) -> None:
        """Hold a reference to an async callback until it completes."""
        self._callback_tasks.add(task)
        task.add_done_callback(self._on_callback_done)

    def _on_callback_done(self, task) -> None:
        """Release a finished callback and log its exception, if any."""
        self._callback_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._log.error("book_update_callback_failed", error=str(exc))