from typing import Optional
import asyncio
import json
from pathlib import Path

import structlog

from fh_ob.book_store import OrderBookTop5


class SocketServer:
    """Unix-domain socket server that fans out order-book snapshots to clients.

    Every connected client receives each :meth:`broadcast` message as a single
    line of compact JSON terminated by a newline. Anything a client sends is
    read and discarded; reading only serves to detect disconnects.
    """

    def __init__(self, socket_path: Path) -> None:
        """Prepare a server that will listen on *socket_path* once started."""
        self._socket_path = socket_path
        self._log = structlog.get_logger().bind(component="socket_server")
        # Writers of currently connected clients; broadcast() targets these.
        self._clients: set[asyncio.StreamWriter] = set()
        self._server: Optional[asyncio.Server] = None

    async def start(self) -> None:
        """Remove any stale socket file and start listening on the socket path."""
        # A leftover socket file (e.g. from a previous run) would make bind()
        # fail; missing_ok avoids the exists()/unlink() race.
        self._socket_path.unlink(missing_ok=True)
        self._server = await asyncio.start_unix_server(
            self._accept_client,
            path=str(self._socket_path),
        )
        self._log.info("socket_server_started", path=str(self._socket_path))

    async def stop(self) -> None:
        """Stop accepting connections, disconnect clients, remove the socket file."""
        if self._server is not None:
            self._server.close()
            # Since Python 3.12.1 Server.wait_closed() also waits for every
            # accepted connection to close. Client handlers otherwise exit only
            # when the peer hangs up, so close them here or shutdown can hang.
            # Closing the transport feeds EOF to each handler's reader.
            for writer in list(self._clients):
                writer.close()
            await self._server.wait_closed()
            self._server = None
        self._socket_path.unlink(missing_ok=True)
        self._log.info("socket_server_stopped")

    def client_count(self) -> int:
        """Return the number of currently connected clients."""
        return len(self._clients)

    async def _accept_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Track one client connection until it ends, then clean it up.

        Incoming data is discarded; the read loop exists only to notice EOF or
        a connection error so the writer can leave the broadcast set.
        Cancellation is not swallowed: cleanup runs, then CancelledError
        propagates as asyncio expects.
        """
        self._clients.add(writer)
        self._log.info("client_connected", addr=writer.get_extra_info("peername"))
        try:
            # readline() returns b"" at EOF: the peer hung up or we closed it.
            while await reader.readline():
                pass
        except ConnectionError:
            pass  # reset / broken pipe is a normal way for a client to leave
        except Exception as exc:
            # Best-effort: an unexpected read error (e.g. ValueError when a line
            # exceeds the stream limit) ends this client only, but is logged.
            self._log.warning("client_read_failed", error=repr(exc))
        finally:
            self._clients.discard(writer)
            writer.close()
            try:
                # Bounded, best-effort wait for the transport to finish closing;
                # a broken connection may raise here and that is not an error.
                await asyncio.wait_for(writer.wait_closed(), timeout=1.0)
            except Exception:
                pass
            self._log.info("client_disconnected")

    async def broadcast(self, book: OrderBookTop5) -> None:
        """Send *book* as one compact JSON line to every connected client.

        Writes are issued to all clients first, then drained concurrently, so
        the call returns once the slowest healthy client has accepted the data.
        Clients whose write or drain fails are removed from the broadcast set
        and their connection is closed, which also ends their handler loop.
        """
        if not self._clients:
            return
        payload = json.dumps(book.to_dict(), separators=(",", ":")).encode() + b"\n"

        # Snapshot: handlers may add or remove clients while we await below.
        targets = list(self._clients)
        failed: set[asyncio.StreamWriter] = set()
        written: list[asyncio.StreamWriter] = []
        for writer in targets:
            try:
                writer.write(payload)
            except Exception as exc:
                self._log.warning("broadcast_write_failed", error=str(exc))
                failed.add(writer)
            else:
                written.append(writer)

        # Drain only writers that accepted the write; draining a failed one
        # would just fail again and log the same client twice.
        if written:
            results = await asyncio.gather(
                *(writer.drain() for writer in written),
                return_exceptions=True,
            )
            for writer, result in zip(written, results):
                if isinstance(result, Exception):
                    self._log.warning("broadcast_drain_failed", error=str(result))
                    failed.add(writer)

        if failed:
            self._clients -= failed
            # Release dropped connections instead of orphaning them; their
            # _accept_client loops then see EOF and finish cleanup.
            for writer in failed:
                writer.close()