484 lines
17 KiB
Python
484 lines
17 KiB
Python
"""
|
|
KuCoin private WebSocket client for fill event delivery.
|
|
|
|
Manages the private WebSocket connection, authenticates via bullet-private,
|
|
subscribes to /spotMarket/tradeOrdersV2, and dispatches fill events to
|
|
waiting executor coroutines.
|
|
|
|
KuCoin market orders placed with ``funds`` often complete with a tiny
|
|
unfilled remainder, emitting ``canceled`` + ``status=done`` rather than
|
|
``filled``. The message handler in ``_handle_message`` treats this
|
|
terminal combination as a successful fill when match events have been
|
|
accumulated.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
from decimal import Decimal
|
|
from typing import Optional
|
|
|
|
import aiohttp
|
|
import structlog
|
|
import websockets
|
|
|
|
from executor.kucoin_api import KuCoinAPI
|
|
|
|
# Module-wide structured logger; every event it emits is tagged with
# component="executor-ws".
logger = structlog.get_logger().bind(component="executor-ws")

# Shared Decimal zero: the starting value for accumulated totals and the
# fallback for unknown balances / average prices.
_D0: Decimal = Decimal("0")
|
|
|
|
|
|
class FillAccumulator:
    """Accumulate match events for a single order, compute aggregated totals.

    Each ``add_match`` call folds one ``match`` event payload into running
    totals: filled size (sum of ``matchSize``), funds (sum of
    ``matchPrice * matchSize``) and a match counter. The first non-empty
    ``side`` / ``symbol`` / ``orderId`` seen is retained.
    """

    def __init__(self, client_oid: str, order_id: str = "") -> None:
        self.client_oid = client_oid
        self.order_id = order_id
        # Sum of matchSize over all match events.
        self.total_size = _D0
        # Sum of matchPrice * matchSize over all match events.
        self.total_funds = _D0
        self.match_count = 0
        self.side = ""
        self.symbol = ""
        # Completion flag; never set by this class itself — owners may use it
        # to mark the order as finished.
        self._done = False

    @staticmethod
    def _to_decimal(value: object) -> Decimal:
        """Convert a numeric payload field to Decimal; None or "" count as zero."""
        if value is None or value == "":
            return _D0
        # str() first so floats convert via their shortest repr, not binary value.
        return Decimal(str(value))

    def add_match(self, data: dict) -> None:
        """Fold one ``match`` event payload into the running totals.

        Parameters
        ----------
        data : dict
            The ``data`` object of an orderChange ``match`` event. Missing,
            null or empty ``matchPrice`` / ``matchSize`` are treated as zero
            rather than raising, so one malformed event cannot crash the
            caller's message loop.
        """
        match_price = self._to_decimal(data.get("matchPrice"))
        match_size = self._to_decimal(data.get("matchSize"))
        self.total_size += match_size
        self.total_funds += match_price * match_size
        self.match_count += 1
        # Keep the first non-empty identifier; `or ""` normalizes JSON nulls.
        if not self.side:
            self.side = data.get("side") or ""
        if not self.symbol:
            self.symbol = data.get("symbol") or ""
        if not self.order_id:
            self.order_id = data.get("orderId") or ""

    @property
    def weighted_avg_price(self) -> Decimal:
        """Size-weighted average match price, or 0 when nothing has filled."""
        if self.total_size <= 0:
            return _D0
        return self.total_funds / self.total_size

    def to_dict(self) -> dict:
        """Return the aggregated fill as a plain dict (Decimal amounts)."""
        return {
            "total_size": self.total_size,
            "total_funds": self.total_funds,
            "weighted_avg_price": self.weighted_avg_price,
            "match_count": self.match_count,
            "order_id": self.order_id,
            "side": self.side,
            "symbol": self.symbol,
        }
|
|
|
|
|
|
class KuCoinWSClient:
    """
    Private WebSocket client for KuCoin execution events.

    Subscribes to /spotMarket/tradeOrdersV2 (global topic) and /account/balance.
    Fill events are dispatched to awaiting executor coroutines via await_fill();
    balance updates are exposed via await_balance() / latest_balance().
    """

    # Upper bound on terminal outcomes parked for orders nobody is awaiting
    # yet; the oldest entries are evicted first so untracked orders cannot
    # grow memory without limit.
    _MAX_EARLY_RESULTS = 1024

    def __init__(
        self,
        kucoin_api: KuCoinAPI,
        private_token_url: str = "https://api.kucoin.com/api/v1/bullet-private",
        reconnect_base_delay: float = 1.0,
        reconnect_max_delay: float = 30.0,
    ) -> None:
        self._api = kucoin_api
        # Not read by this class; the private token is fetched through
        # KuCoinAPI.get_private_token() in _connect_and_run.
        self._private_token_url = private_token_url
        self._reconnect_base_delay = reconnect_base_delay
        self._reconnect_max_delay = reconnect_max_delay
        self._reconnect_delay = reconnect_base_delay
        self._log = logger
        self._running = False
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._connected = False
        # Keepalive settings (seconds); overwritten from the token response.
        self._ping_interval: float = 18.0
        self._ping_timeout: float = 10.0
        # client_oid -> future resolved with (success, fill_data).
        self._fill_futures: dict[str, asyncio.Future] = {}
        # client_oid -> running totals of that order's match events.
        self._accumulators: dict[str, FillAccumulator] = {}
        # client_oid -> terminal outcome that arrived before await_fill().
        self._early_results: dict[str, tuple[bool, dict]] = {}
        # CURRENCY -> future resolved True on the next balance event for that
        # currency (False when the client is stopped). Shared by all waiters.
        self._balance_futures: dict[str, asyncio.Future] = {}
        self._latest_balance: dict[str, Decimal] = {}
        self._worker_task: Optional[asyncio.Task] = None
        # str(request id) -> event set when the server acks that request.
        self._pending_acks: dict[str, asyncio.Event] = {}

    @property
    def is_connected(self) -> bool:
        """True while a WebSocket session is open."""
        return self._connected

    async def start(self) -> None:
        """Start the WebSocket connection worker with reconnection loop.

        Runs until the worker finishes or is cancelled (stop() cancels it).
        """
        self._running = True
        self._worker_task = asyncio.create_task(self._connection_worker())
        try:
            await self._worker_task
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path triggered by stop().
            pass

    async def stop(self) -> None:
        """Stop the WebSocket connection and resolve any pending futures.

        Fill waiters receive ``(False, {})`` and balance waiters ``False``.
        """
        self._running = False
        for future in self._fill_futures.values():
            if not future.done():
                future.set_result((False, {}))
        self._fill_futures.clear()
        self._accumulators.clear()
        self._early_results.clear()
        for future in self._balance_futures.values():
            if not future.done():
                future.set_result(False)
        self._balance_futures.clear()
        for evt in self._pending_acks.values():
            evt.set()
        self._pending_acks.clear()
        if self._worker_task is not None and not self._worker_task.done():
            self._worker_task.cancel()
        if self._ws is not None:
            try:
                await self._ws.close()
            except Exception as e:
                # Best effort: the socket may already be closed or broken.
                self._log.debug("ws_close_error", error=str(e))
        self._connected = False
        self._log.debug("ws_client_stopped")

    async def await_fill(
        self, client_oid: str, timeout_ms: float
    ) -> tuple[bool, dict]:
        """
        Wait for the order identified by client_oid to reach a terminal state.

        A terminal event that arrived before this call (the WS event can beat
        the REST order response) is returned immediately.

        Parameters
        ----------
        client_oid : str
            The client-order ID used when placing the order.
        timeout_ms : float
            Maximum wait time in milliseconds.

        Returns
        -------
        tuple[bool, dict]
            (True, aggregated_fill_data) on fill.
            (False, {}) on timeout, cancel/failure, stop(), or when
            disconnected with no outcome already received.
        """
        early = self._early_results.pop(client_oid, None)
        if early is not None:
            self._accumulators.pop(client_oid, None)
            return early

        if not self._connected:
            self._log.warning(
                "await_fill_not_connected",
                client_oid=client_oid,
            )
            return (False, {})

        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._fill_futures[client_oid] = future
        try:
            result = await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            acc = self._accumulators.get(client_oid)
            self._log.warning(
                "fill_timeout",
                client_oid=client_oid,
                timeout_ms=timeout_ms,
                accumulator=acc.to_dict() if acc is not None else None,
            )
            return (False, {})
        finally:
            # Only drop our own registration, never a newer waiter's future.
            if self._fill_futures.get(client_oid) is future:
                del self._fill_futures[client_oid]

        self._accumulators.pop(client_oid, None)
        return result

    async def await_balance(
        self, currency: str, min_available: Decimal, timeout_ms: float
    ) -> bool:
        """
        Wait until the available balance for *currency* reaches at least *min_available*.

        If the latest known balance already meets the threshold, returns
        immediately. Otherwise waits for balance WS events, re-checking the
        threshold after each one, until it is met or *timeout_ms* elapses.

        Returns True once the threshold is met; False on timeout or when the
        client is stopped while waiting.
        """
        key = currency.upper()
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_ms / 1000.0

        while self._latest_balance.get(key, _D0) < min_available:
            remaining = deadline - loop.time()
            if remaining <= 0:
                self._log_balance_timeout(key, min_available)
                return False
            # One future per currency is shared by every waiter; shield it so
            # a waiter timing out does not cancel it for the others.
            future = self._balance_futures.get(key)
            if future is None or future.done():
                future = loop.create_future()
                self._balance_futures[key] = future
            try:
                updated = await asyncio.wait_for(
                    asyncio.shield(future), timeout=remaining
                )
            except asyncio.TimeoutError:
                self._log_balance_timeout(key, min_available)
                return False
            if not updated:
                # stop() resolves pending balance futures with False.
                return False
        return True

    def _log_balance_timeout(self, key: str, min_available: Decimal) -> None:
        """Emit the await_balance timeout warning for currency *key*."""
        self._log.warning(
            "await_balance_timeout",
            currency=key,
            min_available=str(min_available),
            latest=str(self._latest_balance.get(key, _D0)),
        )

    def latest_balance(self, currency: str) -> Decimal:
        """Return the latest known available balance for *currency*, or 0."""
        return self._latest_balance.get(currency.upper(), _D0)

    async def _connection_worker(self) -> None:
        """Main connection loop with exponential backoff reconnection.

        The delay is reset to the base value on every successful connect and
        doubled (up to the max) after each session ends, whether it ended with
        an exception or a plain socket close.
        """
        while self._running:
            try:
                await self._connect_and_run()
            except asyncio.CancelledError:
                break
            except Exception as e:
                if not self._running:
                    break
                self._log.warning(
                    "ws_reconnecting",
                    error=str(e),
                    delay=self._reconnect_delay,
                )
            else:
                if not self._running:
                    break
                # Back off here too so a server that keeps closing the socket
                # right after connect is not hammered with reconnects.
                self._log.warning(
                    "ws_reconnecting",
                    error="connection closed",
                    delay=self._reconnect_delay,
                )
            self._connected = False
            await asyncio.sleep(self._reconnect_delay)
            self._reconnect_delay = min(
                self._reconnect_delay * 2,
                self._reconnect_max_delay,
            )

    async def _connect_and_run(self) -> None:
        """Authenticate, connect, subscribe, and run the message loop.

        Raises RuntimeError when no usable token / server is returned; returns
        normally once the socket closes or the message loop fails.
        """
        # The HTTP session is only needed for the token request.
        async with aiohttp.ClientSession() as session:
            token_data = await self._api.get_private_token(session)
        if not token_data:
            raise RuntimeError("Failed to obtain private token")

        token = token_data.get("token", "")
        servers = token_data.get("instanceServers", [])
        if not servers:
            raise RuntimeError("No instance servers in token response")

        server = servers[0]
        endpoint = server.get("endpoint", "")
        # The server reports keepalive settings in milliseconds.
        self._ping_interval = server.get("pingInterval", 18000) / 1000.0
        self._ping_timeout = server.get("pingTimeout", 10000) / 1000.0

        ws_url = f"{endpoint}?token={token}&connectId={uuid.uuid4()}"
        # Log only the endpoint: the URL carries the private token.
        self._log.debug("ws_connecting", url=endpoint)

        # NOTE(review): keepalive relies on websockets' protocol-level pings;
        # confirm KuCoin does not require application-level ping messages.
        ws = await websockets.connect(
            ws_url,
            ping_interval=self._ping_interval,
            ping_timeout=self._ping_timeout,
        )
        self._ws = ws
        self._connected = True
        self._reconnect_delay = self._reconnect_base_delay
        self._log.info("ws_connected")

        sub_task = asyncio.create_task(self._subscribe())
        sub_task.add_done_callback(self._on_subscribe_done)

        try:
            async for msg in ws:
                try:
                    await self._handle_message(msg)
                except Exception as e:
                    # One malformed event must not tear down the connection.
                    self._log.error("ws_message_handler_error", error=str(e))
        except websockets.ConnectionClosed as e:
            self._log.warning(
                "ws_connection_closed",
                code=e.code,
                reason=e.reason,
            )
        except Exception as e:
            self._log.error("ws_message_loop_error", error=str(e))
        finally:
            self._connected = False
            sub_task.cancel()
            try:
                # Release the socket before the worker opens a new one.
                await ws.close()
            except Exception as e:
                self._log.debug("ws_close_error", error=str(e))

    def _on_subscribe_done(self, task: asyncio.Task) -> None:
        """Log a failed subscription handshake instead of leaving it unobserved."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._log.warning("ws_subscribe_failed", error=repr(exc))

    async def _subscribe(self) -> None:
        """Subscribe to tradeOrdersV2 and balance channels and await both acks.

        Raises asyncio.TimeoutError when an ack does not arrive within 5 s.
        """
        ws = self._ws
        if ws is None:
            return

        base_id = int(time.time() * 1000)
        orders_ack_id = str(base_id)
        balance_ack_id = str(base_id + 1)
        orders_evt = asyncio.Event()
        balance_evt = asyncio.Event()
        self._pending_acks[orders_ack_id] = orders_evt
        self._pending_acks[balance_ack_id] = balance_evt

        try:
            sub_msg = {
                "id": int(orders_ack_id),
                "type": "subscribe",
                "topic": "/spotMarket/tradeOrdersV2",
                "response": True,
                "privateChannel": "true",
            }
            await ws.send(json.dumps(sub_msg))
            self._log.info("subscribe_sent", topic="/spotMarket/tradeOrdersV2")

            bal_msg = {
                "id": int(balance_ack_id),
                "type": "subscribe",
                "topic": "/account/balance",
                "response": True,
                "privateChannel": "true",
            }
            await ws.send(json.dumps(bal_msg))
            self._log.info("bal_subscribe_sent", topic="/account/balance")

            await asyncio.wait_for(balance_evt.wait(), timeout=5.0)
            self._log.info("bal_subscribe_ack_received")
            # The order-event subscription matters most; verify it as well.
            await asyncio.wait_for(orders_evt.wait(), timeout=5.0)
            self._log.info(
                "subscribe_ack_received", topic="/spotMarket/tradeOrdersV2"
            )
        finally:
            # Drop any ack slots that were never answered.
            self._pending_acks.pop(orders_ack_id, None)
            self._pending_acks.pop(balance_ack_id, None)

    async def _handle_message(self, msg: str) -> None:
        """Parse incoming WS message and dispatch acks, balances and fill events."""
        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            self._log.warning("ws_raw_message_parse_error", raw=msg[:500])
            return
        if not isinstance(data, dict):
            self._log.warning("ws_unexpected_message", raw=str(msg)[:500])
            return

        msg_type = data.get("type")
        if msg_type in ("welcome", "pong"):
            return

        if msg_type == "ack":
            # Acks echo the request id, possibly as a string.
            evt = self._pending_acks.pop(str(data.get("id", "")), None)
            if evt is not None:
                evt.set()
            return

        if msg_type == "error":
            self._log.warning(
                "ws_error_message",
                id=data.get("id"),
                code=data.get("code"),
                data=data.get("data"),
            )
            return

        subject = data.get("subject", "")
        if subject == "account.balance":
            self._handle_balance(data.get("data") or {})
        elif subject == "orderChange":
            self._handle_order_change(data.get("data") or {})

    def _handle_balance(self, payload: dict) -> None:
        """Record an account.balance update and wake waiters for that currency."""
        currency = (payload.get("currency") or "").upper()
        available_raw = payload.get("available")
        if not currency or available_raw is None:
            return
        available = Decimal(str(available_raw))
        self._latest_balance[currency] = available
        self._log.debug("balance_update", currency=currency, available=str(available))
        # Waiters re-check their own threshold after being woken.
        future = self._balance_futures.pop(currency, None)
        if future is not None and not future.done():
            future.set_result(True)

    def _handle_order_change(self, payload: dict) -> None:
        """Accumulate match events and resolve terminal order events."""
        event_type = payload.get("type", "")
        client_oid = payload.get("clientOid", "")
        order_id = payload.get("orderId", "")
        status = payload.get("status", "")

        if not client_oid:
            return

        if event_type == "match":
            acc = self._accumulators.get(client_oid)
            if acc is None:
                acc = FillAccumulator(client_oid, order_id)
                self._accumulators[client_oid] = acc
            acc.add_match(payload)
            return

        if event_type == "filled" and status == "done":
            acc = self._accumulators.get(client_oid)
            if acc is not None:
                acc.order_id = order_id or acc.order_id
                fill_data = acc.to_dict()
                self._log.debug(
                    "fill_received",
                    client_oid=client_oid,
                    order_id=order_id,
                    total_size=fill_data["total_size"],
                    total_funds=fill_data["total_funds"],
                    avg_price=fill_data["weighted_avg_price"],
                    match_count=fill_data["match_count"],
                )
            else:
                # Shouldn't happen in normal flow, but handle defensively
                # using the size reported on the terminal event itself.
                fill_data = {
                    "total_size": Decimal(str(payload.get("filledSize") or "0")),
                    "total_funds": _D0,
                    "weighted_avg_price": _D0,
                    "match_count": 0,
                    "order_id": order_id,
                    "side": payload.get("side", ""),
                    "symbol": payload.get("symbol", ""),
                }
                self._log.warning(
                    "filled_without_matches",
                    client_oid=client_oid,
                    order_id=order_id,
                )
            self._resolve_fill(client_oid, (True, fill_data))
            return

        if event_type in ("canceled", "failed"):
            self._log.warning(
                "ws_terminal_event_full_payload",
                client_oid=client_oid,
                event_type=event_type,
                status=status,
                full_payload=json.dumps(payload, indent=2),
            )
            # Market orders with `funds` send type="canceled" + status="done"
            # when the order completes with a tiny remainder. If we have
            # accumulated matches, this is actually a successful fill.
            acc = self._accumulators.get(client_oid)
            if (
                event_type == "canceled"
                and status == "done"
                and acc is not None
                and acc.match_count > 0
            ):
                acc.order_id = order_id or acc.order_id
                fill_data = acc.to_dict()
                self._log.info(
                    "fill_via_cancel_done",
                    client_oid=client_oid,
                    order_id=order_id,
                    total_size=fill_data["total_size"],
                    total_funds=fill_data["total_funds"],
                    avg_price=fill_data["weighted_avg_price"],
                    match_count=fill_data["match_count"],
                )
                self._resolve_fill(client_oid, (True, fill_data))
                return
            self._resolve_fill(client_oid, (False, {}))

    def _resolve_fill(self, client_oid: str, result: tuple[bool, dict]) -> None:
        """Deliver a terminal outcome to its waiter, or park it for await_fill()."""
        self._accumulators.pop(client_oid, None)
        future = self._fill_futures.get(client_oid)
        if future is not None:
            if not future.done():
                future.set_result(result)
            return
        # No waiter yet: the WS event beat the caller. Keep a bounded record.
        self._early_results[client_oid] = result
        while len(self._early_results) > self._MAX_EARLY_RESULTS:
            self._early_results.pop(next(iter(self._early_results)))