272 lines
10 KiB
Python
272 lines
10 KiB
Python
import asyncio
import contextlib
import inspect
import json
import time
import uuid
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Optional

import aiohttp
import structlog
import websockets

from common.config import FHobSettings
from fh_ob.book_store import BookStore, OrderBookTop5
|
|
|
|
|
|
@dataclass
class _WorkerState:
    """Mutable bookkeeping for a single connection worker.

    Field order defines the positional ``__init__`` signature generated by
    ``@dataclass``; keep it stable.
    """

    # Symbols assigned to this worker's connection.
    symbols: set[str] = field(default_factory=set)
    # Commands consumed by the worker as (action, symbol) tuples; a ``None``
    # entry makes the worker leave its command loop.
    command_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    # Identifier used in log records and in the websocket ``connectId``.
    ws_id: int = field(default=0)
    # Number of reconnects this worker has performed so far.
    reconnect_count: int = field(default=0)
    # Wall-clock time of the most recent reconnect, in ms since the epoch (0 = none yet).
    last_reconnect_ts_ms: int = field(default=0)
    # True while the worker holds an open websocket session.
    connection_active: bool = field(default=False)
|
|
|
|
|
|
class KuCoinWSClient:
    """Stream KuCoin ``/spotMarket/level2Depth5`` updates into a ``BookStore``.

    Symbols are sharded across connection workers; each worker owns one
    websocket session and re-establishes it (with exponential backoff)
    whenever it ends while the client is running. Every ``level2Depth5``
    message is passed to ``book_store.update``; when that returns a truthy
    book, the optional ``on_book_update`` callback is invoked with it.
    Awaitable callback results are scheduled as tasks, not awaited inline.
    """

    # Maximum number of symbols assigned to one websocket connection.
    _MAX_SYMBOLS_PER_CONNECTION = 400
    # Maximum number of symbols joined into one subscribe/unsubscribe topic.
    _MAX_SYMBOLS_PER_TOPIC = 100
    # Fallback ping interval (ms) when the token response has no ``pingInterval``.
    _DEFAULT_PING_INTERVAL_MS = 18000
    _DEPTH_TOPIC_PREFIX = "/spotMarket/level2Depth5:"

    def __init__(
        self,
        settings: FHobSettings,
        book_store: BookStore,
        on_book_update: Optional[Callable[[OrderBookTop5], None | Awaitable[None]]] = None,
    ) -> None:
        """Create an idle client; call ``start()`` to connect.

        Args:
            settings: supplies ``symbols`` (a mutable list kept in sync by
                ``add_symbol``/``remove_symbol``), ``token_url``,
                ``reconnect_base_delay`` and ``reconnect_max_delay``.
            book_store: receives every depth-5 message via ``update``.
            on_book_update: optional sync or async callback for updated books.
        """
        self._settings = settings
        self._book_store = book_store
        self._on_book_update_callback = on_book_update
        self._log = structlog.get_logger().bind(component="ws_client")
        self._running = False
        # Pending subscribe/unsubscribe requests awaiting a reply, keyed by request id.
        self._subscription_events: dict[str, asyncio.Event] = {}
        self._workers: list[_WorkerState] = []
        self._worker_tasks: list[asyncio.Task] = []
        self._http_session: Optional[aiohttp.ClientSession] = None
        # Strong references to scheduled async callbacks: the event loop only
        # keeps weak references to tasks, so unreferenced ones may vanish.
        self._callback_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Spawn one worker per shard of configured symbols and wait for them.

        At least one worker is always created (even with no symbols) so that
        ``add_symbol`` has somewhere to place new symbols. Returns once all
        workers have finished after ``stop()``. If this coroutine itself is
        cancelled, it shuts the client down and re-raises the cancellation.
        """
        self._running = True
        self._workers.clear()
        self._worker_tasks.clear()
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession()

        symbol_list = list(self._settings.symbols)
        size = self._MAX_SYMBOLS_PER_CONNECTION
        # ``or 1`` guarantees one (possibly empty) worker when there are no symbols.
        for ws_id, offset in enumerate(range(0, len(symbol_list) or 1, size), start=1):
            group = set(symbol_list[offset : offset + size])
            self._workers.append(_WorkerState(symbols=group, ws_id=ws_id))

        for state in self._workers:
            self._worker_tasks.append(asyncio.create_task(self._connection_worker(state)))

        try:
            await asyncio.gather(*self._worker_tasks)
        except asyncio.CancelledError:
            if self._running:
                # start() itself was cancelled (not via stop()): release the
                # workers and the HTTP session, then propagate the cancellation.
                await self.stop()
                raise
            # Otherwise stop() cancelled the workers: a normal shutdown.
        self._log.debug("all_workers_stopped")

    async def stop(self) -> None:
        """Cancel all workers (waiting up to 5 s) and close the HTTP session."""
        self._running = False
        for task in self._worker_tasks:
            task.cancel()
        if self._worker_tasks:
            await asyncio.wait(self._worker_tasks, timeout=5)
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        self._log.debug("ws_client_stopped")

    def is_connected(self) -> bool:
        """Return True if at least one worker currently has an open session."""
        return any(w.connection_active for w in self._workers)

    def subscribed_count(self) -> int:
        """Return the number of symbols assigned across all workers."""
        return sum(len(w.symbols) for w in self._workers)

    def reconnect_stats(self) -> tuple[int, int]:
        """Return (total_reconnects, timestamp_ms of last reconnect) across all workers."""
        total = sum(w.reconnect_count for w in self._workers)
        latest = max((w.last_reconnect_ts_ms for w in self._workers), default=0)
        return total, latest

    def get_symbols(self) -> list[str]:
        """Return every assigned symbol, grouped by worker (order within a worker is arbitrary)."""
        return [symbol for w in self._workers for symbol in w.symbols]

    def add_symbol(self, symbol: str) -> bool:
        """Assign ``symbol`` to the least-loaded worker with spare capacity.

        Returns False (and leaves ``settings.symbols`` untouched) when the
        client has no workers, the symbol is already assigned, or every
        worker is at ``_MAX_SYMBOLS_PER_CONNECTION``.
        """
        if not self._workers:
            return False
        if any(symbol in w.symbols for w in self._workers):
            return False
        eligible = [w for w in self._workers if len(w.symbols) < self._MAX_SYMBOLS_PER_CONNECTION]
        if not eligible:
            self._log.warning("all_workers_full", symbol=symbol)
            return False
        worker = min(eligible, key=lambda w: len(w.symbols))
        worker.symbols.add(symbol)
        worker.command_queue.put_nowait(("subscribe", symbol))
        # Record in settings only once a worker has actually accepted the symbol.
        if symbol not in self._settings.symbols:
            self._settings.symbols.append(symbol)
        return True

    def remove_symbol(self, symbol: str) -> bool:
        """Unassign ``symbol`` and queue an unsubscribe on its worker's connection.

        Returns False when no worker has the symbol.
        """
        worker = next((w for w in self._workers if symbol in w.symbols), None)
        if worker is None:
            return False
        worker.symbols.discard(symbol)
        worker.command_queue.put_nowait(("unsubscribe", symbol))
        if symbol in self._settings.symbols:
            self._settings.symbols.remove(symbol)
        return True

    async def _connection_worker(self, state: _WorkerState) -> None:
        """Run websocket sessions for ``state`` until the client stops.

        Any session end while running (failed token fetch or connect, server
        close, failed ping, subscription-ack timeout) counts as a reconnect
        and is followed by a backoff sleep. The delay starts at
        ``reconnect_base_delay``, doubles per consecutive attempt up to
        ``reconnect_max_delay``, and resets after each successful connect.
        It is tracked per worker so one worker's success does not reset
        another worker's backoff.
        """
        delay = self._settings.reconnect_base_delay
        try:
            while self._running:
                try:
                    ws, ping_interval = await self._open_connection(state)
                except Exception as e:
                    error = str(e) or type(e).__name__
                else:
                    delay = self._settings.reconnect_base_delay
                    try:
                        await self._run_session(ws, state, ping_interval)
                        error = "session_ended"
                    except Exception as e:
                        error = str(e) or type(e).__name__
                    finally:
                        state.connection_active = False
                        # Best-effort close; the socket may already be dead.
                        with contextlib.suppress(Exception):
                            await ws.close()

                if not self._running:
                    break
                state.reconnect_count += 1
                state.last_reconnect_ts_ms = int(time.time() * 1000)
                self._log.warning(
                    "ws_reconnecting",
                    ws_id=state.ws_id,
                    reconnect_count=state.reconnect_count,
                    error=error,
                )
                await asyncio.sleep(delay)
                delay = min(delay * 2, self._settings.reconnect_max_delay)
        finally:
            state.connection_active = False
            self._log.debug("worker_exiting", ws_id=state.ws_id)

    async def _open_connection(self, state: _WorkerState) -> tuple:
        """Fetch a public token and open a websocket; return ``(ws, ping_interval_s)``."""
        token, instance = await self._get_public_token()
        # pingInterval is reported in milliseconds; asyncio.sleep wants seconds.
        ping_interval = instance.get("pingInterval", self._DEFAULT_PING_INTERVAL_MS) / 1000.0
        url = f"{instance['endpoint']}?token={token}&connectId={uuid.uuid4()}-{state.ws_id}"
        # Library keepalive pings are disabled; _ping_loop sends KuCoin ping messages.
        ws = await websockets.connect(url, ping_interval=None)
        self._log.debug("ws_connected", ws_id=state.ws_id)
        return ws, ping_interval

    async def _run_session(self, ws, state: _WorkerState, ping_interval: float) -> None:
        """Drive one connected session until its reader, pinger or command loop ends.

        The three loops run concurrently; as soon as any finishes, the others
        are cancelled and awaited. An exception from a finished loop is
        re-raised so the worker logs it and reconnects.
        """
        state.connection_active = True
        tasks = [
            asyncio.create_task(self._read_loop(ws, state.ws_id)),
            asyncio.create_task(self._ping_loop(ws, state.ws_id, ping_interval)),
            asyncio.create_task(self._command_loop(ws, state)),
        ]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            # Let the loops unwind before the caller closes the socket.
            await asyncio.gather(*tasks, return_exceptions=True)
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                raise task.exception()

    async def _read_loop(self, ws, ws_id: int) -> None:
        """Feed every incoming frame to ``_handle_message`` until the socket closes."""
        try:
            async for msg in ws:
                try:
                    await self._handle_message(msg)
                except Exception as e:
                    # One bad message must not tear down the whole session.
                    self._log.error("reader_unexpected_error", ws_id=ws_id, error=str(e))
        except websockets.ConnectionClosed as e:
            self._log.warning("reader_connection_closed", ws_id=ws_id, code=e.code, reason=e.reason)
            raise
        self._log.warning("ws_disconnected", ws_id=ws_id)

    async def _command_loop(self, ws, state: _WorkerState) -> None:
        """Subscribe the worker's symbols, then apply queued (un)subscribe commands.

        Commands queued before this session started are discarded: their
        effect is already reflected in ``state.symbols``, which is subscribed
        in full here. Each command is re-checked against ``state.symbols`` when
        processed, so a symbol added and removed in quick succession ends up
        in its final state. A ``None`` command ends the session.
        """
        queue = state.command_queue
        while not queue.empty():
            queue.get_nowait()
        if state.symbols:
            await self._send_subscribe(ws, list(state.symbols), state.ws_id)

        while True:
            cmd = await queue.get()
            if cmd is None:
                return
            action, symbol = cmd
            if action == "subscribe":
                if symbol not in state.symbols:
                    continue  # removed again before we got to it
                self._log.debug("subscribing_dynamic", symbol=symbol, ws_id=state.ws_id)
                await self._send_subscribe(ws, [symbol], state.ws_id)
            elif action == "unsubscribe":
                if symbol in state.symbols:
                    continue  # re-added before we got to it
                self._log.debug("unsubscribing_dynamic", symbol=symbol, ws_id=state.ws_id)
                await self._send_unsubscribe(ws, [symbol], state.ws_id)
            else:
                self._log.warning("unknown_command", action=action, ws_id=state.ws_id)

    async def _get_public_token(self) -> tuple[str, dict]:
        """POST to ``settings.token_url``; return ``(token, first instance server)``.

        Raises:
            RuntimeError: if the HTTP session is not open or the response
                lacks the expected ``data.token`` / ``data.instanceServers``.
            aiohttp.ClientResponseError: on a non-2xx HTTP status.
        """
        self._log.debug("fetching_public_token", url=self._settings.token_url)
        if self._http_session is None or self._http_session.closed:
            raise RuntimeError("HTTP session is not open; call start() first")
        async with self._http_session.post(self._settings.token_url) as resp:
            resp.raise_for_status()
            data = await resp.json()
        try:
            token = data["data"]["token"]
            instance = data["data"]["instanceServers"][0]
        except (KeyError, IndexError, TypeError) as e:
            raise RuntimeError(f"unexpected token response: {str(data)[:200]}") from e
        self._log.debug("public_token_received", ping_interval_ms=instance.get("pingInterval"))
        return token, instance

    async def _send_subscribe(self, ws, symbols: list[str], ws_id: int) -> None:
        """Subscribe ``symbols`` to the depth-5 topic, awaiting a reply per batch."""
        await self._send_topic_request(ws, "subscribe", symbols, ws_id)

    async def _send_unsubscribe(self, ws, symbols: list[str], ws_id: int) -> None:
        """Unsubscribe ``symbols`` from the depth-5 topic, awaiting a reply per batch."""
        await self._send_topic_request(ws, "unsubscribe", symbols, ws_id)

    async def _send_topic_request(self, ws, msg_type: str, symbols: list[str], ws_id: int) -> None:
        """Send ``msg_type`` requests in batches of ``_MAX_SYMBOLS_PER_TOPIC`` symbols.

        Each batch waits up to ``reconnect_base_delay`` seconds for its reply
        (an ``ack`` or ``error`` message matched by id).

        Raises:
            asyncio.TimeoutError: if a batch gets no reply in time.
        """
        log_event = "subscribing" if msg_type == "subscribe" else "unsubscribing"
        step = self._MAX_SYMBOLS_PER_TOPIC
        for i in range(0, len(symbols), step):
            topic = self._DEPTH_TOPIC_PREFIX + ",".join(symbols[i : i + step])
            ack_id = str(uuid.uuid4())
            evt = asyncio.Event()
            self._subscription_events[ack_id] = evt
            request = {
                "id": ack_id,
                "type": msg_type,
                "topic": topic,
                "response": True,
            }
            self._log.debug(log_event, topic=topic[:80], ws_id=ws_id)
            try:
                await ws.send(json.dumps(request))
                await asyncio.wait_for(evt.wait(), timeout=self._settings.reconnect_base_delay)
            except asyncio.TimeoutError:
                self._log.warning("subscription_ack_timeout", topic=topic[:80], ws_id=ws_id, action=msg_type)
                raise
            finally:
                # Remove the pending entry on every path so timed-out requests don't leak.
                self._subscription_events.pop(ack_id, None)

    async def _ping_loop(self, ws, ws_id: int, interval: float) -> None:
        """Send a KuCoin ``ping`` message every ``interval`` seconds while running.

        Re-raises the send error when a ping fails, which ends the session.
        """
        while self._running:
            await asyncio.sleep(interval)
            try:
                await ws.send(json.dumps({"id": str(uuid.uuid4()), "type": "ping"}))
            except Exception:
                self._log.warning("ping_failed", ws_id=ws_id)
                raise

    async def _handle_message(self, msg: str) -> None:
        """Route one websocket frame: control replies, depth-5 updates, or a warning."""
        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            self._log.warning("invalid_json", msg=msg[:100])
            return
        if not isinstance(data, dict):
            self._log.warning("invalid_json", msg=msg[:100])
            return

        msg_type = data.get("type")

        if msg_type == "welcome":
            self._log.debug("ws_welcome")
            return

        if msg_type == "pong":
            return

        if msg_type == "ack":
            ack_id = data.get("id")
            self._log.debug("subscription_ack", topic=data.get("topic"), ack_id=ack_id)
            evt = self._subscription_events.pop(ack_id, None)
            if evt is not None:
                evt.set()
            return

        if msg_type == "error":
            ack_id = data.get("id")
            self._log.warning("ws_error_message", ack_id=ack_id, code=data.get("code"), data=data.get("data"))
            # Release a pending request waiter so one rejected topic does not
            # time out and tear down the whole session.
            evt = self._subscription_events.pop(ack_id, None)
            if evt is not None:
                evt.set()
            return

        topic = data.get("topic") or ""

        if msg_type == "message" and "level2Depth5" in topic:
            book = self._book_store.update(data)
            if book and self._on_book_update_callback:
                self._dispatch_book_update(book)
        elif topic:
            self._log.warning("ws_unexpected_message", type=msg_type, topic=topic)

    def _dispatch_book_update(self, book: OrderBookTop5) -> None:
        """Invoke the user callback; schedule (and keep a reference to) awaitable results."""
        result = self._on_book_update_callback(book)
        if inspect.isawaitable(result):
            task = asyncio.ensure_future(result)
            self._callback_tasks.add(task)
            task.add_done_callback(self._on_callback_task_done)

    def _on_callback_task_done(self, task: asyncio.Future) -> None:
        """Drop the finished callback task and log any exception it raised."""
        self._callback_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._log.error("book_update_callback_error", error=str(exc))
|