bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,759 @@
|
|
|
1
|
+
"""ZMQ server for bithuman runtime service."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from queue import Empty
|
|
10
|
+
from threading import Lock, Semaphore, Thread
|
|
11
|
+
from typing import Dict, Optional
|
|
12
|
+
|
|
13
|
+
import msgpack
|
|
14
|
+
import numpy as np
|
|
15
|
+
from loguru import logger
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import zmq
|
|
19
|
+
except ImportError:
|
|
20
|
+
raise ImportError("zmq is required for bithuman runtime server")
|
|
21
|
+
|
|
22
|
+
from bithuman.api import VideoFrame
|
|
23
|
+
from bithuman.runtime import Bithuman
|
|
24
|
+
from bithuman.service.messages import (
|
|
25
|
+
AudioRequest,
|
|
26
|
+
CheckInitStatusRequest,
|
|
27
|
+
FrameMessage,
|
|
28
|
+
GetSettingRequest,
|
|
29
|
+
HeartbeatRequest,
|
|
30
|
+
InitRequest,
|
|
31
|
+
InterruptRequest,
|
|
32
|
+
ResponseStatus,
|
|
33
|
+
ServerResponse,
|
|
34
|
+
)
|
|
35
|
+
from bithuman.utils.fps_controller import FPSController
|
|
36
|
+
|
|
37
|
+
# Route loguru output to stdout at INFO level. logger.remove() with no
# arguments drops every handler registered so far (including loguru's default
# stderr sink) before the stdout sink is added.
# NOTE(review): this runs at import time, so it also discards handlers set up
# by any application that imports this module — consider moving it into
# serve() / the __main__ block.
logger.remove()
logger.add(sys.stdout, level="INFO")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SessionWorker(ABC):
    """Abstract base class for session workers.

    Handles frame processing and streaming for a single client session.
    Subclasses supply the execution vehicle (e.g. a thread) and implement
    :attr:`is_running`.
    """

    def __init__(
        self,
        stream_socket: zmq.Socket,
        init_request: InitRequest,
        stream_socket_lock: Lock,
        init_semaphore: Semaphore,
    ) -> None:
        """Initialize the worker.

        Creates the ``Bithuman`` runtime and points it at the avatar model with
        ``load_data=False``; the model data itself is loaded later in
        :meth:`run`, gated by ``init_semaphore``.

        Args:
            stream_socket: ZMQ socket for streaming frames (shared with other
                workers, so every send happens under ``stream_socket_lock``)
            init_request: Initialization request
            stream_socket_lock: Lock for stream socket
            init_semaphore: Semaphore bounding concurrent model loads
        """
        self.client_id = init_request.client_id
        self.stream_socket = stream_socket
        self.stream_socket_lock = stream_socket_lock
        self.init_semaphore = init_semaphore
        self.init_request = init_request
        self.is_active = True
        self.fps = 25

        # FPS control
        self.fps_controller = FPSController(target_fps=self.fps)

        # Statistics for the periodic FPS log line in send_frame()
        self.frame_count = 0
        self.last_log_time = time.time()

        # Runtime is configured here; its data is loaded lazily in run()
        self.runtime = Bithuman()
        self.runtime.set_model(
            model_path=self.init_request.avatar_model_path, load_data=False
        )

        # Last activity timestamp, compared against max_inactive_time
        self.last_active_time = time.time()
        self.max_inactive_time = 5.0  # 5 seconds timeout

        self.interrupt_requested = False

        # Guards stop() so the cleanup runs at most once
        self.cleanup_lock = Lock()
        self.cleaned_up = False

        logger.info(
            f"Created worker for client {self.client_id}: {self.init_request.to_dict()}"
        )

    def get_first_frame(self) -> Optional[np.ndarray]:
        """Return the runtime's first video frame (``None`` if it has none)."""
        return self.runtime.get_first_frame()

    def interrupt(self) -> None:
        """Interrupt current processing by delegating to ``runtime.interrupt()``.

        The runtime call is intended to clear queued input and temporarily mute
        audio while video output continues.
        """
        self.runtime.interrupt()
        logger.info(f"Interrupted processing for client {self.client_id}")

    def run(self) -> None:
        """Load the model, then stream frames until stopped or deactivated.

        The outer loop checks ``is_active`` as well as ``is_running``. Without
        it, the outer loop would keep restarting ``runtime.run()`` forever after
        :meth:`stop` cleared ``is_active``: the inner loop would break on its
        first frame every time.
        """
        logger.info(f"Starting frame processing for client {self.client_id}")
        with self.init_semaphore:
            logger.info(f"Loading model for client {self.client_id}")
            self.runtime.load_data()
            logger.info(f"Model loaded for client {self.client_id}")

        while self.is_running and self.is_active:
            try:
                # Process frames
                for frame in self.runtime.run():
                    if not self.is_running or not self.is_active:
                        logger.debug(
                            f"Worker for client {self.client_id} stopped or inactive"
                        )
                        break

                    # Wait for next frame time
                    self.fps_controller.wait_next_frame()
                    self.send_frame(frame, time.time())
                    # Update FPS controller
                    self.fps_controller.update()

            except Empty:
                time.sleep(0.001)
            except Exception:
                logger.exception(f"Error processing frames for client {self.client_id}")
                # Back off so a persistent failure neither spins the CPU nor
                # floods the log.
                time.sleep(0.1)

    def send_frame(self, frame: VideoFrame, current_time: float) -> None:
        """Publish one frame to the client on the stream socket.

        The frame is wrapped in a ``FrameMessage``, packed with msgpack, and
        sent as ``[client_id, payload]`` with ``zmq.NOBLOCK``. Serialization
        and send failures are logged and the frame is dropped; this method does
        not raise.

        Args:
            frame: Frame (and optional audio chunk) produced by the runtime
            current_time: Timestamp used for the periodic FPS statistics
        """
        try:
            # Create FrameMessage first (outside lock)
            frame_msg = FrameMessage.create(
                client_id=self.client_id,
                frame_image=frame.bgr_image,
                frame_index=frame.frame_index,
                source_message_id=frame.source_message_id,
                end_of_speech=frame.end_of_speech,
                audio_bytes=frame.audio_chunk.bytes if frame.audio_chunk else None,
                sample_rate=(
                    frame.audio_chunk.sample_rate if frame.audio_chunk else None
                ),
                # metadata
                stream_fps=self.fps_controller.fps,
            )

            # Serialize data outside lock
            try:
                msg_data = zmq.Frame(
                    msgpack.packb(frame_msg.to_dict(), use_bin_type=True)
                )
                topic = zmq.Frame(self.client_id.encode())
            except Exception as e:
                logger.error(
                    f"Failed to serialize frame data for client {self.client_id}: {e}"
                )
                return

            # Only lock the actual send: ZMQ sockets are not thread-safe and
            # this one is shared by all workers.
            try:
                with self.stream_socket_lock:
                    self.stream_socket.send_multipart(
                        [topic, msg_data], flags=zmq.NOBLOCK, copy=False
                    )
            except zmq.error.Again as e:
                logger.warning(f"Client {self.client_id} is not receiving frames: {e}")
                return
            except zmq.ZMQError as e:
                logger.error(f"Failed to send frame for client {self.client_id}: {e}")
                return

            # Update statistics (outside lock); log measured FPS every ~5 s
            self.frame_count += 1
            if current_time - self.last_log_time > 5:
                fps = self.frame_count / (current_time - self.last_log_time)
                logger.debug(
                    f"Client {self.client_id} streaming at {fps:.2f} FPS "
                    f"(target: {self.fps_controller.target_fps}, "
                    f"average: {self.fps_controller.fps:.2f})"
                )
                self.frame_count = 0
                self.last_log_time = current_time

        except Exception as e:
            logger.error(f"Failed to send frame for client {self.client_id}: {e}")

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Check if the worker is running.

        Returns:
            True if the worker is still running, False otherwise
        """
        pass

    def stop(self) -> None:
        """Deactivate the worker and release its runtime (idempotent).

        ``cleanup_lock`` together with the ``cleaned_up`` flag makes repeated or
        concurrent calls no-ops after the first one.
        """
        with self.cleanup_lock:
            if self.cleaned_up:
                return
            self.cleaned_up = True
            self.is_active = False

            # Ensure runtime cleanup
            if self.runtime is not None:
                self.runtime.cleanup()
                self.runtime = None

            logger.info(f"Cleaned up worker for client {self.client_id}")

    def update_active_time(self) -> None:
        """Update last active timestamp to current time."""
        self.last_active_time = time.time()
        logger.debug(f"Updated active time for client {self.client_id}")

    @property
    def is_inactive_timeout(self) -> bool:
        """Check if worker has been inactive for too long.

        Returns:
            True if more than ``max_inactive_time`` seconds have passed since
            ``last_active_time``, False otherwise
        """
        return time.time() - self.last_active_time > self.max_inactive_time
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class ThreadedSessionWorker(SessionWorker):
    """Session worker that drives its frame loop on a dedicated thread.

    The thread is constructed here but not started; the owner starts it with
    ``worker.thread.start()``.
    """

    def __init__(
        self,
        stream_socket: zmq.Socket,
        init_request: InitRequest,
        stream_socket_lock: Lock,
        init_semaphore: Semaphore,
    ) -> None:
        """Set up the base worker state and the (not yet started) thread."""
        super().__init__(
            stream_socket=stream_socket,
            init_request=init_request,
            stream_socket_lock=stream_socket_lock,
            init_semaphore=init_semaphore,
        )
        # Flag read by run() through the is_running property
        self.running = True
        self.thread = Thread(target=self.run)
        logger.info(f"Initialized threaded worker for client {self.client_id}")

    @property
    def is_running(self) -> bool:
        """Whether the frame loop should keep going (cleared by ``stop``)."""
        return self.running

    def stop(self) -> None:
        """Signal the loop to end, wait for the thread, then release resources."""
        client = self.client_id
        logger.info(f"Stopping worker for client {client}")
        self.running = False
        # Wait for run() to return before the base cleanup tears down the
        # runtime that run() is using.
        self.thread.join()
        super().stop()
        logger.info(f"Stopped worker for client {client}")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class ZMQBithumanRuntimeServer:
    """ZMQ server for Bithuman Runtime.

    Manages client connections, worker threads, and message routing.
    Handles initialization, audio processing, and frame streaming for multiple
    clients:

    * control requests arrive on a REP socket and each gets exactly one reply;
    * frames are published on a PUB socket by the per-client workers;
    * a monitor thread stops workers whose clients stopped sending messages.

    ``workers``, ``initializing_workers`` and ``init_errors`` are shared by the
    control thread, the monitor thread and per-client init threads, so every
    check-then-act on them happens under ``workers_lock``.
    """

    # Max time (ms) the control loop waits for a request before re-checking
    # ``self.running``.
    _CONTROL_POLL_TIMEOUT_MS = 10
    # Pause (s) after an unexpected ZMQ error in the control loop, so a
    # persistent error does not spin the CPU.
    _CONTROL_ERROR_BACKOFF = 0.01
    # Interval (s) between inactivity sweeps of the monitor thread.
    _MONITOR_INTERVAL = 0.5

    def __init__(
        self,
        host: str = "0.0.0.0",
        control_port: int = 5555,
        stream_port: int = 5556,
        max_concurrent_inits: int = 1,
    ) -> None:
        """Initialize the server, bind both sockets and start its threads.

        Args:
            host: Host address to bind to
            control_port: Port for control messages
            stream_port: Port for frame streaming
            max_concurrent_inits: Maximum number of concurrent initializations
        """
        logger.info(
            f"Initializing ZMQBithumanRuntimeServer on {host} "
            f"(control_port={control_port}, stream_port={stream_port})"
        )
        self.context = zmq.Context()

        # Control socket for commands (REQ/REP pattern)
        self.control_socket = self.context.socket(zmq.REP)
        self.control_socket.bind(f"tcp://{host}:{control_port}")
        logger.debug(f"Control socket bound to tcp://{host}:{control_port}")

        # Pub socket for streaming frames
        self.stream_socket = self.context.socket(zmq.PUB)
        self.stream_socket.bind(f"tcp://{host}:{stream_port}")
        logger.debug(f"Stream socket bound to tcp://{host}:{stream_port}")

        # Shared state. All of it must exist BEFORE the threads below start,
        # because the control thread may handle a request immediately.
        self.workers: Dict[str, SessionWorker] = {}
        self.workers_lock = Lock()
        self.running = True
        # Serializes sends on the PUB socket, which all workers share
        self.stream_socket_lock = Lock()
        # Bounds how many workers may load model data concurrently
        self.init_semaphore = Semaphore(max_concurrent_inits)
        # client_id -> init thread, for initializations still in progress
        self.initializing_workers: Dict[str, Thread] = {}
        # client_id -> message of the last failed initialization
        self.init_errors: Dict[str, str] = {}

        # Start control message handler
        self.control_thread = Thread(target=self._handle_control)
        self.control_thread.start()
        logger.info("Started control message handler thread")

        # Start worker monitor thread
        self.monitor_thread = Thread(target=self._monitor_workers)
        self.monitor_thread.start()
        logger.info("Started worker monitor thread")

    def _handle_control(self) -> None:
        """Receive, dispatch and answer control messages until stopped."""
        logger.info("Starting control message handler loop")

        while self.running:
            try:
                try:
                    # Block briefly for a request instead of busy-polling, but
                    # return often enough to notice self.running going False.
                    if not self.control_socket.poll(
                        timeout=self._CONTROL_POLL_TIMEOUT_MS
                    ):
                        continue
                    msg = msgpack.unpackb(
                        self.control_socket.recv(flags=zmq.NOBLOCK), raw=False
                    )
                except zmq.error.Again:
                    # No messages waiting
                    continue
                except zmq.ZMQError as e:
                    if e.errno == zmq.ETERM:
                        # Context was terminated
                        break
                    logger.error(f"ZMQ error in control handler: {e}")
                    time.sleep(self._CONTROL_ERROR_BACKOFF)
                    continue

                self._process_control_message(msg)

            except Exception as e:
                # A REP socket must answer every received request, so reply
                # with an error instead of leaving the client hanging.
                logger.exception("Error handling control message")
                error_response = ServerResponse(
                    status=ResponseStatus.ERROR,
                    message=str(e),
                )
                self._send_response(error_response, "unknown")

    def _process_control_message(self, msg: dict) -> None:
        """Process a single control message and send its response."""
        cmd = msg.get("command")
        client_id = msg.get("client_id", "unknown")

        logger.debug(f"Received '{cmd}' command from client {client_id}")

        # Any message from a known client counts as activity
        with self.workers_lock:
            worker = self.workers.get(client_id)
            if worker:
                worker.update_active_time()

        # Handle command
        if cmd == "init":
            response = self._handle_init(msg)
        elif cmd == "audio":
            response = self._handle_audio(msg)
        elif cmd == "heartbeat":
            response = self._handle_heartbeat(msg)
        elif cmd == "interrupt":
            response = self._handle_interrupt(msg)
        elif cmd == "check_init_status":
            response = self._handle_check_init_status(msg)
        elif cmd == "get_setting":
            response = self._handle_get_setting(msg)
        else:
            logger.warning(f"Unknown command received from client {client_id}: {cmd}")
            response = ServerResponse(
                status=ResponseStatus.ERROR,
                message=f"Unknown command: {cmd}",
            )

        # Send response
        self._send_response(response, client_id)
        logger.debug(f"Sent response to client {client_id}: {response}")

    def _send_response(self, response: ServerResponse, client_id: str) -> None:
        """Send response to client; ZMQ errors are logged, not raised."""
        try:
            # Use blocking send for REQ/REP pattern
            self.control_socket.send(
                msgpack.packb(response.to_dict(), use_bin_type=True)
            )
        except zmq.ZMQError as e:
            logger.error(f"Failed to send response to client {client_id}: {e}")

    def _handle_init(self, msg: dict) -> ServerResponse:
        """Handle workspace initialization request.

        Returns:
            SUCCESS if the client already has a worker, LOADING if an
            initialization is (now) running in the background, ERROR if the
            model path does not exist or the request could not be processed.
        """
        try:
            request = InitRequest(**msg)
            client_id = request.client_id
            logger.info(f"Handling init request for client {client_id}")

            if not Path(request.avatar_model_path).exists():
                error_msg = f"Workspace not found: {request.avatar_model_path}"
                logger.error(error_msg)
                return ServerResponse(
                    status=ResponseStatus.ERROR,
                    message=error_msg,
                )

            init_thread = Thread(target=self._initialize_worker, args=(request,))

            # Check and register under a single lock acquisition. Otherwise an
            # init thread that finishes between the two membership checks makes
            # the client look unknown, and a duplicate worker gets started.
            with self.workers_lock:
                if client_id in self.workers:
                    logger.info(f"Client {client_id} already initialized")
                    return ServerResponse(
                        status=ResponseStatus.SUCCESS,
                        message="Client already initialized",
                    )
                if client_id in self.initializing_workers:
                    logger.info(f"Client {client_id} initialization in progress")
                    return ServerResponse(
                        status=ResponseStatus.LOADING,
                        message="Initialization in progress",
                    )
                # Forget the error of an earlier failed attempt so that
                # check_init_status reports on this attempt only.
                self.init_errors.pop(client_id, None)
                self.initializing_workers[client_id] = init_thread

            try:
                init_thread.start()
            except Exception:
                # Do not leave the client stuck in "initializing" forever
                with self.workers_lock:
                    self.initializing_workers.pop(client_id, None)
                raise
            logger.info(f"Started initialization thread for client {client_id}")

            return ServerResponse(
                status=ResponseStatus.LOADING,
                message="Started initialization",
            )

        except Exception as e:
            logger.exception("Failed to start initialization")
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _initialize_worker(self, request: InitRequest) -> None:
        """Create, start and register the worker for ``request`` (init thread).

        On success the client moves from ``initializing_workers`` to
        ``workers`` and its first frame is published. On failure the error is
        stored in ``init_errors`` for check_init_status. A worker that was
        already started is unregistered and stopped, so it does not keep
        running behind an ERROR status.
        """
        client_id = request.client_id
        logger.info(f"Initializing worker for client {client_id}")
        worker: Optional[ThreadedSessionWorker] = None
        started = False
        try:
            worker = ThreadedSessionWorker(
                stream_socket=self.stream_socket,
                init_request=request,
                stream_socket_lock=self.stream_socket_lock,
                init_semaphore=self.init_semaphore,
            )
            worker.thread.start()
            started = True
            logger.info(f"Started worker thread for client {client_id}")

            with self.workers_lock:
                self.workers[client_id] = worker
                self.initializing_workers.pop(client_id, None)
            logger.info(f"Worker created for client {client_id}")

            # Send first frame
            first_frame = worker.get_first_frame()
            if first_frame is not None:
                frame_msg = VideoFrame(
                    bgr_image=first_frame, source_message_id="_init_frame"
                )
                worker.send_frame(frame_msg, time.time())
                logger.info(f"Sent first frame to client {client_id}")
            worker.update_active_time()

        except Exception as e:
            logger.exception(f"Failed to initialize worker for client {client_id}")
            with self.workers_lock:
                self.init_errors[client_id] = str(e)
                self.initializing_workers.pop(client_id, None)
                if worker is not None and self.workers.get(client_id) is worker:
                    del self.workers[client_id]
            if worker is not None and started:
                self._stop_worker(client_id, worker)

    def _handle_audio(self, msg: dict) -> ServerResponse:
        """Handle audio processing request by pushing it to the client's runtime."""
        try:
            request = AudioRequest.from_dict(msg)

            with self.workers_lock:
                worker = self.workers.get(request.client_id)
                if not worker:
                    error_msg = "Client not initialized"
                    logger.error(error_msg)
                    return ServerResponse(
                        status=ResponseStatus.ERROR, message=error_msg
                    )
            if not request.data.audio:
                logger.info(f"Received control message: {request}")
            worker.runtime.push(request.data)
            return ServerResponse(status=ResponseStatus.SUCCESS)

        except Exception as e:
            logger.exception(f"Error handling audio for client {msg.get('client_id')}")
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _handle_heartbeat(self, msg: dict) -> ServerResponse:
        """Handle heartbeat request.

        Args:
            msg: Message containing heartbeat parameters

        Returns:
            SUCCESS if the client has a worker or an initialization in
            progress, ERROR otherwise
        """
        try:
            request = HeartbeatRequest(**msg)
            client_id = request.client_id
            logger.debug(f"Handling heartbeat for client {client_id}")

            with self.workers_lock:
                if (
                    client_id not in self.workers
                    and client_id not in self.initializing_workers
                ):
                    error_msg = "Client not initialized"
                    logger.error(error_msg)
                    return ServerResponse(
                        status=ResponseStatus.ERROR, message=error_msg
                    )

            return ServerResponse(status=ResponseStatus.SUCCESS)

        except Exception as e:
            logger.exception(
                f"Error handling heartbeat for client {msg.get('client_id')}"
            )
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _handle_interrupt(self, msg: dict) -> ServerResponse:
        """Handle interrupt request.

        Args:
            msg: Message containing interrupt parameters

        Returns:
            Response indicating success or failure
        """
        try:
            request = InterruptRequest(**msg)
            logger.info(f"Handling interrupt for client {request.client_id}")

            with self.workers_lock:
                worker = self.workers.get(request.client_id)
                if not worker:
                    error_msg = "Client not initialized"
                    logger.error(error_msg)
                    return ServerResponse(
                        status=ResponseStatus.ERROR, message=error_msg
                    )

            worker.interrupt()
            logger.info(f"Interrupted audio for client {request.client_id}")
            return ServerResponse(status=ResponseStatus.SUCCESS)

        except Exception as e:
            logger.exception(
                f"Error handling interrupt for client {msg.get('client_id')}"
            )
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _handle_check_init_status(self, msg: dict) -> ServerResponse:
        """Handle initialization status check request.

        A stored initialization error is reported once and then discarded.
        """
        try:
            request = CheckInitStatusRequest(**msg)
            client_id = request.client_id
            logger.debug(f"Checking init status for client {client_id}")

            # Read all three maps in one critical section. An init thread moves
            # the client out of initializing_workers atomically, and checking
            # without the lock could miss it in both places.
            with self.workers_lock:
                error_msg = self.init_errors.pop(client_id, None)
                initialized = client_id in self.workers
                initializing = client_id in self.initializing_workers

            # Check for initialization error
            if error_msg is not None:
                logger.error(
                    f"Initialization error for client {client_id}: {error_msg}"
                )
                return ServerResponse(status=ResponseStatus.ERROR, message=error_msg)

            # Check if initialization completed
            if initialized:
                logger.info(f"Initialization complete for client {client_id}")
                return ServerResponse(
                    status=ResponseStatus.SUCCESS,
                    message="Initialization complete",
                )

            # Still initializing
            if initializing:
                logger.info(f"Initialization in progress for client {client_id}")
                return ServerResponse(
                    status=ResponseStatus.LOADING,
                    message="Initialization in progress",
                )

            error_msg = "No initialization found for client"
            logger.error(f"{error_msg}: {client_id}")
            return ServerResponse(
                status=ResponseStatus.ERROR,
                message=error_msg,
            )

        except Exception as e:
            logger.exception("Error checking initialization status")
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _handle_get_setting(self, msg: dict) -> ServerResponse:
        """Handle get setting request.

        ``"video_script"`` returns the runtime's video script as a dict; any
        other name is looked up as an attribute of ``runtime.settings``.
        """
        try:
            request = GetSettingRequest(**msg)
            logger.debug(f"Handling get setting request for client {request.client_id}")
            with self.workers_lock:
                worker = self.workers.get(request.client_id)
            if not worker:
                error_msg = "Client not exist or not initialized"
                logger.error(error_msg)
                return ServerResponse(
                    status=ResponseStatus.ERROR,
                    message=error_msg,
                )
            if request.name == "video_script":
                logger.info(f"Retrieved video script for client {request.client_id}")
                return ServerResponse(
                    status=ResponseStatus.SUCCESS,
                    extra={"value": worker.runtime.video_graph.videos_script.to_dict()},
                )

            settings = worker.runtime.settings
            if not hasattr(settings, request.name):
                error_msg = f"Setting {request.name} not found"
                logger.error(error_msg)
                return ServerResponse(
                    status=ResponseStatus.ERROR,
                    message=error_msg,
                )
            value = getattr(settings, request.name)
            logger.info(
                f"Retrieved setting '{request.name}' for client "
                f"{request.client_id}: {value}"
            )
            return ServerResponse(status=ResponseStatus.SUCCESS, extra={"value": value})
        except Exception as e:
            logger.exception("Error getting setting")
            return ServerResponse(status=ResponseStatus.ERROR, message=str(e))

    def _stop_worker(self, client_id: str, worker: SessionWorker) -> None:
        """Stop ``worker``, logging (not propagating) any failure.

        Call this without holding ``workers_lock``: stopping a threaded worker
        joins its thread, which can take a while.
        """
        try:
            worker.stop()
        except Exception:
            logger.exception(f"Error stopping worker for client {client_id}")

    def _monitor_workers(self) -> None:
        """Monitor workers and cleanup inactive ones."""
        logger.info("Starting worker monitor thread")

        while self.running:
            try:
                # Unregister expired workers under the lock...
                with self.workers_lock:
                    expired = [
                        client_id
                        for client_id, worker in self.workers.items()
                        if worker.is_inactive_timeout
                    ]
                    stale = [
                        (client_id, self.workers.pop(client_id))
                        for client_id in expired
                    ]

                # ...but stop them outside it, so the (joining) stop does not
                # block the control thread for every other client.
                for client_id, worker in stale:
                    logger.info(f"Stopping inactive worker for client {client_id}")
                    self._stop_worker(client_id, worker)

            except Exception:
                logger.exception("Error in worker monitor")

            # Sleep on every path, including after an error, to avoid spinning
            time.sleep(self._MONITOR_INTERVAL)

    def stop(self) -> None:
        """Stop the server and cleanup resources.

        Stops all worker threads, closes sockets, and performs cleanup.
        """
        logger.info("Shutting down server...")
        self.running = False

        # Let the control and monitor loops exit first. After that, no new
        # initialization can start and no worker is stopped concurrently.
        self.control_thread.join()
        self.monitor_thread.join()

        # In-flight initializations may still register a worker; wait for
        # them so that worker is stopped below too (worker threads are not
        # daemon threads and would otherwise keep the process alive).
        with self.workers_lock:
            pending_inits = list(self.initializing_workers.values())
        for init_thread in pending_inits:
            init_thread.join()

        # Stop all workers (outside the lock: each stop joins a thread)
        with self.workers_lock:
            workers = list(self.workers.items())
            self.workers.clear()
        for client_id, worker in workers:
            self._stop_worker(client_id, worker)
        logger.info("Stopped all workers")

        # Close sockets without lingering on undelivered messages; otherwise
        # context.term() can block waiting for them to be sent.
        self.control_socket.close(linger=0)
        logger.debug("Closed control socket")
        self.stream_socket.close(linger=0)
        logger.debug("Closed stream socket")
        self.context.term()
        logger.info("Server shutdown complete")
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
def serve(
    host: str = "0.0.0.0",
    control_port: int = 5555,
    stream_port: int = 5556,
    max_concurrent_inits: int = 1,
) -> None:
    """Start the server and block until interrupted.

    Args:
        host: Host address to bind to
        control_port: Port for control messages
        stream_port: Port for frame streaming
        max_concurrent_inits: Maximum number of concurrent initializations,
            forwarded to ``ZMQBithumanRuntimeServer`` (default 1, as before)
    """
    server = ZMQBithumanRuntimeServer(
        host=host,
        control_port=control_port,
        stream_port=stream_port,
        max_concurrent_inits=max_concurrent_inits,
    )
    try:
        logger.info(
            f"Server started on {host} "
            f"(control_port={control_port}, stream_port={stream_port})"
        )
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down server...")
    finally:
        # Always shut down. The server's threads are not daemon threads, so
        # skipping this on any other exit (e.g. SystemExit) would leave the
        # process hanging.
        server.stop()
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse bind address and ports, then serve.
    cli = argparse.ArgumentParser(description="ZMQ Bithuman Runtime Server")
    cli.add_argument(
        "--host", default="0.0.0.0", help="Host address to bind to (default: 0.0.0.0)"
    )
    port_options = (
        ("--control-port", 5555, "Port for control messages (default: 5555)"),
        ("--stream-port", 5556, "Port for frame streaming (default: 5556)"),
    )
    for flag, default_port, help_text in port_options:
        cli.add_argument(flag, type=int, default=default_port, help=help_text)

    opts = cli.parse_args()
    serve(host=opts.host, control_port=opts.control_port, stream_port=opts.stream_port)
|