bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,763 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Dict, Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from ..engine import RUNTIME_VERSION, CompressionType, LoadingMode
|
|
13
|
+
from ..engine.auth import (
|
|
14
|
+
JWTValidator,
|
|
15
|
+
calculate_file_md5,
|
|
16
|
+
generate_hardware_fingerprint,
|
|
17
|
+
generate_uuid,
|
|
18
|
+
request_token as _http_request_token,
|
|
19
|
+
)
|
|
20
|
+
from ..engine.compression import decode_image, encode_image
|
|
21
|
+
from ..engine.image_ops import resize_image
|
|
22
|
+
from ..engine.inference import AudioEncoder
|
|
23
|
+
from ..engine.knn import AudioFeatureIndex
|
|
24
|
+
from ..engine.video_data import LipFormat, VideoData
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _parse_compression_type(compression_type: CompressionType | str) -> CompressionType:
    """Coerce a compression setting into a ``CompressionType`` member.

    ``CompressionType`` members are returned as-is. Strings are matched by
    exact, case-sensitive member name (``"NONE"``, ``"JPEG"``, ``"LZ4"`` or
    ``"TEMP_FILE"``). Any other value is returned unchanged without validation.

    Raises:
        ValueError: if a string does not name a known compression type.
    """
    if isinstance(compression_type, CompressionType):
        return compression_type

    by_name = {
        "NONE": CompressionType.NONE,
        "JPEG": CompressionType.JPEG,
        "LZ4": CompressionType.LZ4,
        "TEMP_FILE": CompressionType.TEMP_FILE,
    }

    # Non-string, non-enum input is deliberately passed through untouched.
    if not isinstance(compression_type, str):
        return compression_type

    member = by_name.get(compression_type)
    if member is None:
        raise ValueError(f"Invalid compression type: {compression_type}")
    return member
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _parse_loading_mode(loading_mode: LoadingMode | str) -> LoadingMode:
    """Coerce a loading-mode setting into a ``LoadingMode`` member.

    ``LoadingMode`` members are returned as-is. Strings are matched by exact,
    case-sensitive member name: ``"SYNC"``, ``"ASYNC"`` or ``"ON_DEMAND"``.

    Raises:
        ValueError: if ``loading_mode`` is neither a ``LoadingMode`` nor one of
            the names above. This mirrors ``_parse_compression_type``; a plain
            dict lookup would leak a bare ``KeyError`` (or ``TypeError`` for
            unhashable input) with no hint about the valid choices.
    """
    if isinstance(loading_mode, LoadingMode):
        return loading_mode
    maps = {
        "SYNC": LoadingMode.SYNC,
        "ASYNC": LoadingMode.ASYNC,
        "ON_DEMAND": LoadingMode.ON_DEMAND,
    }
    # Only strings can name a mode; checking the type first also avoids a
    # TypeError from hashing unhashable values.
    member = maps.get(loading_mode) if isinstance(loading_mode, str) else None
    if member is None:
        raise ValueError(
            f"Invalid loading mode: {loading_mode!r} "
            f"(expected one of {', '.join(maps)})"
        )
    return member
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
# _PythonBithumanRuntime — drop-in replacement for the C++ LibBithuman
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
class _PythonBithumanRuntime:
    """Pure Python replacement for the C++ BithumanRuntime (pybind11 class).

    Every public method matches the C++ pybind11 interface exactly so that
    ``BithumanGenerator`` and ``runtime.py`` can use it without changes.

    The runtime combines three concerns:

    * the frame pipeline: mel chunk -> audio embedding -> nearest audio
      cluster -> blended video frame, optionally resized to ``output_size``;
    * licensing: JWT validation, model-hash binding and a background
      token-refresh thread;
    * a transaction ID that is attached to token requests.
    """

    _runtime_version: str = RUNTIME_VERSION

    def __init__(self, audio_encoder_path: str = "", output_size: int = -1) -> None:
        """Create a runtime.

        Args:
            audio_encoder_path: Path to the ONNX audio encoder; falsy values
                are handed to ``AudioEncoder`` as ``""``.
            output_size: Target frame width (compared against
                ``frame.shape[1]``); values ``<= 0`` keep the original size.
        """
        # Audio encoder (ONNX)
        self._audio_encoder = AudioEncoder(audio_encoder_path if audio_encoder_path else "")

        # Audio features (KNN index)
        self._audio_features = AudioFeatureIndex()

        # Output size (-1 means use original)
        self._output_size = output_size

        # Video data map: name -> VideoData
        self._video_data: Dict[str, VideoData] = {}

        # JWT validator
        self._jwt_validator = JWTValidator()

        # Hardware fingerprint (generated once at init, matches C++)
        self._fingerprint = generate_hardware_fingerprint()

        # Model hash for token verification; an empty string disables the
        # model-hash check in _check_jwt_validation.
        self._model_hash = ""
        self._encrypted_model_hash = ""

        # Token refresh thread state. _refresh_stop_event is replaced by a
        # fresh Event on every start_token_refresh() call (see there).
        self._refresh_thread: Optional[threading.Thread] = None
        self._refresh_stop_event = threading.Event()
        self._refresh_running = False
        self._account_status_error = False

        # Transaction ID (generated in this layer to prevent user tampering)
        self._transaction_id = ""
        self._transaction_lock = threading.Lock()

        # Instance ID from JWT
        self._instance_id = ""

    # ------------------------------------------------------------------
    # Audio encoder
    # ------------------------------------------------------------------
    def set_audio_encoder(self, path: str) -> None:
        """Load (or replace) the ONNX audio encoder from *path*."""
        self._audio_encoder.load(path)

    # ------------------------------------------------------------------
    # Audio features
    # ------------------------------------------------------------------
    def set_audio_feature(self, feature_source: Union[str, os.PathLike, np.ndarray]) -> None:
        """Set audio features from an HDF5 path or an array.

        Args:
            feature_source: A ``str``/``os.PathLike`` pointing at an HDF5 file,
                or an array-like that is copied to ``float32``.
        """
        if isinstance(feature_source, (str, os.PathLike)):
            self._audio_features.load_from_h5(os.fspath(feature_source))
        else:
            self._audio_features.set_features(np.asarray(feature_source).astype(np.float32))

    # ------------------------------------------------------------------
    # Output size
    # ------------------------------------------------------------------
    def set_output_size(self, output_size: int) -> None:
        """Set the target frame width; values ``<= 0`` disable resizing."""
        self._output_size = output_size

    def _resize_to_output(self, frame: np.ndarray) -> np.ndarray:
        """Scale *frame* so ``frame.shape[1]`` equals the configured output size.

        The aspect ratio is preserved (height is rounded). The frame is
        returned unchanged when resizing is disabled or the width already
        matches.
        """
        if self._output_size <= 0 or frame.shape[1] == self._output_size:
            return frame
        scale = self._output_size / frame.shape[1]
        new_h = round(frame.shape[0] * scale)
        return resize_image(frame, self._output_size, new_h)

    # ------------------------------------------------------------------
    # Video management
    # ------------------------------------------------------------------
    def add_video(
        self,
        video_name: str,
        video_path: str,
        video_data_path: str,
        avatar_data_path: str,
        compression_type: CompressionType = CompressionType.JPEG,
        loading_mode: LoadingMode = LoadingMode.ASYNC,
        thread_count: int = 0,
    ) -> None:
        """Add a video. Matches C++ addVideo.

        Registering the same *video_name* twice replaces the earlier entry.
        An empty *avatar_data_path* selects ``LipFormat.NONE``; otherwise the
        lip format is detected from that file.
        """
        lip_format = LipFormat.detect(avatar_data_path) if avatar_data_path else LipFormat.NONE

        self._video_data[video_name] = VideoData(
            video_path=video_path,
            video_data_path=video_data_path,
            avatar_data_path=avatar_data_path,
            lip_format=lip_format,
            compression_type=compression_type,
            loading_mode=loading_mode,
            thread_count=thread_count,
        )

    def _get_video(self, video_name: str) -> VideoData:
        """Return the registered ``VideoData`` or raise ``RuntimeError``."""
        vd = self._video_data.get(video_name)
        if vd is None:
            raise RuntimeError(f"Video not found: {video_name}")
        return vd

    # ------------------------------------------------------------------
    # Core pipeline: process_audio
    # ------------------------------------------------------------------
    def process_audio(
        self, mel_chunk: np.ndarray, video_name: str, frame_idx: int
    ) -> np.ndarray:
        """Process mel chunk → embedding → KNN → blended frame.

        Matches generator.cpp processAudio:
          1. checkJWTValidation()
          2. melChunkToAudioEmbedding(mel_chunk)
          3. find nearest cluster in audio_features_
          4. getBlendedFrame(video_name, frame_idx, cluster_idx, num_clusters)
          5. resize to output_size if needed

        Raises:
            RuntimeError: if the token checks fail or *video_name* is unknown.
        """
        self._check_jwt_validation()

        # 1. Mel → embedding
        embedding = self._audio_encoder.encode(mel_chunk.astype(np.float32))

        # 2. Find nearest cluster
        cluster_idx = self._audio_features.find_nearest(embedding)
        num_clusters = self._audio_features.num_clusters

        # 3. Get blended frame
        vd = self._get_video(video_name)
        frame = vd.get_blended_frame(frame_idx, cluster_idx, num_clusters)

        # 4. Resize if output_size is set
        return self._resize_to_output(frame)

    # ------------------------------------------------------------------
    # Original frame access
    # ------------------------------------------------------------------
    def get_original_frame(self, video_name: str, frame_idx: int) -> np.ndarray:
        """Get original (non-blended) frame. Matches C++ getOriginalFrame.

        Raises:
            RuntimeError: if the token checks fail or *video_name* is unknown.
        """
        self._check_jwt_validation()
        frame = self._get_video(video_name).get_original_frame(frame_idx)
        return self._resize_to_output(frame)

    # ------------------------------------------------------------------
    # Frame count
    # ------------------------------------------------------------------
    def get_num_frames(self, video_name: str) -> int:
        """Return the frame count of *video_name*, or -1 if it is not registered."""
        vd = self._video_data.get(video_name)
        if vd is None:
            return -1
        return vd.num_frames

    # ------------------------------------------------------------------
    # JWT / Token validation
    # ------------------------------------------------------------------
    def validate_token(self, token: str, verbose: bool = True) -> bool:
        """Validate and set a JWT token. Matches C++ validateToken.

        On success the instance ID is copied from the validator.
        """
        result = self._jwt_validator.validate_token(token, verbose=verbose)
        if result:
            self._instance_id = self._jwt_validator._instance_id
        return result

    def is_token_validated(self) -> bool:
        return self._jwt_validator.is_validated()

    def has_expired(self) -> bool:
        return self._jwt_validator.has_expired()

    def get_expiration_time(self) -> int:
        """Return expiration time as int seconds, -1 if not validated."""
        exp = self._jwt_validator.get_expiration_time()
        if exp is None:
            return -1
        return int(exp)

    def get_instance_id(self) -> str:
        return self._instance_id

    # ------------------------------------------------------------------
    # Model hash
    # ------------------------------------------------------------------
    def set_model_hash_from_file(self, model_path: str) -> str:
        """Calculate and store model hash from file. Matches C++ setModelHashFromFile."""
        self._model_hash = calculate_file_md5(model_path)
        self._encrypted_model_hash = self._jwt_validator.encrypt_model_hash(self._model_hash)
        return self._model_hash

    def _encrypted_model_file_hash(
        self, model_path: Optional[str]
    ) -> tuple[str, Optional[str]]:
        """Return ``(raw_md5, encrypted_md5)`` for *model_path*.

        Returns ``("", None)`` when no path is given or the MD5 is empty.
        """
        if not model_path:
            return "", None
        raw_hash = calculate_file_md5(model_path)
        if not raw_hash:
            return "", None
        return raw_hash, self._jwt_validator.encrypt_model_hash(raw_hash)

    # ------------------------------------------------------------------
    # Hardware fingerprint
    # ------------------------------------------------------------------
    def getFingerprint(self) -> str:
        """Return the hardware fingerprint computed at construction time."""
        return self._fingerprint

    # ------------------------------------------------------------------
    # Token request
    # ------------------------------------------------------------------
    def request_token(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> str:
        """Request a token from the auth API. Matches C++ requestToken.

        The token is returned as-is; it is NOT validated or stored here.
        """
        _, runtime_model_hash = self._encrypted_model_file_hash(model_path)

        return _http_request_token(
            api_url=api_url,
            api_secret=api_secret,
            fingerprint=self._fingerprint,
            runtime_model_hash=runtime_model_hash,
            tags=tags,
            transaction_id=transaction_id,
            insecure=insecure,
            timeout=timeout,
        )

    # ------------------------------------------------------------------
    # Token refresh thread
    # ------------------------------------------------------------------
    @staticmethod
    def _is_permanent_auth_error(error: Exception) -> bool:
        """Return True if *error*'s message mentions HTTP 400, 402 or 403.

        NOTE(review): this is a substring match on the message text, so any
        message that merely contains one of these digit runs also matches.
        """
        message = str(error)
        return any(code in message for code in ("402", "403", "400"))

    def start_token_refresh(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        refresh_interval: int = 60,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> bool:
        """Start background token refresh. Matches C++ startTokenRefresh.

        Does an initial synchronous token request, then spawns a daemon
        thread that refreshes every ``refresh_interval`` seconds (fractional
        values are accepted).

        Returns:
            True if started; False if already running or if the initial token
            fails validation.

        Raises:
            RuntimeError: propagated from the initial token request. Account
                status errors (400/402/403) are flagged before re-raising.
        """
        if self._refresh_running:
            return False

        # Calculate runtime model hash once
        raw_hash, runtime_model_hash = self._encrypted_model_file_hash(model_path)
        if raw_hash:
            self._model_hash = raw_hash
            self._encrypted_model_hash = runtime_model_hash

        # Synchronous initial token request (matches C++ behavior)
        try:
            with self._transaction_lock:
                txn_id = self._transaction_id

            token = _http_request_token(
                api_url=api_url,
                api_secret=api_secret,
                fingerprint=self._fingerprint,
                runtime_model_hash=runtime_model_hash,
                tags=tags,
                transaction_id=txn_id,
                insecure=insecure,
                timeout=timeout,
            )
            if not self._jwt_validator.validate_token(token, verbose=True):
                logger.error("Initial token validation failed during start_token_refresh")
                return False
            self._instance_id = self._jwt_validator._instance_id
            logger.info("Initial token acquired and validated")
        except RuntimeError as e:
            if self._is_permanent_auth_error(e):
                self._account_status_error = True
                logger.error(f"Account status error: {e}")
            raise

        # Each run gets its own stop event. A previous worker still blocked in
        # an HTTP request after stop_token_refresh() timed out keeps its
        # (already set) event and exits, instead of being revived by clearing
        # a shared event and running alongside the new worker.
        stop_event = threading.Event()
        self._refresh_stop_event = stop_event
        self._refresh_running = True

        self._refresh_thread = threading.Thread(
            target=self._token_refresh_worker,
            args=(
                api_url,
                api_secret,
                runtime_model_hash,
                tags,
                refresh_interval,
                insecure,
                timeout,
                stop_event,
            ),
            daemon=True,
        )
        self._refresh_thread.start()
        return True

    def _token_refresh_worker(
        self,
        api_url: str,
        api_secret: str,
        runtime_model_hash: Optional[str],
        tags: Optional[str],
        refresh_interval: int,
        insecure: bool,
        timeout: float,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Background token refresh loop. Matches C++ tokenRefreshWorker.

        Every ``refresh_interval`` seconds, request and validate a new token,
        with up to three attempts and 1 s / 2 s backoff after request errors.
        A request error mentioning 400/402/403 is permanent: it sets
        ``_account_status_error`` and ends the loop.

        Args:
            stop_event: Event this run listens to; defaults to the runtime's
                current ``_refresh_stop_event``.
        """
        if stop_event is None:
            stop_event = self._refresh_stop_event
        max_retries = 3

        try:
            # Event.wait() returns True as soon as a stop is requested, so
            # shutdown never has to sit out the rest of the interval.
            while not stop_event.wait(refresh_interval):
                # Attempt token refresh with retries
                with self._transaction_lock:
                    txn_id = self._transaction_id

                success = False
                for attempt in range(max_retries):
                    if stop_event.is_set():
                        break
                    try:
                        token = _http_request_token(
                            api_url=api_url,
                            api_secret=api_secret,
                            fingerprint=self._fingerprint,
                            runtime_model_hash=runtime_model_hash,
                            tags=tags,
                            transaction_id=txn_id,
                            insecure=insecure,
                            timeout=timeout,
                        )
                        if self._jwt_validator.validate_token(token, verbose=True):
                            self._instance_id = self._jwt_validator._instance_id
                            logger.debug("Token refreshed successfully")
                            success = True
                            break
                        logger.warning(
                            f"Token refresh validation failed (attempt {attempt + 1}/{max_retries})"
                        )
                    except RuntimeError as e:
                        # Permanent errors — stop retrying
                        if self._is_permanent_auth_error(e):
                            self._account_status_error = True
                            logger.error(f"Permanent token refresh error: {e}")
                            return
                        logger.warning(
                            f"Token refresh request failed (attempt {attempt + 1}/{max_retries}): {e}"
                        )
                        if attempt < max_retries - 1:
                            # Exponential backoff that still reacts to stop requests.
                            stop_event.wait(2 ** attempt)

                if not success and not stop_event.is_set():
                    logger.error("Token refresh failed after all retries")
        except Exception:
            # Without this, an unexpected error would kill the thread and leave
            # _refresh_running stuck at True, blocking any restart.
            logger.exception("Token refresh worker terminated unexpectedly")
        finally:
            # Only the current worker may clear the flag; a stale worker that
            # finishes late must not mark a newer run as stopped.
            current = self._refresh_thread
            if current is None or current is threading.current_thread():
                self._refresh_running = False

    def stop_token_refresh(self) -> None:
        """Stop the background token refresh thread (waits up to 5 s)."""
        self._refresh_stop_event.set()
        thread = self._refresh_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                logger.warning(
                    "Token refresh thread did not stop within 5s; "
                    "it will exit after its in-flight request"
                )
        self._refresh_running = False

    def is_token_refresh_running(self) -> bool:
        return self._refresh_running

    def is_account_status_error(self) -> bool:
        return self._account_status_error

    # ------------------------------------------------------------------
    # Transaction ID
    # ------------------------------------------------------------------
    def generate_transaction_id(self) -> str:
        """Generate and store a new transaction ID. Matches C++ generateTransactionId."""
        with self._transaction_lock:
            self._transaction_id = generate_uuid()
            return self._transaction_id

    def get_transaction_id(self) -> str:
        with self._transaction_lock:
            return self._transaction_id

    # ------------------------------------------------------------------
    # Static methods
    # ------------------------------------------------------------------
    @staticmethod
    def get_runtime_version() -> str:
        return RUNTIME_VERSION

    @staticmethod
    def build_token_request_data(
        fingerprint: str,
        runtime_model_hash: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
    ) -> str:
        """Build JSON token request body. Matches C++ buildTokenRequestData.

        Optional fields are omitted when falsy.
        """
        import json

        body: dict = {
            "fingerprint": fingerprint,
            "runtime_version": RUNTIME_VERSION,
        }
        if runtime_model_hash:
            body["runtime_model_hash"] = runtime_model_hash
        if tags:
            body["tags"] = tags
        if transaction_id:
            body["transaction_id"] = transaction_id
        return json.dumps(body)

    # ------------------------------------------------------------------
    # Internal JWT check (called before every frame operation)
    # ------------------------------------------------------------------
    def _check_jwt_validation(self) -> None:
        """Check JWT validity before operations. Matches C++ checkJWTValidation.

        Raises RuntimeError if:
          - Account status error (402/403/400)
          - Token not validated
          - Token expired
          - Model hash not allowed
          - Video hash not allowed (not implemented here — future)
        """
        if self._account_status_error:
            raise RuntimeError(
                "Account status error: token refresh failed with permanent error. "
                "Runtime terminated."
            )

        if not self._jwt_validator.is_validated():
            raise RuntimeError("Token validation failed: token not validated")

        if self._jwt_validator.has_expired():
            raise RuntimeError("Token validation failed: token has expired")

        # Model hash verification (only when a model hash has been set)
        if self._model_hash:
            if not self._jwt_validator.is_model_hash_allowed(
                self._model_hash, self._encrypted_model_hash
            ):
                raise RuntimeError(
                    "Token validation failed: model hash not allowed by token"
                )
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
# ---------------------------------------------------------------------------
|
|
531
|
+
# BithumanGenerator — public wrapper (API unchanged)
|
|
532
|
+
# ---------------------------------------------------------------------------
|
|
533
|
+
class BithumanGenerator:
    """High-level Python wrapper for Bithuman Runtime Generator.

    A thin facade over ``_PythonBithumanRuntime``. It normalises argument
    types (paths to ``str``, arrays to ``float32``, compression/loading-mode
    names to enums). It also refuses to start token refresh with an
    ``api_secret``/``api_url`` different from the ones given at construction.
    """

    # Re-export CompressionType enum
    CompressionType = CompressionType

    def __init__(
        self,
        audio_encoder_path: Optional[str] = None,
        output_size: int = -1,
        api_secret: Optional[str] = None,
        api_url: Optional[str] = None,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        insecure: bool = False,
    ):
        """Initialize the generator.

        Args:
            audio_encoder_path: Path to the ONNX audio encoder model
            output_size: Output size for frames
            api_secret: Optional API secret for automatic token refresh
            api_url: Optional API endpoint URL for token requests
            model_path: Optional model file path (for token refresh)
            tags: Optional tags for token requests
            insecure: Whether to disable SSL verification

        Raises:
            FileNotFoundError: if *audio_encoder_path* is given but missing.
            ValueError: if the audio encoder file is empty.
            RuntimeError: if the underlying runtime fails to initialise.
        """
        if audio_encoder_path is not None:
            audio_encoder_path = str(audio_encoder_path)

        # Validate audio encoder model file
        if audio_encoder_path:
            if not os.path.isfile(audio_encoder_path):
                raise FileNotFoundError(
                    f"Audio encoder model not found: {audio_encoder_path}"
                )
            file_size = os.path.getsize(audio_encoder_path)
            if file_size == 0:
                raise ValueError(
                    f"Audio encoder model is empty (0 bytes): {audio_encoder_path}"
                )
            logger.debug(
                f"Audio encoder model validated: {audio_encoder_path} "
                f"({file_size / 1024:.1f} KB)"
            )

        try:
            self._generator = _PythonBithumanRuntime(
                audio_encoder_path or "", output_size
            )
        except Exception as e:
            logger.error(
                f"Failed to initialize runtime with audio encoder "
                f"'{audio_encoder_path}': {e}"
            )
            raise RuntimeError(
                f"ONNX audio encoder initialization failed: {e}. "
                f"This may indicate CPU incompatibility with the INT8 quantized "
                f"model. Check that the deployment environment supports the "
                f"required instruction set."
            ) from e

        # Store initialization parameters for security checks
        self._initialized_api_secret = api_secret
        self._initialized_api_url = api_url
        self._initialized_model_path = model_path
        self._initialized_tags = tags
        self._initialized_insecure = insecure

    @staticmethod
    def _path_arg(path) -> str:
        """Convert an optional path argument to ``str``; ``None`` becomes ``""``.

        ``str(None)`` would yield the bogus path ``"None"``. That is truthy, so
        it would, for example, defeat the runtime's emptiness check on
        ``avatar_data_path``.
        """
        return "" if path is None else str(path)

    def set_model_hash_from_file(self, model_path: str) -> str:
        """Set the model hash for verification against the token from a file."""
        return self._generator.set_model_hash_from_file(model_path)

    def get_instance_id(self) -> str:
        """Get the instance ID of this runtime."""
        return self._generator.get_instance_id()

    def is_token_refresh_running(self) -> bool:
        """Check if token refresh thread is running."""
        return self._generator.is_token_refresh_running()

    def is_account_status_error(self) -> bool:
        """Check if account status error occurred (402, 403, 400)."""
        return self._generator.is_account_status_error()

    def generate_transaction_id(self) -> str:
        """Generate and set a new transaction ID."""
        return self._generator.generate_transaction_id()

    def get_transaction_id(self) -> str:
        """Get current transaction ID."""
        return self._generator.get_transaction_id()

    def set_audio_encoder(self, audio_encoder_path: str) -> None:
        """Set the audio encoder model path."""
        self._generator.set_audio_encoder(str(audio_encoder_path))

    def set_audio_feature(self, audio_feature: Union[str, os.PathLike, np.ndarray]) -> None:
        """Set the audio feature.

        Args:
            audio_feature: A ``str``/``os.PathLike`` pointing at an HDF5 file,
                or an array-like of feature vectors (converted to ``float32``).
        """
        if isinstance(audio_feature, (str, os.PathLike)):
            # os.fspath turns Path objects into the str the runtime expects.
            self._generator.set_audio_feature(os.fspath(audio_feature))
        else:
            # asarray avoids a redundant copy; the runtime copies to float32 itself.
            self._generator.set_audio_feature(np.asarray(audio_feature, dtype=np.float32))

    def set_output_size(self, output_size: int) -> None:
        """Set the output size."""
        self._generator.set_output_size(output_size)

    def add_video(
        self,
        video_name: str,
        video_path: str,
        video_data_path: str,
        avatar_data_path: str,
        compression_type: CompressionType | str = CompressionType.JPEG,
        loading_mode: LoadingMode | str = LoadingMode.ASYNC,
        thread_count: int = 0,
    ) -> None:
        """Add a video to the generator.

        Path arguments may be ``str``, path-like, or ``None`` (passed on as
        ``""``). *compression_type* and *loading_mode* accept enum members or
        their exact names.
        """
        compression_type = _parse_compression_type(compression_type)
        loading_mode = _parse_loading_mode(loading_mode)
        self._generator.add_video(
            str(video_name),
            self._path_arg(video_path),
            self._path_arg(video_data_path),
            self._path_arg(avatar_data_path),
            compression_type,
            loading_mode,
            thread_count,
        )

    def process_audio(
        self, mel_chunk: np.ndarray, video_name: str, frame_idx: int
    ) -> np.ndarray:
        """Process audio chunk and return blended frame."""
        # The runtime makes its own float32 copy, so only ensure an ndarray here.
        return self._generator.process_audio(
            np.asarray(mel_chunk, dtype=np.float32), str(video_name), frame_idx
        )

    def get_original_frame(self, video_name: str, frame_idx: int) -> np.ndarray:
        """Get the original frame."""
        return self._generator.get_original_frame(str(video_name), frame_idx)

    def get_num_frames(self, video_name: str) -> int:
        """Get the number of frames in the video (-1 if unknown)."""
        return self._generator.get_num_frames(str(video_name))

    def is_token_validated(self) -> bool:
        """Check if the token is validated."""
        return self._generator.is_token_validated()

    def get_expiration_time(self) -> int:
        """Get the expiration time of the token (-1 if not validated)."""
        return self._generator.get_expiration_time()

    @staticmethod
    def get_runtime_version() -> str:
        """Get the runtime version."""
        return _PythonBithumanRuntime.get_runtime_version()

    def request_token(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> str:
        """Request a token from the authentication server."""
        return self._generator.request_token(
            api_url,
            api_secret,
            model_path,
            tags,
            transaction_id,
            insecure,
            timeout,
        )

    def start_token_refresh(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        refresh_interval: int = 60,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> bool:
        """Start token refresh thread.

        Returns False if refresh is already running.

        Raises:
            RuntimeError: if an ``api_secret`` was given at construction and
                *api_secret* or *api_url* differs from the construction values.
        """
        # Security check: If token refresh is already running, prevent restart
        if self.is_token_refresh_running():
            logger.warning(
                "Security: Token refresh is already running. "
                "Cannot restart with different parameters. "
                "This prevents security bypass attempts."
            )
            return False

        # Security check: Verify parameters match initialization parameters if set
        if getattr(self, "_initialized_api_secret", None):
            if api_secret != self._initialized_api_secret:
                logger.error(
                    "Security violation: api_secret does not match initialization parameter. "
                    "This prevents security bypass attempts."
                )
                raise RuntimeError(
                    "Cannot change api_secret after initialization. "
                    "This is a security restriction."
                )
            if api_url != self._initialized_api_url:
                logger.error(
                    "Security violation: api_url does not match initialization parameter. "
                    "This prevents security bypass attempts."
                )
                raise RuntimeError(
                    "Cannot change api_url after initialization. "
                    "This is a security restriction."
                )

        return self._generator.start_token_refresh(
            api_url,
            api_secret,
            model_path,
            tags,
            refresh_interval,
            insecure,
            timeout,
        )
|