bithuman 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,788 @@
1
+ """ZMQ client for bithuman runtime service."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import asyncio
7
+ import os
8
+ import sys
9
+ import time
10
+ from collections import deque
11
+ from typing import Any, Awaitable, Callable, Dict, Optional
12
+
13
+ import cv2
14
+ import msgpack
15
+ import numpy as np
16
+ from loguru import logger
17
+
18
+ try:
19
+ import zmq
20
+ import zmq.asyncio
21
+ except ImportError:
22
+ raise ImportError("zmq is required for bithuman runtime client")
23
+
24
+ from bithuman.api import AudioChunk, VideoControl
25
+ from bithuman.audio import AudioStreamBatcher, float32_to_int16, load_audio
26
+ from bithuman.service.messages import (
27
+ AudioRequest,
28
+ CheckInitStatusRequest,
29
+ FrameMessage,
30
+ GetSettingRequest,
31
+ HeartbeatRequest,
32
+ InitRequest,
33
+ InterruptRequest,
34
+ ResponseStatus,
35
+ ServerResponse,
36
+ )
37
+ from bithuman.video_graph.video_script import VideoScript
38
+
39
+
40
class FPSMonitor:
    """Monitor FPS with a sliding window.

    Tracks the arrival time of the most recent ``window_size`` frames and
    derives the frame rate from the span between the oldest and the newest
    timestamp in that window. The published value is refreshed only when
    more than ``_REFRESH_INTERVAL`` seconds have passed since the last
    refresh.
    """

    # Minimum number of seconds between two recalculations of ``current_fps``.
    _REFRESH_INTERVAL = 0.5

    def __init__(self, window_size: int = 30) -> None:
        """Initialize FPS monitor.

        Args:
            window_size: Maximum number of frame timestamps kept for the
                FPS calculation.
        """
        self.timestamps: deque = deque(maxlen=window_size)
        # A monotonic clock is used for all interval measurements so that
        # wall-clock adjustments cannot produce negative or inflated spans.
        self.last_fps_update: float = time.monotonic()
        self.current_fps: float = 0
        logger.debug(f"Initialized FPSMonitor with window_size={window_size}")

    def update(self) -> None:
        """Record the arrival of a new frame and refresh the FPS if due."""
        # Read the clock once so the stored timestamp and the refresh
        # decision are based on the same instant.
        now = time.monotonic()
        self.timestamps.append(now)

        if now - self.last_fps_update > self._REFRESH_INTERVAL:
            self._calculate_fps()
            self.last_fps_update = now

    def _calculate_fps(self) -> None:
        """Recompute ``current_fps`` from the timestamps in the window."""
        if len(self.timestamps) < 2:
            # A single timestamp spans no interval.
            self.current_fps = 0
            return

        time_diff = self.timestamps[-1] - self.timestamps[0]
        if time_diff > 0:
            # N timestamps delimit N - 1 frame intervals.
            self.current_fps = (len(self.timestamps) - 1) / time_diff
        # A zero span keeps the previously computed value.

    @property
    def fps(self) -> float:
        """Get current FPS value."""
        return self.current_fps
76
+
77
+
78
class ZMQBithumanRuntimeClient:
    """Async client for Bithuman Runtime using ZMQ.

    Handles communication with server including:
    - Workspace initialization
    - Audio streaming
    - Frame receiving
    - Heartbeat monitoring

    Control requests go through a REQ socket guarded by an asyncio lock;
    frames arrive on a SUB socket subscribed with the client id as topic.
    """

    def __init__(
        self,
        client_id: str,
        host: str = "127.0.0.1",
        control_port: int = 5555,
        stream_port: int = 5556,
        request_timeout: float = 5,
        max_consecutive_errors: int = 3,
    ) -> None:
        """Initialize client.

        Args:
            client_id: Unique client identifier
            host: Host address
            control_port: Port for control messages
            stream_port: Port for frame streaming
            request_timeout: Timeout for requests in seconds
            max_consecutive_errors: Number of consecutive errors before disconnection
        """
        logger.info(
            f"Initializing BithumanRuntimeClient {client_id} connecting to ports "
            f"{control_port} (control) and {stream_port} (stream)"
        )
        self.client_id = client_id
        self.context = zmq.asyncio.Context()
        self.request_timeout = request_timeout
        self.max_consecutive_errors = max_consecutive_errors

        # Control socket (REQ/REP). The endpoint is remembered so the socket
        # can be rebuilt after a timed-out request.
        self._control_endpoint = f"tcp://{host}:{control_port}"
        self.control_socket = self._create_control_socket()

        # Subscribe socket for frames
        self.stream_socket = self.context.socket(zmq.SUB)
        self.stream_socket.connect(f"tcp://{host}:{stream_port}")
        self.stream_socket.setsockopt_string(zmq.SUBSCRIBE, client_id)
        logger.debug(f"Connected stream socket to tcp://{host}:{stream_port}")

        # Frame callback and FPS monitoring
        self.frame_callback: Optional[
            Callable[[FrameMessage, Dict[str, Any]], Awaitable[None] | None]
        ] = None
        self.fps_monitor = FPSMonitor()
        self.running = True
        self.is_initialized = False

        # Add stream batcher
        self.stream_batcher = AudioStreamBatcher(fps=25, output_sample_rate=16000)

        # Tasks
        self._frame_task: Optional[asyncio.Task] = None
        self._heartbeat_task: Optional[asyncio.Task] = None
        self.init_frame: Optional[np.ndarray] = None
        self.first_frame_received = False

        # Add socket state tracking
        self._socket_lock = asyncio.Lock()
        self._is_closed = False

        # Add connection state callbacks
        self.on_disconnected: Optional[Callable[[], None]] = None
        self.on_connection_error: Optional[Callable[[str], None]] = None

        # Add video script
        self._video_script: Optional[VideoScript] = None

    def _create_control_socket(self) -> zmq.asyncio.Socket:
        """Create a REQ socket with the configured timeouts and connect it."""
        sock = self.context.socket(zmq.REQ)
        request_timeout = self.request_timeout
        # Set timeouts (in milliseconds)
        if request_timeout and request_timeout > 0:
            timeout_ms = int(request_timeout * 1000)
            sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
            sock.setsockopt(zmq.SNDTIMEO, timeout_ms)
            logger.debug(f"Set request timeouts to {request_timeout}s")
        sock.connect(self._control_endpoint)
        logger.debug(f"Connected control socket to {self._control_endpoint}")
        return sock

    def _reset_control_socket(self) -> None:
        """Replace the control socket with a fresh one.

        A REQ socket enforces strict send/recv alternation, so after a
        request whose reply never arrived it would reject every further
        send. Must be called while holding ``_socket_lock``.
        """
        try:
            self.control_socket.close(linger=0)
            self.control_socket = self._create_control_socket()
            logger.debug("Control socket recreated after timeout")
        except zmq.ZMQError as e:
            # Later requests will surface the failure as ZMQ errors.
            logger.error(f"Failed to reset control socket: {e}")

    def set_connection_callbacks(
        self,
        on_disconnected: Optional[Callable[[], None]] = None,
        on_connection_error: Optional[Callable[[str], None]] = None,
    ) -> None:
        """Set callbacks for connection state changes.

        Args:
            on_disconnected: Called without arguments when the heartbeat
                loop gives up on the server.
            on_connection_error: Called with the error message whenever a
                control request fails.
        """
        self.on_disconnected = on_disconnected
        self.on_connection_error = on_connection_error

    async def _handle_connection_error(self, error_msg: str) -> None:
        """Log a connection error and forward it to the registered callback."""
        logger.error(f"Connection error: {error_msg}")
        if self.on_connection_error:
            self.on_connection_error(error_msg)

    async def _handle_disconnection(self) -> None:
        """Log a server disconnection and notify the registered callback."""
        logger.warning("Server disconnected")
        if self.on_disconnected:
            self.on_disconnected()

    async def start(self) -> None:
        """Start the background task that receives frames."""
        logger.info("Starting client tasks")
        self._frame_task = asyncio.create_task(self._receive_frames())

    async def wait_for_init_frame(self, timeout: Optional[float] = None) -> None:
        """Wait until the init frame has been received.

        Args:
            timeout: Maximum time to wait in seconds; a falsy value waits
                indefinitely.

        Raises:
            TimeoutError: If no init frame arrived within ``timeout``.
        """
        logger.info("Waiting for initialization frame")
        start_time = time.time()
        while self.init_frame is None:
            if timeout and time.time() - start_time > timeout:
                logger.error("Timed out waiting for initialization frame")
                raise TimeoutError("Timed out waiting for initialization")
            await asyncio.sleep(0.1)
        logger.debug("Initialization frame received")

    async def wait_for_first_frame(self, timeout: Optional[float] = None) -> None:
        """Wait until the first non-init frame has been received.

        Args:
            timeout: Maximum time to wait in seconds; a falsy value waits
                indefinitely.

        Raises:
            TimeoutError: If no frame arrived within ``timeout``.
        """
        logger.info("Waiting for first frame")
        start_time = time.time()
        while not self.first_frame_received:
            if timeout and time.time() - start_time > timeout:
                logger.error("Timed out waiting for first frame")
                raise TimeoutError("Timed out waiting for first frame")
            await asyncio.sleep(0.1)
        logger.debug("First frame received")

    async def _send_and_receive(
        self, request_data: bytes, client_id: Optional[str] = None
    ) -> ServerResponse:
        """Send request and receive response with error handling and locking.

        Args:
            request_data: Serialized request payload.
            client_id: Client id used in log messages; defaults to this
                client's id.

        Returns:
            The decoded server response, or an ERROR response describing
            the local failure (closed client, timeout, ZMQ or other error).
            Failures are also reported through ``_handle_connection_error``.
        """
        client_id = client_id or self.client_id
        logger.debug(f"Sending request to server (client_id={client_id})")
        if self._is_closed:
            error_msg = "Client is closed"
            logger.error(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)

        try:
            async with self._socket_lock:
                try:
                    await self.control_socket.send(request_data)
                    response_data = await self.control_socket.recv()
                except zmq.Again:
                    error_msg = "Request timed out"
                    logger.error(f"{error_msg} (client_id={client_id})")
                    # The REQ socket is now waiting for a reply that never
                    # came; rebuild it so the next request can be sent.
                    if not self._is_closed:
                        self._reset_control_socket()
                    await self._handle_connection_error(error_msg)
                    return ServerResponse(
                        status=ResponseStatus.ERROR, message=error_msg
                    )

            return ServerResponse.from_dict(msgpack.unpackb(response_data, raw=False))

        except zmq.ZMQError as e:
            error_msg = f"ZMQ error while sending request: {str(e)}"
            logger.error(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)
        except Exception as e:
            error_msg = f"Failed to send request: {str(e)}"
            logger.exception(f"{error_msg} (client_id={client_id})")
            await self._handle_connection_error(error_msg)
            return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)

    async def init_workspace(
        self,
        avatar_model_path: str,
        video_file: Optional[str] = None,
        inference_data_file: Optional[str] = None,
        check_interval: float = 1.0,
        max_retries: Optional[int] = None,
    ) -> ServerResponse:
        """Initialize avatar model with async status checking.

        Args:
            avatar_model_path: Avatar model path
            video_file: Optional video file path
            inference_data_file: Optional inference data file path
            check_interval: Interval between status checks in seconds
            max_retries: Maximum number of status checks before giving up;
                ``None`` polls until the server reports success or error

        Returns:
            Final initialization response
        """
        logger.info(
            f"Initializing avatar model from '{avatar_model_path}' "
            f"with video_file='{video_file}' and "
            f"inference_data_file='{inference_data_file}'"
        )
        # Send initial request
        request = InitRequest(
            client_id=self.client_id,
            avatar_model_path=avatar_model_path,
            video_file=video_file,
            inference_data_file=inference_data_file,
        )
        response = await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )
        logger.debug(f"Initial avatar model init response: {response}")

        if response.status == ResponseStatus.ERROR:
            return response

        # Start heartbeat immediately after sending init request
        if not self._heartbeat_task:
            self._heartbeat_task = asyncio.create_task(self._send_heartbeat())
            logger.info("Heartbeat task started")

        # Keep checking status until complete or error. Every response,
        # including the one from the final allowed poll, is evaluated before
        # the retry budget is consulted.
        retries = 0
        while True:
            if response.status == ResponseStatus.SUCCESS:
                self.is_initialized = True
                logger.info("Avatar model initialized successfully")
                return response
            if response.status == ResponseStatus.ERROR:
                logger.error(f"Initialization failed: {response}")
                return response

            if max_retries is not None and retries >= max_retries:
                break

            await asyncio.sleep(check_interval)
            logger.debug("Checking initialization status")

            status_request = CheckInitStatusRequest(client_id=self.client_id)
            response = await self._send_and_receive(
                msgpack.packb(status_request.to_dict(), use_bin_type=True)
            )
            logger.debug(f"Avatar model init status response: {response}")

            retries += 1

        # If we timeout, don't stop the heartbeat - let the client handle cleanup
        logger.error("Initialization timed out")
        return ServerResponse(
            status=ResponseStatus.ERROR,
            message="Initialization timed out",
        )

    async def get_setting(self, name: str) -> Any:
        """Get a setting from server.

        Args:
            name: Name of the setting to fetch.

        Returns:
            The ``"value"`` entry of the response's ``extra`` mapping.

        Raises:
            RuntimeError: If the server does not answer with SUCCESS.
        """
        logger.debug(f"Getting setting '{name}' from server")
        request = GetSettingRequest(client_id=self.client_id, name=name)
        response = await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )
        if response.status != ResponseStatus.SUCCESS:
            logger.error(f"Failed to get setting '{name}': {response}")
            raise RuntimeError(f"Failed to get setting {name}: {response}")
        logger.debug(f"Received setting '{name}': {response.extra['value']}")
        return response.extra["value"]

    async def get_video_script(self) -> VideoScript:
        """Get the video script from server, caching it after the first fetch."""
        if self._video_script is None:
            data = await self.get_setting("video_script")
            self._video_script = VideoScript.from_dict(data)
        return self._video_script

    async def send_video_control(
        self, target_video: str | None, actions: list[str] | str | None = None
    ) -> ServerResponse:
        """Send video control commands to server.

        Args:
            target_video: Passed to ``VideoControl`` as ``target_video``.
            actions: Passed to ``VideoControl`` as ``action``.

        Returns:
            The server response to the control request.
        """
        logger.info(f"Sending video control to server: '{target_video=}', '{actions=}'")
        request = AudioRequest(
            client_id=self.client_id,
            data=VideoControl(target_video=target_video, action=actions),
        )
        return await self._send_audio(request)

    async def send_answer_finished_sentinel(self) -> ServerResponse:
        """Send answer finished sentinel to server to enable latter action triggers."""
        logger.debug("Sending answer finished sentinel to server")
        request = AudioRequest(
            client_id=self.client_id,
            data=VideoControl(end_of_speech=True),
        )
        return await self._send_audio(request)

    async def send_audio(
        self,
        audio_bytes: bytes,
        sample_rate: int,
        is_last: bool = False,
        **kwargs: Any,
    ) -> str | None:
        """Stream audio data in chunks.

        Args:
            audio_bytes: Raw audio payload handed to ``AudioChunk.from_bytes``.
            sample_rate: Sample rate of ``audio_bytes`` in Hz.
            is_last: Forwarded to ``AudioChunk.from_bytes`` as ``last_chunk``.
            **kwargs: Extra keyword arguments forwarded to ``VideoControl``.

        Returns:
            Message ID for the streamed audio, or ``None`` if the server did
            not answer with SUCCESS.
        """
        logger.trace(
            f"Streaming audio data: sample_rate={sample_rate}, "
            f"duration={len(audio_bytes) / 2 / sample_rate:.2f}s, is_last={is_last}"
        )

        control = VideoControl(
            audio=AudioChunk.from_bytes(audio_bytes, sample_rate, last_chunk=is_last),
            **kwargs,
        )
        request = AudioRequest(client_id=self.client_id, data=control)
        response = await self._send_audio(request)

        if response.status != ResponseStatus.SUCCESS:
            logger.error(f"Failed to send audio chunk: {response}")
            return None

        return control.message_id

    async def interrupt(self) -> ServerResponse:
        """Interrupt current audio processing.

        Resets the local audio stream batcher, then sends an interrupt
        request to the server and returns its response.
        """
        logger.info("Interrupting current audio processing")
        self.reset_audio_stream()

        request = InterruptRequest(client_id=self.client_id)
        return await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )

    def set_frame_callback(
        self, callback: Callable[[FrameMessage, Dict[str, Any]], Awaitable[None] | None]
    ) -> None:
        """Set callback for received frames. Can be sync or async function."""
        self.frame_callback = callback

    async def _send_audio(self, request: AudioRequest) -> ServerResponse:
        """Serialize an audio/control request and send it to the server."""
        return await self._send_and_receive(
            msgpack.packb(request.to_dict(), use_bin_type=True)
        )

    async def _receive_frames(self) -> None:
        """Receive and process frames until stopped or cancelled.

        Each message is decoded into a ``FrameMessage``. The init frame is
        stored in ``init_frame``; any other frame marks the first frame as
        received and updates the FPS monitor. Once a regular frame has been
        seen, every frame is passed to the frame callback.
        """
        logger.info(f"Starting frame receiver task (client_id={self.client_id})")

        while self.running:
            try:
                parts = await self.stream_socket.recv_multipart()
                if len(parts) < 2:
                    logger.warning(
                        f"Received incomplete message "
                        f"(client_id={self.client_id}, parts_count={len(parts)})"
                    )
                    continue

                # First frame is the subscription topic, second the payload.
                # Indexing (instead of two-way unpacking) tolerates messages
                # that carry additional trailing frames.
                msg = parts[1]

                try:
                    frame_dict = msgpack.unpackb(msg, raw=False)

                    # Keep one client_id check for security
                    if frame_dict["client_id"] != self.client_id:
                        logger.warning(
                            f"Received message for wrong client "
                            f"(client_id={self.client_id}, "
                            f"received_client_id={frame_dict['client_id']})"
                        )
                        continue

                    frame_msg = FrameMessage(**frame_dict)

                except Exception as e:
                    logger.error(
                        f"Error processing message "
                        f"(client_id={self.client_id}, error={str(e)})"
                    )
                    continue

                if frame_msg.source_message_id == "_init_frame":
                    logger.info(
                        f"Received init frame (client_id={self.client_id}, "
                        f"shape={frame_msg.image.shape})"
                    )
                    self.init_frame = frame_msg.image
                else:
                    self.first_frame_received = True
                    self.fps_monitor.update()

                if self.frame_callback and self.first_frame_received:
                    result = self.frame_callback(
                        frame_msg, {"fps": self.fps_monitor.fps}
                    )
                    # Async callbacks return an awaitable that must be awaited.
                    if result is not None and isinstance(result, Awaitable):
                        await result

            except asyncio.CancelledError:
                logger.info("Frame receiver task cancelled")
                break
            except Exception as e:
                logger.exception(
                    f"Error receiving frame "
                    f"(client_id={self.client_id}, error={str(e)})"
                )
                # Brief pause so a persistent failure does not spin the loop.
                await asyncio.sleep(0.001)

    async def _send_heartbeat(self) -> None:
        """Send periodic heartbeat to server.

        Sends one heartbeat per second. After ``max_consecutive_errors``
        failed heartbeats in a row the disconnection handler is invoked and
        the loop ends.
        """
        logger.info("Heartbeat task started")
        consecutive_errors = 0
        while self.running:
            try:
                if self._is_closed:
                    logger.debug("Client is closed, stopping heartbeat")
                    break

                request = HeartbeatRequest(client_id=self.client_id)
                logger.debug("Sending heartbeat to server")
                response = await self._send_and_receive(
                    msgpack.packb(request.to_dict(), use_bin_type=True)
                )
                logger.debug(f"Heartbeat response: {response}")

                if response.status != ResponseStatus.SUCCESS:
                    consecutive_errors += 1
                    logger.warning(f"Heartbeat failed: {response}")

                    # If we get multiple consecutive failures, assume disconnection
                    if consecutive_errors >= self.max_consecutive_errors:
                        logger.error(
                            f"Multiple heartbeat failures, "
                            f"assuming disconnected (client_id={self.client_id})"
                        )
                        await self._handle_disconnection()
                        break
                else:
                    consecutive_errors = 0  # Reset counter on successful heartbeat

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                logger.info("Heartbeat task cancelled")
                break
            except Exception as e:
                consecutive_errors += 1
                logger.exception(f"Error sending heartbeat: {e}")

                # Check for disconnection on consecutive errors
                if consecutive_errors >= self.max_consecutive_errors:
                    logger.error(
                        f"Multiple heartbeat errors, "
                        f"assuming disconnected (client_id={self.client_id})"
                    )
                    await self._handle_disconnection()
                    break

                await asyncio.sleep(0.2)

    async def close(self) -> None:
        """Close the client.

        Marks the client closed, stops the background tasks, closes both
        sockets and terminates the ZMQ context.
        """
        logger.info("Closing client connection")
        self._is_closed = True  # Set closed flag first

        async with self._socket_lock:  # Ensure no ongoing operations
            self.running = False

            if self._frame_task:
                self._frame_task.cancel()
                try:
                    await self._frame_task
                except asyncio.CancelledError:
                    logger.debug("Frame receiving task cancelled")

            if self._heartbeat_task:
                self._heartbeat_task.cancel()
                try:
                    await self._heartbeat_task
                except asyncio.CancelledError:
                    logger.debug("Heartbeat task cancelled")

            # Close sockets. Each one is closed independently so that a
            # failure on the first cannot leave the second open.
            all_closed = True
            for sock in (self.control_socket, self.stream_socket):
                try:
                    sock.close(linger=0)
                except zmq.ZMQError as e:
                    all_closed = False
                    logger.error(f"Error closing sockets: {e}")
            if all_closed:
                logger.debug("Sockets closed")

            # Terminate context
            try:
                self.context.term()
                logger.debug("ZMQ context terminated")
            except zmq.ZMQError as e:
                logger.error(f"Error terminating context: {e}")

        logger.info("Client closed successfully")

    def reset_audio_stream(self) -> None:
        """Reset the audio stream batcher."""
        logger.debug("Resetting audio stream batcher")
        self.stream_batcher.reset()
571
+
572
+
573
async def run_example_client(
    client_id: str,
    avatar_model_path: str,
    window_name: Optional[str] = None,
) -> None:
    """Run an example client with visualization.

    Connects a client, initializes the workspace for ``avatar_model_path``
    and shows every received frame in an OpenCV window with an FPS overlay
    until the user presses ``q`` or the run is interrupted.

    Args:
        client_id: Identifier used for the client connection.
        avatar_model_path: Avatar model path sent to the server.
        window_name: Title of the display window; defaults to
            ``"Client <client_id>"``.
    """
    # The title depends only on the arguments, so resolve it once.
    title = window_name if window_name else f"Client {client_id}"
    quit_codes = (ord("q"), ord("Q"))

    async def show_frame(frame: FrameMessage, frame_info: dict) -> None:
        """Overlay the FPS on the frame, display it and handle the quit key."""
        canvas = frame.image
        fps_text = f"FPS: {frame_info.get('fps', 0):.1f}"

        # Overlay drawing parameters
        origin = (10, 30)
        green = (0, 255, 0)
        cv2.putText(
            canvas, fps_text, origin, cv2.FONT_HERSHEY_SIMPLEX, 1, green, 2, cv2.LINE_AA
        )
        logger.debug(f"Displaying frame {frame.source_message_id} with FPS {fps_text}")

        cv2.imshow(title, canvas)
        pressed = cv2.waitKey(1) & 0xFF
        if pressed in quit_codes:
            logger.info("Quit key pressed, raising KeyboardInterrupt")
            raise KeyboardInterrupt

        # Yield control so other tasks get a chance to run
        await asyncio.sleep(0)

    # Set up the client with the (async) display callback
    client = ZMQBithumanRuntimeClient(client_id)
    client.set_frame_callback(show_frame)
    logger.info("Starting example client")
    await client.start()

    try:
        init_response = await client.init_workspace(avatar_model_path)
        if init_response.status == ResponseStatus.SUCCESS:
            logger.info("Workspace initialized in example client")

            # Idle here; frames are handled by the receiver task
            while True:
                await asyncio.sleep(0.1)
        else:
            logger.error(f"Failed to initialize workspace: {init_response}")
    except KeyboardInterrupt:
        logger.info("Shutting down client...")
    finally:
        await client.close()
        cv2.destroyAllWindows()
        logger.info("Example client closed")
630
+
631
+
632
async def play_audio_example(
    client_ids: list[str],
    audio_file: str,
    text: Optional[str] = None,
    control_port: int = 5555,
) -> None:
    """Play audio file to multiple clients.

    For each client id a short-lived client is created, the audio file is
    loaded, converted to int16 and sent in a single request, and the client
    is closed again. Errors for one client are logged and do not prevent
    the remaining clients from being served.

    Args:
        client_ids: List of client IDs to send audio to
        audio_file: Path to audio file
        text: Optional text to be spoken
        control_port: Server control port
    """
    for client_id in client_ids:
        client: Optional[ZMQBithumanRuntimeClient] = None
        try:
            # Create client
            client = ZMQBithumanRuntimeClient(client_id, control_port=control_port)
            await client.start()

            # Load and prepare audio
            logger.info(f"Loading audio file: {audio_file}")
            audio_data, sample_rate = load_audio(audio_file)
            audio_data = float32_to_int16(audio_data)

            # Stream audio
            logger.info(
                f"Streaming audio to {client_id}: "
                f"duration={len(audio_data) / sample_rate:.2f}s, "
                f"sample_rate={sample_rate}Hz"
            )
            await client.send_audio(
                audio_bytes=audio_data.tobytes(),
                sample_rate=sample_rate,
                is_last=False,
                text=text,
            )

        except Exception as e:
            logger.exception(f"Error playing audio to client {client_id}: {e}")
        finally:
            # Cleanup runs on the failure path too, so sockets, context and
            # the receiver task are always released.
            if client is not None:
                try:
                    await client.close()
                except Exception as e:
                    logger.exception(f"Error closing client {client_id}: {e}")
675
+
676
+
677
async def interrupt_example(client_ids: list[str], control_port: int = 5555) -> None:
    """Interrupt audio processing for multiple clients.

    For each client id a short-lived client is created, an interrupt
    request is sent and the client is closed again. A failed interrupt is
    logged; errors for one client do not stop the remaining ones.

    Args:
        client_ids: List of client IDs to interrupt
        control_port: Server control port
    """
    for client_id in client_ids:
        client: Optional[ZMQBithumanRuntimeClient] = None
        try:
            client = ZMQBithumanRuntimeClient(client_id, control_port=control_port)
            await client.start()
            response = await client.interrupt()
            # interrupt() reports failures through its return value, so the
            # response has to be inspected to notice them.
            if response.status != ResponseStatus.SUCCESS:
                logger.error(f"Failed to interrupt client {client_id}: {response}")
        except Exception as e:
            logger.exception(f"Error interrupting client {client_id}: {e}")
        finally:
            # Always release sockets, context and the receiver task.
            if client is not None:
                try:
                    await client.close()
                except Exception as e:
                    logger.exception(f"Error closing client {client_id}: {e}")
692
+
693
+
694
def _build_cli_parser() -> argparse.ArgumentParser:
    """Assemble the CLI parser with the start, play-audio and interrupt commands."""
    cli = argparse.ArgumentParser(description="ZMQ Bithuman Runtime Client CLI")
    commands = cli.add_subparsers(dest="command", help="Commands")

    # "start": test client with a display window
    start_cmd = commands.add_parser(
        "start", help="Start a test client with visualization"
    )
    start_cmd.add_argument("--client-id", type=str, required=True, help="Client ID")
    start_cmd.add_argument(
        "--avatar-model-path", type=str, required=True, help="Avatar model path"
    )
    start_cmd.add_argument("--window-name", type=str, help="Window name for display")

    # "play-audio": send an audio file to one or more clients
    play_cmd = commands.add_parser(
        "play-audio", help="Play audio to one or more clients"
    )
    play_cmd.add_argument(
        "--client-id",
        type=str,
        nargs="+",
        required=True,
        help="Client IDs to send audio to",
    )
    play_cmd.add_argument(
        "--audio-file", type=str, required=True, help="Path to audio file"
    )
    play_cmd.add_argument("--text", type=str, help="Text to be spoken")
    play_cmd.add_argument(
        "--control-port", type=int, default=5555, help="Server control port"
    )

    # "interrupt": stop audio processing for one or more clients
    stop_cmd = commands.add_parser("interrupt", help="Interrupt audio processing")
    stop_cmd.add_argument(
        "--client-id",
        type=str,
        nargs="+",
        required=True,
        help="Client IDs to interrupt",
    )
    stop_cmd.add_argument(
        "--control-port", type=int, default=5555, help="Server control port"
    )

    return cli


def main() -> None:
    """Main entry point for the client CLI.

    Parses the command line, configures logging to stderr at INFO level and
    runs the coroutine that belongs to the chosen sub-command. Without a
    sub-command the help text is printed. Any error other than a keyboard
    interrupt is logged and ends the process with exit status 1.
    """
    parser = _build_cli_parser()
    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    logger.remove()
    logger.add(sys.stderr, level="INFO")

    def _start_job():
        # The model path is made absolute before it is handed to the client.
        args.avatar_model_path = os.path.abspath(args.avatar_model_path)
        return run_example_client(
            args.client_id, args.avatar_model_path, args.window_name
        )

    def _play_job():
        return play_audio_example(
            client_ids=args.client_id,
            audio_file=args.audio_file,
            text=args.text,
            control_port=args.control_port,
        )

    def _interrupt_job():
        return interrupt_example(
            client_ids=args.client_id,
            control_port=args.control_port,
        )

    # Coroutines are built lazily so only the selected command is created.
    jobs = {
        "start": _start_job,
        "play-audio": _play_job,
        "interrupt": _interrupt_job,
    }

    try:
        logger.debug(f"Executing command '{args.command}' with args: {args}")
        make_job = jobs.get(args.command)
        if make_job is not None:
            asyncio.run(make_job())
    except KeyboardInterrupt:
        logger.info("Operation interrupted by user")
    except Exception as e:
        logger.exception(f"Error executing command {args.command}: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()