vox-parakeet 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ .eggs/
8
+ *.egg
9
+ .venv/
10
+ venv/
11
+ .env
12
+ .pytest_cache/
13
+ .ruff_cache/
14
+ .mypy_cache/
15
+ *.so
16
+ *.dylib
17
+ .coverage
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: vox-parakeet
3
+ Version: 0.1.0
4
+ Summary: NVIDIA Parakeet STT adapter for Vox
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: numpy<2.4,>=1.26.0
7
+ Requires-Dist: onnx-asr[gpu,hub]>=0.10.0; platform_machine == 'x86_64'
8
+ Requires-Dist: onnx-asr[hub]>=0.10.0; platform_machine != 'x86_64'
9
+ Requires-Dist: soundfile>=0.13.1
10
+ Requires-Dist: vox>=0.1.0
@@ -0,0 +1,22 @@
1
+ [project]
2
+ name = "vox-parakeet"
3
+ version = "0.1.0"
4
+ description = "NVIDIA Parakeet STT adapter for Vox"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "vox>=0.1.0",
8
+ "onnx-asr[gpu,hub]>=0.10.0; platform_machine == 'x86_64'",
9
+ "onnx-asr[hub]>=0.10.0; platform_machine != 'x86_64'",
10
+ "numpy>=1.26.0,<2.4",
11
+ "soundfile>=0.13.1",
12
+ ]
13
+
14
+ [project.entry-points."vox.adapters"]
15
+ parakeet = "vox_parakeet.adapter:ParakeetAdapter"
16
+
17
+ [build-system]
18
+ requires = ["hatchling"]
19
+ build-backend = "hatchling.build"
20
+
21
+ [tool.hatch.build.targets.wheel]
22
+ packages = ["src/vox_parakeet"]
@@ -0,0 +1,3 @@
1
+ from vox_parakeet.adapter import ParakeetAdapter
2
+
3
+ __all__ = ["ParakeetAdapter"]
@@ -0,0 +1,288 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import tempfile
5
+ import time
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import onnx_asr
12
+ import soundfile as sf
13
+ from numpy.typing import NDArray
14
+
15
+ from vox.core.adapter import STTAdapter
16
+ from vox.core.types import (
17
+ AdapterInfo,
18
+ ModelFormat,
19
+ ModelType,
20
+ TranscribeResult,
21
+ TranscriptSegment,
22
+ WordTimestamp,
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ PARAKEET_SAMPLE_RATE = 16_000
28
+ DEFAULT_MODEL_ID = "nemo-parakeet-tdt-0.6b-v3"
29
+
30
+ # Rough VRAM estimates by model size token in the model ID.
31
+ _VRAM_ESTIMATES: dict[str, int] = {
32
+ "0.6b": 1_300_000_000, # ~1.3 GB
33
+ "1.1b": 2_500_000_000, # ~2.5 GB
34
+ }
35
+
36
+
37
+ def _normalize_model_id(model_id: str) -> str:
38
+ """Convert a HuggingFace repo ID (e.g. ``nvidia/parakeet-tdt-0.6b-v3``)
39
+ to the ``nemo-`` prefixed form that ``onnx-asr`` expects
40
+ (e.g. ``nemo-parakeet-tdt-0.6b-v3``).
41
+
42
+ If the string already starts with ``nemo-`` or has no known prefix it is
43
+ returned unchanged.
44
+ """
45
+ if "/" in model_id:
46
+ # Take the repo name after the slash and add the nemo- prefix.
47
+ _, repo_name = model_id.split("/", 1)
48
+ return f"nemo-{repo_name}"
49
+ return model_id
50
+
51
+
52
+ def _resolve_model_id(model_path: str, source: str | None) -> str:
53
+ """Resolve the ONNX-ASR model identifier for a local path or catalog source."""
54
+ if source:
55
+ return _normalize_model_id(source)
56
+
57
+ path = Path(model_path)
58
+ if path.exists():
59
+ return str(path)
60
+
61
+ return _normalize_model_id(model_path)
62
+
63
+
64
+ def _get_providers(device: str) -> list[str]:
65
+ """Return ONNX Runtime execution providers for *device*."""
66
+ if device == "cpu":
67
+ return ["CPUExecutionProvider"]
68
+
69
+ if device == "cuda":
70
+ try:
71
+ import onnxruntime as ort
72
+
73
+ available = ort.get_available_providers()
74
+ except ImportError:
75
+ raise RuntimeError("Parakeet requires onnxruntime to be installed") from None
76
+
77
+ if "CUDAExecutionProvider" in available:
78
+ return ["CUDAExecutionProvider", "CPUExecutionProvider"]
79
+
80
+ raise RuntimeError(
81
+ "Parakeet requires CUDAExecutionProvider for non-CPU devices; "
82
+ "CPU fallback is disabled"
83
+ )
84
+
85
+ return ["CPUExecutionProvider"]
86
+
87
+
88
+ @dataclass
89
+ class _Word:
90
+ word: str
91
+ start: float
92
+ end: float
93
+
94
+
95
+ def _tokens_to_words(tokens: list[str], timestamps: list[float]) -> list[_Word]:
96
+ """Merge sub-word tokens into whole words using leading-space heuristics."""
97
+ if not tokens or not timestamps:
98
+ return []
99
+
100
+ if len(tokens) != len(timestamps):
101
+ logger.warning(
102
+ "Token/timestamp length mismatch: %d tokens, %d timestamps",
103
+ len(tokens),
104
+ len(timestamps),
105
+ )
106
+ min_len = min(len(tokens), len(timestamps))
107
+ tokens = tokens[:min_len]
108
+ timestamps = timestamps[:min_len]
109
+
110
+ words: list[_Word] = []
111
+ current_word = ""
112
+ current_start: float | None = None
113
+
114
+ for token, ts in zip(tokens, timestamps, strict=False):
115
+ token_stripped = token.strip()
116
+ if not token_stripped:
117
+ continue
118
+
119
+ is_punctuation = len(token_stripped) == 1 and not token_stripped.isalnum()
120
+
121
+ if token.startswith(" ") or current_start is None:
122
+ if current_word and current_start is not None:
123
+ words.append(_Word(word=current_word, start=current_start, end=ts))
124
+ current_word = token_stripped
125
+ current_start = ts
126
+ elif is_punctuation:
127
+ current_word += token_stripped
128
+ else:
129
+ current_word += token_stripped
130
+
131
+ if current_word and current_start is not None:
132
+ end_time = timestamps[-1] if timestamps else current_start
133
+ words.append(_Word(word=current_word, start=current_start, end=end_time))
134
+
135
+ return words
136
+
137
+
138
+ def _estimate_vram(model_id: str) -> int:
139
+ """Return a rough VRAM estimate in bytes based on the model ID."""
140
+ lower = model_id.lower()
141
+ for size_key, vram in _VRAM_ESTIMATES.items():
142
+ if size_key in lower:
143
+ return vram
144
+ # Conservative fallback for unknown sizes.
145
+ return _VRAM_ESTIMATES["0.6b"]
146
+
147
+
148
+ class ParakeetAdapter(STTAdapter):
149
+ """NVIDIA Parakeet STT adapter backed by ``onnx-asr``."""
150
+
151
+ def __init__(self) -> None:
152
+ self._model: onnx_asr.adapters.TextResultsAsrAdapter | None = None
153
+ self._model_with_ts: onnx_asr.adapters.TimestampedResultsAsrAdapter | None = None
154
+ self._loaded = False
155
+ self._model_id: str = DEFAULT_MODEL_ID
156
+ self._device: str = "cpu"
157
+
158
+ # ------------------------------------------------------------------
159
+ # STTAdapter interface
160
+ # ------------------------------------------------------------------
161
+
162
+ def info(self) -> AdapterInfo:
163
+ return AdapterInfo(
164
+ name="parakeet",
165
+ type=ModelType.STT,
166
+ architectures=("parakeet", "parakeet-tdt", "parakeet-ctc"),
167
+ default_sample_rate=PARAKEET_SAMPLE_RATE,
168
+ supported_formats=(ModelFormat.ONNX,),
169
+ supports_streaming=False,
170
+ supports_word_timestamps=True,
171
+ supports_language_detection=False,
172
+ supported_languages=("en",),
173
+ )
174
+
175
+ def load(self, model_path: str, device: str, **kwargs: Any) -> None:
176
+ if self._loaded:
177
+ return
178
+
179
+ self._device = device
180
+ # Prefer the catalog source (HuggingFace repo ID) passed via _source;
181
+ # fall back to model_path for backward compatibility.
182
+ source = kwargs.pop("_source", None)
183
+ self._model_id = _resolve_model_id(model_path, source)
184
+
185
+ logger.info("Loading Parakeet ONNX model: %s", self._model_id)
186
+ start = time.perf_counter()
187
+
188
+ providers = _get_providers(device)
189
+ self._model = onnx_asr.load_model(self._model_id, providers=providers)
190
+ self._model_with_ts = self._model.with_timestamps()
191
+
192
+ elapsed = time.perf_counter() - start
193
+ logger.info("Parakeet model loaded in %.2fs", elapsed)
194
+ self._loaded = True
195
+
196
+ def unload(self) -> None:
197
+ self._model = None
198
+ self._model_with_ts = None
199
+ self._loaded = False
200
+ logger.info("Parakeet adapter unloaded")
201
+
202
+ @property
203
+ def is_loaded(self) -> bool:
204
+ return self._loaded
205
+
206
+ def transcribe(
207
+ self,
208
+ audio: NDArray[np.float32],
209
+ *,
210
+ language: str | None = None,
211
+ word_timestamps: bool = False,
212
+ initial_prompt: str | None = None,
213
+ temperature: float = 0.0,
214
+ ) -> TranscribeResult:
215
+ if not self._loaded or self._model is None:
216
+ raise RuntimeError("Parakeet model is not loaded — call load() first")
217
+
218
+ if language and language not in ("en", "english"):
219
+ logger.warning("Parakeet only supports English, ignoring language=%s", language)
220
+
221
+ if len(audio) == 0:
222
+ logger.warning("Empty audio buffer, returning empty result")
223
+ return TranscribeResult(text="", language="en", duration_ms=0, model=self._model_id)
224
+
225
+ audio_duration_ms = int(len(audio) / PARAKEET_SAMPLE_RATE * 1000)
226
+
227
+ # onnx-asr requires a file path — write audio to a temporary WAV.
228
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
229
+ sf.write(tmp.name, audio, PARAKEET_SAMPLE_RATE)
230
+ temp_path = Path(tmp.name)
231
+
232
+ try:
233
+ if word_timestamps:
234
+ result = self._model_with_ts.recognize(str(temp_path))
235
+ text = result.text
236
+ words = _tokens_to_words(result.tokens, result.timestamps)
237
+ word_ts = tuple(
238
+ WordTimestamp(
239
+ word=w.word,
240
+ start_ms=int(w.start * 1000),
241
+ end_ms=int(w.end * 1000),
242
+ )
243
+ for w in words
244
+ )
245
+ segments = (
246
+ (TranscriptSegment(
247
+ text=text,
248
+ start_ms=0,
249
+ end_ms=audio_duration_ms,
250
+ words=word_ts,
251
+ language="en",
252
+ ),)
253
+ if text
254
+ else ()
255
+ )
256
+ else:
257
+ text = self._model.recognize(str(temp_path))
258
+ segments = (
259
+ (TranscriptSegment(
260
+ text=text,
261
+ start_ms=0,
262
+ end_ms=audio_duration_ms,
263
+ language="en",
264
+ ),)
265
+ if text
266
+ else ()
267
+ )
268
+ finally:
269
+ temp_path.unlink(missing_ok=True)
270
+
271
+ text = text.strip() if text else ""
272
+
273
+ if not text:
274
+ logger.warning("Empty transcription for %dms audio", audio_duration_ms)
275
+ else:
276
+ logger.info("Transcribed %dms audio: %s", audio_duration_ms, text[:80])
277
+
278
+ return TranscribeResult(
279
+ text=text,
280
+ segments=segments,
281
+ language="en",
282
+ duration_ms=audio_duration_ms,
283
+ model=self._model_id,
284
+ )
285
+
286
+ def estimate_vram_bytes(self, **kwargs: Any) -> int:
287
+ model_id = kwargs.get("_source") or kwargs.get("model_id") or self._model_id or DEFAULT_MODEL_ID
288
+ return _estimate_vram(_normalize_model_id(str(model_id)))