# triangular_arbitrage_bot/fh_ob/ws_client.py
#
# 272 lines
# 10 KiB
# Python

import asyncio
import json
import time
import uuid
from dataclasses import dataclass, field
from typing import Callable, Optional, Awaitable
import aiohttp
import structlog
import websockets
from common.config import FHobSettings
from fh_ob.book_store import BookStore, OrderBookTop5
@dataclass
class _WorkerState:
symbols: set[str] = field(default_factory=set)
command_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
ws_id: int = 0
reconnect_count: int = 0
last_reconnect_ts_ms: int = 0
connection_active: bool = False
class KuCoinWSClient:
    """Stream KuCoin ``level2Depth5`` order-book messages into a ``BookStore``.

    Configured symbols are split into shards of at most
    ``MAX_SYMBOLS_PER_CONNECTION``. Each shard is served by one worker that owns
    one websocket connection. On every (re)connect the worker re-subscribes all
    of its symbols. Between reconnect attempts it backs off exponentially, from
    ``settings.reconnect_base_delay`` up to ``settings.reconnect_max_delay``.

    Each book returned by ``book_store.update`` is passed to ``on_book_update``
    (when given). If the callback returns a coroutine, that coroutine is
    scheduled as a task.
    """

    # Most symbols assigned to a single websocket connection.
    # NOTE(review): presumably mirrors an exchange limit; confirm against KuCoin docs.
    MAX_SYMBOLS_PER_CONNECTION = 400
    # Most symbols joined into one subscribe/unsubscribe topic string.
    MAX_SYMBOLS_PER_TOPIC = 100
    # Keepalive interval in seconds, used when the token response has no pingInterval.
    DEFAULT_PING_INTERVAL = 18.0

    def __init__(
        self,
        settings: FHobSettings,
        book_store: BookStore,
        on_book_update: Optional[Callable[[OrderBookTop5], None | Awaitable[None]]] = None,
    ) -> None:
        """Store dependencies. No I/O happens until :meth:`start` is called.

        Args:
            settings: Provides ``symbols`` (a mutable list), ``token_url``,
                ``reconnect_base_delay`` and ``reconnect_max_delay``.
            book_store: Receives every ``level2Depth5`` data message.
            on_book_update: Optional sync or async callback for each updated book.
        """
        self._settings = settings
        self._book_store = book_store
        self._on_book_update_callback = on_book_update
        self._log = structlog.get_logger().bind(component="ws_client")
        self._running = False
        # Used as the subscription-ack timeout. Reconnect backoff is tracked per
        # worker, so one shard's failures cannot reset or inflate another's.
        self._reconnect_delay = settings.reconnect_base_delay
        self._ping_interval = self.DEFAULT_PING_INTERVAL
        self._subscription_events: dict[str, asyncio.Event] = {}
        self._workers: list[_WorkerState] = []
        self._worker_tasks: list[asyncio.Task] = []
        self._http_session: Optional[aiohttp.ClientSession] = None
        # Strong references to in-flight async callback tasks. The event loop only
        # keeps weak references, so unreferenced tasks may be garbage-collected.
        self._background_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Spawn one worker per symbol shard and run until all workers exit.

        At least one worker is created even with no configured symbols, so
        :meth:`add_symbol` always has a connection to attach to.
        """
        self._running = True
        self._workers.clear()
        self._worker_tasks.clear()
        self._subscription_events.clear()
        # Reuse a still-open session instead of leaking it on a repeated start().
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession()
        symbol_list = list(self._settings.symbols)
        shard = self.MAX_SYMBOLS_PER_CONNECTION
        for ws_id, offset in enumerate(range(0, len(symbol_list) or 1, shard), start=1):
            self._workers.append(
                _WorkerState(symbols=set(symbol_list[offset : offset + shard]), ws_id=ws_id)
            )
        self._worker_tasks.extend(
            asyncio.create_task(self._connection_worker(state)) for state in self._workers
        )
        try:
            await asyncio.gather(*self._worker_tasks)
        except asyncio.CancelledError:
            # stop() cancels the workers, which surfaces here; treat it as shutdown.
            pass
        self._log.debug("all_workers_stopped")

    async def stop(self) -> None:
        """Cancel all workers (waiting up to 5 s for them), then close the HTTP session."""
        self._running = False
        for task in self._worker_tasks:
            task.cancel()
        if self._worker_tasks:
            await asyncio.wait(self._worker_tasks, timeout=5)
        if self._http_session and not self._http_session.closed:
            await self._http_session.close()
        self._log.debug("ws_client_stopped")

    def is_connected(self) -> bool:
        """Return True if at least one worker currently holds a live session."""
        return any(w.connection_active for w in self._workers)

    def subscribed_count(self) -> int:
        """Return the number of symbols currently assigned across all workers."""
        return sum(len(w.symbols) for w in self._workers)

    def reconnect_stats(self) -> tuple[int, int]:
        """Return (total_reconnects, timestamp_ms of last reconnect) across all workers."""
        total = sum(w.reconnect_count for w in self._workers)
        latest = max((w.last_reconnect_ts_ms for w in self._workers), default=0)
        return total, latest

    def get_symbols(self) -> list[str]:
        """Return every assigned symbol, grouped by worker in worker order."""
        return [symbol for w in self._workers for symbol in w.symbols]

    def add_symbol(self, symbol: str) -> bool:
        """Assign *symbol* to the least-loaded worker and queue a live subscribe.

        Returns False when the client has not started, the symbol is already
        assigned, or every worker is full. In those cases ``settings.symbols``
        is left untouched.
        """
        if not self._workers:
            return False
        if any(symbol in w.symbols for w in self._workers):
            return False
        eligible = [w for w in self._workers if len(w.symbols) < self.MAX_SYMBOLS_PER_CONNECTION]
        if not eligible:
            self._log.warning("all_workers_full", symbol=symbol)
            return False
        worker = min(eligible, key=lambda w: len(w.symbols))
        worker.symbols.add(symbol)
        # Record the symbol only once it is actually assigned, so a rejected
        # add does not leave a phantom entry behind.
        if symbol not in self._settings.symbols:
            self._settings.symbols.append(symbol)
        worker.command_queue.put_nowait(("subscribe", symbol))
        return True

    def remove_symbol(self, symbol: str) -> bool:
        """Unassign *symbol* and queue an unsubscribe on its live connection.

        Returns False if no worker owns the symbol.
        """
        worker = next((w for w in self._workers if symbol in w.symbols), None)
        if worker is None:
            return False
        worker.symbols.discard(symbol)
        if symbol in self._settings.symbols:
            self._settings.symbols.remove(symbol)
        # Without this the symbol would keep streaming until the next reconnect.
        worker.command_queue.put_nowait(("unsubscribe", symbol))
        return True

    async def _connection_worker(self, state: _WorkerState) -> None:
        """Keep one websocket for *state* alive until the client is stopped.

        Each iteration fetches a public token, connects and runs a session.
        When a connect attempt fails or an established session ends, the
        reconnect is recorded in *state* and delayed by an exponential backoff.
        The backoff resets after each successful connect.
        """
        base_delay = self._settings.reconnect_base_delay
        delay = base_delay
        while self._running:
            try:
                token, instance = await self._get_public_token()
                self._ping_interval = (
                    instance.get("pingInterval", self.DEFAULT_PING_INTERVAL * 1000) / 1000.0
                )
                ws = await websockets.connect(
                    instance["endpoint"] + f"?token={token}&connectId={uuid.uuid4()}-{state.ws_id}",
                    ping_interval=None,
                )
            except asyncio.CancelledError:
                break
            except Exception as e:
                reason = str(e)
            else:
                self._log.debug("ws_connected", ws_id=state.ws_id)
                delay = base_delay
                reason = "connection_lost"
                try:
                    await self._run_session(ws, state)
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    reason = str(e)
                finally:
                    await self._close_quietly(ws, state.ws_id)
            if not self._running:
                break
            state.reconnect_count += 1
            state.last_reconnect_ts_ms = int(time.time() * 1000)
            self._log.warning(
                "ws_reconnecting",
                ws_id=state.ws_id,
                reconnect_count=state.reconnect_count,
                error=reason,
            )
            try:
                await asyncio.sleep(delay)
            except asyncio.CancelledError:
                break
            delay = min(delay * 2, self._settings.reconnect_max_delay)
        self._log.debug("worker_exiting", ws_id=state.ws_id)

    async def _run_session(self, ws, state: _WorkerState) -> None:
        """Serve one connected websocket until it becomes unusable.

        Starts the reader and keepalive tasks, subscribes every symbol in
        ``state.symbols``, then applies queued ``("subscribe" | "unsubscribe",
        symbol)`` commands. Returns when any of the following happens:

        - the reader or pinger stops, meaning the connection is dead;
        - a ``None`` command is dequeued, which forces a reconnect;
        - a send or ack fails.

        Commands dropped at the end of a session are harmless, because the next
        session re-subscribes exactly ``state.symbols``.
        """
        state.connection_active = True
        ping_task = asyncio.create_task(self._ping_loop(ws, state.ws_id))
        # The reader must already be running when we subscribe, so acks are seen.
        reader_task = asyncio.create_task(self._read_messages(ws, state.ws_id))
        get_task: Optional[asyncio.Task] = None
        try:
            if state.symbols:
                await self._send_subscribe(ws, list(state.symbols), state.ws_id)
            while True:
                get_task = asyncio.create_task(state.command_queue.get())
                done, _ = await asyncio.wait(
                    {get_task, reader_task, ping_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if get_task not in done:
                    # Reader or pinger finished, so the socket is dead. End the
                    # session now rather than waiting for the next command.
                    self._log.warning(
                        "ws_session_ended",
                        ws_id=state.ws_id,
                        cause="reader_stopped" if reader_task in done else "ping_stopped",
                    )
                    return
                cmd = get_task.result()
                get_task = None
                if cmd is None:
                    return
                action, symbol = cmd
                if action == "subscribe":
                    self._log.debug("subscribing_dynamic", symbol=symbol, ws_id=state.ws_id)
                    await self._send_subscribe(ws, [symbol], state.ws_id)
                elif action == "unsubscribe":
                    self._log.debug("unsubscribing_dynamic", symbol=symbol, ws_id=state.ws_id)
                    await self._send_topic_request(ws, "unsubscribe", [symbol], state.ws_id)
        except websockets.ConnectionClosed as e:
            self._log.warning("ws_disconnected", ws_id=state.ws_id, code=e.code, reason=e.reason)
        except Exception as e:
            self._log.error("command_loop_error", ws_id=state.ws_id, error=str(e))
        finally:
            state.connection_active = False
            helpers = [t for t in (get_task, ping_task, reader_task) if t is not None]
            for task in helpers:
                task.cancel()
            await asyncio.gather(*helpers, return_exceptions=True)

    async def _read_messages(self, ws, ws_id: int) -> None:
        """Dispatch every frame from *ws* until the connection closes or fails."""
        try:
            async for msg in ws:
                await self._handle_message(msg)
        except websockets.ConnectionClosed as e:
            self._log.warning("reader_connection_closed", ws_id=ws_id, code=e.code, reason=e.reason)
        except Exception as e:
            self._log.error("reader_unexpected_error", ws_id=ws_id, error=str(e))

    async def _close_quietly(self, ws, ws_id: int) -> None:
        """Close *ws* on a best-effort basis; failures are only logged at debug level."""
        try:
            await ws.close()
        except Exception as e:
            self._log.debug("ws_close_failed", ws_id=ws_id, error=str(e))

    async def _get_public_token(self) -> tuple[str, dict]:
        """Request a public websocket token; return it with the first instance server.

        Raises ``aiohttp.ClientResponseError`` on a non-2xx response, and a
        lookup error if the body lacks ``data.token`` or ``data.instanceServers``.
        """
        self._log.debug("fetching_public_token", url=self._settings.token_url)
        async with self._http_session.post(self._settings.token_url) as resp:
            resp.raise_for_status()
            data = await resp.json()
        token = data["data"]["token"]
        instance = data["data"]["instanceServers"][0]
        self._log.debug("public_token_received", ping_interval_ms=instance.get("pingInterval"))
        return token, instance

    async def _send_subscribe(self, ws, symbols: list[str], ws_id: int) -> None:
        """Subscribe *symbols* in topic batches, waiting for each ack."""
        await self._send_topic_request(ws, "subscribe", symbols, ws_id)

    async def _send_topic_request(self, ws, action: str, symbols: list[str], ws_id: int) -> None:
        """Send ``action`` ("subscribe"/"unsubscribe") for *symbols* and await each ack.

        Symbols are grouped ``MAX_SYMBOLS_PER_TOPIC`` per topic. Raises
        ``asyncio.TimeoutError`` if an ack does not arrive within
        ``self._reconnect_delay`` seconds.
        """
        event_name = "subscribing" if action == "subscribe" else "unsubscribing"
        step = self.MAX_SYMBOLS_PER_TOPIC
        for offset in range(0, len(symbols), step):
            topic = "/spotMarket/level2Depth5:" + ",".join(symbols[offset : offset + step])
            ack_id = str(uuid.uuid4())
            evt = asyncio.Event()
            self._subscription_events[ack_id] = evt
            try:
                self._log.debug(event_name, topic=topic[:80], ws_id=ws_id)
                await ws.send(json.dumps({
                    "id": ack_id,
                    "type": action,
                    "topic": topic,
                    "response": True,
                }))
                await asyncio.wait_for(evt.wait(), timeout=self._reconnect_delay)
            except asyncio.TimeoutError:
                self._log.warning("subscription_ack_timeout", topic=topic[:80], ws_id=ws_id)
                raise
            finally:
                # Drop the waiter whether or not the ack arrived, so failed or
                # timed-out requests do not accumulate in the map.
                self._subscription_events.pop(ack_id, None)

    async def _ping_loop(self, ws, ws_id: int) -> None:
        """Send a JSON ``ping`` message every ``self._ping_interval`` seconds.

        KuCoin's keepalive is an application-level ``{"id": ..., "type": "ping"}``
        message, answered by the ``pong`` messages that ``_handle_message``
        ignores. Exits after the first failed send.
        """
        while self._running:
            await asyncio.sleep(self._ping_interval)
            try:
                await ws.send(json.dumps({"id": str(int(time.time() * 1000)), "type": "ping"}))
            except Exception:
                self._log.warning("ping_failed", ws_id=ws_id)
                break

    async def _handle_message(self, msg: str) -> None:
        """Route one raw websocket frame.

        ``welcome``, ``pong`` and ``ack`` control messages are consumed here, and
        ``error`` messages are logged. ``level2Depth5`` data messages update the
        book store and trigger the book-update callback. Any other message that
        carries a topic is logged as unexpected.
        """
        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            self._log.warning("invalid_json", msg=msg[:100])
            return
        if not isinstance(data, dict):
            # Valid JSON but not an object. The .get() calls below would raise
            # and take the reader down with it.
            self._log.warning("invalid_json", msg=msg[:100])
            return
        msg_type = data.get("type")
        if msg_type == "welcome":
            self._log.debug("ws_welcome")
            return
        if msg_type == "pong":
            return
        if msg_type == "ack":
            ack_id = data.get("id")
            self._log.debug("subscription_ack", topic=data.get("topic"), ack_id=ack_id)
            evt = self._subscription_events.pop(ack_id, None)
            if evt is not None:
                evt.set()
            return
        if msg_type == "error":
            self._log.error(
                "ws_error_message",
                ack_id=data.get("id"),
                code=data.get("code"),
                data=data.get("data"),
            )
            return
        topic = data.get("topic", "")
        if msg_type == "message" and "level2Depth5" in topic:
            book = self._book_store.update(data)
            if book and self._on_book_update_callback:
                self._dispatch_book_update(book)
        elif topic:
            self._log.warning("ws_unexpected_message", type=msg_type, topic=topic)

    def _dispatch_book_update(self, book: OrderBookTop5) -> None:
        """Invoke the callback and schedule it as a tracked task if it is async.

        A failing callback is logged rather than propagated, so it cannot stop
        the reader.
        """
        try:
            result = self._on_book_update_callback(book)
        except Exception as e:
            self._log.error("book_update_callback_error", error=str(e))
            return
        if asyncio.iscoroutine(result):
            task = asyncio.create_task(result)
            self._background_tasks.add(task)
            task.add_done_callback(self._on_callback_done)

    def _on_callback_done(self, task: asyncio.Task) -> None:
        """Release a finished callback task and log its exception, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._log.error("book_update_callback_error", error=str(exc))