bithuman 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,763 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import threading
5
+ import time
6
+ import uuid
7
+ from typing import Dict, Optional, Union
8
+
9
+ import numpy as np
10
+ from loguru import logger
11
+
12
+ from ..engine import RUNTIME_VERSION, CompressionType, LoadingMode
13
+ from ..engine.auth import (
14
+ JWTValidator,
15
+ calculate_file_md5,
16
+ generate_hardware_fingerprint,
17
+ generate_uuid,
18
+ request_token as _http_request_token,
19
+ )
20
+ from ..engine.compression import decode_image, encode_image
21
+ from ..engine.image_ops import resize_image
22
+ from ..engine.inference import AudioEncoder
23
+ from ..engine.knn import AudioFeatureIndex
24
+ from ..engine.video_data import LipFormat, VideoData
25
+
26
+
27
def _parse_compression_type(compression_type: CompressionType | str) -> CompressionType:
    """Coerce *compression_type* into a ``CompressionType`` member.

    Enum members are returned unchanged. Strings are looked up by their exact
    member name ("NONE", "JPEG", "LZ4", "TEMP_FILE"); an unrecognised name
    raises ``ValueError``. Values that are neither an enum member nor a string
    are handed back as-is.
    """
    if isinstance(compression_type, CompressionType):
        return compression_type

    by_name = {
        "NONE": CompressionType.NONE,
        "JPEG": CompressionType.JPEG,
        "LZ4": CompressionType.LZ4,
        "TEMP_FILE": CompressionType.TEMP_FILE,
    }

    if not isinstance(compression_type, str):
        # Pass-through for anything that is not a name string.
        return compression_type

    member = by_name.get(compression_type)
    if member is None:
        raise ValueError(f"Invalid compression type: {compression_type}")
    return member
41
+
42
+
43
def _parse_loading_mode(loading_mode: LoadingMode | str) -> LoadingMode:
    """Coerce *loading_mode* into a ``LoadingMode`` member.

    Enum members are returned unchanged. Strings are looked up by their exact
    member name ("SYNC", "ASYNC", "ON_DEMAND").

    Raises:
        ValueError: if *loading_mode* is neither a ``LoadingMode`` member nor
            one of the recognised names. This mirrors
            ``_parse_compression_type`` instead of leaking a bare ``KeyError``.
    """
    if isinstance(loading_mode, LoadingMode):
        return loading_mode
    maps = {
        "SYNC": LoadingMode.SYNC,
        "ASYNC": LoadingMode.ASYNC,
        "ON_DEMAND": LoadingMode.ON_DEMAND,
    }
    if isinstance(loading_mode, str) and loading_mode in maps:
        return maps[loading_mode]
    raise ValueError(f"Invalid loading mode: {loading_mode}")
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # _PythonBithumanRuntime — drop-in replacement for the C++ LibBithuman
56
+ # ---------------------------------------------------------------------------
57
class _PythonBithumanRuntime:
    """Pure Python replacement for the C++ BithumanRuntime (pybind11 class).

    Every public method matches the C++ pybind11 interface exactly so that
    ``BithumanGenerator`` and ``runtime.py`` can use it without changes.

    The runtime bundles:

    * an ONNX ``AudioEncoder`` plus an ``AudioFeatureIndex`` used to map a mel
      chunk to its nearest audio cluster,
    * a registry of ``VideoData`` objects keyed by video name,
    * a ``JWTValidator`` that gates every frame operation, and
    * an optional daemon thread that periodically refreshes the token.
    """

    _runtime_version: str = RUNTIME_VERSION

    def __init__(self, audio_encoder_path: str = "", output_size: int = -1) -> None:
        """Create the runtime.

        Args:
            audio_encoder_path: Path to the ONNX audio encoder. Falsy values
                are forwarded to ``AudioEncoder`` as ``""``.
            output_size: Target output frame width. Any value <= 0 (default
                ``-1``) keeps frames at their original size.
        """
        # Audio encoder (ONNX)
        self._audio_encoder = AudioEncoder(audio_encoder_path if audio_encoder_path else "")

        # Audio features (KNN index)
        self._audio_features = AudioFeatureIndex()

        # Output width in pixels (<= 0 means "use original size")
        self._output_size = output_size

        # Video data map: name -> VideoData
        self._video_data: Dict[str, VideoData] = {}

        # JWT validator
        self._jwt_validator = JWTValidator()

        # Hardware fingerprint (generated once at init, matches C++)
        self._fingerprint = generate_hardware_fingerprint()

        # Model hash for token verification ("" = not set, check skipped)
        self._model_hash = ""
        self._encrypted_model_hash = ""

        # Token refresh thread state. start_token_refresh() installs a fresh
        # stop event for every thread it spawns. A thread that outlived
        # stop_token_refresh()'s join timeout therefore keeps its own
        # already-set event and can never be revived by a later restart.
        self._refresh_thread: Optional[threading.Thread] = None
        self._refresh_stop_event = threading.Event()
        self._refresh_running = False
        self._account_status_error = False

        # Transaction ID (generated in C++ layer to prevent user tampering)
        self._transaction_id = ""
        self._transaction_lock = threading.Lock()

        # Instance ID from JWT
        self._instance_id = ""

    # ------------------------------------------------------------------
    # Audio encoder
    # ------------------------------------------------------------------
    def set_audio_encoder(self, path: str) -> None:
        """Load a (new) ONNX audio encoder from *path*."""
        self._audio_encoder.load(path)

    # ------------------------------------------------------------------
    # Audio features
    # ------------------------------------------------------------------
    def set_audio_feature(
        self, feature_source: Union[str, os.PathLike, np.ndarray]
    ) -> None:
        """Set audio features from an HDF5 path or an array.

        Args:
            feature_source: Either a path (``str`` or ``os.PathLike``) to an
                HDF5 feature file, or an array-like of features. Arrays are
                always copied into a new ``float32`` array, so later mutation
                by the caller cannot affect the index.
        """
        if isinstance(feature_source, (str, os.PathLike)):
            self._audio_features.load_from_h5(os.fspath(feature_source))
        else:
            self._audio_features.set_features(
                np.array(feature_source, dtype=np.float32)
            )

    # ------------------------------------------------------------------
    # Output size
    # ------------------------------------------------------------------
    def set_output_size(self, output_size: int) -> None:
        """Set the output frame width (<= 0 disables resizing)."""
        self._output_size = output_size

    # ------------------------------------------------------------------
    # Video management
    # ------------------------------------------------------------------
    def add_video(
        self,
        video_name: str,
        video_path: str,
        video_data_path: str,
        avatar_data_path: str,
        compression_type: CompressionType = CompressionType.JPEG,
        loading_mode: LoadingMode = LoadingMode.ASYNC,
        thread_count: int = 0,
    ) -> None:
        """Add a video. Matches C++ addVideo (lines 385-407).

        Registers a ``VideoData`` under *video_name*, replacing any previous
        entry with the same name. The lip format is detected from
        *avatar_data_path*, or set to ``LipFormat.NONE`` when that is empty.
        """
        lip_format = LipFormat.detect(avatar_data_path) if avatar_data_path else LipFormat.NONE

        vd = VideoData(
            video_path=video_path,
            video_data_path=video_data_path,
            avatar_data_path=avatar_data_path,
            lip_format=lip_format,
            compression_type=compression_type,
            loading_mode=loading_mode,
            thread_count=thread_count,
        )
        self._video_data[video_name] = vd

    def _get_video(self, video_name: str) -> VideoData:
        """Return the ``VideoData`` registered as *video_name* or raise ``RuntimeError``."""
        vd = self._video_data.get(video_name)
        if vd is None:
            raise RuntimeError(f"Video not found: {video_name}")
        return vd

    def _resize_to_output(self, frame: np.ndarray) -> np.ndarray:
        """Scale *frame* to width ``output_size``, preserving aspect ratio.

        Returns *frame* untouched when resizing is disabled (``output_size``
        <= 0) or its width (``shape[1]``) already matches.
        """
        if self._output_size <= 0 or frame.shape[1] == self._output_size:
            return frame
        scale = self._output_size / frame.shape[1]
        # Clamp so very wide frames never round down to a zero-height image.
        new_h = max(1, round(frame.shape[0] * scale))
        return resize_image(frame, self._output_size, new_h)

    # ------------------------------------------------------------------
    # Core pipeline: process_audio
    # ------------------------------------------------------------------
    def process_audio(
        self, mel_chunk: np.ndarray, video_name: str, frame_idx: int
    ) -> np.ndarray:
        """Process mel chunk → embedding → KNN → blended frame.

        Matches generator.cpp processAudio (lines 321-359):
        1. checkJWTValidation()
        2. melChunkToAudioEmbedding(mel_chunk)
        3. find nearest cluster in audio_features_
        4. getBlendedFrame(video_name, frame_idx, cluster_idx, num_clusters)
        5. resize to output_size if needed

        Raises:
            RuntimeError: if the token check fails or *video_name* is unknown.
        """
        self._check_jwt_validation()

        # 1. Mel → embedding
        embedding = self._audio_encoder.encode(mel_chunk.astype(np.float32))

        # 2. Find nearest cluster
        cluster_idx = self._audio_features.find_nearest(embedding)
        num_clusters = self._audio_features.num_clusters

        # 3. Get blended frame
        vd = self._get_video(video_name)
        frame = vd.get_blended_frame(frame_idx, cluster_idx, num_clusters)

        # 4. Resize if output_size is set
        return self._resize_to_output(frame)

    # ------------------------------------------------------------------
    # Original frame access
    # ------------------------------------------------------------------
    def get_original_frame(self, video_name: str, frame_idx: int) -> np.ndarray:
        """Get original (non-blended) frame. Matches C++ getOriginalFrame.

        Raises:
            RuntimeError: if the token check fails or *video_name* is unknown.
        """
        self._check_jwt_validation()
        frame = self._get_video(video_name).get_original_frame(frame_idx)
        return self._resize_to_output(frame)

    # ------------------------------------------------------------------
    # Frame count
    # ------------------------------------------------------------------
    def get_num_frames(self, video_name: str) -> int:
        """Return the frame count of *video_name*, or -1 if it is not registered."""
        vd = self._video_data.get(video_name)
        if vd is None:
            return -1
        return vd.num_frames

    # ------------------------------------------------------------------
    # JWT / Token validation
    # ------------------------------------------------------------------
    def validate_token(self, token: str, verbose: bool = True) -> bool:
        """Validate and set a JWT token. Matches C++ validateToken.

        On success the instance ID extracted by the validator is cached.
        """
        result = self._jwt_validator.validate_token(token, verbose=verbose)
        if result:
            self._instance_id = self._jwt_validator._instance_id
        return result

    def is_token_validated(self) -> bool:
        return self._jwt_validator.is_validated()

    def has_expired(self) -> bool:
        return self._jwt_validator.has_expired()

    def get_expiration_time(self) -> int:
        """Return expiration time as int seconds, -1 if not validated."""
        exp = self._jwt_validator.get_expiration_time()
        if exp is None:
            return -1
        return int(exp)

    def get_instance_id(self) -> str:
        return self._instance_id

    # ------------------------------------------------------------------
    # Model hash
    # ------------------------------------------------------------------
    def set_model_hash_from_file(self, model_path: str) -> str:
        """Calculate and store model hash from file. Matches C++ setModelHashFromFile.

        Returns:
            The raw MD5 hash of *model_path*.
        """
        self._model_hash = calculate_file_md5(model_path)
        self._encrypted_model_hash = self._jwt_validator.encrypt_model_hash(self._model_hash)
        return self._model_hash

    # ------------------------------------------------------------------
    # Hardware fingerprint
    # ------------------------------------------------------------------
    def getFingerprint(self) -> str:  # noqa: N802 - name mirrors the C++ binding
        return self._fingerprint

    # ------------------------------------------------------------------
    # Token request
    # ------------------------------------------------------------------
    def request_token(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> str:
        """Request a token from the auth API. Matches C++ requestToken.

        The token is returned but NOT validated or stored; the stored model
        hash is not modified either.
        """
        # Calculate runtime model hash if model_path is provided
        runtime_model_hash = None
        if model_path:
            raw_hash = calculate_file_md5(model_path)
            if raw_hash:
                runtime_model_hash = self._jwt_validator.encrypt_model_hash(raw_hash)

        return _http_request_token(
            api_url=api_url,
            api_secret=api_secret,
            fingerprint=self._fingerprint,
            runtime_model_hash=runtime_model_hash,
            tags=tags,
            transaction_id=transaction_id,
            insecure=insecure,
            timeout=timeout,
        )

    # ------------------------------------------------------------------
    # Token refresh thread
    # ------------------------------------------------------------------
    @staticmethod
    def _is_permanent_error(error_msg: str) -> bool:
        """True when an auth error message carries a 402/403/400 status."""
        return "402" in error_msg or "403" in error_msg or "400" in error_msg

    def start_token_refresh(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        refresh_interval: int = 60,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> bool:
        """Start background token refresh. Matches C++ startTokenRefresh.

        Does an initial synchronous token request, then spawns a daemon
        thread that refreshes every ``refresh_interval`` seconds.

        Returns:
            True if started. False if already running, or if the initial
            token fails validation.

        Raises:
            RuntimeError: propagated from the initial token request. A
                402/403/400 error additionally sets the account-status-error
                flag.
        """
        if self._refresh_running:
            return False

        # Calculate runtime model hash once
        runtime_model_hash = None
        if model_path:
            raw_hash = calculate_file_md5(model_path)
            if raw_hash:
                runtime_model_hash = self._jwt_validator.encrypt_model_hash(raw_hash)
                self._model_hash = raw_hash
                self._encrypted_model_hash = runtime_model_hash

        # Synchronous initial token request (matches C++ behavior)
        try:
            with self._transaction_lock:
                txn_id = self._transaction_id

            token = _http_request_token(
                api_url=api_url,
                api_secret=api_secret,
                fingerprint=self._fingerprint,
                runtime_model_hash=runtime_model_hash,
                tags=tags,
                transaction_id=txn_id,
                insecure=insecure,
                timeout=timeout,
            )
            if not self._jwt_validator.validate_token(token, verbose=True):
                logger.error("Initial token validation failed during start_token_refresh")
                return False
            self._instance_id = self._jwt_validator._instance_id
            logger.info("Initial token acquired and validated")
        except RuntimeError as e:
            if self._is_permanent_error(str(e)):
                self._account_status_error = True
                logger.error(f"Account status error: {e}")
            raise

        # Each refresh thread owns a brand-new stop event. Re-using (and
        # clearing) one shared event would revive a previous thread that is
        # still blocked in an HTTP request after stop_token_refresh() gave up
        # waiting for it.
        stop_event = threading.Event()
        self._refresh_stop_event = stop_event
        self._refresh_running = True

        self._refresh_thread = threading.Thread(
            target=self._token_refresh_worker,
            args=(
                api_url,
                api_secret,
                runtime_model_hash,
                tags,
                refresh_interval,
                insecure,
                timeout,
                stop_event,
            ),
            daemon=True,
        )
        self._refresh_thread.start()
        return True

    def _token_refresh_worker(
        self,
        api_url: str,
        api_secret: str,
        runtime_model_hash: Optional[str],
        tags: Optional[str],
        refresh_interval: int,
        insecure: bool,
        timeout: float,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Background token refresh loop. Matches C++ tokenRefreshWorker.

        Every ``refresh_interval`` seconds a new token is requested, with up
        to 3 attempts and exponential backoff (1 s, 2 s) after request
        errors. A 402/403/400 error sets the account-status-error flag and
        ends the loop. Both the interval wait and the backoff wait return as
        soon as *stop_event* is set.
        """
        if stop_event is None:
            stop_event = self._refresh_stop_event
        max_retries = 3

        try:
            # Event.wait returns True once a stop is requested, otherwise
            # False after ``refresh_interval`` seconds -> time to refresh.
            while not stop_event.wait(refresh_interval):
                # Attempt token refresh with retries
                with self._transaction_lock:
                    txn_id = self._transaction_id

                success = False
                for attempt in range(max_retries):
                    if stop_event.is_set():
                        break
                    try:
                        token = _http_request_token(
                            api_url=api_url,
                            api_secret=api_secret,
                            fingerprint=self._fingerprint,
                            runtime_model_hash=runtime_model_hash,
                            tags=tags,
                            transaction_id=txn_id,
                            insecure=insecure,
                            timeout=timeout,
                        )
                        if self._jwt_validator.validate_token(token, verbose=True):
                            self._instance_id = self._jwt_validator._instance_id
                            logger.debug("Token refreshed successfully")
                            success = True
                            break
                        logger.warning(
                            f"Token refresh validation failed (attempt {attempt + 1}/{max_retries})"
                        )
                    except RuntimeError as e:
                        # Permanent errors — stop retrying
                        if self._is_permanent_error(str(e)):
                            self._account_status_error = True
                            logger.error(f"Permanent token refresh error: {e}")
                            return
                        logger.warning(
                            f"Token refresh request failed (attempt {attempt + 1}/{max_retries}): {e}"
                        )
                        if attempt < max_retries - 1:
                            # Exponential backoff, interruptible by stop.
                            stop_event.wait(2 ** attempt)

                if not success and not stop_event.is_set():
                    logger.error("Token refresh failed after all retries")
        finally:
            # A superseded thread (one that outlived a stop/restart cycle)
            # must not mark the newer refresh thread as stopped.
            current = self._refresh_thread
            if current is None or current is threading.current_thread():
                self._refresh_running = False

    def stop_token_refresh(self) -> None:
        """Signal the background refresh thread to stop; wait up to 5 s for it."""
        self._refresh_stop_event.set()
        thread = self._refresh_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=5.0)
            if thread.is_alive():
                # Its private stop event stays set, so it exits on its own
                # once the in-flight request returns.
                logger.warning(
                    "Token refresh thread did not stop within 5s; "
                    "it will exit after its current request"
                )
        self._refresh_running = False

    def is_token_refresh_running(self) -> bool:
        return self._refresh_running

    def is_account_status_error(self) -> bool:
        return self._account_status_error

    # ------------------------------------------------------------------
    # Transaction ID
    # ------------------------------------------------------------------
    def generate_transaction_id(self) -> str:
        """Generate and store a new transaction ID. Matches C++ generateTransactionId."""
        with self._transaction_lock:
            self._transaction_id = generate_uuid()
            return self._transaction_id

    def get_transaction_id(self) -> str:
        with self._transaction_lock:
            return self._transaction_id

    # ------------------------------------------------------------------
    # Static methods
    # ------------------------------------------------------------------
    @staticmethod
    def get_runtime_version() -> str:
        return RUNTIME_VERSION

    @staticmethod
    def build_token_request_data(
        fingerprint: str,
        runtime_model_hash: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
    ) -> str:
        """Build JSON token request body. Matches C++ buildTokenRequestData.

        Optional fields are included only when truthy.
        """
        import json

        body: dict = {
            "fingerprint": fingerprint,
            "runtime_version": RUNTIME_VERSION,
        }
        if runtime_model_hash:
            body["runtime_model_hash"] = runtime_model_hash
        if tags:
            body["tags"] = tags
        if transaction_id:
            body["transaction_id"] = transaction_id
        return json.dumps(body)

    # ------------------------------------------------------------------
    # Internal JWT check (called before every frame operation)
    # ------------------------------------------------------------------
    def _check_jwt_validation(self) -> None:
        """Check JWT validity before operations. Matches C++ checkJWTValidation.

        Raises RuntimeError if:
            - Account status error (402/403/400)
            - Token not validated
            - Token expired
            - Model hash not allowed
            - Video hash not allowed (not implemented here — future)
        """
        if self._account_status_error:
            raise RuntimeError(
                "Account status error: token refresh failed with permanent error. "
                "Runtime terminated."
            )

        if not self._jwt_validator.is_validated():
            raise RuntimeError("Token validation failed: token not validated")

        if self._jwt_validator.has_expired():
            raise RuntimeError("Token validation failed: token has expired")

        # Model hash verification (skipped while no model hash is set)
        if self._model_hash:
            if not self._jwt_validator.is_model_hash_allowed(
                self._model_hash, self._encrypted_model_hash
            ):
                raise RuntimeError(
                    "Token validation failed: model hash not allowed by token"
                )
528
+
529
+
530
+ # ---------------------------------------------------------------------------
531
+ # BithumanGenerator — public wrapper (API unchanged)
532
+ # ---------------------------------------------------------------------------
533
class BithumanGenerator:
    """High-level Python wrapper for Bithuman Runtime Generator.

    Thin facade over ``_PythonBithumanRuntime``. It normalises arguments:
    ``str()`` on names and paths, ``float32`` on arrays, and enum names
    accepted as strings. It also refuses to start token refresh against a
    different ``api_secret``/``api_url`` than the ones supplied at
    construction.
    """

    # Re-export CompressionType enum
    CompressionType = CompressionType

    def __init__(
        self,
        audio_encoder_path: Optional[str] = None,
        output_size: int = -1,
        api_secret: Optional[str] = None,
        api_url: Optional[str] = None,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        insecure: bool = False,
    ):
        """Initialize the generator.

        Args:
            audio_encoder_path: Path to the ONNX audio encoder model
            output_size: Output size for frames
            api_secret: Optional API secret for automatic token refresh
            api_url: Optional API endpoint URL for token requests
            model_path: Optional model file path (for token refresh)
            tags: Optional tags for token requests
            insecure: Whether to disable SSL verification

        Raises:
            FileNotFoundError: if *audio_encoder_path* is given but missing.
            ValueError: if the audio encoder file is empty.
            RuntimeError: if the underlying runtime fails to initialise.
        """
        if audio_encoder_path is not None:
            audio_encoder_path = str(audio_encoder_path)

        # Validate audio encoder model file
        if audio_encoder_path:
            if not os.path.isfile(audio_encoder_path):
                raise FileNotFoundError(
                    f"Audio encoder model not found: {audio_encoder_path}"
                )
            file_size = os.path.getsize(audio_encoder_path)
            if file_size == 0:
                raise ValueError(
                    f"Audio encoder model is empty (0 bytes): {audio_encoder_path}"
                )
            logger.debug(
                f"Audio encoder model validated: {audio_encoder_path} "
                f"({file_size / 1024:.1f} KB)"
            )

        try:
            self._generator = _PythonBithumanRuntime(
                audio_encoder_path or "", output_size
            )
        except Exception as e:
            logger.error(
                f"Failed to initialize runtime with audio encoder "
                f"'{audio_encoder_path}': {e}"
            )
            raise RuntimeError(
                f"ONNX audio encoder initialization failed: {e}. "
                f"This may indicate CPU incompatibility with the INT8 quantized "
                f"model. Check that the deployment environment supports the "
                f"required instruction set."
            ) from e

        # Store initialization parameters for security checks
        self._initialized_api_secret = api_secret
        self._initialized_api_url = api_url
        self._initialized_model_path = model_path
        self._initialized_tags = tags
        self._initialized_insecure = insecure

    def set_model_hash_from_file(self, model_path: str) -> str:
        """Set the model hash for verification against the token from a file."""
        return self._generator.set_model_hash_from_file(model_path)

    def get_instance_id(self) -> str:
        """Get the instance ID of this runtime."""
        return self._generator.get_instance_id()

    def is_token_refresh_running(self) -> bool:
        """Check if token refresh thread is running."""
        return self._generator.is_token_refresh_running()

    def is_account_status_error(self) -> bool:
        """Check if account status error occurred (402, 403, 400)."""
        return self._generator.is_account_status_error()

    def generate_transaction_id(self) -> str:
        """Generate and set a new transaction ID."""
        return self._generator.generate_transaction_id()

    def get_transaction_id(self) -> str:
        """Get current transaction ID."""
        return self._generator.get_transaction_id()

    def set_audio_encoder(self, audio_encoder_path: str) -> None:
        """Set the audio encoder model path."""
        self._generator.set_audio_encoder(str(audio_encoder_path))

    def set_audio_feature(
        self, audio_feature: Union[str, os.PathLike, np.ndarray]
    ) -> None:
        """Set the audio feature.

        Args:
            audio_feature: Path (``str`` or ``os.PathLike``) to an HDF5
                feature file, or an array-like of features. Arrays are
                converted to ``float32`` here; ``np.asarray`` skips the copy
                when the input already is ``float32``.
        """
        if isinstance(audio_feature, (str, os.PathLike)):
            # os.fspath lets pathlib.Path work; previously it fell through
            # to ``.astype`` and raised AttributeError.
            self._generator.set_audio_feature(os.fspath(audio_feature))
        else:
            self._generator.set_audio_feature(
                np.asarray(audio_feature, dtype=np.float32)
            )

    def set_output_size(self, output_size: int) -> None:
        """Set the output size."""
        self._generator.set_output_size(output_size)

    def add_video(
        self,
        video_name: str,
        video_path: str,
        video_data_path: str,
        avatar_data_path: str,
        compression_type: CompressionType | str = CompressionType.JPEG,
        loading_mode: LoadingMode | str = LoadingMode.ASYNC,
        thread_count: int = 0,
    ) -> None:
        """Add a video to the generator.

        *compression_type* and *loading_mode* may be given as enum members
        or as their member-name strings (e.g. ``"JPEG"``, ``"ASYNC"``).
        """
        compression_type = _parse_compression_type(compression_type)
        loading_mode = _parse_loading_mode(loading_mode)
        self._generator.add_video(
            str(video_name),
            str(video_path),
            str(video_data_path),
            str(avatar_data_path),
            compression_type,
            loading_mode,
            thread_count,
        )

    def process_audio(
        self, mel_chunk: np.ndarray, video_name: str, frame_idx: int
    ) -> np.ndarray:
        """Process audio chunk and return blended frame."""
        return self._generator.process_audio(
            mel_chunk.astype(np.float32), str(video_name), frame_idx
        )

    def get_original_frame(self, video_name: str, frame_idx: int) -> np.ndarray:
        """Get the original frame."""
        return self._generator.get_original_frame(str(video_name), frame_idx)

    def get_num_frames(self, video_name: str) -> int:
        """Get the number of frames in the video (-1 if unknown)."""
        return self._generator.get_num_frames(str(video_name))

    def is_token_validated(self) -> bool:
        """Check if the token is validated."""
        return self._generator.is_token_validated()

    def get_expiration_time(self) -> int:
        """Get the expiration time of the token (-1 if not validated)."""
        return self._generator.get_expiration_time()

    @staticmethod
    def get_runtime_version() -> str:
        """Get the runtime version."""
        return _PythonBithumanRuntime.get_runtime_version()

    def request_token(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        transaction_id: Optional[str] = None,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> str:
        """Request a token from the authentication server."""
        return self._generator.request_token(
            api_url,
            api_secret,
            model_path,
            tags,
            transaction_id,
            insecure,
            timeout,
        )

    def start_token_refresh(
        self,
        api_url: str,
        api_secret: str,
        model_path: Optional[str] = None,
        tags: Optional[str] = None,
        refresh_interval: int = 60,
        insecure: bool = False,
        timeout: float = 30.0,
    ) -> bool:
        """Start token refresh thread.

        Returns:
            False if refresh is already running; otherwise the result of the
            underlying runtime's ``start_token_refresh``.

        Raises:
            RuntimeError: if an ``api_secret`` was supplied at construction
                and either *api_secret* or *api_url* differs from the values
                given then.
        """
        # Security check: If token refresh is already running, prevent restart
        if self.is_token_refresh_running():
            logger.warning(
                "Security: Token refresh is already running. "
                "Cannot restart with different parameters. "
                "This prevents security bypass attempts."
            )
            return False

        # Security check: Verify parameters match initialization parameters if set
        if self._initialized_api_secret:
            if api_secret != self._initialized_api_secret:
                logger.error(
                    "Security violation: api_secret does not match initialization parameter. "
                    "This prevents security bypass attempts."
                )
                raise RuntimeError(
                    "Cannot change api_secret after initialization. "
                    "This is a security restriction."
                )
            if api_url != self._initialized_api_url:
                logger.error(
                    "Security violation: api_url does not match initialization parameter. "
                    "This prevents security bypass attempts."
                )
                raise RuntimeError(
                    "Cannot change api_url after initialization. "
                    "This is a security restriction."
                )

        return self._generator.start_token_refresh(
            api_url,
            api_secret,
            model_path,
            tags,
            refresh_interval,
            insecure,
            timeout,
        )