triangular_arbitrage_bot/fh_ob/socket_server.py

95 lines
2.9 KiB
Python

from typing import Optional
import asyncio
import json
from pathlib import Path
import structlog
from fh_ob.book_store import OrderBookTop5
class SocketServer:
    """Fan out order-book snapshots to local clients over a Unix domain socket.

    Each :meth:`broadcast` writes one compact, newline-terminated JSON object
    to every connected client. Clients are not expected to send anything: the
    per-connection handler reads and discards incoming bytes only so that a
    disconnect (EOF or reset) is noticed and the client is dropped.
    """

    def __init__(self, socket_path: Path) -> None:
        """Prepare a server that will listen on *socket_path* once started.

        Args:
            socket_path: Filesystem path of the Unix socket to create.
        """
        self._socket_path = socket_path
        self._log = structlog.get_logger().bind(component="socket_server")
        # Writers of currently connected clients; the targets of broadcast().
        self._clients: set[asyncio.StreamWriter] = set()
        # Listening server; None before start() and again after stop().
        self._server: Optional[asyncio.Server] = None

    async def start(self) -> None:
        """Start listening, replacing any file already present at the path."""
        # A socket file left behind by an earlier run would make bind() fail,
        # so remove it first.
        self._remove_socket_file()
        self._server = await asyncio.start_unix_server(
            self._accept_client,
            path=str(self._socket_path),
        )
        self._log.info("socket_server_started", path=str(self._socket_path))

    async def stop(self) -> None:
        """Stop listening, disconnect all clients and remove the socket file.

        Safe to call more than once, or before :meth:`start`.
        """
        server, self._server = self._server, None
        if server is not None:
            server.close()
            # Since Python 3.12, Server.wait_closed() also waits for every
            # active connection to finish. Clients here never hang up on
            # their own, so without closing them stop() would block forever.
            clients = list(self._clients)
            self._clients.clear()
            for writer in clients:
                writer.close()
            await server.wait_closed()
        self._remove_socket_file()
        self._log.info("socket_server_stopped")

    def client_count(self) -> int:
        """Return the number of currently connected clients."""
        return len(self._clients)

    def _remove_socket_file(self) -> None:
        """Delete the socket file if present; a missing file is not an error."""
        try:
            self._socket_path.unlink()
        except FileNotFoundError:
            pass

    async def _accept_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Register one client and keep it registered until it disconnects.

        Used as the ``start_unix_server`` connection callback. Cleanup runs in
        ``finally``; ``asyncio.CancelledError`` is deliberately not caught so
        that cancelling this task actually cancels it.
        """
        peer = writer.get_extra_info("peername")
        self._clients.add(writer)
        self._log.info("client_connected", addr=peer)
        try:
            while True:
                try:
                    # Incoming data is discarded, so read raw chunks instead
                    # of lines: readline() raises on over-long lines.
                    chunk = await reader.read(4096)
                except (ConnectionResetError, BrokenPipeError):
                    break  # peer vanished abruptly; same as EOF for us
                except Exception as e:
                    self._log.warning("client_read_failed", addr=peer, error=str(e))
                    break
                if not chunk:
                    break  # EOF: the client closed its end
        finally:
            self._clients.discard(writer)
            writer.close()
            self._log.info("client_disconnected", addr=peer)
            try:
                await asyncio.wait_for(writer.wait_closed(), timeout=1.0)
            except Exception:
                # Best effort: the transport is already closing, and a timeout
                # or late connection error here changes nothing for us.
                pass

    async def broadcast(
        self, book: OrderBookTop5, *, drain_timeout: Optional[float] = None
    ) -> None:
        """Send *book* to every connected client as one JSON line.

        The payload is ``book.to_dict()`` serialised as compact JSON plus a
        trailing newline. It is written to all clients first, then all
        writers are drained concurrently. A client whose write or drain fails
        (or, when *drain_timeout* is set, does not drain in time) is removed
        and its connection closed, so it sees EOF and can reconnect instead
        of staying connected while silently receiving nothing.

        Args:
            book: Snapshot to publish; must provide ``to_dict()``.
            drain_timeout: Optional per-client limit, in seconds, on waiting
                for a client's send buffer to drain. ``None`` (the default)
                waits indefinitely, so one stalled reader delays this call.
        """
        if not self._clients:
            return
        payload = json.dumps(book.to_dict(), separators=(",", ":")).encode() + b"\n"

        failed: set[asyncio.StreamWriter] = set()
        written: list[asyncio.StreamWriter] = []
        # No await happens in this loop, so self._clients cannot change while
        # it is being iterated; `written` is the snapshot used after awaiting.
        for w in self._clients:
            try:
                w.write(payload)
            except Exception as e:
                self._log.warning("broadcast_write_failed", error=str(e))
                failed.add(w)
            else:
                written.append(w)

        if written:
            results = await asyncio.gather(
                *(self._drain(w, drain_timeout) for w in written),
                return_exceptions=True,
            )
            for w, res in zip(written, results):
                if isinstance(res, BaseException):
                    # str(TimeoutError()) is empty; fall back to the type name.
                    self._log.warning(
                        "broadcast_drain_failed",
                        error=str(res) or type(res).__name__,
                    )
                    failed.add(w)

        for w in failed:
            self._clients.discard(w)
            # Closing the transport also feeds EOF to this client's reader, so
            # its _accept_client task finishes and releases the socket.
            w.close()

    @staticmethod
    async def _drain(
        writer: asyncio.StreamWriter, timeout: Optional[float]
    ) -> None:
        """Await ``writer.drain()``, bounded by *timeout* seconds if given."""
        if timeout is None:
            await writer.drain()
        else:
            await asyncio.wait_for(writer.drain(), timeout)