# File: triangular_arbitrage_bot/executor/ws_client.py
# (File-viewer metadata captured with the paste: 484 lines, 17 KiB, Python)

"""
KuCoin private WebSocket client for fill event delivery.
Manages the private WebSocket connection, authenticates via bullet-private,
subscribes to /spotMarket/tradeOrdersV2, and dispatches fill events to
waiting executor coroutines.
KuCoin market orders placed with ``funds`` often complete with a tiny
unfilled remainder, emitting ``canceled`` + ``status=done`` rather than
``filled``. The message handler in ``_handle_message`` treats this
terminal combination as a successful fill when match events have been
accumulated.
"""
import asyncio
import json
import time
import uuid
from decimal import Decimal
from typing import Optional
import aiohttp
import structlog
import websockets
from executor.kucoin_api import KuCoinAPI
# Module-level structured logger; the bound ``component`` key tags every
# record emitted through it as coming from the executor WebSocket client.
logger = structlog.get_logger().bind(component="executor-ws")
# Shared Decimal zero, used in this module as the initial value for
# accumulated sizes/funds and as the default for unknown balances.
_D0 = Decimal("0")
class FillAccumulator:
    """Accumulate match events for a single order and compute aggregated totals.

    Each ``match`` payload contributes its ``matchSize`` to ``total_size`` and
    ``matchPrice * matchSize`` to ``total_funds``. ``side``, ``symbol`` and
    ``order_id`` are taken from the first payload that provides them.
    """

    def __init__(self, client_oid: str, order_id: str = "") -> None:
        self.client_oid = client_oid
        self.order_id = order_id
        self.total_size = _D0
        self.total_funds = _D0
        self.match_count = 0
        self.side = ""
        self.symbol = ""
        # Flipped by the WS client when the terminal event arrives before any
        # coroutine is waiting on this order.
        self._done = False

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(client_oid={self.client_oid!r}, "
            f"order_id={self.order_id!r}, total_size={self.total_size}, "
            f"total_funds={self.total_funds}, match_count={self.match_count})"
        )

    @staticmethod
    def _to_decimal(value: object) -> Decimal:
        """Convert a payload field to Decimal; ``None`` and ``""`` count as zero.

        A key that is present but null would otherwise become
        ``Decimal("None")`` and raise ``decimal.InvalidOperation``.
        """
        if value is None or value == "":
            return _D0
        return Decimal(str(value))

    def add_match(self, data: dict) -> None:
        """Fold one ``match`` event payload into the running totals.

        Missing, null or empty ``matchPrice`` / ``matchSize`` contribute zero,
        but the event is still counted in ``match_count``.
        """
        match_price = self._to_decimal(data.get("matchPrice"))
        match_size = self._to_decimal(data.get("matchSize"))
        self.total_size += match_size
        self.total_funds += match_price * match_size
        self.match_count += 1
        # Identity fields are filled once, from the first payload carrying them.
        if not self.side:
            self.side = data.get("side") or ""
        if not self.symbol:
            self.symbol = data.get("symbol") or ""
        if not self.order_id:
            self.order_id = data.get("orderId") or ""

    @property
    def weighted_avg_price(self) -> Decimal:
        """Size-weighted average match price, or zero when nothing was matched."""
        if self.total_size <= 0:
            return _D0
        return self.total_funds / self.total_size

    def to_dict(self) -> dict:
        """Return the aggregated fill as a plain dict."""
        return {
            "total_size": self.total_size,
            "total_funds": self.total_funds,
            "weighted_avg_price": self.weighted_avg_price,
            "match_count": self.match_count,
            "order_id": self.order_id,
            "side": self.side,
            "symbol": self.symbol,
        }
class KuCoinWSClient:
    """
    Private WebSocket client for KuCoin execution events.

    Subscribes to /spotMarket/tradeOrdersV2 (global topic) and
    /account/balance, dispatches fill events to awaiting executor coroutines
    via await_fill(), and tracks the latest available balance per currency
    for await_balance() / latest_balance().
    """

    def __init__(
        self,
        kucoin_api: "KuCoinAPI",
        private_token_url: str = "https://api.kucoin.com/api/v1/bullet-private",
        reconnect_base_delay: float = 1.0,
        reconnect_max_delay: float = 30.0,
    ) -> None:
        self._api = kucoin_api
        self._private_token_url = private_token_url
        self._reconnect_base_delay = reconnect_base_delay
        self._reconnect_max_delay = reconnect_max_delay
        # Current backoff delay; doubled after each failure, reset on connect.
        self._reconnect_delay = reconnect_base_delay
        self._log = logger
        self._running = False
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._connected = False
        # Seconds; overwritten from the token response on every connect.
        self._ping_interval: float = 18.0
        self._ping_timeout: float = 10.0
        # client_oid -> future resolved with (success, fill_data).
        self._fill_futures: dict[str, asyncio.Future] = {}
        # client_oid -> running totals of match events.
        self._accumulators: dict[str, FillAccumulator] = {}
        # CURRENCY -> future resolved with True on a balance event,
        # or False when the client is stopped.
        self._balance_futures: dict[str, asyncio.Future] = {}
        self._latest_balance: dict[str, Decimal] = {}
        self._worker_task: Optional[asyncio.Task] = None
        # subscribe-message id -> event set when the matching ack arrives.
        self._pending_acks: dict[str, asyncio.Event] = {}

    @property
    def is_connected(self) -> bool:
        """Whether the WebSocket is currently connected."""
        return self._connected

    async def start(self) -> None:
        """Start the WebSocket connection worker with reconnection loop."""
        self._running = True
        self._worker_task = asyncio.create_task(self._connection_worker())
        try:
            await self._worker_task
        except asyncio.CancelledError:
            pass

    async def stop(self) -> None:
        """Stop the WebSocket connection and resolve any pending futures."""
        self._running = False
        # Release every waiter with a failure result before tearing down.
        for future in self._fill_futures.values():
            if not future.done():
                future.set_result((False, {}))
        self._fill_futures.clear()
        self._accumulators.clear()
        for future in self._balance_futures.values():
            if not future.done():
                future.set_result(False)
        self._balance_futures.clear()
        for evt in self._pending_acks.values():
            evt.set()
        self._pending_acks.clear()
        if self._worker_task is not None and not self._worker_task.done():
            self._worker_task.cancel()
        if self._ws is not None:
            try:
                await self._ws.close()
            except Exception as exc:
                # Closing is best-effort; record the reason instead of hiding it.
                self._log.debug("ws_close_error", error=str(exc))
        self._connected = False
        self._log.debug("ws_client_stopped")

    async def await_fill(
        self, client_oid: str, timeout_ms: float
    ) -> tuple[bool, dict]:
        """
        Wait for the order identified by client_oid to be fully filled.

        Parameters
        ----------
        client_oid : str
            The client-order ID used when placing the order.
        timeout_ms : float
            Maximum wait time in milliseconds.

        Returns
        -------
        tuple[bool, dict]
            (success, aggregated_fill_data) on fill.
            (False, {}) on timeout, failure, or disconnected.
        """
        # A terminal fill may have arrived before the caller started waiting.
        # Deliver it even if the socket has dropped since, so a completed
        # order is never reported as failed.
        early = self._accumulators.get(client_oid)
        if early is not None and early._done:
            del self._accumulators[client_oid]
            return (True, early.to_dict())

        if not self._connected:
            self._log.warning(
                "await_fill_not_connected",
                client_oid=client_oid,
            )
            return (False, {})

        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._fill_futures[client_oid] = future
        try:
            result = await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            partial = self._accumulators.get(client_oid)
            self._log.warning(
                "fill_timeout",
                client_oid=client_oid,
                timeout_ms=timeout_ms,
                accumulator=partial.to_dict() if partial is not None else None,
            )
            return (False, {})
        finally:
            # Runs on success, timeout and cancellation; only removes our own
            # registration.
            if self._fill_futures.get(client_oid) is future:
                del self._fill_futures[client_oid]
        self._accumulators.pop(client_oid, None)
        return result

    async def await_balance(
        self, currency: str, min_available: Decimal, timeout_ms: float
    ) -> bool:
        """
        Wait until the available balance for *currency* reaches at least *min_available*.

        If the latest known balance already meets the threshold, returns
        immediately. Otherwise waits for balance WS events and re-checks the
        threshold after each one, within a single overall deadline.

        Returns True once the threshold is met, False on timeout or when the
        client is stopped while waiting.
        """
        key = currency.upper()
        if self._latest_balance.get(key, _D0) >= min_available:
            return True

        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_ms / 1000.0
        while True:
            remaining = deadline - loop.time()
            timed_out = remaining <= 0
            updated = False
            if not timed_out:
                future: asyncio.Future = loop.create_future()
                self._balance_futures[key] = future
                try:
                    updated = await asyncio.wait_for(future, timeout=remaining)
                except asyncio.TimeoutError:
                    timed_out = True
                finally:
                    if self._balance_futures.get(key) is future:
                        del self._balance_futures[key]
            if timed_out:
                self._log.warning(
                    "await_balance_timeout",
                    currency=key,
                    min_available=str(min_available),
                    latest=str(self._latest_balance.get(key, _D0)),
                )
                return False
            if not updated:
                # stop() resolves pending balance futures with False.
                return False
            # Any balance event wakes us; only succeed once the threshold holds.
            if self._latest_balance.get(key, _D0) >= min_available:
                return True

    def latest_balance(self, currency: str) -> Decimal:
        """Return the latest known available balance for *currency*, or 0."""
        return self._latest_balance.get(currency.upper(), _D0)

    async def _connection_worker(self) -> None:
        """Main connection loop with exponential backoff reconnection."""
        while self._running:
            try:
                await self._connect_and_run()
            except asyncio.CancelledError:
                break
            except Exception as e:
                if not self._running:
                    break
                self._connected = False
                self._log.warning(
                    "ws_reconnecting",
                    error=str(e),
                    delay=self._reconnect_delay,
                )
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    self._reconnect_max_delay,
                )

    async def _connect_and_run(self) -> None:
        """Authenticate, connect, subscribe, and run the message loop.

        Raises
        ------
        RuntimeError
            If no private token or no instance server could be obtained.
        """
        # The HTTP session is only needed for the token request.
        async with aiohttp.ClientSession() as session:
            token_data = await self._api.get_private_token(session)
            if not token_data:
                raise RuntimeError("Failed to obtain private token")
            token = token_data.get("token", "")
            servers = token_data.get("instanceServers", [])
            if not servers:
                raise RuntimeError("No instance servers in token response")
            server = servers[0]
            endpoint = server.get("endpoint", "")
            # The server reports milliseconds; the attributes hold seconds.
            self._ping_interval = server.get("pingInterval", 18000) / 1000.0
            self._ping_timeout = server.get("pingTimeout", 10000) / 1000.0

        ws_url = f"{endpoint}?token={token}&connectId={uuid.uuid4()}"
        # Log the endpoint only: the query string carries the private token.
        self._log.debug("ws_connecting", url=endpoint)
        ws = await websockets.connect(
            ws_url,
            ping_interval=self._ping_interval,
            ping_timeout=self._ping_timeout,
        )
        self._ws = ws
        self._connected = True
        self._reconnect_delay = self._reconnect_base_delay
        self._log.info("ws_connected")

        # Subscribing runs as a task because acks arrive through the message
        # loop below; the callback surfaces any failure of that task.
        sub_task = asyncio.create_task(self._subscribe())
        sub_task.add_done_callback(self._on_subscribe_done)
        try:
            async for msg in ws:
                try:
                    await self._handle_message(msg)
                except Exception as e:
                    # One bad message must not tear down the private feed.
                    self._log.error("ws_message_handler_error", error=str(e))
        except websockets.ConnectionClosed as e:
            self._log.warning(
                "ws_connection_closed",
                code=e.code,
                reason=e.reason,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self._log.error("ws_message_loop_error", error=str(e))
        finally:
            self._connected = False
            if not sub_task.done():
                sub_task.cancel()
            # Release the socket on every exit path.
            try:
                await ws.close()
            except Exception as e:
                self._log.debug("ws_close_error", error=str(e))

    def _on_subscribe_done(self, task: asyncio.Task) -> None:
        """Log a failed subscription task so its exception is not dropped."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # repr() because a bare TimeoutError has an empty str().
            self._log.error("ws_subscribe_failed", error=repr(exc))

    async def _subscribe(self) -> None:
        """Subscribe to tradeOrdersV2 and balance channels and await both acks.

        Raises
        ------
        asyncio.TimeoutError
            If an ack does not arrive within 5 seconds of being awaited.
        """
        ws = self._ws
        if ws is None:
            return
        subscriptions = (
            ("/spotMarket/tradeOrdersV2", "subscribe_sent", "subscribe_ack_received"),
            ("/account/balance", "bal_subscribe_sent", "bal_subscribe_ack_received"),
        )
        base_id = int(time.time() * 1000)
        awaiting: list[tuple[str, asyncio.Event, str]] = []
        try:
            for offset, (topic, sent_event, ack_event) in enumerate(subscriptions):
                ack_id = str(base_id + offset)
                evt = asyncio.Event()
                # Register before sending so a fast ack cannot be missed.
                self._pending_acks[ack_id] = evt
                awaiting.append((ack_id, evt, ack_event))
                await ws.send(
                    json.dumps(
                        {
                            "id": int(ack_id),
                            "type": "subscribe",
                            "topic": topic,
                            "response": True,
                            "privateChannel": "true",
                        }
                    )
                )
                self._log.info(sent_event, topic=topic)
            for _ack_id, evt, ack_event in awaiting:
                await asyncio.wait_for(evt.wait(), timeout=5.0)
                self._log.info(ack_event)
        finally:
            # Drop registrations that were never acknowledged.
            for ack_id, _evt, _ack_event in awaiting:
                self._pending_acks.pop(ack_id, None)

    async def _handle_message(self, msg: str) -> None:
        """Parse incoming WS message and dispatch fill / balance events."""
        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            self._log.warning("ws_raw_message_parse_error", raw=msg[:500])
            return
        if not isinstance(data, dict):
            return

        msg_type = data.get("type")
        if msg_type in ("welcome", "pong"):
            return
        if msg_type == "ack":
            evt = self._pending_acks.pop(str(data.get("id", "")), None)
            if evt is not None:
                evt.set()
            return

        subject = data.get("subject", "")
        payload = data.get("data") or {}
        if subject == "account.balance":
            self._handle_balance(payload)
        elif subject == "orderChange":
            self._handle_order_change(payload)

    def _handle_balance(self, payload: dict) -> None:
        """Record the available balance and wake the waiter for that currency."""
        currency = (payload.get("currency") or "").upper()
        available_raw = payload.get("available")
        if not currency or available_raw is None:
            return
        available = Decimal(str(available_raw))
        self._latest_balance[currency] = available
        self._log.debug("balance_update", currency=currency, available=str(available))
        future = self._balance_futures.get(currency)
        if future is not None and not future.done():
            future.set_result(True)

    def _resolve_fill(self, client_oid: str, fill_data: dict) -> None:
        """Deliver a successful fill to its waiter, or mark it done for later."""
        future = self._fill_futures.get(client_oid)
        if future is not None:
            if not future.done():
                future.set_result((True, fill_data))
        elif client_oid in self._accumulators:
            # Nobody is waiting yet; await_fill() picks this up on entry.
            self._accumulators[client_oid]._done = True

    def _handle_order_change(self, payload: dict) -> None:
        """Process one orderChange payload (match / filled / canceled / failed)."""
        event_type = payload.get("type", "")
        client_oid = payload.get("clientOid", "")
        order_id = payload.get("orderId", "")
        status = payload.get("status", "")
        if not client_oid:
            return

        if event_type == "match":
            acc = self._accumulators.get(client_oid)
            if acc is None:
                acc = FillAccumulator(client_oid, order_id)
                self._accumulators[client_oid] = acc
            acc.add_match(payload)
            return

        if event_type == "filled" and status == "done":
            acc = self._accumulators.get(client_oid)
            if acc is not None:
                acc.order_id = order_id or acc.order_id
                fill_data = acc.to_dict()
                self._log.debug(
                    "fill_received",
                    client_oid=client_oid,
                    order_id=order_id,
                    total_size=fill_data["total_size"],
                    total_funds=fill_data["total_funds"],
                    avg_price=fill_data["weighted_avg_price"],
                    match_count=fill_data["match_count"],
                )
            else:
                # Shouldn't happen in normal flow, but handle defensively.
                fill_data = {
                    "total_size": Decimal(str(payload.get("filledSize", "0"))),
                    "total_funds": _D0,
                    "weighted_avg_price": _D0,
                    "match_count": 0,
                    "order_id": order_id,
                    "side": payload.get("side", ""),
                    "symbol": payload.get("symbol", ""),
                }
                self._log.warning(
                    "filled_without_matches",
                    client_oid=client_oid,
                    order_id=order_id,
                )
            self._resolve_fill(client_oid, fill_data)
            return

        if event_type not in ("canceled", "failed"):
            return

        self._log.warning(
            "ws_terminal_event_full_payload",
            client_oid=client_oid,
            event_type=event_type,
            status=status,
            full_payload=json.dumps(payload, indent=2),
        )
        # Market orders with `funds` send type="canceled" + status="done" when
        # the order completes with a tiny remainder. If we have accumulated
        # matches, this is actually a successful fill.
        acc = self._accumulators.get(client_oid)
        if (
            event_type == "canceled"
            and status == "done"
            and acc is not None
            and acc.match_count > 0
        ):
            acc.order_id = order_id or acc.order_id
            fill_data = acc.to_dict()
            self._log.info(
                "fill_via_cancel_done",
                client_oid=client_oid,
                order_id=order_id,
                total_size=fill_data["total_size"],
                total_funds=fill_data["total_funds"],
                avg_price=fill_data["weighted_avg_price"],
                match_count=fill_data["match_count"],
            )
            self._resolve_fill(client_oid, fill_data)
            return

        # Genuine cancel/failure: release the waiter with a failure result.
        future = self._fill_futures.get(client_oid)
        if future is not None and not future.done():
            future.set_result((False, {}))