bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
bithuman/runtime.py
ADDED
@@ -0,0 +1,1004 @@
"""Bithuman Runtime."""

from __future__ import annotations

import asyncio
import copy
import logging
import os
import threading
import time
from functools import cached_property
from pathlib import Path
from queue import Empty, Queue
from threading import Event
from typing import Callable, Generic, Iterable, Optional, Tuple, TypeVar

import numpy as np
from loguru import logger

from . import audio as audio_utils
from .api import AudioChunk, VideoControl, VideoFrame
from .config import load_settings
from .lib.generator import BithumanGenerator
from .utils import calculate_file_hash
from .video_graph import Frame as FrameMeta
from .video_graph import VideoGraphNavigator

logging.getLogger("numba").setLevel(logging.WARNING)

T = TypeVar("T")

BufferEmptyCallback = Callable[[], bool]


class _ActionDebouncer:
    """Prevent redundant action playback across consecutive frames."""

    __slots__ = ("_lock", "_last_signature")

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._last_signature: Optional[Tuple[str, str]] = None

    def prepare(self, control: VideoControl) -> None:
        if not control.action:
            return

        signature = (control.action, control.target_video or "")
        with self._lock:
            if not control.force_action and signature == self._last_signature:
                logger.debug("Suppressing repeated action: {}", signature)
                control.action = None
            else:
                self._last_signature = signature

    def reset(self) -> None:
        with self._lock:
            self._last_signature = None


class Bithuman:
    """Bithuman Runtime."""

    def __init__(
        self,
        *,
        input_buffer_size: int = 0,
        token: Optional[str] = None,
        model_path: Optional[str] = None,
        api_secret: Optional[str] = None,
        api_url: str = "https://auth.api.bithuman.ai/v1/runtime-tokens/request",
        tags: Optional[str] = None,
        insecure: bool = True,
        num_threads: int = 0,
        verbose: Optional[bool] = None,
    ) -> None:
        """Initialize the Bithuman Runtime.

        Args:
            input_buffer_size: The size of the input buffer.
            token: The token for the Bithuman Runtime. Either token or api_secret must be provided.
            model_path: The path to the avatar model.
            api_secret: API secret for API authentication. Either token or api_secret must be provided.
            api_url: API endpoint URL for token requests.
            tags: Optional tags for the token request.
            insecure: Disable SSL certificate verification (not recommended for production use).
            num_threads: Number of threads for processing. 0 = single-threaded,
                >0 = use the specified number of threads, <0 = auto-detect the optimal thread count.
            verbose: Enable verbose logging for token validation. If None, reads from the
                BITHUMAN_VERBOSE environment variable.
        """
        # Set verbose from parameter or environment variable
        if verbose is None:
            verbose = os.getenv("BITHUMAN_VERBOSE", "false").lower() in ("true", "1", "yes", "on")
        self._verbose = verbose
        self._num_threads = num_threads

        # Transaction ID will be generated in the C++ layer at start() to prevent user tampering
        self.transaction_id = ""

        logger.debug(
            f"Initializing Bithuman runtime with: model_path={model_path}, "
            f"token={token is not None}, api_secret={api_secret is not None}, verbose={verbose}"
        )

        # Log environment variables for debugging
        logger.debug(f"BITHUMAN_VERBOSE env var: {os.getenv('BITHUMAN_VERBOSE', 'not set')}")
        logger.debug(f"LOADING_MODE env var: {os.getenv('LOADING_MODE', 'not set')}")

        # Mask sensitive information in logs
        if api_secret:
            masked_secret = f"{api_secret[:5]}...{api_secret[-5:] if len(api_secret) > 10 else '***'}"
            logger.debug(f"API secret provided: {masked_secret}")
        if token:
            masked_token = f"{token[:10]}...{token[-10:] if len(token) > 20 else '***'}"
            logger.debug(f"Token provided: {masked_token}")

        if not token and not api_secret:
            logger.error("Neither token nor api_secret provided")
            raise ValueError("Either token or api_secret must be provided")

        self.settings = copy.deepcopy(load_settings())

        try:
            # Initialize the generator with token refresh parameters if provided;
            # token refresh will start automatically.
            self.generator = BithumanGenerator(
                audio_encoder_path=str(self.settings.AUDIO_ENCODER_PATH),
                api_secret=api_secret if api_secret else None,
                api_url=api_url if api_secret else None,
                model_path=model_path if api_secret else None,
                tags=tags if api_secret else None,
                insecure=insecure if api_secret else False,
            )
        except Exception as e:
            logger.error(f"Failed to initialize BithumanGenerator: {e}")
            raise

        self.video_graph: Optional[VideoGraphNavigator] = None

        # Store token request parameters.
        # Note: these are stored as private attributes to prevent direct modification
        # after initialization. Token refresh parameters are locked once refresh starts.
        self._model_path = model_path
        self._api_secret = api_secret
        self._api_url = api_url
        self._tags = tags
        self._insecure = insecure
        self._token = token
        # Track whether token refresh has been started, to prevent parameter changes
        self._token_refresh_started = False

        # Token refresh state
        self._action_debouncer = _ActionDebouncer()

        # Account status flag - set to True when the account has issues (402, 403, etc.)
        self._account_status_error = threading.Event()

        try:
            self._warmup()
        except Exception as e:
            logger.error(f"Warmup failed: {e}")
            raise

        # Ignore audio when muted
        self.muted = Event()
        self.interrupt_event = Event()
        self._input_buffer = ThreadSafeAsyncQueue[VideoControl](
            maxsize=input_buffer_size
        )

        # Video
        self.audio_batcher = audio_utils.AudioStreamBatcher(
            output_sample_rate=self.settings.INPUT_SAMPLE_RATE
        )
        self._video_loaded = False
        self._sample_per_video_frame = (
            self.settings.INPUT_SAMPLE_RATE / self.settings.FPS
        )
        self._idle_timeout: float = 0.001

        self._model_hash = None
        if self._model_path:
            # Load the model if provided
            self._initialize_token()
            self.set_model(self._model_path)

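The constructor above supports two authentication paths: pass a pre-issued runtime token, or pass an api_secret (plus model_path) and let the runtime request and refresh tokens itself once run() is called. A minimal sketch of both, with placeholder credentials and a hypothetical avatar model path:

    from bithuman.runtime import Bithuman

    # Path A: API secret; tokens are requested and refreshed automatically.
    runtime = Bithuman(
        api_secret="sk-...",                 # placeholder secret
        model_path="/path/to/avatar_model",  # hypothetical model path
    )

    # Path B: a pre-issued runtime token, validated during initialization.
    runtime = Bithuman(token="<runtime-token>", model_path="/path/to/avatar_model")
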
    def set_idle_timeout(self, idle_timeout: float) -> None:
        """Set the idle timeout for the Bithuman Runtime.

        Args:
            idle_timeout: The idle timeout in seconds.
        """
        self._idle_timeout = idle_timeout

    def _regenerate_transaction_id(self) -> None:
        """Generate a transaction ID for new runtime sessions.

        This method is called when starting the runtime to ensure each session
        has a unique transaction identifier. Once token refresh starts, the transaction ID
        is locked and cannot be regenerated, to prevent billing bypass.
        """
        old_id = self.transaction_id
        self.transaction_id = self.generator.generate_transaction_id()
        logger.debug(f"Generated transaction ID: {old_id} -> {self.transaction_id}")

    def set_token(self, token: str, verbose: Optional[bool] = None) -> bool:
        """Set and validate the token for the Bithuman Runtime.

        This method validates the provided token and sets it for subsequent operations if valid.

        Args:
            token: The token to validate and set.
            verbose: Enable verbose logging for token validation. If None, uses the instance default.

        Returns:
            bool: True if the token is valid and set successfully.

        Raises:
            ValueError: If the token is invalid.
        """
        if verbose is None:
            verbose = self._verbose

        logger.debug(f"Attempting to set token: {token[:10]}...{token[-10:] if len(token) > 20 else '***'}")

        is_valid = self.generator._generator.validate_token(token, verbose)
        if not is_valid:
            logger.error("Token validation failed - token is invalid")
            raise ValueError("Invalid token")

        logger.debug("Token validated and set successfully")
        return True

    def is_token_validated(self) -> bool:
        """Check if the token is validated."""
        return self.generator.is_token_validated()

    def get_expiration_time(self) -> int:
        """Get the expiration time of the token."""
        return self.generator.get_expiration_time()


    def _handle_token_validation_error(self, e: RuntimeError, context: str = "") -> RuntimeError:
        """Handle token validation errors.

        This method checks if a RuntimeError is related to token validation failure
        and converts it to a standardized exception that can be caught by callers.

        Args:
            e: The RuntimeError exception to check
            context: Additional context about where the error occurred

        Returns:
            RuntimeError: A standardized token validation error, or re-raises the original exception
        """
        error_msg = str(e).lower()
        token_expired_indicators = [
            "token has expired",
            "validation failed: token has expired",
            "validation failed: token not validated",
        ]

        for indicator in token_expired_indicators:
            if indicator in error_msg:
                logger.error(
                    f"Token validation failed{(' during ' + context) if context else ''}: {str(e)}"
                )
                # Return a standardized exception that can be caught by callers
                return RuntimeError("Token validation failed: token has expired")

        # Not a token expiration error, re-raise the original exception
        raise

    def _initialize_token(self) -> None:
        """Initialize token if provided by user.

        If user provided a token, validate and set it.
        If user provided api_secret, token refresh is handled automatically.
        """
        if self._token:
            logger.debug("Token provided, validating...")
            try:
                is_valid = self.generator._generator.validate_token(self._token, self._verbose)
                if not is_valid:
                    raise ValueError("Token validation failed")
                logger.debug("Token validated and set successfully")
            except Exception as e:
                logger.warning(f"Token validation failed: {e}")
                raise
        # If api_secret is provided, token refresh is handled automatically


    def set_model(self, model_path: str) -> "Bithuman":
        """Set the video file or workspace directory.

        Args:
            model_path: The workspace directory.

        Returns:
            The Bithuman instance, for chaining.
        """
        if not model_path:
            logger.error("No model path provided to set_model()")
            raise ValueError("Model path cannot be empty")

        if model_path == self._model_path and self._video_loaded:
            logger.debug("Model path is the same as the current model path, skipping")
            return self

        if not Path(model_path).exists():
            raise FileNotFoundError(f"Model path {model_path} does not exist")

        # Security check: if token refresh is already running, verify model_path matches.
        # This prevents users from changing model_path after token refresh has started.
        if self._api_secret and self._api_url:
            if self.generator.is_token_refresh_running():
                # Token refresh is already running - verify model_path hasn't changed
                if self._model_path and model_path != self._model_path:
                    logger.error(
                        f"Security violation: Cannot change model_path from '{self._model_path}' "
                        f"to '{model_path}' after token refresh has started. "
                        "Token refresh parameters are locked for security."
                    )
                    raise RuntimeError(
                        "Cannot change model_path after token refresh has started. "
                        "This is a security restriction."
                    )
            else:
                # Token refresh not started yet - request an initial token for validation (not refresh).
                # This is needed for load_data() validation but does not start the background refresh thread.
                try:
                    # Generate a transaction ID before requesting the token
                    if not self.transaction_id:
                        self._regenerate_transaction_id()

                    # Request the initial token (single request, no background refresh thread)
                    initial_token = self.generator.request_token(
                        api_url=self._api_url,
                        api_secret=self._api_secret,
                        model_path=model_path,
                        tags=self._tags,
                        transaction_id=self.transaction_id,
                        insecure=self._insecure,
                        timeout=30.0,
                    )

                    # Validate and set the token
                    is_valid = self.generator._generator.validate_token(initial_token, self._verbose)
                    if not is_valid:
                        logger.error("Initial token validation failed in set_model()")
                        raise RuntimeError("Initial token validation failed")

                    logger.debug("Initial token requested and validated in set_model()")
                except Exception as e:
                    logger.error(f"Failed to request initial token in set_model(): {e}")
                    raise

        # Store the model path for token requests
        self._model_path = model_path

        if Path(model_path).is_file():
            # Set the model hash for JWT validation (used by add_video)
            try:
                model_hash = self.generator.set_model_hash_from_file(model_path)
                self._model_hash = model_hash
            except Exception as e:
                logger.error(f"Failed to calculate model hash: {e}")
                raise
        else:
            logger.info(
                "Skipping model hash verification for a non-file avatar model; "
                "make sure the token is valid for this kind of usage."
            )
            self._model_hash = None

        try:
            self.video_graph = VideoGraphNavigator.from_workspace(
                model_path, extract_to_local=self.settings.EXTRACT_WORKSPACE_TO_LOCAL
            ).load_workspace()
        except Exception as e:
            logger.error(f"Failed to create VideoGraphNavigator or load workspace: {e}")
            raise

        try:
            self.video_graph.update_runtime_configs(self.settings)
        except Exception as e:
            logger.error(f"Failed to update runtime configs: {e}")
            raise

        try:
            self.generator.set_output_size(self.settings.OUTPUT_WIDTH)
        except Exception as e:
            logger.error(f"Failed to set output size: {e}")
            raise

        self._video_loaded = False

        try:
            self.load_data()
        except Exception as e:
            logger.error(f"load_data() failed: {e}")
            raise

        return self

    @property
    def model_hash(self) -> Optional[str]:
        """Get the model hash (read-only).

        Returns the unique model hash that was generated during model loading.
        This property is read-only and cannot be modified after initialization.

        Returns:
            Optional[str]: The model hash if a file model was loaded, None otherwise.
        """
        return self._model_hash

    def load_data(self) -> None:
        """Load the workspace and set up related components."""
        if self._video_loaded:
            return
        if self.video_graph is None:
            logger.error("Video graph is None. Model may not be set properly.")
            raise ValueError("Video graph is not set. Call set_model() first.")

        models_path = Path(self.video_graph.avatar_model_path)

        def find_avatar_data_file(video_path: str) -> Optional[str]:
            video_name = Path(video_path).stem
            for layout in ["feature-first", "time-first"]:
                files = list(models_path.glob(f"*/{video_name}.{layout}.*"))
                if files:
                    return str(files[0])
            return None

        try:
            audio_feature_files = list(models_path.glob("*/feature_centers.npy"))
            audio_feature_file = audio_feature_files[0]
        except IndexError:
            logger.error(f"Audio features file not found in {models_path}")
            raise FileNotFoundError(f"Audio features file not found in {models_path}")

        try:
            audio_features = np.load(audio_feature_file)
        except Exception as e:
            logger.error(f"Failed to load audio features: {e}")
            raise

        try:
            self.generator.set_audio_feature(audio_features)
        except Exception as e:
            logger.error(f"Failed to set audio feature in generator: {e}")
            raise

        videos = list(self.video_graph.videos.items())
        filler_videos = list(self.video_graph.filler_videos.items())
        logger.info(
            f"Loading model data: {len(videos)} models and {len(filler_videos)} fillers"
        )

        for name, video in videos + filler_videos:
            video_data_path = video.video_data_path
            avatar_data_path = find_avatar_data_file(video.video_path)

            if video.lip_sync_required:
                if not (video_data_path and avatar_data_path):
                    logger.error(f"Model data not found for video {name}")
                    raise ValueError(f"Model data not found for video {name}")
            else:
                video_data_path, avatar_data_path = "", ""

            # Process the video data file if needed
            try:
                video_data_path = self._process_video_data_file(video_data_path)
            except Exception as e:
                logger.error(f"Failed to process video data file for {name}: {e}")
                raise

            try:
                self.generator.add_video(
                    name,
                    video_path=video.video_path,
                    video_data_path=video_data_path,
                    avatar_data_path=avatar_data_path,
                    compression_type=self.settings.COMPRESS_METHOD,
                    loading_mode=self.settings.LOADING_MODE,
                    thread_count=self._num_threads,
                )
            except Exception as e:
                logger.error(f"Failed to add video {name} to generator: {e}")
                raise

        logger.info("Model data loaded successfully")
        self._video_loaded = True

    def get_first_frame(self) -> Optional[np.ndarray]:
        """Get the first frame of the video."""
        if not self.video_graph:
            logger.error("Model is not set. Call set_model() first.")
            return None
        try:
            frame = self.video_graph.get_first_frame(self.settings.OUTPUT_WIDTH)
            return frame
        except Exception as e:
            logger.error(f"Failed to get the first frame: {e}")
            return None

    def get_frame_size(self) -> tuple[int, int]:
        """Get the frame size in width and height."""
        image = self.get_first_frame()
        if image is None:
            logger.error("Failed to get the first frame")
            raise ValueError("Failed to get the first frame")
        size = (image.shape[1], image.shape[0])
        return size

    def interrupt(self) -> None:
        """Interrupt the daemon."""
        # clear the input buffer
        while not self._input_buffer.empty():
            try:
                self._input_buffer.get_nowait()
            except Empty:
                break
        self.audio_batcher.reset()
        self.interrupt_event.set()

    def set_muted(self, mute: bool) -> None:
        """Set the muted state."""
        if mute:
            self.muted.set()
        else:
            self.muted.clear()

    def push_audio(
        self, data: bytes, sample_rate: int, last_chunk: bool = True
    ) -> None:
        """Push the audio to the input buffer."""
        self._input_buffer.put(VideoControl.from_audio(data, sample_rate, last_chunk))

    def flush(self) -> None:
        """Flush the input buffer."""
        self._input_buffer.put(VideoControl(end_of_speech=True))

    def push(self, control: VideoControl) -> None:
        """Push the control (with audio, text, action, etc.) to the input buffer."""
        self._input_buffer.put(control)

    def run(
        self,
        out_buffer_empty: Optional[BufferEmptyCallback] = None,
        *,
        idle_timeout: float | None = None,
    ) -> Iterable[VideoFrame]:
        """Run the runtime loop, yielding video frames for queued controls."""
        # Start token refresh if api_secret is provided and refresh is not already running.
        # This ensures a token is available before the runtime starts processing.
        if self._api_secret and self._api_url and self._model_path:
            if not self.generator.is_token_refresh_running():
                try:
                    # Generate a transaction ID before starting token refresh
                    if not self.transaction_id:
                        self._regenerate_transaction_id()

                    success = self.generator.start_token_refresh(
                        api_url=self._api_url,
                        api_secret=self._api_secret,
                        model_path=self._model_path,
                        tags=self._tags,
                        refresh_interval=60,
                        insecure=self._insecure,
                        timeout=30.0,
                    )
                    if success:
                        logger.debug("Token refresh started in run()")
                        self._token_refresh_started = True
                        # start_token_refresh performs a synchronous initial token request, so the
                        # token should be ready; give the validation state a moment to settle.
                        time.sleep(0.2)
                    else:
                        logger.error("Failed to start token refresh in run()")
                        raise RuntimeError("Failed to start token refresh")
                except Exception as e:
                    logger.error(f"Failed to start token refresh in run(): {e}")
                    raise
            else:
                # Token refresh already running - just ensure the transaction ID is set
                if not self.transaction_id:
                    self._regenerate_transaction_id()
        else:
            # Generate a transaction ID only if not already set (prevents bypassing billing)
            if not self.transaction_id:
                self._regenerate_transaction_id()

        # Current frame index, reset for every new audio
        if self.video_graph is None:
            raise ValueError("Model is not set. Call set_model() first.")

        curr_frame_index = 0
        action_played = False  # Whether the action has been played in this speech
        token_expired = False  # Flag to track token expiration

        while True:
            # Check if the token has expired - if so, stop immediately
            if token_expired:
                logger.error("Token has expired, stopping runtime")
                # Stop token refresh if running
                if self.generator.is_token_refresh_running():
                    try:
                        self.generator._generator.stop_token_refresh()
                    except Exception as stop_error:
                        logger.warning(f"Error stopping token refresh: {stop_error}")
                break

            try:
                if self.interrupt_event.is_set():
                    # Clear the interrupt event for the next loop
                    self.interrupt_event.clear()
                    action_played = False
                control = self._input_buffer.get(
                    timeout=idle_timeout or self._idle_timeout
                )
                if control.action:
                    logger.debug(f"Action: {control.action}")
                if self.muted.is_set():
                    # Consume and skip the audio when muted
                    control = VideoControl(message_id="MUTED")
                    action_played = False  # Reset the action played flag
            except Empty:
                if out_buffer_empty and not out_buffer_empty():
                    continue
                control = VideoControl(message_id="IDLE")  # idle

            if self.video_graph is None:
                # cleanup() has been called
                logger.debug("Stopping runtime after cleanup")
                break

            # Drive the video from the script when the input names no target video or action
            if not control.target_video and not control.action:
                control.target_video, control.action, reset_action = (
                    self.video_graph.videos_script.get_video_and_actions(
                        curr_frame_index,
                        control.emotion_preds,
                        text=control.text,
                        is_idle=control.is_idle,
                        settings=self.settings,
                    )
                )
                if reset_action:
                    action_played = False

            if not control.is_idle:
                # Avoid playing the action multiple times in a conversation
                if action_played and not control.force_action:
                    control.action = None
                elif control.action:
                    action_played = True

            try:
                frames_yielded = False
                for frame in self.process(control):
                    yield frame
                    curr_frame_index += 1
                    frames_yielded = True

            except RuntimeError as e:
                # Catch token validation errors
                error_msg = str(e).lower()
                if ("token has expired" in error_msg or
                        "token validation failed" in error_msg or
                        "validation failed" in error_msg):
                    logger.error(f"Token validation failed in run() loop: {str(e)}, stopping video stream")
                    token_expired = True
                    # Stop token refresh if running
                    if self.generator.is_token_refresh_running():
                        try:
                            self.generator._generator.stop_token_refresh()
                        except Exception as stop_error:
                            logger.warning(f"Error stopping token refresh: {stop_error}")
                    break
                # Re-raise other RuntimeErrors
                raise

            if control.end_of_speech:
                self.audio_batcher.reset()
                # Pass through the end-of-speech flag
                yield VideoFrame(
                    source_message_id=control.message_id,
                    end_of_speech=control.end_of_speech,
                )

                # Reset the action played flag
                action_played = False
                curr_frame_index = 0
                self.video_graph.videos_script.last_nonidle_frame = 0
                self._action_debouncer.reset()

                # Reset the video graph if needed
                self.video_graph.next_n_frames(num_frames=0, on_user_speech=True)

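run() is a generator: it drains the input buffer, yields VideoFrame objects (idle frames when the buffer is empty), and stops only after cleanup or token expiry. A rough producer/consumer sketch around it, reusing the runtime instance from the constructor sketch above and assuming a hypothetical pcm_chunks iterable of 16 kHz 16-bit mono PCM byte chunks (the real rate comes from settings.INPUT_SAMPLE_RATE):

    import threading

    def feed(runtime: Bithuman, pcm_chunks: list[bytes]) -> None:
        # Producer: push audio for one utterance, then mark the end of speech.
        for chunk in pcm_chunks:
            runtime.push_audio(chunk, sample_rate=16000, last_chunk=False)
        runtime.flush()

    threading.Thread(target=feed, args=(runtime, pcm_chunks), daemon=True).start()

    # Consumer: each frame may carry a BGR image and an aligned audio chunk.
    for frame in runtime.run():
        if frame.end_of_speech:
            break
        # render frame.bgr_image / play frame.audio_chunk here
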
    def process(self, control: VideoControl) -> Iterable[VideoFrame]:
        """Process the audio or control data."""

        def _get_next_frame() -> FrameMeta:
            if control.action or control.target_video:
                self._action_debouncer.prepare(control)
                if control.action:
                    logger.debug(f"Getting next frame for control: {control.target_video} {control.action}")

            return self.video_graph.next_n_frames(
                num_frames=1,
                target_video_name=control.target_video,
                actions_name=control.action,
                on_agent_speech=control.is_speaking,
                stop_on_user_speech_override=control.stop_on_user_speech,
                stop_on_agent_speech_override=control.stop_on_agent_speech,
            )[0]

        frame_index = 0
        for padded_chunk in self.audio_batcher.push(control.audio):
            audio_array = padded_chunk.array

            # get the mel chunks on padded audio
            mel_chunks = audio_utils.get_mel_chunks(
                audio_utils.int16_to_float32(audio_array), fps=self.settings.FPS
            )
            # unpad the audio and mel chunks
            audio_array = self.audio_batcher.unpad(audio_array)
            start = self.audio_batcher.pre_pad_video_frames
            valid_frames = int(len(audio_array) / self._sample_per_video_frame)
            mel_chunks = mel_chunks[start : start + valid_frames]

            num_frames = len(mel_chunks)
            samples_per_frame = len(audio_array) // max(num_frames, 1)
            for i, mel_chunk in enumerate(mel_chunks):
                if self.muted.is_set():
                    return
                if self.interrupt_event.is_set():
                    self.interrupt_event.clear()
                    return

                try:
                    frame_meta = _get_next_frame()
                    frame = self._process_talking_frame(frame_meta, mel_chunk)
                except RuntimeError as e:
                    # Catch token validation errors
                    error_msg = str(e).lower()
                    if ("token has expired" in error_msg or
                            "token validation failed" in error_msg or
                            "validation failed" in error_msg):
                        logger.error(f"Token validation failed during frame processing: {str(e)}, stopping video stream")
                        # Re-raise to stop the generator and runtime
                        raise
                    # Re-raise other RuntimeErrors
                    raise

                audio_start = i * samples_per_frame
                audio_end = (
                    audio_start + samples_per_frame
                    if i < num_frames - 1
                    else len(audio_array)
                )
                yield VideoFrame(
                    bgr_image=frame,
                    audio_chunk=AudioChunk(
                        data=audio_array[audio_start:audio_end],
                        sample_rate=padded_chunk.sample_rate,
                        last_chunk=i == num_frames - 1,
                    ),
                    frame_index=frame_index,
                    source_message_id=control.message_id,
                )
                frame_index += 1

        if frame_index == 0 and not control.audio:
            # generate idle frame if no frame is generated
            try:
                frame_meta = _get_next_frame()
                frame = self._process_idle_frame(frame_meta)
            except RuntimeError as e:
                # Catch token validation errors
                error_msg = str(e).lower()
                if ("token has expired" in error_msg or
                        "token validation failed" in error_msg or
                        "validation failed" in error_msg):
                    logger.error("Token validation failed during idle frame processing, stopping video stream")
                    # Re-raise to stop the generator and runtime
                    raise
                # Re-raise other RuntimeErrors
                raise

            yield VideoFrame(
                bgr_image=frame,
                frame_index=frame_index,
                source_message_id=control.message_id,
            )

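In process() above, each padded audio chunk is converted to mel chunks at the video frame rate and then sliced so that every yielded frame carries roughly one frame's worth of samples. A worked example of that arithmetic, assuming for illustration only that INPUT_SAMPLE_RATE = 16000 and FPS = 25 (both values actually come from the runtime settings):

    sample_rate, fps = 16000, 25                 # assumed example values
    samples_per_video_frame = sample_rate / fps  # 640 samples per video frame
    # A 1-second utterance maps to 25 mel chunks / video frames, each paired
    # with 640 int16 samples; the final frame absorbs any remainder samples.
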
    def _process_talking_frame(
        self, frame: FrameMeta, mel_chunk: np.ndarray
    ) -> np.ndarray:
        """Process a talking frame with audio-driven lip sync.

        This method processes audio and generates a frame. Token validation is checked
        internally. If token validation fails, RuntimeError will be raised.

        Args:
            frame: Frame metadata
            mel_chunk: Mel spectrogram chunk

        Returns:
            Processed frame as numpy array

        Raises:
            RuntimeError: If token validation fails
        """
        try:
            frame_np = self.generator.process_audio(
                mel_chunk, frame.video_name, frame.frame_index
            )
            return frame_np
        except RuntimeError as e:
            # Handle token validation errors
            raise self._handle_token_validation_error(e, "talking frame processing") from e

    def _process_idle_frame(self, frame: FrameMeta) -> np.ndarray:
        """Get the idle frame with cache.

        This method gets an idle frame. Token validation is checked automatically.

        Args:
            frame: Frame metadata

        Returns:
            Processed frame as numpy array

        Raises:
            RuntimeError: If token validation fails
        """
        try:
            if not self.settings.PROCESS_IDLE_VIDEO:
                frame_np = self.generator.get_original_frame(
                    frame.video_name, frame.frame_index
                )
            else:
                frame_np = self.generator.process_audio(
                    self.silent_mel_chunk, frame.video_name, frame.frame_index
                )

            return frame_np
        except RuntimeError as e:
            # Handle token validation errors
            raise self._handle_token_validation_error(e, "idle frame processing") from e

    @cached_property
    def silent_mel_chunk(self) -> np.ndarray:
        """The mel chunk for silent audio."""
        audio_np = np.zeros(self.settings.INPUT_SAMPLE_RATE * 1, dtype=np.float32)
        return audio_utils.get_mel_chunks(audio_np, fps=self.settings.FPS)[0]

    def _process_video_data_file(self, video_data_path: str) -> str:
        """Process the video data file."""
        if not video_data_path:
            return video_data_path

        if video_data_path.endswith(".pth"):
            logger.debug(f"Converting pth to h5, torch is required: {video_data_path}")
            from .lib.pth2h5 import convert_pth_to_h5

            return convert_pth_to_h5(video_data_path)
        return video_data_path

    def _warmup(self) -> None:
        """Warm up the audio processing."""
        audio_utils.get_mel_chunks(
            np.zeros(16000, dtype=np.float32), fps=self.settings.FPS
        )


    def cleanup(self) -> None:
        """Clean up the video graph."""
        if self.video_graph:
            self.video_graph.cleanup()
            self.video_graph = None

    def __del__(self) -> None:
        """Clean up the video graph."""
        self.cleanup()

    @classmethod
    def create(
        cls,
        *,
        model_path: Optional[str] = None,
        token: Optional[str] = None,
        api_secret: Optional[str] = None,
        api_url: str = "https://auth.api.bithuman.ai/v1/runtime-tokens/request",
        tags: Optional[str] = None,
        insecure: bool = True,
        input_buffer_size: int = 0,
        verbose: Optional[bool] = None,
    ) -> "Bithuman":
        """Create a fully initialized Bithuman instance.

        Token validation and refresh are handled automatically:
        - If token is provided, it will be validated
        - If api_secret and model_path are provided, token refresh will start when run() is called
        """
        # Create instance - token refresh will start lazily when run() is called
        instance = cls(
            input_buffer_size=input_buffer_size,
            token=token,
            model_path=model_path,
            api_secret=api_secret,
            api_url=api_url,
            tags=tags,
            insecure=insecure,
            verbose=verbose,
        )

        # Validate token if provided (validation happens in _initialize_token during start/set_model)
        # Token refresh will start lazily when run() is called if api_secret is provided

        # Set model if provided
        if model_path:
            try:
                instance.set_model(model_path)
            except Exception as e:
                logger.error(f"Failed to set model: {e}")
                raise
        else:
            logger.warning("No model path provided to factory method")

        # Verify initialization success
        try:
            if instance.video_graph is None:
                raise ValueError("Video graph not initialized")
        except Exception as e:
            logger.error(f"Initialization verification failed: {e}")
            raise

        return instance



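A sketch of the factory path described in create() above, again with a placeholder credential and a hypothetical model path; create() sets the model and verifies that the video graph loaded before returning:

    runtime = Bithuman.create(
        api_secret="sk-...",                 # placeholder secret
        model_path="/path/to/avatar_model",  # hypothetical model path
    )
    width, height = runtime.get_frame_size()
    first = runtime.get_first_frame()        # BGR numpy array, or None on failure
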
class ThreadSafeAsyncQueue(Generic[T]):
    """A thread-safe queue that can be used from both async and sync contexts.

    This queue uses a standard threading.Queue internally for thread safety,
    but provides async methods for use in async contexts.
    """

    def __init__(
        self, maxsize: int = 0, event_loop: Optional[asyncio.AbstractEventLoop] = None
    ):
        """Initialize the queue.

        Args:
            maxsize: Maximum size of the queue. 0 means unlimited.
            event_loop: The event loop to use.
        """
        self._queue = Queue[T](maxsize=maxsize)
        self._loop = event_loop

    def put_nowait(self, item: T) -> None:
        """Put an item into the queue without blocking."""
        self._queue.put_nowait(item)

    async def aput(self, item: T, *args, **kwargs) -> None:
        """Put an item into the queue asynchronously."""
        # Use run_in_executor to avoid blocking the event loop
        if not self._loop:
            self._loop = asyncio.get_event_loop()
        await self._loop.run_in_executor(None, self._queue.put, item, *args, **kwargs)

    def put(self, item: T, *args, **kwargs) -> None:
        """Put an item into the queue."""
        self._queue.put(item, *args, **kwargs)

    def get_nowait(self) -> T:
        """Get an item from the queue without blocking."""
        return self._queue.get_nowait()

    async def aget(self, *args, **kwargs) -> T:
        """Get an item from the queue asynchronously."""
        # Use run_in_executor to avoid blocking the event loop
        if not self._loop:
            self._loop = asyncio.get_event_loop()
        return await self._loop.run_in_executor(None, self._queue.get, *args, **kwargs)

    def get(self, *args, **kwargs) -> T:
        """Get an item from the queue."""
        return self._queue.get(*args, **kwargs)

    def task_done(self) -> None:
        """Mark a task as done."""
        self._queue.task_done()

    def empty(self) -> bool:
        """Check if the queue is empty."""
        return self._queue.empty()

    def qsize(self) -> int:
        """Get the size of the queue."""
        return self._queue.qsize()

    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Set the event loop."""
        self._loop = loop
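A minimal sketch of the mixed sync/async pattern ThreadSafeAsyncQueue is built for: a plain thread produces items with the blocking put(), while an asyncio consumer awaits aget(), which runs the blocking get() in the default executor. The producer/consumer names below are illustrative only:

    import asyncio
    import threading

    q: ThreadSafeAsyncQueue[int] = ThreadSafeAsyncQueue(maxsize=8)

    def producer() -> None:
        for i in range(3):
            q.put(i)                 # blocking put from a worker thread

    async def consumer() -> None:
        q.set_loop(asyncio.get_running_loop())
        for _ in range(3):
            print(await q.aget())    # awaits without blocking the event loop

    threading.Thread(target=producer, daemon=True).start()
    asyncio.run(consumer())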