jaxcld 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. jaxcld-0.1.0/PKG-INFO +80 -0
  2. jaxcld-0.1.0/README.md +314 -0
  3. jaxcld-0.1.0/README_PIP.md +44 -0
  4. jaxcld-0.1.0/jaxcld/__init__.py +47 -0
  5. jaxcld-0.1.0/jaxcld/models/__init__.py +11 -0
  6. jaxcld-0.1.0/jaxcld/models/asr_model.py +495 -0
  7. jaxcld-0.1.0/jaxcld/models/cvx_grelu_mlp.py +74 -0
  8. jaxcld-0.1.0/jaxcld/models/cvx_mlp.py +63 -0
  9. jaxcld-0.1.0/jaxcld/models/cvx_relu_mlp.py +120 -0
  10. jaxcld-0.1.0/jaxcld/models/get_model.py +26 -0
  11. jaxcld-0.1.0/jaxcld/models/grelu_mlp.py +44 -0
  12. jaxcld-0.1.0/jaxcld/models/lang_detect_head.py +152 -0
  13. jaxcld-0.1.0/jaxcld/models/relu_mlp.py +71 -0
  14. jaxcld-0.1.0/jaxcld/models/two_layer_mlp.py +11 -0
  15. jaxcld-0.1.0/jaxcld/optimizers/__init__.py +4 -0
  16. jaxcld-0.1.0/jaxcld/optimizers/adamW.py +39 -0
  17. jaxcld-0.1.0/jaxcld/optimizers/admm.py +103 -0
  18. jaxcld-0.1.0/jaxcld/optimizers/dadapt_adamW.py +38 -0
  19. jaxcld-0.1.0/jaxcld/optimizers/dist_shampoo/__init__.py +4 -0
  20. jaxcld-0.1.0/jaxcld/optimizers/dist_shampoo/distributed_shampoo.py +2831 -0
  21. jaxcld-0.1.0/jaxcld/optimizers/dist_shampoo/quantization_utils.py +115 -0
  22. jaxcld-0.1.0/jaxcld/optimizers/pcg.py +69 -0
  23. jaxcld-0.1.0/jaxcld/optimizers/sgd.py +38 -0
  24. jaxcld-0.1.0/jaxcld/optimizers/shampoo.py +37 -0
  25. jaxcld-0.1.0/jaxcld/optimizers/yogi.py +36 -0
  26. jaxcld-0.1.0/jaxcld/preconditioner/__init__.py +4 -0
  27. jaxcld-0.1.0/jaxcld/preconditioner/nystrom.py +102 -0
  28. jaxcld-0.1.0/jaxcld/training/__init__.py +8 -0
  29. jaxcld-0.1.0/jaxcld/training/train.py +164 -0
  30. jaxcld-0.1.0/jaxcld/training/train_no_jit.py +126 -0
  31. jaxcld-0.1.0/jaxcld/utils/__init__.py +4 -0
  32. jaxcld-0.1.0/jaxcld/utils/linops_utils.py +50 -0
  33. jaxcld-0.1.0/jaxcld/utils/load_data.py +459 -0
  34. jaxcld-0.1.0/jaxcld/utils/metric_utils.py +59 -0
  35. jaxcld-0.1.0/jaxcld/utils/model_utils.py +113 -0
  36. jaxcld-0.1.0/jaxcld/utils/opt_utils.py +31 -0
  37. jaxcld-0.1.0/jaxcld/utils/proximal_utils.py +22 -0
  38. jaxcld-0.1.0/jaxcld/utils/train_utils.py +7 -0
  39. jaxcld-0.1.0/jaxcld/utils/whisper_dataloader.py +142 -0
  40. jaxcld-0.1.0/jaxcld.egg-info/PKG-INFO +80 -0
  41. jaxcld-0.1.0/jaxcld.egg-info/SOURCES.txt +45 -0
  42. jaxcld-0.1.0/jaxcld.egg-info/dependency_links.txt +1 -0
  43. jaxcld-0.1.0/jaxcld.egg-info/requires.txt +28 -0
  44. jaxcld-0.1.0/jaxcld.egg-info/top_level.txt +1 -0
  45. jaxcld-0.1.0/pyproject.toml +55 -0
  46. jaxcld-0.1.0/setup.cfg +4 -0
  47. jaxcld-0.1.0/tests/test_final_dry_asr_and_heads.py +251 -0
jaxcld-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,80 @@
1
+ Metadata-Version: 2.4
2
+ Name: jaxcld
3
+ Version: 0.1.0
4
+ Summary: CLD: language detection heads for ASR models
5
+ Author: CLD contributors
6
+ License: MIT
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: numpy>=1.24
10
+ Requires-Dist: torch>=2.0.0
11
+ Requires-Dist: torchaudio>=2.0.0
12
+ Requires-Dist: transformers==4.56.2
13
+ Requires-Dist: scikit-learn>=1.3.0
14
+ Provides-Extra: train
15
+ Requires-Dist: datasets[audio]==3.6.0; extra == "train"
16
+ Requires-Dist: soundfile>=0.12.1; extra == "train"
17
+ Requires-Dist: scipy>=1.10; extra == "train"
18
+ Requires-Dist: tqdm>=4.66; extra == "train"
19
+ Requires-Dist: pandas>=1.5.0; extra == "train"
20
+ Requires-Dist: librosa>=0.10.1; extra == "train"
21
+ Requires-Dist: noisereduce>=3.0.0; extra == "train"
22
+ Requires-Dist: pydub>=0.25.1; extra == "train"
23
+ Requires-Dist: accelerate>=0.20.0; extra == "train"
24
+ Requires-Dist: evaluate>=0.4.0; extra == "train"
25
+ Requires-Dist: jiwer>=3.0.0; extra == "train"
26
+ Requires-Dist: torchcodec==0.10.0; extra == "train"
27
+ Requires-Dist: wandb>=0.15.0; extra == "train"
28
+ Requires-Dist: tensorboard>=2.13.0; extra == "train"
29
+ Requires-Dist: huggingface_hub>=0.17.0; extra == "train"
30
+ Requires-Dist: gradio>=3.0.0; extra == "train"
31
+ Requires-Dist: audiomentations==0.43.1; extra == "train"
32
+ Requires-Dist: jax==0.7.2; extra == "train"
33
+ Requires-Dist: optax==0.2.6; extra == "train"
34
+ Requires-Dist: flax==0.11.2; extra == "train"
35
+ Requires-Dist: python-dotenv==1.1.1; extra == "train"
36
+
37
+ ## jaxcld
38
+
39
+ `jaxcld` is a lightweight language-detection module for multilingual ASR models (Whisper / MMS). It provides an `ASRModel` wrapper plus pluggable language detection heads you can attach at inference time.
40
+
41
+ ## Install
42
+
43
+ ```bash
44
+ pip install jaxcld
45
+ ```
46
+
47
+ If you are developing from source:
48
+
49
+ ```bash
50
+ pip install -e .
51
+ ```
52
+
53
+ ## Using the package (minimal inference example)
54
+
55
+ ```python
56
+ import numpy as np
57
+
58
+ from jaxcld import ASRModel, CVXNNLangDetectHead, NNLangDetectHead, SVMLangDetectHead
59
+
60
+ # 1) Load the base ASR model
61
+ languages = ["en", "hi", "id", "ms", "zh"]
62
+ asr = ASRModel.from_pretrained("openai/whisper-small", config={"languages": languages})
63
+
64
+ # 2) Load a language detection head artifact (choose ONE; swap in an alternative below if needed)
65
+ head = CVXNNLangDetectHead.load("path/to/openai_whisper-small_trained_cvx_mlp.pkl", asr)
66
+ # head = NNLangDetectHead.load("path/to/openai_whisper-small_nn_head.pkl", asr)
67
+ # head = SVMLangDetectHead.load("path/to/openai_whisper-small_linear_svm.pkl", asr)
68
+
69
+ # 3) Attach head and run inference
70
+ asr.set_lang_detect_head(head)
71
+
72
+ audio_16k_mono: np.ndarray = ... # shape (T,), sampling rate 16kHz
73
+ pred_langs, pred_texts = asr.predict(audio_16k_mono)
74
+ print(pred_langs[0], pred_texts[0])
75
+ ```
76
+
77
+ ## Notes
78
+
79
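+ - Head artifacts (`*.pkl`) are produced by training scripts in the source repository; this pip README intentionally focuses only on **package usage**. Head artifacts are pickled, and loading a pickle can execute arbitrary code, so only load artifacts from sources you trust.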
80
+
jaxcld-0.1.0/README.md ADDED
@@ -0,0 +1,314 @@
1
+ ## Convex Low-resource Accent-Robust Language Detection in Speech Recognition
2
+
3
+ This repository provides the official implementation of **CLD**, a lightweight language-detection module for multilingual ASR. This codebase contains our pip-installable Python package (`jaxcld/`) along with our training/benchmark scripts; the convex head is implemented in JAX and optimized via ADMM for high performance in low-resource settings. In short, the package attaches a small language detection head (Convex NN / small NN / linear SVM) to ASR encoder representations and uses it to select the language token (Whisper) or adapter (MMS) before decoding.
4
+
5
+ ![Approach overview](assets/fig_1_2.png)
6
+
7
+ ## Highlights
8
+
9
+ - High Accuracy: Excels in binary and multiclass language detection (Table 2).
10
+ - Low-Resource Robustness: Effective with limited data (Figures 1 & 2).
11
+ - Efficient: 13x training speedup over traditional NNs, thanks to ADMM optimization and JAX.
12
+
13
+ <!--
14
+ ## What’s in this repo
15
+
16
+ - **`jaxcld/`**: package with `ASRModel` adapters (Whisper + MMS) and language detection heads
17
+ - **Training scripts**
18
+ - **Whisper fine-tuning**: `train_whisper.py`
19
+ - **Convex head (CVXNN, JAX + ADMM/CRONOS)**: `train_cvxnn.py`
20
+ - **Small NN head (PyTorch)**: `train_nn.py`
21
+ - **Linear SVM head (sklearn)**: `train_linear_svm.py`
22
+ - **Evaluation**: `benchmark_cld.py` (language detection metrics + WER/CER, with optional per-accent breakdown)
23
+ - **Tests**: `tests/` (smoke-tests for loading heads and running inference end-to-end) -->
24
+
25
+ ## Requirements
26
+
27
+ This repo supports two common setups:
28
+
29
+ - **Package-only install** (inference usage):
30
+
31
+ ```bash
32
+ pip install -e .
33
+ ```
34
+
35
+ - **Full training/benchmark environment** (recommended if you run the scripts in this repo):
36
+
37
+ ```bash
38
+ pip install -e ".[train]"
39
+ ```
40
+
41
+ If you prefer installing from the pinned dependency list instead:
42
+
43
+ ```bash
44
+ pip install -r requirements.txt
45
+ ```
46
+
47
+ ## Pip README (package-only)
48
+
49
+ For the pip/PyPI page we use a separate, minimal README focused only on **using** the `jaxcld` package:
50
+
51
+ - `README_PIP.md`
52
+
53
+ ## Using the package
54
+
55
+ ### Minimal inference example (Whisper)
56
+
57
+ ```python
58
+ import numpy as np
59
+
60
+ from jaxcld import ASRModel, CVXNNLangDetectHead, NNLangDetectHead, SVMLangDetectHead
61
+
62
+ # 1) Load the base ASR model
63
+ languages = ["en", "hi", "id", "ms", "zh"]
64
+ asr = ASRModel.from_pretrained("openai/whisper-small", config={"languages": languages})
65
+
66
+ # 2) Load a language detection head artifact (choose ONE; swap in an alternative below if needed)
67
+ head = CVXNNLangDetectHead.load("path/to/openai_whisper-small_trained_cvx_mlp.pkl", asr)
68
+ # head = NNLangDetectHead.load("path/to/openai_whisper-small_nn_head.pkl", asr)
69
+ # head = SVMLangDetectHead.load("path/to/openai_whisper-small_linear_svm.pkl", asr)
70
+
71
+ # 3) Attach head and run inference
72
+ asr.set_lang_detect_head(head)
73
+
74
+ audio_16k_mono: np.ndarray = ... # shape (T,), sampling rate 16kHz
75
+ pred_langs, pred_texts = asr.predict(audio_16k_mono)
76
+ print(pred_langs[0], pred_texts[0])
77
+ ```
78
+
79
+ ## Training
80
+
81
+ ### Data format
82
+
83
+ All training/evaluation scripts expect a **Hugging Face `DatasetDict` saved to disk** (loaded via `datasets.load_from_disk(...)`) with splits like `train`, `valid`, `test`. Use our `data_ingestion.py` script to prepare your data.
84
+
85
+ ```bash
86
+ python data_ingestion.py \
87
+ --config configs/en_hi_config.json \
88
+ --out data/en_hi \
89
+ --common-voice-dir /absolute/path/to/CommonVoice \
90
+ --augment
91
+ ```
92
+
93
+ - Required: `--config` JSON (see example below), `--out` save directory.
94
+ - Optional: `--augment` enables audiomentations; `--musan-dir` for background noise; `--common-voice-dir` for local Common Voice.
95
+ - Output: a saved `DatasetDict` at `data/en_hi` with columns: `audio`, `text`, `lang`, `accent`.
96
+
97
+ Minimal config example (see more in `configs/`):
98
+ ```json
99
+ {
100
+ "name": "English-Hindi example",
101
+ "languages": {
102
+ "en": {
103
+ "accents": [
104
+ { "code": "us", "column_name": "United States English", "dataset": "common_voice" }
105
+ ]
106
+ },
107
+ "hi": {
108
+ "accents": [
109
+ { "code": "hi", "column_name": "", "dataset": "common_voice" }
110
+ ]
111
+ }
112
+ },
113
+ "params": {
114
+ "samples_per_class": 1000,
115
+ "split": { "train": 0.8, "val": 0.1, "test": 0.1 }
116
+ }
117
+ }
118
+ ```
119
+
120
+ Notes:
121
+ - Common Voice samples are selected by matching `column_name` against the `accents` field of `validated.tsv`. Use `override_code` to point to alternative folders (see `configs/final_config.json`).
122
+ - Lahaja examples are matched by `native_language` (e.g., `"Telugu"`, `"Konkani"`).
123
+
124
+ ### Train language detection heads
125
+
126
+ All heads are trained on **pooled encoder embeddings** extracted by `ASRModel.load_data(...)` from a dataset on disk.
127
+
128
+ #### CVXNN (convex head, JAX + ADMM/CRONOS)
129
+
130
+ ```bash
131
+ python train_cvxnn.py \
132
+ --model_name openai/whisper-small \
133
+ --dataset_path data/multiclass \
134
+ --languages en,hi,id,ms,zh \
135
+ --output_dir models/lang_heads \
136
+ --neuron 64 \
137
+ --beta 0.001 \
138
+ --rho 0.1 \
139
+ --admm_iters 6
140
+ ```
141
+
142
+ This produces a pickled artifact like:
143
+ - `models/lang_heads/openai/whisper-small/openai_whisper-small_trained_cvx_mlp.pkl`
144
+
145
+ #### NN head (PyTorch)
146
+
147
+ ```bash
148
+ python train_nn.py \
149
+ --dataset_path data/multiclass \
150
+ --model_name openai/whisper-small \
151
+ --languages en,hi,id,ms,zh \
152
+ --output_dir models/lang_heads \
153
+ --num_train_epochs 10 \
154
+ --learning_rate 1e-3 \
155
+ --per_device_train_batch_size 256
156
+ ```
157
+
158
+ This produces a pickled artifact like:
159
+ - `models/lang_heads/openai/whisper-small/openai_whisper-small_nn_head.pkl`
160
+
161
+ #### Linear SVM head (sklearn)
162
+
163
+ ```bash
164
+ python train_linear_svm.py \
165
+ --model_name openai/whisper-small \
166
+ --data_dir data/multiclass \
167
+ --languages en,hi,id,ms,zh \
168
+ --output_dir models/lang_heads \
169
+ --C 1.0 \
170
+ --max_iter 5000
171
+ ```
172
+
173
+ This produces a pickled artifact like:
174
+ - `models/lang_heads/openai/whisper-small/openai_whisper-small_linear_svm.pkl`
175
+
176
+
177
+ ### Fine-tune Whisper
178
+
179
+ Use `train_whisper.py` to fine-tune a Whisper checkpoint on a preprocessed dataset directory:
180
+
181
+ ```bash
182
+ python train_whisper.py \
183
+ --data_dir data/multiclass \
184
+ --model_id openai/whisper-small \
185
+ --output_dir models/whisper-small-finetuned \
186
+ --num_train_epochs 3 \
187
+ --learning_rate 1e-5 \
188
+ --per_device_train_batch_size 8 \
189
+ --per_device_eval_batch_size 8 \
190
+ --gradient_accumulation_steps 1 \
191
+ --eval_strategy steps \
192
+ --eval_steps 1000 \
193
+ --save_steps 1000
194
+ ```
195
+
196
+ Optional logging:
197
+
198
+ ```bash
199
+ python train_whisper.py ... \
200
+ --wandb_project CLD \
201
+ --run_name whisper-small-finetune-final_dry
202
+ ```
203
+
204
+ ## Evaluation
205
+
206
+ Use `benchmark_cld.py` to evaluate **language detection** and **transcription quality** (WER/CER) on the `test` split.
207
+
208
+ ### Whisper + CVXNN head
209
+
210
+ ```bash
211
+ python benchmark_cld.py \
212
+ --dataset_path data/multiclass \
213
+ --model_name openai/whisper-small \
214
+ --cld_type cvx \
215
+ --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_trained_cvx_mlp.pkl \
216
+ --languages en,hi,id,ms,zh \
217
+ --batch_size 32 \
218
+ --no_wandb
219
+ ```
220
+
221
+ ### Whisper + NN head
222
+
223
+ ```bash
224
+ python benchmark_cld.py \
225
+ --dataset_path data/multiclass \
226
+ --model_name openai/whisper-small \
227
+ --cld_type nn \
228
+ --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_nn_head.pkl \
229
+ --languages en,hi,id,ms,zh \
230
+ --batch_size 32 \
231
+ --no_wandb
232
+ ```
233
+
234
+ ### Whisper + linear SVM head
235
+
236
+ ```bash
237
+ python benchmark_cld.py \
238
+ --dataset_path data/multiclass \
239
+ --model_name openai/whisper-small \
240
+ --cld_type linear_svm \
241
+ --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_linear_svm.pkl \
242
+ --languages en,hi,id,ms,zh \
243
+ --batch_size 32 \
244
+ --no_wandb
245
+ ```
246
+
247
+ ### Whisper vanilla language ID (no head)
248
+
249
+ ```bash
250
+ python benchmark_cld.py \
251
+ --dataset_path data/multiclass \
252
+ --model_name openai/whisper-small \
253
+ --cld_type vanilla \
254
+ --languages en,hi,id,ms,zh \
255
+ --batch_size 32 \
256
+ --no_wandb
257
+ ```
258
+
259
+ <!-- ## Pre-trained models
260
+
261
+ _TBD._ This repo supports loading three head types:
262
+
263
+ | Head type | Artifact | Loader |
264
+ | --- | --- | --- |
265
+ | CVXNN | `*_trained_cvx_mlp.pkl` | `CVXNNLangDetectHead.load(...)` |
266
+ | NN | `*_nn_head.pkl` | `NNLangDetectHead.load(...)` |
267
+ | Linear SVM | `*_linear_svm.pkl` | `SVMLangDetectHead.load(...)` | -->
268
+
269
+ ## Results
270
+
271
+ Paper results (Table 5):
272
+
273
+ ![Table 5](assets/table_5.png)
274
+
275
+ To reproduce the evaluation numbers for a given head, run `benchmark_cld.py` as shown in the Evaluation section.
276
+ <!--
277
+ ## Tests
278
+
279
+ ```bash
280
+ pytest -q
281
+ ```
282
+
283
+ Note: tests are designed to **skip** if the local dataset at `data/test/final_dry/` is missing or if large model weights are unavailable.
284
+
285
+ ## Contributing
286
+
287
+ - **Bugs / features**: please open an issue with a minimal reproduction.
288
+ - **Pull requests**: keep changes focused, add/update tests when behavior changes, and document new scripts/flags in `README.md`.
289
+
290
+ ## License
291
+
292
+ MIT (see `pyproject.toml`). -->
293
+
294
+ <!-- ## Citation
295
+
296
+ If you use this code in your work, please cite the paper:
297
+
298
+ ```bibtex
299
+ @article{cld2026,
300
+ title = {CLD: Convex Language Detection Heads for Accent-Robust Multilingual ASR},
301
+ author = {TBD},
302
+ journal = {TBD},
303
+ year = {2026}
304
+ }
305
+ ``` -->
306
+
307
+ <!-- ## Questions / missing info (to finalize this README)
308
+
309
+ - **Paper metadata**: what is the final paper title, author list, venue, and arXiv/camera-ready link?
310
+ - **Dataset recipe**: how should users reproduce `data/test/final_dry/` from raw sources (which datasets, filtering, splits, and preprocessing)?
311
+ - **Accent labels**: what is the definition/source of the `accent` field (taxonomy + how it’s derived)?
312
+ - **Default language set**: is `en,hi,id,ms,zh` the canonical set, or just the example from your experiments?
313
+ - **Pretrained artifacts**: where should the pretrained Whisper checkpoints and head artifacts be hosted (HF Hub / Google Drive / release assets), and what are the exact filenames?
314
+ - **Reproduction commands**: which exact `train_*` commands correspond to Table 5 (hyperparameters + seeds + compute setup)? -->
jaxcld-0.1.0/README_PIP.md ADDED
@@ -0,0 +1,44 @@
1
+ ## jaxcld
2
+
3
+ `jaxcld` is a lightweight language-detection module for multilingual ASR models (Whisper / MMS). It provides an `ASRModel` wrapper plus pluggable language detection heads you can attach at inference time.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install jaxcld
9
+ ```
10
+
11
+ If you are developing from source:
12
+
13
+ ```bash
14
+ pip install -e .
15
+ ```
16
+
17
+ ## Using the package (minimal inference example)
18
+
19
+ ```python
20
+ import numpy as np
21
+
22
+ from jaxcld import ASRModel, CVXNNLangDetectHead, NNLangDetectHead, SVMLangDetectHead
23
+
24
+ # 1) Load the base ASR model
25
+ languages = ["en", "hi", "id", "ms", "zh"]
26
+ asr = ASRModel.from_pretrained("openai/whisper-small", config={"languages": languages})
27
+
28
+ # 2) Load a language detection head artifact (choose ONE; swap in an alternative below if needed)
29
+ head = CVXNNLangDetectHead.load("path/to/openai_whisper-small_trained_cvx_mlp.pkl", asr)
30
+ # head = NNLangDetectHead.load("path/to/openai_whisper-small_nn_head.pkl", asr)
31
+ # head = SVMLangDetectHead.load("path/to/openai_whisper-small_linear_svm.pkl", asr)
32
+
33
+ # 3) Attach head and run inference
34
+ asr.set_lang_detect_head(head)
35
+
36
+ audio_16k_mono: np.ndarray = ... # shape (T,), sampling rate 16kHz
37
+ pred_langs, pred_texts = asr.predict(audio_16k_mono)
38
+ print(pred_langs[0], pred_texts[0])
39
+ ```
40
+
41
+ ## Notes
42
+
43
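+ - Head artifacts (`*.pkl`) are produced by training scripts in the source repository; this pip README intentionally focuses only on **package usage**. Head artifacts are pickled, and loading a pickle can execute arbitrary code, so only load artifacts from sources you trust.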
44
+
jaxcld-0.1.0/jaxcld/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ `jaxcld` package public API.
3
+
4
+ The goal is to support:
5
+
6
+ from jaxcld import ASRModel, CVXNNLangDetectHead
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
# Package version string (keep in sync with the packaging metadata).
__version__ = "0.1.0"

# Names exported by ``from jaxcld import *``. None of them is bound eagerly:
# each one is resolved on first access by the module-level ``__getattr__``.
__all__ = ["ASRModel", "CVXNNLangDetectHead", "NNLangDetectHead", "SVMLangDetectHead"]
19
+
20
+
21
# Public names served lazily by ``__getattr__``, each mapped to the submodule
# (relative to this package) that defines it.
_LAZY_EXPORTS = {
    "ASRModel": ".models.asr_model",
    "CVXNNLangDetectHead": ".models.lang_detect_head",
    "NNLangDetectHead": ".models.lang_detect_head",
    "SVMLangDetectHead": ".models.lang_detect_head",
}


def __getattr__(name: str):
    """Resolve a lazily exported public name on first access (PEP 562).

    Imports are deferred so that ``import jaxcld`` stays cheap and succeeds even
    when heavy dependencies (e.g. torch, transformers) are absent, while
    ``from jaxcld import ASRModel, ...`` still works when they are installed.
    The resolved object is cached in the module namespace, so later lookups
    of the same name no longer go through this hook.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object called ``name`` from its defining submodule.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
        ImportError: If the defining submodule, or something it imports, is
            missing. The message includes the underlying "No module named ..."
            text so the user knows which dependency to install.
    """
    module_path = _LAZY_EXPORTS.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    try:
        # In a package ``__init__``, ``__name__`` is the package itself, so it
        # is the correct anchor for resolving the relative ``module_path``.
        module = importlib.import_module(module_path, __name__)
    except ModuleNotFoundError as e:
        raise ImportError(
            f"Could not import {name!r} from {__name__!r} because a module is missing "
            f"({e}). Install jaxcld with its runtime dependencies, e.g. "
            "`pip install jaxcld` (or `pip install -e .` from source), and ensure "
            "`torch`, `torchaudio`, and `transformers` are available."
        ) from e

    value = getattr(module, name)
    # Cache in the package namespace so subsequent lookups skip this hook.
    globals()[name] = value
    return value


def __dir__():
    """List the package's attributes, including lazy exports not yet imported."""
    return sorted(set(globals()) | set(_LAZY_EXPORTS))
jaxcld-0.1.0/jaxcld/models/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """Model implementations and language detection heads.
2
+
3
+ Note: keep this module light (avoid importing torch/transformers at import time).
4
+ Import symbols from their defining modules directly, e.g.:
5
+
6
+ from jaxcld.models.asr_model import ASRModel
7
+ from jaxcld.models.lang_detect_head import CVXNNLangDetectHead
8
+ """
9
+
10
# Intentionally empty: this package re-exports nothing, so
# ``from jaxcld.models import *`` binds no names. Import from the defining
# submodules instead.
__all__: list[str] = []
11
+