95 lines
2.9 KiB
Python
95 lines
2.9 KiB
Python
from typing import Optional
|
|
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import structlog
|
|
|
|
from fh_ob.book_store import OrderBookTop5
|
|
|
|
|
|
class SocketServer:
    """Broadcast order-book snapshots to clients over a Unix domain socket.

    Every connected client receives one compact JSON object per line
    (newline-delimited JSON) for each call to :meth:`broadcast`. Anything a
    client sends is read and discarded; the per-client read loop exists only
    to notice when the client disconnects.
    """

    def __init__(self, socket_path: Path) -> None:
        """Create a server that will listen on *socket_path* once started.

        Args:
            socket_path: Filesystem path of the Unix socket to listen on.
        """
        self._socket_path = socket_path
        self._log = structlog.get_logger().bind(component="socket_server")
        # Writers of currently connected clients; mutated by the per-client
        # handler tasks and by broadcast()/stop().
        self._clients: set[asyncio.StreamWriter] = set()
        # None until start() runs; reset to None by stop().
        self._server: Optional[asyncio.Server] = None

    async def start(self) -> None:
        """Start listening on the Unix socket.

        A leftover file at the socket path (e.g. from a run that did not shut
        down cleanly) is removed first, since binding a Unix socket to an
        existing path fails.
        """
        # missing_ok avoids the exists()/unlink() check-then-act race.
        self._socket_path.unlink(missing_ok=True)
        self._server = await asyncio.start_unix_server(
            self._accept_client,
            path=str(self._socket_path),
        )
        self._log.info("socket_server_started", path=str(self._socket_path))

    async def stop(self) -> None:
        """Stop accepting connections, disconnect clients and remove the socket.

        Connected clients are closed *before* waiting on the server: since
        Python 3.12.1 ``Server.wait_closed()`` also waits for every accepted
        connection to close, and the handlers here only finish once their
        peer disconnects, so waiting first could block indefinitely.
        """
        server, self._server = self._server, None
        if server is not None:
            server.close()
            for writer in list(self._clients):
                # Closing the transport makes the handler's readline() hit
                # EOF, so its cleanup runs and the connection is released.
                writer.close()
            self._clients.clear()
            await server.wait_closed()
        self._socket_path.unlink(missing_ok=True)
        self._log.info("socket_server_stopped")

    def client_count(self) -> int:
        """Return the number of currently connected clients."""
        return len(self._clients)

    async def _accept_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Register a newly connected client and hold it until it disconnects.

        Incoming data is read and discarded; the loop only detects EOF or a
        connection error. Cancellation is not swallowed: cleanup still runs
        and ``asyncio.CancelledError`` then propagates, so the task is
        correctly reported as cancelled.
        """
        self._clients.add(writer)
        self._log.info("client_connected", addr=writer.get_extra_info("peername"))
        try:
            while True:
                try:
                    line = await reader.readline()
                except (ConnectionResetError, BrokenPipeError):
                    # Ordinary abrupt disconnect; nothing worth reporting.
                    break
                except Exception as exc:
                    # e.g. a ValueError from readline() when a client sends a
                    # line longer than the stream's buffer limit. Drop the
                    # client, but record why instead of failing silently.
                    self._log.warning("client_read_failed", error=str(exc))
                    break
                if not line:
                    # EOF: the client closed its end.
                    break
        finally:
            self._clients.discard(writer)
            writer.close()
            try:
                await asyncio.wait_for(writer.wait_closed(), timeout=1.0)
            except Exception:
                # Best effort only: the peer may already be gone (OSError) or
                # slow to finish closing (timeout). CancelledError is not an
                # Exception subclass, so cancellation still propagates.
                pass
            self._log.info("client_disconnected")

    async def broadcast(self, book: OrderBookTop5) -> None:
        """Send *book* to every connected client as one line of compact JSON.

        The payload is queued on every healthy writer first, then those
        writers are drained concurrently; broadcast() returns once the slowest
        of them has drained. Clients whose write or drain fails are removed
        and closed, so their handler task sees the disconnect and finishes
        rather than leaving an orphaned connection open.

        Args:
            book: Snapshot to publish; serialized via ``book.to_dict()``.
        """
        if not self._clients:
            return

        msg_bytes = json.dumps(book.to_dict(), separators=(",", ":")).encode() + b"\n"

        bad: set[asyncio.StreamWriter] = set()
        written: list[asyncio.StreamWriter] = []
        # Iterate over a snapshot: handler tasks may mutate _clients while
        # we are suspended in drain().
        for w in list(self._clients):
            if w.is_closing():
                bad.add(w)
                continue
            try:
                w.write(msg_bytes)
            except Exception as e:
                self._log.warning("broadcast_write_failed", error=str(e))
                bad.add(w)
            else:
                written.append(w)

        if written:
            # Only drain writers that actually accepted the payload.
            drain_results = await asyncio.gather(
                *(w.drain() for w in written),
                return_exceptions=True,
            )
            for w, res in zip(written, drain_results):
                if isinstance(res, Exception):
                    self._log.warning("broadcast_drain_failed", error=str(res))
                    bad.add(w)

        self._clients -= bad
        for w in bad:
            # close() is idempotent; it ends the client's handler loop.
            w.close()