bithuman-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
bithuman/lib/pth2h5.py ADDED
@@ -0,0 +1,106 @@
+ from __future__ import annotations
+
+ import argparse
+ from pathlib import Path
+ from typing import Union
+
+ import cv2
+ import h5py
+ import numpy as np
+
+
+ def encode_mask_as_jpg(mask: np.ndarray, quality: int = 85) -> bytes:
+     """Encode a mask array as JPG bytes.
+
+     Args:
+         mask: Numpy array of mask (single channel)
+         quality: JPG compression quality (1-100)
+
+     Returns:
+         JPG encoded bytes
+     """
+     # Convert to BGR for OpenCV
+     mask = (mask * 255).astype(np.uint8)
+     mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+
+     # Encode as JPG
+     _, encoded = cv2.imencode(".jpg", mask, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
+     return encoded.tobytes()
+
+
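Since the mask is stored as lossy JPG bytes, reading it back needs a decode step. A minimal round-trip sketch; decode_mask_from_jpg is a hypothetical helper, not part of this package:

import cv2
import numpy as np

def decode_mask_from_jpg(buf: bytes) -> np.ndarray:
    """Inverse of encode_mask_as_jpg: JPG bytes back to a float mask in [0, 1]."""
    # cv2.imdecode expects a 1-D uint8 buffer; request a single grayscale channel.
    img = cv2.imdecode(np.frombuffer(buf, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    # JPG is lossy, so the recovered values are approximate.
    return img.astype(np.float32) / 255.0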
+ def convert_pth_to_h5(
+     pth_path: Union[str, Path], h5_path: Union[str, Path, None] = None
+ ) -> str:
+     """Convert a PyTorch model file (.pth) to an HDF5 file (.h5).
+
+     Args:
+         pth_path: Path to the PyTorch model file
+         h5_path: Path to save the HDF5 file. If None, saves to the same location as pth_path
+
+     Returns:
+         Path to the written HDF5 file as a string
+     """
+     try:
+         import torch
+     except ImportError:
+         raise ImportError(
+             "PyTorch is not installed. Please install it using 'pip install torch'."
+         )
+
+     pth_path = Path(pth_path)
+     if h5_path is None:
+         h5_path = pth_path.with_suffix(".h5")
+
+     # Load the PyTorch model
+     data = torch.load(str(pth_path))
+
+     # Extract data
+     face_masks: list[bytes] = []
+     face_coords: list[np.ndarray] = []
+     frame_wh = data[0]["frame_wh"].numpy().astype(np.int32)
+
+     for item in data:
+         padded_crop_coords = item["padded_crop_coords"].numpy()
+         face_xyxy = item["face_coords"].numpy()
+         face_mask = item["face_mask"]
+
+         # Encode face mask if needed
+         if not isinstance(face_mask, bytes):
+             face_mask = encode_mask_as_jpg(face_mask.numpy())
+
+         # Adjust coordinates
+         shift_x, shift_y = padded_crop_coords[:2].astype(np.int32)
+         face_xyxy = face_xyxy[:4].astype(np.int32)
+         face_xyxy[0::2] += shift_x
+         face_xyxy[1::2] += shift_y
+
+         face_masks.append(face_mask)
+         face_coords.append(face_xyxy)
+
+     # Save to H5 file
+     with h5py.File(h5_path, "w") as f:
+         f.create_dataset("face_coords", data=face_coords)
+
+         dt = h5py.special_dtype(vlen=np.dtype("uint8"))
+         masks_dataset = f.create_dataset(
+             "face_masks", shape=(len(face_masks),), dtype=dt
+         )
+         for i, mask in enumerate(face_masks):
+             masks_dataset[i] = np.frombuffer(mask, dtype=np.uint8)
+         f.attrs["frame_wh"] = frame_wh
+
+     return str(h5_path)
+
+
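For context on the resulting layout: the HDF5 file holds an N x 4 "face_coords" dataset of int32 xyxy boxes, a variable-length uint8 "face_masks" dataset of JPG bytes, and a "frame_wh" attribute. A short reader sketch; the file name is a placeholder:

import h5py

with h5py.File("avatar_data.h5", "r") as f:    # placeholder path
    frame_w, frame_h = f.attrs["frame_wh"]      # original frame size
    coords = f["face_coords"][:]                # (N, 4) int32 xyxy boxes per frame
    mask_jpg = bytes(f["face_masks"][0])        # JPG bytes for the first frame's mask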
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Convert PyTorch .pth files to HDF5 format"
+     )
+     parser.add_argument("pth_path", type=str, help="Path to input .pth file")
+     parser.add_argument(
+         "--output", "-o", type=str, help="Path to output .h5 file (optional)"
+     )
+
+     args = parser.parse_args()
+     convert_pth_to_h5(args.pth_path, args.output)
+
+
+ if __name__ == "__main__":
+     main()
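Because the module ends with a __main__ guard, the converter can be driven either as a script or programmatically. Two equivalent ways to call it; the input file name is a placeholder:

# As a module script:
#   python -m bithuman.lib.pth2h5 avatar_data.pth -o avatar_data.h5

from bithuman.lib.pth2h5 import convert_pth_to_h5

h5_path = convert_pth_to_h5("avatar_data.pth")  # writes avatar_data.h5 next to the input
print(h5_path)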
bithuman/plugins/stt.py ADDED
@@ -0,0 +1,185 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, Union
+
+ import aiohttp
+ import numpy as np
+ from loguru import logger
+
+ from bithuman.audio import resample
+
+ try:
+     from bithuman_local_voice import BithumanLocalSTT as BithumanSTTImpl
+ except ImportError:
+     BithumanSTTImpl = None
+
+ try:
+     from livekit.agents import utils
+     from livekit.agents.stt import (
+         STT,
+         SpeechData,
+         SpeechEvent,
+         SpeechEventType,
+         STTCapabilities,
+     )
+     from livekit.agents.types import NOT_GIVEN, APIConnectOptions, NotGivenOr
+ except ImportError:
+     raise ImportError(
+         "livekit is required, please install it with `pip install livekit-agents`"
+     )
+
+
+ @dataclass
+ class _STTOptions:
+     locale: str = "en-US"
+     on_device: bool = True
+     punctuation: bool = True
+     debug: bool = False
+
+
+ class BithumanSTTError(Exception):
+     pass
+
+
+ class BithumanLocalSTT(STT):
+     _SAMPLE_RATE = 16000
+
+     def __init__(
+         self,
+         *,
+         locale: str = "en-US",
+         server_url: str | None = None,
+         on_device: bool = True,
+         punctuation: bool = True,
+         debug: bool = False,
+     ):
+         capabilities = STTCapabilities(streaming=False, interim_results=False)
+         super().__init__(capabilities=capabilities)
+         self._opts = _STTOptions(
+             locale=locale, on_device=on_device, punctuation=punctuation, debug=debug
+         )
+         self._server_url: str | None = None
+         self._session: aiohttp.ClientSession | None = None
+         self._stt_impl = None
+
+         if server_url:
+             self._server_url = server_url
+             self._session = aiohttp.ClientSession()
+         else:
+             if BithumanSTTImpl is None:
+                 raise ImportError(
+                     "bithuman_local_voice is required if server_url is not provided, "
+                     "please install it with `pip install bithuman_local_voice`"
+                 )
+             self._stt_impl = BithumanSTTImpl(
+                 locale=locale,
+                 on_device=on_device,
+                 punctuation=punctuation,
+                 debug=debug,
+             )
+
+     async def _recognize_impl(
+         self,
+         buffer: utils.audio.AudioBuffer,
+         *,
+         language: NotGivenOr[str] = NOT_GIVEN,
+         conn_options: APIConnectOptions,
+     ):
+         if utils.is_given(language) and language != self._opts.locale:
+             try:
+                 await self._set_locale(language)
+             except Exception as e:
+                 logger.error(f"Failed to set locale: {e}")
+
+         frame = utils.audio.combine_frames(buffer)
+         audio_data = np.frombuffer(frame.data, dtype=np.int16)
+         if frame.sample_rate != self._SAMPLE_RATE:
+             audio_data = resample(audio_data, frame.sample_rate, self._SAMPLE_RATE)
+
+         try:
+             result = await self._recognize_audio(
+                 audio_data, sample_rate=self._SAMPLE_RATE
+             )
+         except Exception as e:
+             logger.warning(f"Failed to recognize audio with error: {e}")
+             result = {"text": "", "confidence": 0.0}
+
+         return SpeechEvent(
+             type=SpeechEventType.FINAL_TRANSCRIPT,
+             request_id="",
+             alternatives=[
+                 SpeechData(
+                     language=self._opts.locale,
+                     text=result["text"],
+                     confidence=result["confidence"],
+                 )
+             ],
+         )
+
+     async def _recognize_audio(
+         self, audio_data: Union[bytes, np.ndarray], sample_rate: int = 16000
+     ) -> Dict:
+         """Recognize speech in the provided audio data.
+
+         Args:
+             audio_data: Audio data as bytes or numpy array
+             sample_rate: Sample rate of the audio data
+
+         Returns:
+             Dict containing transcription and confidence score
+         """
+         if self._stt_impl is not None:
+             return await self._stt_impl.recognize(audio_data, sample_rate)
+
+         # Convert numpy array to bytes if needed
+         if isinstance(audio_data, np.ndarray):
+             # Convert float samples in [-1, 1] to int16 PCM if not already int16
+             if audio_data.dtype != np.int16:
+                 audio_data = (audio_data * 32767).astype(np.int16)
+             audio_bytes = audio_data.tobytes()
+         else:
+             audio_bytes = audio_data
+
+         async with self._session.post(
+             f"{self._server_url}/transcribe",
+             data=audio_bytes,
+             headers={"Content-Type": "application/octet-stream"},
+             timeout=5,
+         ) as response:
+             response.raise_for_status()
+             result: dict = await response.json()
+             return {
+                 "text": result.get("text", ""),
+                 "confidence": result.get("confidence", 0.0),
+             }
+
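The server mode above implies a small HTTP contract: POST /transcribe with raw 16 kHz int16 PCM in the body, answered by JSON with "text" and "confidence"; and POST /setLocale with a JSON body containing "locale", answered by JSON with "success" (and optionally "message"). A minimal, hypothetical aiohttp stub that satisfies that contract; it is not part of this package, and the recognizer call is a placeholder:

from aiohttp import web


def _run_asr(pcm: bytes) -> str:
    # Placeholder: a real server would run an actual speech recognizer here.
    return ""


async def transcribe(request: web.Request) -> web.Response:
    pcm = await request.read()  # raw PCM bytes sent by BithumanLocalSTT
    return web.json_response({"text": _run_asr(pcm), "confidence": 1.0})


async def set_locale(request: web.Request) -> web.Response:
    body = await request.json()
    return web.json_response({"success": True, "locale": body.get("locale")})


app = web.Application()
app.add_routes([web.post("/transcribe", transcribe), web.post("/setLocale", set_locale)])
# web.run_app(app, port=8091)  # port is an arbitrary choice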
+     async def _set_locale(self, locale: str):
+         """Set the recognition locale.
+
+         Args:
+             locale: Locale identifier (e.g. en-US, fr-FR)
+         """
+         if self._stt_impl is not None:
+             await self._stt_impl.set_locale(locale)
+             self._opts.locale = locale
+             return
+
+         assert self._server_url is not None
+         async with self._session.post(
+             f"{self._server_url}/setLocale",
+             json={"locale": locale},
+             headers={"Content-Type": "application/json"},
+             timeout=5,
+         ) as response:
+             response.raise_for_status()
+             result: dict = await response.json()
+             if not result.get("success"):
+                 raise BithumanSTTError(result.get("message", "Unknown error"))
+             self._opts.locale = locale
+
+     async def aclose(self):
+         if self._stt_impl is not None:
+             await self._stt_impl.stop()
+         if self._session is not None:
+             await self._session.close()
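To close out, a hedged sketch of exercising this plugin on its own, outside a full agent pipeline. The WAV path and server URL are placeholders, and the call goes through the recognize() entry point that the livekit-agents STT base class wraps around _recognize_impl; treat the exact surrounding LiveKit API as an assumption:

import asyncio
import wave

from livekit import rtc

from bithuman.plugins.stt import BithumanLocalSTT


async def transcribe_wav(path: str) -> str:
    # Read a 16-bit PCM WAV file and wrap it in a LiveKit AudioFrame.
    with wave.open(path, "rb") as wf:
        pcm = wf.readframes(wf.getnframes())
        frame = rtc.AudioFrame(
            data=pcm,
            sample_rate=wf.getframerate(),
            num_channels=wf.getnchannels(),
            samples_per_channel=wf.getnframes(),
        )

    # Server mode; local mode would omit server_url and use bithuman_local_voice.
    stt = BithumanLocalSTT(server_url="http://localhost:8091")
    try:
        event = await stt.recognize(frame)
        return event.alternatives[0].text
    finally:
        await stt.aclose()


if __name__ == "__main__":
    print(asyncio.run(transcribe_wav("sample.wav")))  # placeholder file name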