vox-parakeet 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vox-parakeet
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: NVIDIA Parakeet STT adapter for Vox
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: numpy<2.4,>=1.26.0
|
|
7
|
+
Requires-Dist: onnx-asr[gpu,hub]>=0.10.0; platform_machine == 'x86_64'
|
|
8
|
+
Requires-Dist: onnx-asr[hub]>=0.10.0; platform_machine != 'x86_64'
|
|
9
|
+
Requires-Dist: soundfile>=0.13.1
|
|
10
|
+
Requires-Dist: vox>=0.1.0
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "vox-parakeet"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "NVIDIA Parakeet STT adapter for Vox"
|
|
5
|
+
requires-python = ">=3.11"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"vox>=0.1.0",
|
|
8
|
+
"onnx-asr[gpu,hub]>=0.10.0; platform_machine == 'x86_64'",
|
|
9
|
+
"onnx-asr[hub]>=0.10.0; platform_machine != 'x86_64'",
|
|
10
|
+
"numpy>=1.26.0,<2.4",
|
|
11
|
+
"soundfile>=0.13.1",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.entry-points."vox.adapters"]
|
|
15
|
+
parakeet = "vox_parakeet.adapter:ParakeetAdapter"
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["hatchling"]
|
|
19
|
+
build-backend = "hatchling.build"
|
|
20
|
+
|
|
21
|
+
[tool.hatch.build.targets.wheel]
|
|
22
|
+
packages = ["src/vox_parakeet"]
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import tempfile
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import onnx_asr
|
|
12
|
+
import soundfile as sf
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
|
|
15
|
+
from vox.core.adapter import STTAdapter
|
|
16
|
+
from vox.core.types import (
|
|
17
|
+
AdapterInfo,
|
|
18
|
+
ModelFormat,
|
|
19
|
+
ModelType,
|
|
20
|
+
TranscribeResult,
|
|
21
|
+
TranscriptSegment,
|
|
22
|
+
WordTimestamp,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
PARAKEET_SAMPLE_RATE = 16_000
|
|
28
|
+
DEFAULT_MODEL_ID = "nemo-parakeet-tdt-0.6b-v3"
|
|
29
|
+
|
|
30
|
+
# Rough VRAM estimates by model size token in the model ID.
|
|
31
|
+
_VRAM_ESTIMATES: dict[str, int] = {
|
|
32
|
+
"0.6b": 1_300_000_000, # ~1.3 GB
|
|
33
|
+
"1.1b": 2_500_000_000, # ~2.5 GB
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _normalize_model_id(model_id: str) -> str:
|
|
38
|
+
"""Convert a HuggingFace repo ID (e.g. ``nvidia/parakeet-tdt-0.6b-v3``)
|
|
39
|
+
to the ``nemo-`` prefixed form that ``onnx-asr`` expects
|
|
40
|
+
(e.g. ``nemo-parakeet-tdt-0.6b-v3``).
|
|
41
|
+
|
|
42
|
+
If the string already starts with ``nemo-`` or has no known prefix it is
|
|
43
|
+
returned unchanged.
|
|
44
|
+
"""
|
|
45
|
+
if "/" in model_id:
|
|
46
|
+
# Take the repo name after the slash and add the nemo- prefix.
|
|
47
|
+
_, repo_name = model_id.split("/", 1)
|
|
48
|
+
return f"nemo-{repo_name}"
|
|
49
|
+
return model_id
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _resolve_model_id(model_path: str, source: str | None) -> str:
|
|
53
|
+
"""Resolve the ONNX-ASR model identifier for a local path or catalog source."""
|
|
54
|
+
if source:
|
|
55
|
+
return _normalize_model_id(source)
|
|
56
|
+
|
|
57
|
+
path = Path(model_path)
|
|
58
|
+
if path.exists():
|
|
59
|
+
return str(path)
|
|
60
|
+
|
|
61
|
+
return _normalize_model_id(model_path)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_providers(device: str) -> list[str]:
|
|
65
|
+
"""Return ONNX Runtime execution providers for *device*."""
|
|
66
|
+
if device == "cpu":
|
|
67
|
+
return ["CPUExecutionProvider"]
|
|
68
|
+
|
|
69
|
+
if device == "cuda":
|
|
70
|
+
try:
|
|
71
|
+
import onnxruntime as ort
|
|
72
|
+
|
|
73
|
+
available = ort.get_available_providers()
|
|
74
|
+
except ImportError:
|
|
75
|
+
raise RuntimeError("Parakeet requires onnxruntime to be installed") from None
|
|
76
|
+
|
|
77
|
+
if "CUDAExecutionProvider" in available:
|
|
78
|
+
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
79
|
+
|
|
80
|
+
raise RuntimeError(
|
|
81
|
+
"Parakeet requires CUDAExecutionProvider for non-CPU devices; "
|
|
82
|
+
"CPU fallback is disabled"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
return ["CPUExecutionProvider"]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass
|
|
89
|
+
class _Word:
|
|
90
|
+
word: str
|
|
91
|
+
start: float
|
|
92
|
+
end: float
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _tokens_to_words(tokens: list[str], timestamps: list[float]) -> list[_Word]:
|
|
96
|
+
"""Merge sub-word tokens into whole words using leading-space heuristics."""
|
|
97
|
+
if not tokens or not timestamps:
|
|
98
|
+
return []
|
|
99
|
+
|
|
100
|
+
if len(tokens) != len(timestamps):
|
|
101
|
+
logger.warning(
|
|
102
|
+
"Token/timestamp length mismatch: %d tokens, %d timestamps",
|
|
103
|
+
len(tokens),
|
|
104
|
+
len(timestamps),
|
|
105
|
+
)
|
|
106
|
+
min_len = min(len(tokens), len(timestamps))
|
|
107
|
+
tokens = tokens[:min_len]
|
|
108
|
+
timestamps = timestamps[:min_len]
|
|
109
|
+
|
|
110
|
+
words: list[_Word] = []
|
|
111
|
+
current_word = ""
|
|
112
|
+
current_start: float | None = None
|
|
113
|
+
|
|
114
|
+
for token, ts in zip(tokens, timestamps, strict=False):
|
|
115
|
+
token_stripped = token.strip()
|
|
116
|
+
if not token_stripped:
|
|
117
|
+
continue
|
|
118
|
+
|
|
119
|
+
is_punctuation = len(token_stripped) == 1 and not token_stripped.isalnum()
|
|
120
|
+
|
|
121
|
+
if token.startswith(" ") or current_start is None:
|
|
122
|
+
if current_word and current_start is not None:
|
|
123
|
+
words.append(_Word(word=current_word, start=current_start, end=ts))
|
|
124
|
+
current_word = token_stripped
|
|
125
|
+
current_start = ts
|
|
126
|
+
elif is_punctuation:
|
|
127
|
+
current_word += token_stripped
|
|
128
|
+
else:
|
|
129
|
+
current_word += token_stripped
|
|
130
|
+
|
|
131
|
+
if current_word and current_start is not None:
|
|
132
|
+
end_time = timestamps[-1] if timestamps else current_start
|
|
133
|
+
words.append(_Word(word=current_word, start=current_start, end=end_time))
|
|
134
|
+
|
|
135
|
+
return words
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _estimate_vram(model_id: str) -> int:
    """Return a rough VRAM estimate in bytes for *model_id*.

    The ID is lower-cased and searched for the size tokens in
    ``_VRAM_ESTIMATES``, in dict order; the first match wins. When no token
    matches (e.g. a local path without a size in its name), the largest
    known estimate is returned so an unrecognized model is over- rather than
    under-budgeted.
    """
    lower = model_id.lower()
    for size_key, vram in _VRAM_ESTIMATES.items():
        if size_key in lower:
            return vram
    # Conservative fallback for unknown sizes: assume the biggest known model.
    return max(_VRAM_ESTIMATES.values())
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class ParakeetAdapter(STTAdapter):
    """NVIDIA Parakeet STT adapter backed by ``onnx-asr``.

    Audio passed to :meth:`transcribe` is treated as already sampled at
    ``PARAKEET_SAMPLE_RATE``; no resampling is performed. Results are always
    tagged as English.
    """

    def __init__(self) -> None:
        # Plain-text recognizer and its timestamped variant; both are set by
        # load() and cleared by unload().
        self._model: onnx_asr.adapters.TextResultsAsrAdapter | None = None
        self._model_with_ts: onnx_asr.adapters.TimestampedResultsAsrAdapter | None = None
        self._loaded = False
        self._model_id: str = DEFAULT_MODEL_ID
        self._device: str = "cpu"

    # ------------------------------------------------------------------
    # STTAdapter interface
    # ------------------------------------------------------------------

    def info(self) -> AdapterInfo:
        """Return static metadata describing this adapter's capabilities."""
        return AdapterInfo(
            name="parakeet",
            type=ModelType.STT,
            architectures=("parakeet", "parakeet-tdt", "parakeet-ctc"),
            default_sample_rate=PARAKEET_SAMPLE_RATE,
            supported_formats=(ModelFormat.ONNX,),
            supports_streaming=False,
            supports_word_timestamps=True,
            supports_language_detection=False,
            supported_languages=("en",),
        )

    def load(self, model_path: str, device: str, **kwargs: Any) -> None:
        """Load the ONNX model and its timestamped variant.

        Does nothing if a model is already loaded; call :meth:`unload` first
        to switch models or devices.

        Args:
            model_path: An existing local path or a model name. Used only when
                no ``_source`` keyword is supplied.
            device: Device string passed to ``_get_providers`` to choose the
                ONNX Runtime execution providers.
            **kwargs: ``_source`` -- optional catalog source (e.g. a
                HuggingFace repo ID) that takes precedence over *model_path*.
                Other keys are ignored.

        Raises:
            RuntimeError: Propagated from ``_get_providers`` when the requested
                device's execution provider is unavailable.
        """
        if self._loaded:
            return

        self._device = device
        # Prefer the catalog source (HuggingFace repo ID) passed via _source;
        # fall back to model_path for backward compatibility.
        source = kwargs.pop("_source", None)
        self._model_id = _resolve_model_id(model_path, source)

        logger.info("Loading Parakeet ONNX model: %s", self._model_id)
        start = time.perf_counter()

        providers = _get_providers(device)
        self._model = onnx_asr.load_model(self._model_id, providers=providers)
        self._model_with_ts = self._model.with_timestamps()

        elapsed = time.perf_counter() - start
        logger.info("Parakeet model loaded in %.2fs", elapsed)
        # Only mark as loaded once both recognizers exist.
        self._loaded = True

    def unload(self) -> None:
        """Drop references to both recognizers and mark the adapter unloaded."""
        self._model = None
        self._model_with_ts = None
        self._loaded = False
        logger.info("Parakeet adapter unloaded")

    @property
    def is_loaded(self) -> bool:
        """Whether :meth:`load` has completed successfully since the last unload."""
        return self._loaded

    def transcribe(
        self,
        audio: NDArray[np.float32],
        *,
        language: str | None = None,
        word_timestamps: bool = False,
        initial_prompt: str | None = None,
        temperature: float = 0.0,
    ) -> TranscribeResult:
        """Transcribe *audio* as a single English segment.

        Args:
            audio: Waveform samples at ``PARAKEET_SAMPLE_RATE``. ``len(audio)``
                is treated as the number of samples when computing duration.
            language: Only English is supported; any other value is logged
                and ignored (comparison is case-insensitive).
            word_timestamps: When true, use the timestamped recognizer and
                attach per-word timings to the returned segment.
            initial_prompt: Accepted for interface compatibility; unused.
            temperature: Accepted for interface compatibility; unused.

        Returns:
            A ``TranscribeResult`` with stripped text. ``segments`` holds one
            segment spanning the whole clip, or is empty when nothing was
            recognized.

        Raises:
            RuntimeError: If :meth:`load` has not been called.
        """
        if not self._loaded or self._model is None:
            raise RuntimeError("Parakeet model is not loaded — call load() first")

        if language and language.lower() not in ("en", "english"):
            logger.warning("Parakeet only supports English, ignoring language=%s", language)

        if len(audio) == 0:
            logger.warning("Empty audio buffer, returning empty result")
            return TranscribeResult(text="", language="en", duration_ms=0, model=self._model_id)

        audio_duration_ms = int(len(audio) / PARAKEET_SAMPLE_RATE * 1000)

        raw_text, word_ts = self._recognize(audio, word_timestamps=word_timestamps)
        # Strip once, up front, so the result text and the segment text agree
        # and whitespace-only output produces no segment.
        text = raw_text.strip() if raw_text else ""

        segments: tuple[TranscriptSegment, ...]
        if not text:
            segments = ()
        elif word_ts is None:
            segments = (
                TranscriptSegment(
                    text=text,
                    start_ms=0,
                    end_ms=audio_duration_ms,
                    language="en",
                ),
            )
        else:
            segments = (
                TranscriptSegment(
                    text=text,
                    start_ms=0,
                    end_ms=audio_duration_ms,
                    words=word_ts,
                    language="en",
                ),
            )

        if not text:
            logger.warning("Empty transcription for %dms audio", audio_duration_ms)
        else:
            logger.info("Transcribed %dms audio: %s", audio_duration_ms, text[:80])

        return TranscribeResult(
            text=text,
            segments=segments,
            language="en",
            duration_ms=audio_duration_ms,
            model=self._model_id,
        )

    def estimate_vram_bytes(self, **kwargs: Any) -> int:
        """Estimate VRAM in bytes for a model.

        The model is taken from the first truthy of ``_source``, ``model_id``,
        the currently resolved model ID, and ``DEFAULT_MODEL_ID``.
        """
        model_id = kwargs.get("_source") or kwargs.get("model_id") or self._model_id or DEFAULT_MODEL_ID
        return _estimate_vram(_normalize_model_id(str(model_id)))

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _recognize(
        self,
        audio: NDArray[np.float32],
        *,
        word_timestamps: bool,
    ) -> tuple[str, tuple[WordTimestamp, ...] | None]:
        """Run onnx-asr on *audio* via a temporary WAV file.

        Returns the raw (unstripped) recognized text, plus per-word timings in
        milliseconds when *word_timestamps* is true (``None`` otherwise).
        """
        # This adapter feeds onnx-asr a file path, so the samples go through a
        # temporary WAV. The handle is closed before soundfile writes to the
        # path, and the file is removed in ``finally`` even if writing or
        # recognition fails.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_path = Path(tmp.name)

        try:
            sf.write(str(temp_path), audio, PARAKEET_SAMPLE_RATE)
            if not word_timestamps:
                return self._model.recognize(str(temp_path)), None
            result = self._model_with_ts.recognize(str(temp_path))
        finally:
            temp_path.unlink(missing_ok=True)

        words = _tokens_to_words(result.tokens, result.timestamps)
        word_ts = tuple(
            WordTimestamp(
                word=w.word,
                start_ms=int(w.start * 1000),
                end_ms=int(w.end * 1000),
            )
            for w in words
        )
        return result.text, word_ts
|