bithuman 1.0.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
bithuman/lib/pth2h5.py
ADDED
@@ -0,0 +1,106 @@
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Union

import cv2
import h5py
import numpy as np


def encode_mask_as_jpg(mask: np.ndarray, quality: int = 85) -> bytes:
    """Encode a mask array as JPG bytes.

    Args:
        mask: Numpy array of mask (single channel)
        quality: JPG compression quality (1-100)

    Returns:
        JPG encoded bytes
    """
    # Convert to BGR for OpenCV
    mask = (mask * 255).astype(np.uint8)
    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

    # Encode as JPG
    _, encoded = cv2.imencode(".jpg", mask, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    return encoded.tobytes()
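As an aside (not part of the package), a minimal round-trip sketch for `encode_mask_as_jpg`: OpenCV's `cv2.imdecode` recovers the mask from the JPG bytes, up to JPG quantization loss. The toy square mask below is hypothetical:

import cv2
import numpy as np

from bithuman.lib.pth2h5 import encode_mask_as_jpg

# Hypothetical float mask in [0, 1], as the encoder expects.
mask = np.zeros((64, 64), dtype=np.float32)
mask[16:48, 16:48] = 1.0

jpg_bytes = encode_mask_as_jpg(mask)

# Decode back to a single-channel uint8 image and rescale to [0, 1].
decoded = cv2.imdecode(np.frombuffer(jpg_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
restored = decoded.astype(np.float32) / 255.0
print(restored.shape, float(np.abs(restored - mask).max()))  # small but nonzero: JPG is lossy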
def convert_pth_to_h5(
    pth_path: Union[str, Path], h5_path: Union[str, Path, None] = None
) -> str:
    """Convert a PyTorch model file (.pth) to a HDF5 file (.h5).

    Args:
        pth_path: Path to the PyTorch model file
        h5_path: Path to save the HDF5 file. If None, saves to same location as pth_path

    Returns:
        Path to the written HDF5 file as a string.
    """
    try:
        import torch
    except ImportError:
        raise ImportError(
            "PyTorch is not installed. Please install it using 'pip install torch'."
        )

    pth_path = Path(pth_path)
    if h5_path is None:
        h5_path = pth_path.with_suffix(".h5")

    # Load the PyTorch model
    data = torch.load(str(pth_path))

    # Extract data
    face_masks: list[bytes] = []
    face_coords: list[np.ndarray] = []
    frame_wh = data[0]["frame_wh"].numpy().astype(np.int32)

    for item in data:
        padded_crop_coords = item["padded_crop_coords"].numpy()
        face_xyxy = item["face_coords"].numpy()
        face_mask = item["face_mask"]

        # Encode face mask if needed
        if not isinstance(face_mask, bytes):
            face_mask = encode_mask_as_jpg(face_mask.numpy())

        # Adjust coordinates
        shift_x, shift_y = padded_crop_coords[:2].astype(np.int32)
        face_xyxy = face_xyxy[:4].astype(np.int32)
        face_xyxy[0::2] += shift_x
        face_xyxy[1::2] += shift_y

        face_masks.append(face_mask)
        face_coords.append(face_xyxy)

    # Save to H5 file
    with h5py.File(h5_path, "w") as f:
        f.create_dataset("face_coords", data=face_coords)

        dt = h5py.special_dtype(vlen=np.dtype("uint8"))
        masks_dataset = f.create_dataset(
            "face_masks", shape=(len(face_masks),), dtype=dt
        )
        for i, mask in enumerate(face_masks):
            masks_dataset[i] = np.frombuffer(mask, dtype=np.uint8)
        f.attrs["frame_wh"] = frame_wh

    return str(h5_path)


def main():
    parser = argparse.ArgumentParser(
        description="Convert PyTorch .pth files to HDF5 format"
    )
    parser.add_argument("pth_path", type=str, help="Path to input .pth file")
    parser.add_argument(
        "--output", "-o", type=str, help="Path to output .h5 file (optional)"
    )

    args = parser.parse_args()
    convert_pth_to_h5(args.pth_path, args.output)


if __name__ == "__main__":
    main()
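For orientation, a minimal usage sketch of the converter. The input path is hypothetical, and the `.pth` file is assumed to hold the structure `convert_pth_to_h5` expects: a sequence of dicts with `frame_wh`, `padded_crop_coords`, `face_coords`, and `face_mask` tensors:

import h5py

from bithuman.lib.pth2h5 import convert_pth_to_h5

# Hypothetical input; with no output path, the .h5 is written next to the .pth.
h5_path = convert_pth_to_h5("avatar_data.pth")

with h5py.File(h5_path, "r") as f:
    coords = f["face_coords"][:]            # (N, 4) int32 xyxy boxes, one per frame
    first_mask = bytes(f["face_masks"][0])  # variable-length JPG bytes
    print(f.attrs["frame_wh"], coords.shape, len(first_mask))

Since the module guards `main()` behind `__main__`, the same conversion should also be reachable from the shell as `python -m bithuman.lib.pth2h5 avatar_data.pth -o avatar_data.h5`.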
bithuman/plugins/stt.py
ADDED
@@ -0,0 +1,185 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Union

import aiohttp
import numpy as np
from loguru import logger

from bithuman.audio import resample

try:
    from bithuman_local_voice import BithumanLocalSTT as BithumanSTTImpl
except ImportError:
    BithumanSTTImpl = None

try:
    from livekit.agents import utils
    from livekit.agents.stt import (
        STT,
        SpeechData,
        SpeechEvent,
        SpeechEventType,
        STTCapabilities,
    )
    from livekit.agents.types import NOT_GIVEN, APIConnectOptions, NotGivenOr
except ImportError:
    raise ImportError(
        "livekit is required, please install it with `pip install livekit-agents`"
    )


@dataclass
class _STTOptions:
    locale: str = "en-US"
    on_device: bool = True
    punctuation: bool = True
    debug: bool = False


class BithumanSTTError(Exception):
    pass


class BithumanLocalSTT(STT):
    _SAMPLE_RATE = 16000

    def __init__(
        self,
        *,
        locale="en-US",
        server_url=None,
        on_device=True,
        punctuation=True,
        debug=False,
    ):
        capabilities = STTCapabilities(streaming=False, interim_results=False)
        super().__init__(capabilities=capabilities)
        self._opts = _STTOptions(
            locale=locale, on_device=on_device, punctuation=punctuation, debug=debug
        )
        self._server_url: str | None = None
        self._session: aiohttp.ClientSession | None = None
        self._stt_impl = None

        if server_url:
            self._server_url = server_url
            self._session = aiohttp.ClientSession()
        else:
            if BithumanSTTImpl is None:
                raise ImportError(
                    "bithuman_local_voice is required if server_url is not provided, "
                    "please install it with `pip install bithuman_local_voice`"
                )
            self._stt_impl = BithumanSTTImpl(
                locale=locale,
                on_device=on_device,
                punctuation=punctuation,
                debug=debug,
            )

    async def _recognize_impl(
        self,
        buffer: utils.audio.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> SpeechEvent:
        if utils.is_given(language) and language != self._opts.locale:
            try:
                await self._set_locale(language)
            except Exception as e:
                logger.error(f"Failed to set locale: {e}")

        frame = utils.audio.combine_frames(buffer)
        audio_data = np.frombuffer(frame.data, dtype=np.int16)
        if frame.sample_rate != self._SAMPLE_RATE:
            audio_data = resample(audio_data, frame.sample_rate, self._SAMPLE_RATE)

        try:
            result = await self._recognize_audio(
                audio_data, sample_rate=self._SAMPLE_RATE
            )
        except Exception as e:
            logger.warning(f"Failed to recognize audio with error: {e}")
            result = {"text": "", "confidence": 0.0}

        return SpeechEvent(
            type=SpeechEventType.FINAL_TRANSCRIPT,
            request_id="",
            alternatives=[
                SpeechData(
                    language=self._opts.locale,
                    text=result["text"],
                    confidence=result["confidence"],
                )
            ],
        )

    async def _recognize_audio(
        self, audio_data: Union[bytes, np.ndarray], sample_rate: int = 16000
    ) -> Dict:
        """Recognize speech in the provided audio data.

        Args:
            audio_data: Audio data as bytes or numpy array
            sample_rate: Sample rate of the audio data

        Returns:
            Dict containing transcription and confidence score
        """
        if self._stt_impl is not None:
            return await self._stt_impl.recognize(audio_data, sample_rate)

        # Convert numpy array to bytes if needed
        if isinstance(audio_data, np.ndarray):
            # Convert float samples in [-1, 1] to int16 if not already
            if audio_data.dtype != np.int16:
                audio_data = (audio_data * 32767).astype(np.int16)
            audio_bytes = audio_data.tobytes()
        else:
            audio_bytes = audio_data

        async with self._session.post(
            f"{self._server_url}/transcribe",
            data=audio_bytes,
            headers={"Content-Type": "application/octet-stream"},
            timeout=5,
        ) as response:
            response.raise_for_status()
            result: dict = await response.json()
            return {
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
            }
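The client code above pins down the HTTP contract that the `server_url` mode assumes: `POST /transcribe` receives raw 16 kHz mono int16 PCM as `application/octet-stream` and returns JSON with `text` and `confidence`, while `POST /setLocale` (used by `_set_locale` below) receives `{"locale": ...}` and returns `{"success": ..., "message": ...}`. The package does not ship such a server; the following aiohttp stub is a hypothetical stand-in that satisfies the contract with a dummy recognizer, useful only for wiring tests:

# Hypothetical test stub for the server_url mode; not part of bithuman.
import numpy as np
from aiohttp import web


async def transcribe(request: web.Request) -> web.Response:
    pcm = np.frombuffer(await request.read(), dtype=np.int16)
    # A real server would run STT here; we just echo the audio length.
    return web.json_response({"text": f"<{len(pcm)} samples>", "confidence": 1.0})


async def set_locale(request: web.Request) -> web.Response:
    body = await request.json()
    ok = isinstance(body.get("locale"), str)
    return web.json_response({"success": ok, "message": "" if ok else "missing locale"})


app = web.Application()
app.add_routes([web.post("/transcribe", transcribe), web.post("/setLocale", set_locale)])

if __name__ == "__main__":
    web.run_app(app, port=8080)  # pair with BithumanLocalSTT(server_url="http://localhost:8080")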
    async def _set_locale(self, locale: str):
        """Set the recognition locale.

        Args:
            locale: Locale identifier (e.g. en-US, fr-FR)
        """
        if self._stt_impl is not None:
            await self._stt_impl.set_locale(locale)
            self._opts.locale = locale
            return

        assert self._server_url is not None
        async with self._session.post(
            f"{self._server_url}/setLocale",
            json={"locale": locale},
            headers={"Content-Type": "application/json"},
            timeout=5,
        ) as response:
            response.raise_for_status()
            result: dict = await response.json()
            if result.get("success"):
                self._opts.locale = locale
                return
            raise BithumanSTTError(result.get("message", "Unknown error"))

    async def aclose(self):
        if self._stt_impl is not None:
            await self._stt_impl.stop()
        if self._session is not None:
            await self._session.close()
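Finally, a minimal end-to-end sketch of the plugin. It assumes `livekit-agents` is installed and that the base class's public `recognize()` entry point dispatches to `_recognize_impl` as usual; the second of silence stands in for real microphone audio, and the on-device path additionally requires `bithuman_local_voice`:

import asyncio

import numpy as np
from livekit import rtc

from bithuman.plugins.stt import BithumanLocalSTT


async def main() -> None:
    # On-device mode; pass server_url="http://localhost:8080" to use the HTTP mode instead.
    stt = BithumanLocalSTT(locale="en-US")

    # One second of 16 kHz mono int16 silence as placeholder audio.
    samples = np.zeros(16000, dtype=np.int16)
    frame = rtc.AudioFrame(
        data=samples.tobytes(),
        sample_rate=16000,
        num_channels=1,
        samples_per_channel=len(samples),
    )

    event = await stt.recognize([frame])
    print(event.alternatives[0].text)

    await stt.aclose()


asyncio.run(main())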