bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,788 @@
|
|
|
1
|
+
"""ZMQ client for bithuman runtime service."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import asyncio
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
from collections import deque
|
|
11
|
+
from typing import Any, Awaitable, Callable, Dict, Optional
|
|
12
|
+
|
|
13
|
+
import cv2
|
|
14
|
+
import msgpack
|
|
15
|
+
import numpy as np
|
|
16
|
+
from loguru import logger
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import zmq
|
|
20
|
+
import zmq.asyncio
|
|
21
|
+
except ImportError:
|
|
22
|
+
raise ImportError("zmq is required for bithuman runtime client")
|
|
23
|
+
|
|
24
|
+
from bithuman.api import AudioChunk, VideoControl
|
|
25
|
+
from bithuman.audio import AudioStreamBatcher, float32_to_int16, load_audio
|
|
26
|
+
from bithuman.service.messages import (
|
|
27
|
+
AudioRequest,
|
|
28
|
+
CheckInitStatusRequest,
|
|
29
|
+
FrameMessage,
|
|
30
|
+
GetSettingRequest,
|
|
31
|
+
HeartbeatRequest,
|
|
32
|
+
InitRequest,
|
|
33
|
+
InterruptRequest,
|
|
34
|
+
ResponseStatus,
|
|
35
|
+
ServerResponse,
|
|
36
|
+
)
|
|
37
|
+
from bithuman.video_graph.video_script import VideoScript
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class FPSMonitor:
    """Estimate frames-per-second over a rolling window of frame arrival times.

    Each call to ``update`` records the current wall-clock time. The FPS
    estimate is recomputed from the retained timestamps at most once every
    0.5 seconds and exposed through the ``fps`` property.
    """

    def __init__(self, window_size: int = 30) -> None:
        """Create a monitor that keeps at most ``window_size`` timestamps."""
        self.timestamps: deque = deque(maxlen=window_size)
        self.last_fps_update: float = time.time()
        self.current_fps: float = 0
        logger.debug(f"Initialized FPSMonitor with window_size={window_size}")

    def update(self) -> None:
        """Record one frame and refresh the estimate if 0.5 s have passed."""
        self.timestamps.append(time.time())

        since_refresh = time.time() - self.last_fps_update
        if since_refresh <= 0.5:
            return
        self._calculate_fps()
        self.last_fps_update = time.time()

    def _calculate_fps(self) -> None:
        """Recompute ``current_fps`` from the oldest and newest timestamps.

        Fewer than two samples resets the estimate to 0. A zero (or
        negative) span leaves the previous estimate untouched.
        """
        sample_count = len(self.timestamps)
        if sample_count < 2:
            self.current_fps = 0
            return

        span = self.timestamps[-1] - self.timestamps[0]
        if span > 0:
            # N timestamps bound N - 1 inter-frame intervals.
            self.current_fps = (sample_count - 1) / span

    @property
    def fps(self) -> float:
        """Most recently computed FPS estimate."""
        return self.current_fps
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class ZMQBithumanRuntimeClient:
    """Async client for Bithuman Runtime using ZMQ.

    Handles communication with server including:
    - Workspace initialization
    - Audio streaming
    - Frame receiving
    - Heartbeat monitoring

    Control messages use a REQ socket. Requests are serialized by
    ``_socket_lock`` so only one is in flight at a time. Frames arrive on a
    SUB socket subscribed to ``client_id``.
    """

    def __init__(
        self,
        client_id: str,
        host: str = "127.0.0.1",
        control_port: int = 5555,
        stream_port: int = 5556,
        request_timeout: float = 5,
        max_consecutive_errors: int = 3,
    ) -> None:
        """Initialize client and connect both sockets.

        Args:
            client_id: Unique client identifier; also used as the SUB
                subscription prefix for the frame stream.
            host: Host address
            control_port: Port for control messages
            stream_port: Port for frame streaming
            request_timeout: Timeout for requests in seconds. A falsy or
                non-positive value leaves the ZMQ send/receive timeouts unset.
            max_consecutive_errors: Number of consecutive heartbeat failures
                before the client treats the server as disconnected.
        """
        logger.info(
            f"Initializing BithumanRuntimeClient {client_id} connecting to ports "
            f"{control_port} (control) and {stream_port} (stream)"
        )
        self.client_id = client_id
        self.context = zmq.asyncio.Context()
        self.request_timeout = request_timeout
        self.max_consecutive_errors = max_consecutive_errors

        # Control socket (REQ/REP)
        self.control_socket = self.context.socket(zmq.REQ)
        # A strict REQ socket refuses to send until the previous reply has
        # arrived. A single timed-out request would then make every later
        # send fail with EFSM. Relaxed mode allows a new request after a lost
        # reply; correlation discards stale replies to abandoned requests.
        try:
            self.control_socket.setsockopt(zmq.REQ_RELAXED, 1)
            self.control_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        except (AttributeError, zmq.ZMQError) as e:
            logger.warning(
                f"Could not enable REQ_RELAXED/REQ_CORRELATE ({e}); "
                f"a timed-out request may leave the control socket unusable"
            )
        # Set timeouts (in milliseconds)
        if request_timeout and request_timeout > 0:
            self.control_socket.setsockopt(zmq.RCVTIMEO, int(request_timeout * 1000))
            self.control_socket.setsockopt(zmq.SNDTIMEO, int(request_timeout * 1000))
            logger.debug(f"Set request timeouts to {request_timeout}s")
        self.control_socket.connect(f"tcp://{host}:{control_port}")
        logger.debug(f"Connected control socket to tcp://{host}:{control_port}")

        # Subscribe socket for frames
        self.stream_socket = self.context.socket(zmq.SUB)
        self.stream_socket.connect(f"tcp://{host}:{stream_port}")
        self.stream_socket.setsockopt_string(zmq.SUBSCRIBE, client_id)
        logger.debug(f"Connected stream socket to tcp://{host}:{stream_port}")

        # Frame callback and FPS monitoring
        self.frame_callback: Optional[
            Callable[[FrameMessage, Dict[str, Any]], Awaitable[None] | None]
        ] = None
        self.fps_monitor = FPSMonitor()
        self.running = True
        self.is_initialized = False

        # Add stream batcher
        self.stream_batcher = AudioStreamBatcher(fps=25, output_sample_rate=16000)

        # Tasks
        self._frame_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None
        self.init_frame: Optional[np.ndarray] = None
        self.first_frame_received = False

        # Add socket state tracking
        self._socket_lock = asyncio.Lock()
        self._is_closed = False

        # Add connection state callbacks
        self.on_disconnected: Optional[Callable[[], None]] = None
        self.on_connection_error: Optional[Callable[[str], None]] = None

        # Add video script
        self._video_script: Optional[VideoScript] = None

    def set_connection_callbacks(
        self,
        on_disconnected: Optional[Callable[[], None]] = None,
        on_connection_error: Optional[Callable[[str], None]] = None,
    ) -> None:
        """Set callbacks for connection state changes.

        Both callbacks are replaced; passing ``None`` clears a callback.
        """
        self.on_disconnected = on_disconnected
        self.on_connection_error = on_connection_error

    async def _handle_connection_error(self, error_msg: str) -> None:
        """Log a connection error and forward it to ``on_connection_error``."""
        logger.error(f"Connection error: {error_msg}")
        if self.on_connection_error:
            self.on_connection_error(error_msg)

    async def _handle_disconnection(self) -> None:
        """Log a server disconnection and invoke ``on_disconnected``."""
        logger.warning("Server disconnected")
        if self.on_disconnected:
            self.on_disconnected()

    async def start(self) -> None:
        """Start the background frame-receiver task."""
        logger.info("Starting client tasks")
        self._frame_task = asyncio.create_task(self._receive_frames())

    async def wait_for_init_frame(self, timeout: Optional[float] = None) -> None:
        """Poll until the server's init frame has been received.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            TimeoutError: If the init frame does not arrive within ``timeout``.
        """
        logger.info("Waiting for initialization frame")
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start_time = time.monotonic()
        while self.init_frame is None:
            if timeout is not None and time.monotonic() - start_time > timeout:
                logger.error("Timed out waiting for initialization frame")
                raise TimeoutError("Timed out waiting for initialization")
            await asyncio.sleep(0.1)
        logger.debug("Initialization frame received")

    async def wait_for_first_frame(self, timeout: Optional[float] = None) -> None:
        """Poll until the first non-init frame has been received.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            TimeoutError: If no frame arrives within ``timeout``.
        """
        logger.info("Waiting for first frame")
        start_time = time.monotonic()
        while not self.first_frame_received:
            if timeout is not None and time.monotonic() - start_time > timeout:
                logger.error("Timed out waiting for first frame")
                raise TimeoutError("Timed out waiting for first frame")
            await asyncio.sleep(0.1)
        logger.debug("First frame received")

    async def _send_and_receive(
        self, request_data: bytes, client_id: Optional[str] = None
    ) -> ServerResponse:
        """Send one packed request on the control socket and await its reply.

        Failures are not raised. They are reported through
        ``_handle_connection_error`` and returned as an ERROR response.

        Args:
            request_data: msgpack-encoded request.
            client_id: Identifier used only in log messages; defaults to
                ``self.client_id``.

        Returns:
            The decoded server response, or an ERROR response on failure.
        """
        client_id = client_id or self.client_id
        logger.debug(f"Sending request to server (client_id={client_id})")
        if self._is_closed:
            error_msg = "Client is closed"
            logger.error(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)

        try:
            # REQ sockets are strictly request/reply; the lock keeps
            # concurrent callers (e.g. heartbeat vs. audio) from interleaving.
            async with self._socket_lock:
                try:
                    await self.control_socket.send(request_data)
                    response_data = await self.control_socket.recv()
                except zmq.Again:
                    # With REQ_RELAXED enabled in __init__, the socket can
                    # still send the next request after this timeout.
                    error_msg = "Request timed out"
                    logger.error(f"{error_msg} (client_id={client_id})")
                    await self._handle_connection_error(error_msg)
                    return ServerResponse(
                        status=ResponseStatus.ERROR, message=error_msg
                    )

                return ServerResponse.from_dict(msgpack.unpackb(response_data, raw=False))

        except zmq.ZMQError as e:
            error_msg = f"ZMQ error while sending request: {str(e)}"
            logger.error(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)
        except Exception as e:
            error_msg = f"Failed to send request: {str(e)}"
            logger.exception(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)

    async def init_workspace(
        self,
        avatar_model_path: str,
        video_file: Optional[str] = None,
        inference_data_file: Optional[str] = None,
        check_interval: float = 1.0,
        max_retries: Optional[int] = None,
    ) -> ServerResponse:
        """Initialize avatar model with async status checking.

        Sends an ``InitRequest``. Unless that request fails, the heartbeat task
        is started. The method then polls ``CheckInitStatusRequest`` every
        ``check_interval`` seconds until the server reports SUCCESS or ERROR.

        Args:
            avatar_model_path: Avatar model path
            video_file: Optional video file path
            inference_data_file: Optional inference data file path
            check_interval: Interval between status checks in seconds
            max_retries: Maximum number of status checks before giving up;
                ``None`` polls indefinitely. The response to the last check
                is always examined.

        Returns:
            Final initialization response. If ``max_retries`` is exhausted
            without a terminal status, an ERROR response with message
            "Initialization timed out" is returned.
        """
        logger.info(
            f"Initializing avatar model from '{avatar_model_path}' "
            f"with video_file='{video_file}' and "
            f"inference_data_file='{inference_data_file}'"
        )
        # Send initial request
        request = InitRequest(
            client_id=self.client_id,
            avatar_model_path=avatar_model_path,
            video_file=video_file,
            inference_data_file=inference_data_file,
        )
        response = await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )
        logger.debug(f"Initial avatar model init response: {response}")

        if response.status == ResponseStatus.ERROR:
            return response

        # Start heartbeat immediately after sending init request
        if not self._heartbeat_task:
            self._heartbeat_task = asyncio.create_task(self._send_heartbeat())
            logger.info("Heartbeat task started")

        # Examine every response, including the one from the final status
        # poll, before deciding that the retry budget is exhausted.
        retries = 0
        while True:
            if response.status == ResponseStatus.SUCCESS:
                self.is_initialized = True
                logger.info("Avatar model initialized successfully")
                return response
            if response.status == ResponseStatus.ERROR:
                logger.error(f"Initialization failed: {response}")
                return response
            if max_retries is not None and retries >= max_retries:
                break

            await asyncio.sleep(check_interval)
            logger.debug("Checking initialization status")

            status_request = CheckInitStatusRequest(client_id=self.client_id)
            response = await self._send_and_receive(
                msgpack.packb(status_request.to_dict(), use_bin_type=True)
            )
            logger.debug(f"Avatar model init status response: {response}")

            retries += 1

        # If we timeout, don't stop the heartbeat - let the client handle cleanup
        logger.error("Initialization timed out")
        return ServerResponse(
            status=ResponseStatus.ERROR,
            message="Initialization timed out",
        )

    async def get_setting(self, name: str) -> Any:
        """Fetch setting ``name`` from the server.

        Returns:
            ``response.extra["value"]`` from the server's reply.

        Raises:
            RuntimeError: If the server does not answer with SUCCESS.
        """
        logger.debug(f"Getting setting '{name}' from server")
        request = GetSettingRequest(client_id=self.client_id, name=name)
        response = await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )
        if response.status != ResponseStatus.SUCCESS:
            logger.error(f"Failed to get setting '{name}': {response}")
            raise RuntimeError(f"Failed to get setting {name}: {response}")
        logger.debug(f"Received setting '{name}': {response.extra['value']}")
        return response.extra["value"]

    async def get_video_script(self) -> VideoScript:
        """Return the server's video script, fetching it on first use only."""
        if self._video_script is None:
            data = await self.get_setting("video_script")
            self._video_script = VideoScript.from_dict(data)
        return self._video_script

    async def send_video_control(
        self, target_video: str | None, actions: list[str] | str | None = None
    ) -> ServerResponse:
        """Send a ``VideoControl`` with ``target_video`` and ``actions``."""
        logger.info(f"Sending video control to server: '{target_video=}', '{actions=}'")
        request = AudioRequest(
            client_id=self.client_id,
            data=VideoControl(target_video=target_video, action=actions),
        )
        return await self._send_audio(request)

    async def send_answer_finished_sentinel(self) -> ServerResponse:
        """Send answer finished sentinel to server to enable latter action triggers."""
        logger.debug("Sending answer finished sentinel to server")
        request = AudioRequest(
            client_id=self.client_id,
            data=VideoControl(end_of_speech=True),
        )
        return await self._send_audio(request)

    async def send_audio(
        self,
        audio_bytes: bytes,
        sample_rate: int,
        is_last: bool = False,
        **kwargs: Any,
    ) -> str | None:
        """Send one chunk of audio wrapped in a ``VideoControl``.

        Args:
            audio_bytes: Raw audio bytes. The trace-log duration estimate
                assumes 2 bytes per sample (presumably int16 mono; confirm
                against ``AudioChunk.from_bytes``).
            sample_rate: Sample rate of ``audio_bytes`` in Hz.
            is_last: Marks the chunk as the last one of the stream.
            **kwargs: Extra ``VideoControl`` fields (e.g. ``text``).

        Returns:
            Message ID for the streamed audio, or ``None`` if the server did
            not answer with SUCCESS.
        """
        logger.trace(
            f"Streaming audio data: sample_rate={sample_rate}, "
            f"duration={len(audio_bytes) / 2 / sample_rate:.2f}s, is_last={is_last}"
        )

        control = VideoControl(
            audio=AudioChunk.from_bytes(audio_bytes, sample_rate, last_chunk=is_last),
            **kwargs,
        )
        request = AudioRequest(client_id=self.client_id, data=control)
        response = await self._send_audio(request)

        if response.status != ResponseStatus.SUCCESS:
            logger.error(f"Failed to send audio chunk: {response}")
            return None

        return control.message_id

    async def interrupt(self) -> ServerResponse:
        """Reset the local audio batcher and ask the server to interrupt."""
        logger.info("Interrupting current audio processing")
        self.reset_audio_stream()

        request = InterruptRequest(client_id=self.client_id)
        return await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )

    def set_frame_callback(
        self, callback: Callable[[FrameMessage, Dict[str, Any]], Awaitable[None] | None]
    ) -> None:
        """Set callback for received frames. Can be sync or async function."""
        self.frame_callback = callback

    async def _send_audio(self, request: AudioRequest) -> ServerResponse:
        """Pack an ``AudioRequest`` and send it on the control socket."""
        return await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )

    async def _receive_frames(self) -> None:
        """Receive frames from the SUB socket until ``running`` is cleared.

        The frame tagged ``_init_frame`` is stored in ``init_frame``. Every
        other frame sets ``first_frame_received`` and ticks the FPS monitor.
        Once a non-init frame has been seen, each frame is passed to
        ``frame_callback``, and an awaitable result is awaited.
        """
        logger.info(f"Starting frame receiver task (client_id={self.client_id})")

        while self.running:
            try:
                parts = await self.stream_socket.recv_multipart()
                if len(parts) < 2:
                    logger.warning(
                        f"Received incomplete message "
                        f"(client_id={self.client_id}, parts_count={len(parts)})"
                    )
                    continue

                _, msg = parts

                try:
                    frame_dict = msgpack.unpackb(msg, raw=False)

                    # SUBSCRIBE matches by prefix, so another client whose id
                    # starts with ours would also be delivered here.
                    if frame_dict["client_id"] != self.client_id:
                        logger.warning(
                            f"Received message for wrong client "
                            f"(client_id={self.client_id}, "
                            f"received_client_id={frame_dict['client_id']})"
                        )
                        continue

                    frame_msg = FrameMessage(**frame_dict)

                except Exception as e:
                    logger.error(
                        f"Error processing message "
                        f"(client_id={self.client_id}, error={str(e)})"
                    )
                    continue

                if frame_msg.source_message_id == "_init_frame":
                    logger.info(
                        f"Received init frame (client_id={self.client_id}, "
                        f"shape={frame_msg.image.shape})"
                    )
                    self.init_frame = frame_msg.image
                else:
                    self.first_frame_received = True
                    self.fps_monitor.update()

                if self.frame_callback and self.first_frame_received:
                    result = self.frame_callback(
                        frame_msg, {"fps": self.fps_monitor.fps}
                    )
                    if result is not None and isinstance(result, Awaitable):
                        await result

            except asyncio.CancelledError:
                logger.info("Frame receiver task cancelled")
                break
            except Exception as e:
                logger.exception(
                    f"Error receiving frame "
                    f"(client_id={self.client_id}, error={str(e)})"
                )
                await asyncio.sleep(0.001)

    async def _send_heartbeat(self) -> None:
        """Send a heartbeat roughly every second while the client runs.

        After ``max_consecutive_errors`` consecutive failed heartbeats
        (non-SUCCESS responses or exceptions), ``_handle_disconnection`` is
        called and the task stops.
        """
        logger.info("Heartbeat task started")
        consecutive_errors = 0
        while self.running:
            try:
                if self._is_closed:
                    logger.debug("Client is closed, stopping heartbeat")
                    break

                request = HeartbeatRequest(client_id=self.client_id)
                logger.debug("Sending heartbeat to server")
                response = await self._send_and_receive(
                    msgpack.packb(request.to_dict(), use_bin_type=True)
                )
                logger.debug(f"Heartbeat response: {response}")

                if response.status != ResponseStatus.SUCCESS:
                    consecutive_errors += 1
                    logger.warning(f"Heartbeat failed: {response}")

                    # If we get multiple consecutive failures, assume disconnection
                    if consecutive_errors >= self.max_consecutive_errors:
                        logger.error(
                            f"Multiple heartbeat failures, "
                            f"assuming disconnected (client_id={self.client_id})"
                        )
                        await self._handle_disconnection()
                        break
                else:
                    consecutive_errors = 0  # Reset counter on successful heartbeat

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                logger.info("Heartbeat task cancelled")
                break
            except Exception as e:
                consecutive_errors += 1
                logger.exception(f"Error sending heartbeat: {e}")

                # Check for disconnection on consecutive errors
                if consecutive_errors >= self.max_consecutive_errors:
                    logger.error(
                        f"Multiple heartbeat errors, "
                        f"assuming disconnected (client_id={self.client_id})"
                    )
                    await self._handle_disconnection()
                    break

                await asyncio.sleep(0.2)

    async def close(self) -> None:
        """Stop background tasks, close both sockets and terminate the context.

        Calls after the first one return immediately.
        """
        if self._is_closed:
            logger.debug("Client already closed")
            return
        logger.info("Closing client connection")
        self._is_closed = True  # Set closed flag first so new requests are rejected

        async with self._socket_lock:  # Ensure no ongoing operations
            self.running = False

            if self._frame_task:
                self._frame_task.cancel()
                try:
                    await self._frame_task
                except asyncio.CancelledError:
                    logger.debug("Frame receiving task cancelled")

            if self._heartbeat_task:
                self._heartbeat_task.cancel()
                try:
                    await self._heartbeat_task
                except asyncio.CancelledError:
                    logger.debug("Heartbeat task cancelled")

            # Close sockets; linger=0 drops any unsent messages immediately.
            try:
                self.control_socket.close(linger=0)
                self.stream_socket.close(linger=0)
                logger.debug("Sockets closed")
            except zmq.ZMQError as e:
                logger.error(f"Error closing sockets: {e}")

            # Terminate context
            try:
                self.context.term()
                logger.debug("ZMQ context terminated")
            except zmq.ZMQError as e:
                logger.error(f"Error terminating context: {e}")

        logger.info("Client closed successfully")

    def reset_audio_stream(self) -> None:
        """Reset the audio stream batcher."""
        logger.debug("Resetting audio stream batcher")
        self.stream_batcher.reset()
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
async def run_example_client(
    client_id: str,
    avatar_model_path: str,
    window_name: Optional[str] = None,
) -> None:
    """Run an example client with visualization.

    Received frames are shown in an OpenCV window with an FPS overlay.
    Pressing ``q`` in the window shuts the client down.

    Args:
        client_id: Client identifier passed to ``ZMQBithumanRuntimeClient``.
        avatar_model_path: Avatar model path passed to ``init_workspace``.
        window_name: Optional window title; defaults to ``"Client <client_id>"``.
    """
    # The frame callback runs inside the client's frame-receiver task. An
    # exception raised there never reaches this coroutine and would kill
    # that task instead, so shutdown is signalled through an event.
    stop_event = asyncio.Event()
    window = window_name or f"Client {client_id}"

    async def show_frame(frame: FrameMessage, frame_info: dict) -> None:
        # Add FPS text to frame
        image = frame.image
        fps_text = f"FPS: {frame_info.get('fps', 0):.1f}"
        cv2.putText(
            image,
            fps_text,
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
            cv2.LINE_AA,
        )
        logger.debug(f"Displaying frame {frame.source_message_id} with FPS {fps_text}")

        cv2.imshow(window, image)
        key = chr(cv2.waitKey(1) & 0xFF).lower()
        if key == "q":
            logger.info("Quit key pressed, stopping client")
            stop_event.set()

        # Example of async processing
        await asyncio.sleep(0)  # Allow other tasks to run

    # Initialize client
    client = ZMQBithumanRuntimeClient(client_id)
    client.set_frame_callback(show_frame)  # Now accepts async callback
    logger.info("Starting example client")

    try:
        await client.start()

        # Initialize workspace
        response = await client.init_workspace(avatar_model_path)
        if response.status != ResponseStatus.SUCCESS:
            logger.error(f"Failed to initialize workspace: {response}")
            return
        logger.info("Workspace initialized in example client")

        # Keep running until the quit key is pressed or the user interrupts
        await stop_event.wait()
        logger.info("Shutting down client...")

    except KeyboardInterrupt:
        logger.info("Shutting down client...")
    finally:
        await client.close()
        cv2.destroyAllWindows()
        logger.info("Example client closed")
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
async def play_audio_example(
    client_ids: list[str],
    audio_file: str,
    text: Optional[str] = None,
    control_port: int = 5555,
) -> None:
    """Play audio file to multiple clients.

    The file is decoded once and then sent to each client in turn. Errors
    for one client are logged and do not stop the others. Every client that
    was created is closed, even when sending fails.

    Args:
        client_ids: List of client IDs to send audio to
        audio_file: Path to audio file
        text: Optional text to be spoken
        control_port: Server control port
    """
    if not client_ids:
        return

    # Load and prepare audio once; it is identical for every client.
    try:
        logger.info(f"Loading audio file: {audio_file}")
        audio_data, sample_rate = load_audio(audio_file)
        audio_data = float32_to_int16(audio_data)
    except Exception as e:
        logger.exception(f"Error loading audio file {audio_file}: {e}")
        return
    audio_bytes = audio_data.tobytes()

    for client_id in client_ids:
        client = None
        try:
            # Create client
            client = ZMQBithumanRuntimeClient(client_id, control_port=control_port)
            await client.start()

            # Stream audio
            logger.info(
                f"Streaming audio to {client_id}: "
                f"duration={len(audio_data) / sample_rate:.2f}s, "
                f"sample_rate={sample_rate}Hz"
            )
            # NOTE(review): the whole file is sent as a single chunk with
            # is_last=False; confirm the server does not wait for a final chunk.
            await client.send_audio(
                audio_bytes=audio_bytes,
                sample_rate=sample_rate,
                is_last=False,
                text=text,
            )

        except Exception as e:
            logger.exception(f"Error playing audio to client {client_id}: {e}")
        finally:
            # Cleanup runs even when a step above failed, so sockets and the
            # ZMQ context are not leaked.
            if client is not None:
                try:
                    await client.close()
                except Exception as e:
                    logger.exception(f"Error closing client {client_id}: {e}")
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
async def interrupt_example(client_ids: list[str], control_port: int = 5555) -> None:
    """Interrupt audio processing for multiple clients.

    Errors for one client are logged and do not stop the others. Every
    client that was created is closed, even when the interrupt fails.

    Args:
        client_ids: List of client IDs to interrupt
        control_port: Server control port
    """
    for client_id in client_ids:
        client = None
        try:
            client = ZMQBithumanRuntimeClient(client_id, control_port=control_port)
            await client.start()
            await client.interrupt()
        except Exception as e:
            logger.exception(f"Error interrupting client {client_id}: {e}")
        finally:
            # Always release sockets/context, even if start/interrupt raised.
            if client is not None:
                try:
                    await client.close()
                except Exception as e:
                    logger.exception(f"Error closing client {client_id}: {e}")
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def main() -> None:
    """Main entry point for the client CLI.

    Sub-commands:
        start       run a visualizing client (``run_example_client``)
        play-audio  stream an audio file to one or more clients
        interrupt   interrupt audio processing for one or more clients

    With no sub-command, help is printed and the function returns. Unexpected
    errors are logged and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="ZMQ Bithuman Runtime Client CLI")
    commands = parser.add_subparsers(dest="command", help="Commands")

    # Start client command
    start_cmd = commands.add_parser(
        "start", help="Start a test client with visualization"
    )
    start_cmd.add_argument("--client-id", type=str, required=True, help="Client ID")
    start_cmd.add_argument(
        "--avatar-model-path", type=str, required=True, help="Avatar model path"
    )
    start_cmd.add_argument("--window-name", type=str, help="Window name for display")

    # Play audio command
    play_cmd = commands.add_parser(
        "play-audio", help="Play audio to one or more clients"
    )
    play_cmd.add_argument(
        "--client-id",
        type=str,
        nargs="+",
        required=True,
        help="Client IDs to send audio to",
    )
    play_cmd.add_argument(
        "--audio-file", type=str, required=True, help="Path to audio file"
    )
    play_cmd.add_argument("--text", type=str, help="Text to be spoken")
    play_cmd.add_argument(
        "--control-port", type=int, default=5555, help="Server control port"
    )

    # Interrupt command
    interrupt_cmd = commands.add_parser(
        "interrupt", help="Interrupt audio processing"
    )
    interrupt_cmd.add_argument(
        "--client-id",
        type=str,
        nargs="+",
        required=True,
        help="Client IDs to interrupt",
    )
    interrupt_cmd.add_argument(
        "--control-port", type=int, default=5555, help="Server control port"
    )

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    logger.remove()
    logger.add(sys.stderr, level="INFO")

    try:
        logger.debug(f"Executing command '{args.command}' with args: {args}")
        # Build the coroutine for the chosen command, then run it.
        if args.command == "start":
            args.avatar_model_path = os.path.abspath(args.avatar_model_path)
            job = run_example_client(
                args.client_id, args.avatar_model_path, args.window_name
            )
        elif args.command == "play-audio":
            job = play_audio_example(
                client_ids=args.client_id,
                audio_file=args.audio_file,
                text=args.text,
                control_port=args.control_port,
            )
        elif args.command == "interrupt":
            job = interrupt_example(
                client_ids=args.client_id,
                control_port=args.control_port,
            )
        else:
            job = None

        if job is not None:
            asyncio.run(job)

    except KeyboardInterrupt:
        logger.info("Operation interrupted by user")
    except Exception as e:
        logger.exception(f"Error executing command {args.command}: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|