bithuman 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bithuman might be problematic. Click here for more details.

Files changed (44) hide show
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,257 @@
1
+ """Frame compression/decompression — replaces image.cpp encode/decode."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import struct
7
+ import tempfile
8
+ from typing import Optional
9
+ import numpy as np
10
+ from .enums import CompressionType
11
+
12
+ IMAGE_TYPE_BGR8 = 16 # Matches CV_8UC3
13
+
14
# ---------------------------------------------------------------------------
# TurboJPEG (preferred) with cv2 fallback
# ---------------------------------------------------------------------------
# Define defaults up front so `_tj` and `_HAS_TURBOJPEG` always exist.
# (Previously, when the `turbojpeg` package itself failed to import,
# `_tj` was never bound at all.)
_tj = None
_HAS_TURBOJPEG = False

try:
    from turbojpeg import TurboJPEG, TJPF_BGR, TJFLAG_FASTDCT, TJSAMP_420

    # Candidate shared-library locations, tried in order.
    # `None` lets the turbojpeg package auto-detect the library.
    _TURBOJPEG_LIB_PATHS = [
        None,  # auto-detect first
        "/usr/lib/x86_64-linux-gnu/libturbojpeg.so",
        "/usr/lib/libturbojpeg.so",
        "/opt/libjpeg-turbo/lib64/libturbojpeg.so",
        "/usr/lib/aarch64-linux-gnu/libturbojpeg.so",
    ]
    for _lib_path in _TURBOJPEG_LIB_PATHS:
        try:
            _tj = TurboJPEG(_lib_path) if _lib_path else TurboJPEG()
            break
        except (RuntimeError, OSError):
            # This candidate path failed; try the next one.
            continue
    _HAS_TURBOJPEG = _tj is not None
except ImportError:
    # turbojpeg package not installed; JPEG codecs fall back to cv2.
    pass

# ---------------------------------------------------------------------------
# LZ4
# ---------------------------------------------------------------------------
try:
    import lz4.block

    _HAS_LZ4 = True
except ImportError:
    _HAS_LZ4 = False
47
+
48
+
49
# ---------------------------------------------------------------------------
# JPEG
# ---------------------------------------------------------------------------
def encode_jpeg(image: np.ndarray, quality: int = 95) -> bytes:
    """Compress a BGR uint8 image into JPEG bytes.

    Mirrors image.cpp encodeJPEG: TurboJPEG with TJPF_BGR and 4:2:0 chroma
    subsampling when available, otherwise cv2.imencode.
    """
    if not _HAS_TURBOJPEG:
        import cv2

        ok, encoded = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
        return encoded.tobytes()
    return _tj.encode(
        image, quality=quality, pixel_format=TJPF_BGR, jpeg_subsample=TJSAMP_420
    )
65
+
66
+
67
def decode_jpeg(data: bytes) -> np.ndarray:
    """Decompress JPEG bytes into a BGR uint8 array of shape (H, W, 3).

    Mirrors image.cpp decodeJPEG: TurboJPEG with TJPF_BGR when available,
    otherwise cv2.imdecode.
    """
    if not data:
        raise RuntimeError("Empty JPEG data")
    if _HAS_TURBOJPEG:
        return _tj.decode(data, pixel_format=TJPF_BGR)
    import cv2

    decoded = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise RuntimeError("Failed to decompress JPEG")
    return decoded
83
+
84
+
85
# ---------------------------------------------------------------------------
# LZ4
# ---------------------------------------------------------------------------
def encode_lz4(image: np.ndarray) -> bytes:
    """Serialize a BGR image as a 16-byte header plus LZ4-compressed pixels.

    Header layout matches image.cpp encodeLZ4:
    [uint32 width, uint32 height, uint32 type(16), uint32 original_size]
    """
    if not _HAS_LZ4:
        raise RuntimeError("lz4 package not installed")
    height, width = image.shape[:2]
    header = struct.pack("<IIII", width, height, IMAGE_TYPE_BGR8, width * height * 3)
    # When the array is already C-contiguous, hand its raw buffer straight to
    # lz4 (buffer protocol) and skip the .tobytes() copy.
    if image.flags.c_contiguous:
        payload = image.data
    else:
        payload = np.ascontiguousarray(image).tobytes()
    return header + lz4.block.compress(payload, store_size=False)
108
+
109
+
110
def decode_lz4(data: bytes) -> np.ndarray:
    """Inverse of encode_lz4: validate the 16-byte header, decompress pixels.

    Matches image.cpp decodeLZ4.
    """
    if not _HAS_LZ4:
        raise RuntimeError("lz4 package not installed")
    if len(data) < 16:
        raise RuntimeError("Invalid LZ4 data: header too small")
    width, height, img_type, orig_size = struct.unpack("<IIII", data[:16])
    if img_type != IMAGE_TYPE_BGR8:
        raise RuntimeError(f"Invalid image type in LZ4 data: {img_type}")
    expected = width * height * 3
    if orig_size != expected:
        raise RuntimeError(
            f"Invalid data size in LZ4 header: {orig_size} != {expected}"
        )
    raw = lz4.block.decompress(data[16:], uncompressed_size=orig_size)
    # frombuffer yields a read-only view over `raw`; copy so the caller owns
    # a writable array.
    return np.frombuffer(raw, dtype=np.uint8).reshape(height, width, 3).copy()
129
+
130
+
131
# ---------------------------------------------------------------------------
# NONE (raw with header)
# ---------------------------------------------------------------------------
def encode_none(image: np.ndarray) -> bytes:
    """Serialize a BGR image as a 16-byte header plus uncompressed pixels.

    Matches image.cpp encodeNONE; header layout is identical to encode_lz4.
    """
    height, width = image.shape[:2]
    header = struct.pack("<IIII", width, height, IMAGE_TYPE_BGR8, width * height * 3)
    return header + np.ascontiguousarray(image).tobytes()
143
+
144
+
145
def decode_none(data: bytes) -> np.ndarray:
    """Inverse of encode_none: validate the header, reshape the raw pixels.

    Matches image.cpp decodeNONE.
    """
    if len(data) < 16:
        raise RuntimeError("Invalid raw data: header too small")
    width, height, img_type, data_size = struct.unpack("<IIII", data[:16])
    if img_type != IMAGE_TYPE_BGR8:
        raise RuntimeError(f"Invalid image type in raw data: {img_type}")
    expected = width * height * 3
    # Both the declared size and the actual payload length must match exactly.
    if data_size != expected or len(data) != expected + 16:
        raise RuntimeError("Invalid data size in raw header")
    return np.frombuffer(data[16:], dtype=np.uint8).reshape(height, width, 3).copy()
159
+
160
+
161
# ---------------------------------------------------------------------------
# TEMP_FILE
# ---------------------------------------------------------------------------
# Monotonic counter used to generate unique frame file names within a process.
_temp_counter = 0


def encode_temp_file(image: np.ndarray, temp_dir: str, quality: int = 95) -> bytes:
    """JPEG-encode `image` into a file under `temp_dir`; return its path as bytes.

    Matches image.cpp encodeTempFile.
    """
    global _temp_counter
    if not temp_dir:
        raise RuntimeError("Temp dir is not set")
    _temp_counter += 1
    out_path = os.path.join(temp_dir, f"frame_{_temp_counter}.bin")
    with open(out_path, "wb") as fh:
        fh.write(encode_jpeg(image, quality))
    return out_path.encode("utf-8")
181
+
182
+
183
def decode_temp_file(path_data: bytes) -> np.ndarray:
    """Load and JPEG-decode the file whose path is given as UTF-8 bytes.

    Matches image.cpp decodeTempFile.
    """
    if not path_data:
        raise RuntimeError("Empty path data")
    with open(path_data.decode("utf-8"), "rb") as fh:
        return decode_jpeg(fh.read())
194
+
195
+
196
# ---------------------------------------------------------------------------
# Temp dir management
# ---------------------------------------------------------------------------
def create_temp_dir() -> str:
    """Create and return a fresh per-video temp directory.

    Directories live under <system tmp>/bithuman_tmp/video_XXXX; the shared
    parent is created on demand. Matches image.cpp createTempDir.
    """
    base = os.path.join(tempfile.gettempdir(), "bithuman_tmp")
    os.makedirs(base, exist_ok=True)
    return tempfile.mkdtemp(prefix="video_", dir=base)
207
+
208
+
209
def cleanup_temp_dir(path: str) -> None:
    """Recursively delete a temp directory; missing/empty paths are a no-op.

    Matches image.cpp cleanupTempDir. Errors during removal are ignored.
    """
    import shutil

    if not path or not os.path.isdir(path):
        return
    shutil.rmtree(path, ignore_errors=True)
218
+
219
+
220
# ---------------------------------------------------------------------------
# Unified encode/decode API
# ---------------------------------------------------------------------------
def encode_image(
    image: np.ndarray,
    compression: CompressionType,
    quality: int = 95,
    temp_dir: str = "",
) -> bytes:
    """Encode `image` with the requested compression scheme.

    Matches image.cpp encodeImage. `quality` applies to JPEG/TEMP_FILE;
    `temp_dir` is required only for TEMP_FILE.
    """
    if compression == CompressionType.JPEG:
        return encode_jpeg(image, quality)
    if compression == CompressionType.TEMP_FILE:
        return encode_temp_file(image, temp_dir, quality)
    if compression == CompressionType.LZ4:
        return encode_lz4(image)
    if compression == CompressionType.NONE:
        return encode_none(image)
    raise ValueError(f"Unknown compression type: {compression}")
242
+
243
+
244
def decode_image(data: bytes, compression: CompressionType) -> np.ndarray:
    """Decode `data` produced by encode_image with the same compression type.

    Matches image.cpp decodeImage.
    """
    if compression == CompressionType.JPEG:
        return decode_jpeg(data)
    if compression == CompressionType.TEMP_FILE:
        return decode_temp_file(data)
    if compression == CompressionType.LZ4:
        return decode_lz4(data)
    if compression == CompressionType.NONE:
        return decode_none(data)
    raise ValueError(f"Unknown compression type: {compression}")
@@ -0,0 +1,16 @@
1
+ """Pure Python enums replacing pybind11-exported C++ enums."""
2
+
3
+ from enum import IntEnum
4
+
5
+
6
class CompressionType(IntEnum):
    """Frame compression scheme; replaces the pybind11-exported C++ enum.

    Values must stay in sync with the C++ side, since they are serialized
    across the service boundary.
    """

    NONE = 0  # 16-byte header + raw BGR pixels
    JPEG = 1  # TurboJPEG/cv2 JPEG byte stream
    LZ4 = 2  # 16-byte header + LZ4 block-compressed pixels
    TEMP_FILE = 3  # JPEG written to a temp file; payload is the UTF-8 path
11
+
12
+
13
class LoadingMode(IntEnum):
    """Loading strategy selector; replaces the pybind11-exported C++ enum.

    Member semantics are defined by the consumer (presumably the video
    loader) and are not visible in this module — confirm against callers.
    """

    SYNC = 0
    ASYNC = 1
    ON_DEMAND = 2
@@ -0,0 +1,192 @@
1
+ """Image operations — replaces image_ops.cpp (SSE2 blend + bilinear resize).
2
+
3
+ Provides two blend implementations:
4
+ 1. numba JIT (primary) — per-pixel with exact div255 formula, SIMD-comparable
5
+ 2. numpy fallback — vectorized uint16 arithmetic, ~3-5ms for typical face ROI
6
+
7
+ Resize delegates to cv2.resize (IPP/SIMD-optimized internally).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import numpy as np
13
+ import cv2
14
+
15
+ try:
16
+ from numba import njit, prange
17
+
18
+ _HAS_NUMBA = True
19
+ except ImportError:
20
+ _HAS_NUMBA = False
21
+
22
+
23
# ---------------------------------------------------------------------------
# Blend: numba JIT (primary path)
# ---------------------------------------------------------------------------
if _HAS_NUMBA:

    @njit(cache=True, parallel=True)
    def _blend_3ch_numba(
        frame: np.ndarray,
        lip: np.ndarray,
        mask: np.ndarray,
        roi_x: int,
        roi_y: int,
        roi_h: int,
        roi_w: int,
    ) -> None:
        """In-place alpha blend with 3-channel mask using exact div255.

        Writes out = div255(lip * mask + frame * (255 - mask)) for each byte
        of the (roi_h, roi_w) region of `frame` anchored at (roi_x, roi_y).
        Rows are parallelized via prange. lip/frame/mask are expected uint8;
        products fit in uint16 because m + (255 - m) == 255, so the weighted
        sum never exceeds 255 * 255 = 65025.
        """
        for y in prange(roi_h):
            for x in range(roi_w):
                for c in range(3):
                    # Per-channel alpha: each mask channel weights its own plane.
                    m = np.uint16(mask[y, x, c])
                    im = np.uint16(255) - m
                    val = np.uint16(lip[y, x, c]) * m + np.uint16(
                        frame[roi_y + y, roi_x + x, c]
                    ) * im
                    # div255: exact integer division by 255
                    val = (val + np.uint16(1) + ((val + np.uint16(1)) >> np.uint16(8))) >> np.uint16(8)
                    frame[roi_y + y, roi_x + x, c] = np.uint8(val)

    @njit(cache=True, parallel=True)
    def _blend_1ch_numba(
        frame: np.ndarray,
        lip: np.ndarray,
        mask: np.ndarray,
        roi_x: int,
        roi_y: int,
        roi_h: int,
        roi_w: int,
    ) -> None:
        """In-place alpha blend with 1-channel mask using exact div255.

        Same formula as _blend_3ch_numba, but a single mask value per pixel
        is shared across all three color channels. Accepts the mask as
        (h, w) or (h, w, 1); mask.ndim is fixed per compiled signature, so
        the ternary below costs nothing after JIT specialization.
        """
        for y in prange(roi_h):
            for x in range(roi_w):
                m = np.uint16(mask[y, x, 0] if mask.ndim == 3 else mask[y, x])
                im = np.uint16(255) - m
                for c in range(3):
                    val = np.uint16(lip[y, x, c]) * m + np.uint16(
                        frame[roi_y + y, roi_x + x, c]
                    ) * im
                    # div255: exact integer division by 255
                    val = (val + np.uint16(1) + ((val + np.uint16(1)) >> np.uint16(8))) >> np.uint16(8)
                    frame[roi_y + y, roi_x + x, c] = np.uint8(val)
72
+
73
+
74
# ---------------------------------------------------------------------------
# Blend: numpy fallback
# ---------------------------------------------------------------------------
def _blend_numpy(
    frame: np.ndarray,
    lip: np.ndarray,
    mask: np.ndarray,
    roi_x: int,
    roi_y: int,
    roi_h: int,
    roi_w: int,
) -> None:
    """Vectorized numpy blend with the exact div255 formula.

    Writes out = div255(lip * mask + face * (255 - mask)) into the
    (roi_h, roi_w) region of `frame` anchored at (roi_x, roi_y), where
    div255(x) = (x + 1 + ((x + 1) >> 8)) >> 8.

    Args:
        frame: (H, W, 3) uint8 BGR — modified in-place.
        lip: at least (roi_h, roi_w, 3) uint8 overlay; extra rows/cols ignored.
        mask: (roi_h, roi_w), (roi_h, roi_w, 1) or (roi_h, roi_w, 3) uint8.
    """
    roi = frame[roi_y : roi_y + roi_h, roi_x : roi_x + roi_w]

    # Normalize mask to 3-D uint16. A trailing length-1 channel (or a 2-D
    # mask promoted to one) broadcasts against the 3-channel ROI directly —
    # no need to materialize a broadcast copy.
    m = mask.astype(np.uint16)
    if m.ndim == 2:
        m = m[:, :, np.newaxis]

    im = np.uint16(255) - m
    # Products stay within uint16: lip*m + face*(255-m) <= 255*255 = 65025.
    val = lip[:roi_h, :roi_w].astype(np.uint16) * m + roi.astype(np.uint16) * im
    # div255: (val + 1 + ((val + 1) >> 8)) >> 8
    val_p1 = val + np.uint16(1)
    frame[roi_y : roi_y + roi_h, roi_x : roi_x + roi_w] = (
        (val_p1 + (val_p1 >> 8)) >> 8
    ).astype(np.uint8)
107
+
108
+
109
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def blend_face_region(
    frame: np.ndarray,
    roi_x: int,
    roi_y: int,
    roi_w: int,
    roi_h: int,
    lip: np.ndarray,
    mask: np.ndarray,
) -> None:
    """Alpha-blend the lip overlay into the face region of `frame`, in-place.

    Mirrors image_ops.cpp blend_face_region (lines 429-503). Each byte is
    computed as out = div255(lip * mask + face * (255 - mask)), with
    div255(x) = (x + 1 + ((x + 1) >> 8)) >> 8.

    Args:
        frame: (H, W, 3) uint8 BGR — modified in-place.
        roi_x, roi_y: top-left corner of the face region within frame.
        roi_w, roi_h: size of the face region.
        lip: (roi_h, roi_w, 3) uint8 BGR lip overlay.
        mask: (roi_h, roi_w, 1) or (roi_h, roi_w, 3) uint8 blend mask.
    """
    if not _HAS_NUMBA:
        # Pure-numpy fallback when numba is unavailable.
        _blend_numpy(frame, lip, mask, roi_x, roi_y, roi_h, roi_w)
        return

    mask_has_three_channels = mask.ndim == 3 and mask.shape[2] == 3
    kernel = _blend_3ch_numba if mask_has_three_channels else _blend_1ch_numba
    kernel(frame, lip, mask, roi_x, roi_y, roi_h, roi_w)
143
+
144
+
145
def resize_image(image: np.ndarray, new_width: int, new_height: int) -> np.ndarray:
    """Bilinear resize via cv2.resize (IPP/SIMD optimized internally).

    Replaces image_ops.cpp resize_image (lines 347-375). Returns the input
    array unchanged (no copy) for empty images or when the size already
    matches the target.
    """
    if image.size == 0:
        return image
    height, width = image.shape[:2]
    if (width, height) == (new_width, new_height):
        return image
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
156
+
157
+
158
def resize_image_scale(image: np.ndarray, scale: float) -> np.ndarray:
    """Uniformly scale `image` by `scale` (bilinear).

    Replaces resize_image_scale in C++. A scale of exactly 1.0 or an empty
    image is returned unchanged without copying.
    """
    if scale == 1.0 or image.size == 0:
        return image
    target_w = round(image.shape[1] * scale)
    target_h = round(image.shape[0] * scale)
    return resize_image(image, target_w, target_h)
165
+
166
+
167
# ---------------------------------------------------------------------------
# Pre-warm numba JIT in background thread (avoids ~200ms spike on first blend)
# ---------------------------------------------------------------------------
def _warmup_numba_jit() -> None:
    """Compile both blend kernels on tiny inputs so the first real call is fast.

    With cache=True the compiled artifacts persist to __pycache__, which also
    speeds up subsequent process starts. Any failure here is swallowed —
    warmup is purely an optimization.
    """
    if not _HAS_NUMBA:
        return
    try:
        warm_frame = np.zeros((2, 2, 3), dtype=np.uint8)
        warm_lip = np.zeros((1, 1, 3), dtype=np.uint8)
        _blend_3ch_numba(
            warm_frame, warm_lip, np.zeros((1, 1, 3), dtype=np.uint8), 0, 0, 1, 1
        )
        _blend_1ch_numba(
            warm_frame, warm_lip, np.zeros((1, 1, 1), dtype=np.uint8), 0, 0, 1, 1
        )
    except Exception:
        pass  # Warmup failure is non-fatal


import threading as _threading

_warmup_thread = _threading.Thread(target=_warmup_numba_jit, daemon=True)
_warmup_thread.start()
@@ -0,0 +1,108 @@
1
+ """ONNX audio encoder inference — replaces C++ melChunkToAudioEmbedding."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import threading
7
+ from typing import Dict, Optional
8
+
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+
12
# Module-level ONNX session cache keyed by absolute model path.
# ONNX Runtime internally shares model weights across sessions, but this
# avoids Python-side overhead (~8 KB per duplicate session object).
# All reads/writes of the cache go through _onnx_session_lock (see
# AudioEncoder.load), so concurrent loads of the same model reuse one session.
_onnx_session_cache: Dict[str, ort.InferenceSession] = {}
_onnx_session_lock = threading.Lock()
17
+
18
+
19
class AudioEncoder:
    """Wraps an ONNX audio encoder session for mel-to-embedding inference.

    Matches the C++ BithumanRuntime ONNX setup:
    - intra_op_num_threads=1 (avoid nested parallelism)
    - disable CPU memory arena (container compatibility)
    - fallback from EXTENDED to BASIC optimization level
    - pre-allocated input buffer reused across calls

    Sessions are shared process-wide through the module-level cache, so
    constructing several encoders for the same model path is cheap.
    """

    def __init__(self, model_path: str = ""):
        """Optionally load the model immediately; empty path defers to load()."""
        self._session: Optional[ort.InferenceSession] = None
        self._input_name: Optional[str] = None
        self._output_name: Optional[str] = None
        # Pre-allocated input buffer matching C++ onnx_input_data_ (1*1*80*16)
        self._input_buf: np.ndarray = np.zeros((1, 1, 80, 16), dtype=np.float32)
        if model_path:
            self.load(model_path)

    def load(self, model_path: str) -> None:
        """Load ONNX model, reusing cached session if available.

        Mirrors generator.cpp setAudioEncoder (lines 439-502):
        - Try EXTENDED first, fall back to BASIC on failure.

        Thread-safe via double-checked locking: the cache is checked under
        the lock, the (slow) session construction happens outside it, and
        the result is re-checked under the lock before being published.

        Raises:
            RuntimeError: if the model cannot be loaded at either
                optimization level.
        """
        abs_path = os.path.abspath(model_path)

        with _onnx_session_lock:
            if abs_path in _onnx_session_cache:
                # Cache hit: reuse the shared session and its I/O names.
                self._session = _onnx_session_cache[abs_path]
                self._input_name = self._session.get_inputs()[0].name
                self._output_name = self._session.get_outputs()[0].name
                return

        # Create new session outside the lock (loading can be slow)
        opts = ort.SessionOptions()
        opts.intra_op_num_threads = 1
        opts.enable_cpu_mem_arena = False

        try:
            opts.graph_optimization_level = (
                ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            )
            session = ort.InferenceSession(model_path, opts)
        except Exception:
            # EXTENDED failed; retry once at the more conservative BASIC level.
            opts.graph_optimization_level = (
                ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
            )
            try:
                session = ort.InferenceSession(model_path, opts)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load audio encoder model '{model_path}': {e}"
                ) from e

        with _onnx_session_lock:
            # Double-check: another thread may have cached it while we were loading
            if abs_path in _onnx_session_cache:
                session = _onnx_session_cache[abs_path]
            else:
                _onnx_session_cache[abs_path] = session

        self._session = session
        self._input_name = self._session.get_inputs()[0].name
        self._output_name = self._session.get_outputs()[0].name

    def encode(self, mel_chunk: np.ndarray) -> np.ndarray:
        """Convert a mel spectrogram chunk to an audio embedding.

        Args:
            mel_chunk: float32 array of shape (80, 16).

        Returns:
            float32 array of shape (embedding_dim,).

        Raises:
            RuntimeError: if no model has been loaded yet.

        Mirrors generator.cpp melChunkToAudioEmbedding (lines 505-533):
        - Input shape: (1, 1, 80, 16), row-major (C-contiguous = numpy default).
        - Output: squeezed to 1-D embedding vector.

        NOTE(review): reusing self._input_buf makes this method not safe for
        concurrent calls on one instance — confirm callers serialize access.
        """
        if self._session is None:
            raise RuntimeError("Audio encoder not initialized")

        # Copy into pre-allocated buffer (matches C++ zero-copy pattern)
        np.copyto(self._input_buf[0, 0], mel_chunk.astype(np.float32))

        result = self._session.run(
            [self._output_name],
            {self._input_name: self._input_buf},
        )
        return result[0].squeeze()
bithuman/engine/knn.py ADDED
@@ -0,0 +1,58 @@
1
+ """Audio feature KNN search — replaces Eigen squaredNorm + minCoeff."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
class AudioFeatureIndex:
    """Nearest-neighbour lookup over audio feature cluster centers.

    Python port of the Eigen expression in generator.cpp processAudio
    (lines 337-342):
    ``(audio_features_.rowwise() - audio_embed.transpose())
    .rowwise().squaredNorm().minCoeff(&min_index)``
    """

    def __init__(self) -> None:
        # (num_clusters, embedding_dim) float32, row-major; empty until set.
        self._features: np.ndarray = np.empty((0, 0), dtype=np.float32)

    def set_features(self, features: np.ndarray) -> None:
        """Store cluster centers as a C-contiguous float32 matrix.

        Args:
            features: float32 array shape (num_clusters, embedding_dim), row-major.
        """
        self._features = np.ascontiguousarray(features, dtype=np.float32)

    def load_from_h5(self, h5_path: str) -> None:
        """Populate the index from the ``audio_feature`` dataset of an HDF5 file.

        Mirrors generator.cpp setAudioFeature (lines 409-432):
        Reads dataset ``audio_feature`` as row-major float32 matrix.
        """
        import h5py

        with h5py.File(h5_path, "r") as handle:
            self.set_features(handle["audio_feature"][:])

    def find_nearest(self, embedding: np.ndarray) -> int:
        """Return the index of the cluster center closest to ``embedding``.

        Distance metric is squared Euclidean (no sqrt needed for argmin).

        Args:
            embedding: float32 array shape (embedding_dim,).
        """
        deltas = self._features - embedding
        squared_dists = (deltas * deltas).sum(axis=1)
        return int(np.argmin(squared_dists))

    @property
    def num_clusters(self) -> int:
        """Number of stored cluster centers (rows)."""
        return self._features.shape[0]

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of each cluster center; 0 before features are set."""
        if self._features.size == 0:
            return 0
        return self._features.shape[1]