jaxcld 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxcld/__init__.py +47 -0
- jaxcld/models/__init__.py +11 -0
- jaxcld/models/asr_model.py +495 -0
- jaxcld/models/cvx_grelu_mlp.py +74 -0
- jaxcld/models/cvx_mlp.py +63 -0
- jaxcld/models/cvx_relu_mlp.py +120 -0
- jaxcld/models/get_model.py +26 -0
- jaxcld/models/grelu_mlp.py +44 -0
- jaxcld/models/lang_detect_head.py +152 -0
- jaxcld/models/relu_mlp.py +71 -0
- jaxcld/models/two_layer_mlp.py +11 -0
- jaxcld/optimizers/__init__.py +4 -0
- jaxcld/optimizers/adamW.py +39 -0
- jaxcld/optimizers/admm.py +103 -0
- jaxcld/optimizers/dadapt_adamW.py +38 -0
- jaxcld/optimizers/dist_shampoo/__init__.py +4 -0
- jaxcld/optimizers/dist_shampoo/distributed_shampoo.py +2831 -0
- jaxcld/optimizers/dist_shampoo/quantization_utils.py +115 -0
- jaxcld/optimizers/pcg.py +69 -0
- jaxcld/optimizers/sgd.py +38 -0
- jaxcld/optimizers/shampoo.py +37 -0
- jaxcld/optimizers/yogi.py +36 -0
- jaxcld/preconditioner/__init__.py +4 -0
- jaxcld/preconditioner/nystrom.py +102 -0
- jaxcld/training/__init__.py +8 -0
- jaxcld/training/train.py +164 -0
- jaxcld/training/train_no_jit.py +126 -0
- jaxcld/utils/__init__.py +4 -0
- jaxcld/utils/linops_utils.py +50 -0
- jaxcld/utils/load_data.py +459 -0
- jaxcld/utils/metric_utils.py +59 -0
- jaxcld/utils/model_utils.py +113 -0
- jaxcld/utils/opt_utils.py +31 -0
- jaxcld/utils/proximal_utils.py +22 -0
- jaxcld/utils/train_utils.py +7 -0
- jaxcld/utils/whisper_dataloader.py +142 -0
- jaxcld-0.1.0.dist-info/METADATA +80 -0
- jaxcld-0.1.0.dist-info/RECORD +40 -0
- jaxcld-0.1.0.dist-info/WHEEL +5 -0
- jaxcld-0.1.0.dist-info/top_level.txt +1 -0
jaxcld/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
`jaxcld` package public API.
|
|
3
|
+
|
|
4
|
+
The goal is to support:
|
|
5
|
+
|
|
6
|
+
from jaxcld import ASRModel, CVXNNLangDetectHead
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
__version__ = "0.1.0"
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"ASRModel",
|
|
15
|
+
"CVXNNLangDetectHead",
|
|
16
|
+
"NNLangDetectHead",
|
|
17
|
+
"SVMLangDetectHead",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def __getattr__(name: str):
    """Resolve the package's public classes lazily (module-level attribute hook).

    The model modules import heavy optional dependencies (torch, transformers),
    so they are only imported when one of the exported names is requested. A
    bare ``import jaxcld`` therefore works without those packages, while
    ``from jaxcld import ASRModel, ...`` still works when they are installed.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The requested class object.

    Raises:
        ImportError: A module required by the requested class is not installed.
        AttributeError: ``name`` is not one of the lazily exported classes.
    """
    try:
        if name == "ASRModel":
            from .models.asr_model import ASRModel as exported
        elif name == "CVXNNLangDetectHead":
            from .models.lang_detect_head import CVXNNLangDetectHead as exported
        elif name == "NNLangDetectHead":
            from .models.lang_detect_head import NNLangDetectHead as exported
        elif name == "SVMLangDetectHead":
            from .models.lang_detect_head import SVMLangDetectHead as exported
        else:
            # Not a ModuleNotFoundError, so it passes straight through the handler below.
            raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    except ModuleNotFoundError as missing:
        raise ImportError(
            "Missing optional dependency. Install jaxcld with its runtime dependencies, e.g. "
            "`pip install -e .` (or `pip install .`) and ensure `torch`, `torchaudio`, and "
            "`transformers` are available."
        ) from missing
    return exported
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Model implementations and language detection heads.
|
|
2
|
+
|
|
3
|
+
Note: keep this module light (avoid importing torch/transformers at import time).
|
|
4
|
+
Import symbols from their defining modules directly, e.g.:
|
|
5
|
+
|
|
6
|
+
from jaxcld.models.asr_model import ASRModel
|
|
7
|
+
from jaxcld.models.lang_detect_head import CVXNNLangDetectHead
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
__all__ = []
|
|
11
|
+
|
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
import os, time, torch, types, pickle
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from safetensors.torch import load_file
|
|
4
|
+
from transformers import WhisperForConditionalGeneration, WhisperProcessor, Wav2Vec2ForCTC, AutoProcessor, AutoModelForAudioClassification
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
from datasets import load_from_disk
|
|
9
|
+
import torch
|
|
10
|
+
import torchaudio
|
|
11
|
+
from typing import Tuple
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
|
|
15
|
+
# Floating-point precision applied to the models and their inputs in this module.
dtype = torch.float16

# Two-letter language code -> three-letter code.
ISO2_TO_ISO3 = dict(en="eng", zh="zho", hi="hin", id="ind", ms="zlm")

# Three-letter language code -> two-letter code. Several distinct three-letter
# codes collapse onto "zh".
ISO3_TO_ISO2 = {
    **dict.fromkeys(
        ("cdo", "cmn", "cpx", "czh", "hak", "hsn", "mnp", "nan", "wuu", "yue"),
        "zh",
    ),
    "eng": "en",
    "hin": "hi",
    "zlm": "ms",
    "ind": "id",
}

# Used by MMS adapter selection to try multiple script-specific adapters.
# The empty string stands for "no script suffix".
POSSIBLE_SCRIPTS = ["", "Latn", "Cyrl", "Arab", "Deva", "Hans", "Hant"]
|
|
44
|
+
|
|
45
|
+
class ASRModel(ABC):
    """Abstract interface shared by the speech-recognition backends in this module.

    Concrete backends (``Whisper`` and ``MMS`` in this file) are normally created
    through :meth:`from_pretrained`, which picks the backend from the checkpoint
    name prefix.
    """

    def __init__(self, model_name, config):
        """Store the checkpoint name and configuration; subclasses load the model.

        Args:
            model_name: Checkpoint identifier the backend was created from.
            config: Mutable configuration dict. Backends read ``"languages"``
                from it and may write the inferred language list back into it.
        """
        self.model_name = model_name
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config=None):
        """Instantiate the backend that matches ``model_name``.

        Args:
            model_name: Checkpoint name; must start with ``"openai/whisper"`` or
                ``"facebook/mms"``.
            config: Optional configuration dict. When omitted, a fresh dict is
                created for this model.

        Returns:
            A ``Whisper`` or ``MMS`` instance.

        Raises:
            ValueError: ``model_name`` matches no known backend prefix.
        """
        # A new dict per call: backends mutate their config (e.g. writing
        # config["languages"]), so a shared default dict would leak state
        # from one model into the next.
        if config is None:
            config = {}
        if model_name.startswith("openai/whisper"):
            return Whisper(model_name, config)
        if model_name.startswith("facebook/mms"):
            return MMS(model_name, config)
        raise ValueError(f"Unknown model name: {model_name}")

    @abstractmethod
    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train"):
        """Return features and labels extracted from a dataset split (backend specific)."""
        pass

    def load_data_jax(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train"):
        """Run :meth:`load_data` and convert features and labels to JAX arrays.

        Returns:
            ``(A, y)`` as ``jax.numpy`` arrays: ``A`` holds one feature row per
            sample and ``y`` the matching labels.
        """
        # Only the first two values are used: MMS.load_data also returns the
        # number of languages, which previously made this unpacking fail.
        A, y, *_ = self.load_data(dataset_path, caller_script, data_seed, dataset_split)
        A = jnp.array(A)  # (n, feature_dim)
        y = jnp.array(y)  # (n,)
        return A, y

    @abstractmethod
    def set_lang_detect_head(self, lang_detect_head):
        """Install (or clear) a custom language-detection head."""
        pass

    @abstractmethod
    def predict(self, audio):
        """Runs transcription on the audio, returns list of language tokens and transcriptions"""
        pass

    @abstractmethod
    def get_dimensions(self):
        """Return the size of the hidden representation produced by the backend."""
        pass

    @abstractmethod
    def get_device(self):
        """Return the torch device that model inputs must be moved to."""
        pass
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def whisper_custom_retrieve_init_tokens_creator(asr_model, languages, tokenizer):
    """Build a replacement for the Whisper model's ``_retrieve_init_tokens`` hook.

    The returned function is written as a method (its first argument is the
    object it gets bound to); ``Whisper.set_lang_detect_head`` binds it to the
    generation model. When called it runs the encoder, asks ``asr_model.head``
    for one class index per batch item and turns those into decoder prompts.

    Args:
        asr_model: Wrapper object providing ``head.predict`` and a
            ``lang_tokens`` list that receives the chosen language codes.
        languages: Language codes; the head's class indices index into this list.
        tokenizer: Tokenizer used to look up the ``<|startoftranscript|>`` and
            ``<|transcribe|>`` token ids.

    Returns:
        A function ``(self, input_features, batch_size, generation_config=None,
        **kwargs)`` returning a ``torch.long`` tensor with one
        ``[start-of-transcript, language, transcribe]`` row per predicted class.
    """

    def _custom_retrieve_init_tokens(self, input_features, batch_size, generation_config=None, **kwargs):
        encoded = self.model.encoder(input_features, return_dict=True)
        predicted_classes = asr_model.head.predict(encoded.last_hidden_state)

        if not languages:
            raise ValueError("config['languages'] must be provided (non-empty) when using a custom language detection head.")

        # Head predicts indices into config['languages']; anything that cannot
        # be used as an index falls back to the first configured language.
        picked = []
        for class_id in predicted_classes:
            try:
                language = languages[int(class_id)]
            except Exception:
                language = languages[0]
            picked.append(language)

        language_token_ids = [
            self.generation_config.lang_to_id[f"<|{language}|>"] for language in picked
        ]
        # Record the decisions so the wrapper can report them after generation.
        asr_model.lang_tokens.extend(picked)

        start_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        task_id = tokenizer.convert_tokens_to_ids("<|transcribe|>")
        prompt_rows = [[start_id, language_id, task_id] for language_id in language_token_ids]

        return torch.tensor(prompt_rows, dtype=torch.long, device=input_features.device)

    return _custom_retrieve_init_tokens
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Whisper(ASRModel):
    """ASR backend built on a Hugging Face ``WhisperForConditionalGeneration`` checkpoint.

    The spoken language comes either from Whisper's own first decoder step (no
    head installed, or a head whose ``SKIP`` attribute is truthy) or from a
    custom head installed with :meth:`set_lang_detect_head`.
    """

    def __init__(self, model_name, config=None):
        """Load the checkpoint and its processor.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained``.
            config: Optional configuration dict; ``"languages"`` is read from it
                and may be written back by :meth:`load_data`.
        """
        # A fresh dict per instance: load_data() writes config["languages"], so a
        # shared default dict would carry one model's languages into the next.
        super().__init__(model_name, {} if config is None else config)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        self.model.to(dtype=dtype)
        self.model.config.forced_decoder_ids = None
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.head = None  # default head
        # Filled by predict() / the custom init-token hook.
        self.lang_tokens = []

    @staticmethod
    def _read_waveform(audio):
        """Return ``(samples, sampling_rate)`` for a dataset audio entry, or ``None`` if its file is missing."""
        if isinstance(audio, dict) and audio.get('array') is not None:
            return audio['array'], audio['sampling_rate']
        # Otherwise the entry is (or carries) a path to an audio file.
        audio_path = audio['path'] if isinstance(audio, dict) else audio
        if not os.path.exists(audio_path):
            return None
        waveform, sr = torchaudio.load(audio_path)
        return waveform.mean(0).numpy(), sr  # average the channels

    def _pooled_encoder_state(self, audio):
        """Return the encoder's last hidden state averaged over time, or ``None`` for missing audio."""
        loaded = self._read_waveform(audio)
        if loaded is None:
            return None
        audio_arr, sr = loaded

        # Resample to 16kHz. The samples are cast to float32 first so that
        # double-precision input does not clash with the resampler's kernel.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            audio_arr = resampler(torch.as_tensor(audio_arr, dtype=torch.float32)).numpy()

        inputs = self.processor(audio_arr, sampling_rate=16000, return_tensors='pt').to(self.get_device(), dtype=dtype)

        with torch.no_grad():
            encoder_outputs = self.model.model.encoder(inputs.input_features, output_hidden_states=True)
            hidden = encoder_outputs.last_hidden_state.squeeze(0)  # (seq_len, hidden_dim)

        return hidden.mean(0).cpu().numpy()  # (hidden_dim,)

    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train", shuffle=True):
        """Extract mean-pooled encoder features and language labels from a dataset split.

        Args:
            dataset_path (str): Path to a dataset directory readable by ``load_from_disk``.
            caller_script (str): Accepted for interface compatibility; not used here.
            data_seed (int): Seed applied to NumPy's global RNG (controls the shuffle).
            dataset_split (str): Name of the split to read.
            shuffle (bool): Whether to apply a random permutation to the samples.

        Returns:
            ``(A, y)``: ``A`` is a NumPy array with one pooled feature row per
            usable sample, ``y`` an integer array of indices into
            ``config["languages"]``.

        Raises:
            ValueError: No sample had both readable audio and a known language.

        Side effects:
            If ``config["languages"]`` is empty, the sorted set of ``"lang"``
            values found in the split is stored there.
        """
        np.random.seed(data_seed)

        dataset = load_from_disk(dataset_path)
        split_data = dataset[dataset_split]
        print(f"Loaded {len(split_data)} {dataset_split} samples")

        languages = self.config.get("languages")
        if not languages:
            # Infer languages from the dataset split and persist for downstream consumers.
            languages = sorted({sample.get("lang") for sample in split_data if sample.get("lang") is not None})
            self.config["languages"] = languages

        lang_to_index = {lang: i for i, lang in enumerate(languages)}

        self.model.eval()

        features = []
        labels = []
        for sample in split_data:
            # Check the label first so the encoder is not run for samples that
            # would be discarded anyway.
            label = lang_to_index.get(sample.get("lang"))
            if label is None:
                continue
            pooled = self._pooled_encoder_state(sample['audio'])
            if pooled is None:
                continue  # Skip invalid audio
            features.append(pooled)
            labels.append(label)

        if not features:
            raise ValueError("No valid audio samples found")
        print(f"Extracted {len(features)} valid samples across {len(languages)} language(s)")

        A = np.array(features)
        y = np.array(labels, dtype=int)

        if shuffle:
            perm = np.random.permutation(A.shape[0])
            A = A[perm]
            y = y[perm]

        return A, y

    def set_lang_detect_head(self, lang_detect_head):
        """Install a custom language-detection head and hook it into generation.

        Raises:
            ValueError: A head is given but ``config["languages"]`` is empty.
        """
        self.head = lang_detect_head
        if self.head:
            languages = self.config.get("languages")
            if not languages:
                raise ValueError("config['languages'] must be provided when using a custom language detection head.")
            # Replace the model's init-token lookup so generation starts from the
            # language chosen by the head.
            self.model._retrieve_init_tokens = types.MethodType(
                whisper_custom_retrieve_init_tokens_creator(self, languages, self.processor.tokenizer),
                self.model,
            )

    def _detect_language_vanilla(self, input_features):
        """Detect one language code per batch item using Whisper's own decoder.

        The decoder is fed a single start-of-transcript token; the language is
        the highest-scoring entry among the language tokens listed in
        ``generation_config.lang_to_id``.

        Returns:
            List of language codes (e.g. ``['en']`` for a batch of one).
        """
        batch_size = input_features.shape[0]
        # Look the start token up instead of hard-coding its id.
        sot_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        decoder_input_ids = torch.full((batch_size, 1), sot_token_id, dtype=torch.long, device=input_features.device)

        with torch.no_grad():  # inference only: do not build an autograd graph
            model_output = self.model(input_features, decoder_input_ids=decoder_input_ids)
        logits = model_output.logits[:, -1, :]  # (batch_size, vocab_size)

        # Compare language tokens only, so a non-language token can never win.
        lang_token_ids = list(self.model.generation_config.lang_to_id.values())
        best = torch.argmax(logits[:, lang_token_ids], dim=-1)  # (batch_size,)

        return [self.id_to_lang(lang_token_ids[i]) for i in best.tolist()]

    def predict(self, audio):
        """Transcribe ``audio`` (16 kHz samples) and report the language per item.

        Returns:
            ``(lang_tokens, transcription)``: list of language codes and list of
            decoded transcriptions.
        """
        input_features = self.processor(audio, sampling_rate=16000, return_tensors="pt").input_features
        input_features = input_features.to(self.get_device(), dtype=dtype)

        # Reset before generating: the custom init-token hook appends to this list.
        self.lang_tokens = []
        predicted_ids = self.model.generate(input_features, repetition_penalty=1.1, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold=1.35)
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        if self.head is None or getattr(self.head, "SKIP", False):
            self.lang_tokens = self._detect_language_vanilla(input_features)
        return self.lang_tokens, transcription

    def get_dimensions(self):
        """Return the model's hidden size (``config.d_model``)."""
        return self.model.config.d_model

    def get_device(self):
        """Return the device holding the last encoder layer (inputs are moved there)."""
        return next(self.model.model.encoder.layers[-1].parameters()).device

    def lang_to_id(self, lang):
        """Return the token id of the ``<|lang|>`` language token."""
        lang_code = f"<|{lang}|>"
        return self.model.generation_config.lang_to_id[lang_code]

    def id_to_lang(self, tid):
        """Return the language code for token id ``tid`` (``""`` if it is not a language token)."""
        lang_to_id = self.model.generation_config.lang_to_id
        id_to_token = {token_id: token for token, token_id in lang_to_id.items()}
        # Tokens look like "<|en|>"; strip the "<|" and "|>" wrappers.
        return id_to_token.get(tid, " ")[2:-2]
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class MMS(ASRModel):
    """ASR backend built on a ``Wav2Vec2ForCTC`` checkpoint with per-language adapters.

    Language identification uses either a custom head (see
    :meth:`set_lang_detect_head`) or the ``facebook/mms-lid-126`` audio
    classification model loaded alongside the ASR model.
    """

    def __init__(self, model_name: str, config: dict = None):
        """Load the ASR model, its processor and the language-ID model.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained``.
            config: Optional configuration dict. ``"languages"`` (or the older
                ``"class_names"``) lists the language codes the head predicts.
        """
        # A fresh dict per instance: load_data() writes config["languages"], so a
        # shared default dict would carry one model's languages into the next.
        config = {} if config is None else config
        super().__init__(model_name, config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device, dtype=dtype)
        self.lid_model = AutoModelForAudioClassification.from_pretrained("facebook/mms-lid-126").to(self.device, dtype=dtype)
        self.head = None
        self.current_adapter = None

        self.iso2_to_iso3 = ISO2_TO_ISO3
        self.languages = config.get("languages") or []
        # Back-compat: older config used "class_names" for iso2 codes.
        if not self.languages and config.get("class_names"):
            self.languages = list(config.get("class_names"))
        # Adapter names use three-letter codes; unknown codes are kept unchanged.
        self.class_names = [self.iso2_to_iso3.get(cid, cid) for cid in self.languages]

    @staticmethod
    def _read_waveform(audio):
        """Return ``(samples, sampling_rate)`` for a dataset audio entry, or ``None`` if its file is missing."""
        if isinstance(audio, dict) and audio.get('array') is not None:
            return audio['array'], audio['sampling_rate']
        # Otherwise the entry is (or carries) a path to an audio file.
        audio_path = audio['path'] if isinstance(audio, dict) else audio
        if not os.path.exists(audio_path):
            return None
        waveform, sr = torchaudio.load(audio_path)
        return waveform.mean(0).numpy(), sr  # average the channels

    def _pooled_encoder_state(self, audio):
        """Return the wav2vec2 last hidden state averaged over time, or ``None`` for missing audio."""
        loaded = self._read_waveform(audio)
        if loaded is None:
            return None
        audio_arr, sr = loaded

        # Resample to 16kHz. The samples are cast to float32 first so that
        # double-precision input does not clash with the resampler's kernel.
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            audio_arr = resampler(torch.as_tensor(audio_arr, dtype=torch.float32)).numpy()

        inputs = self.processor(audio_arr, sampling_rate=16000, return_tensors='pt').to(self.get_device(), dtype=dtype)

        with torch.no_grad():
            encoder_outputs = self.model.wav2vec2(inputs.input_values, output_hidden_states=True)
            hidden = encoder_outputs.last_hidden_state.squeeze(0)  # (seq_len, hidden_dim)

        return hidden.mean(0).cpu().numpy()  # (hidden_dim,)

    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train", shuffle=True):
        """Extract mean-pooled encoder features and language labels from a dataset split.

        Args:
            dataset_path (str): Path to a dataset directory readable by ``load_from_disk``.
            caller_script (str): Accepted for interface compatibility; not used here.
            data_seed (int): Seed applied to NumPy's global RNG (controls the shuffle).
            dataset_split (str): Name of the split to read.
            shuffle (bool): Whether to apply a random permutation to the samples.

        Returns:
            ``(A, y, num_languages)``: pooled feature rows, integer indices into
            the language list, and the length of that list. Note the third
            element: ``Whisper.load_data`` returns only ``(A, y)``.

        Raises:
            ValueError: No sample had both readable audio and a known language.

        Side effects:
            Updates ``config["languages"]`` (when inferred from the data),
            ``self.languages`` and ``self.class_names``.
        """
        np.random.seed(data_seed)

        dataset = load_from_disk(dataset_path)
        split_data = dataset[dataset_split]
        print(f"Loaded {len(split_data)} {dataset_split} samples")

        languages = self.config.get("languages") or self.languages
        if not languages:
            languages = sorted({sample.get("lang") for sample in split_data if sample.get("lang") is not None})
            self.config["languages"] = languages
        self.languages = languages
        self.class_names = [self.iso2_to_iso3.get(cid, cid) for cid in self.languages]
        lang_to_index = {lang: i for i, lang in enumerate(languages)}

        features = []
        labels = []
        for sample in split_data:
            # Check the label first so the encoder is not run for samples that
            # would be discarded anyway.
            label = lang_to_index.get(sample.get("lang"))
            if label is None:
                continue
            pooled = self._pooled_encoder_state(sample['audio'])
            if pooled is None:
                continue  # Skip invalid audio
            features.append(pooled)
            labels.append(label)

        if not features:
            raise ValueError("No valid audio samples found")
        print(f"Extracted {len(features)} valid samples across {len(languages)} language(s)")

        A = np.array(features)
        y = np.array(labels, dtype=int)

        if shuffle:
            perm = np.random.permutation(A.shape[0])
            A = A[perm]
            y = y[perm]

        return A, y, len(languages)

    def set_lang_detect_head(self, lang_detect_head):
        """Install (or clear, with a falsy value) the custom language-detection head."""
        self.head = lang_detect_head

    def _classify_language(self, input_values):
        """Run the language-ID model on prepared input values and return its labels."""
        with torch.no_grad():
            logits = self.lid_model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        return [self.lid_model.config.id2label[pid] for pid in pred_ids]

    def _detect_language_vanilla(self, audio_list):
        """Detect one language label per item of ``audio_list`` with the language-ID model."""
        inputs = self.processor(audio_list, sampling_rate=16000, padding="longest", return_tensors="pt")
        input_values = inputs.input_values.to(self.device, dtype=dtype)
        return self._classify_language(input_values)

    def predict(self, audio):
        """Detect the language of each clip and transcribe it with the matching adapter.

        Args:
            audio: One clip or a list of clips of 16 kHz samples.

        Returns:
            ``(language, transcription)`` for a single clip, otherwise
            ``(languages, transcriptions)`` as lists in input order. Language
            codes found in ``ISO3_TO_ISO2`` are reported in two-letter form.
        """
        # Ensure audio is a list (single np.ndarray or list of them)
        if not isinstance(audio, list):
            audio = [audio]

        batch_size = len(audio)

        # Prepare batch once; both language detection and transcription use it.
        inputs = self.processor(audio, sampling_rate=16000, padding="longest", return_tensors="pt")
        input_values = inputs.input_values.to(self.device, dtype=dtype)

        # 1. Detect language(s)
        if self.head:
            # Run frozen encoder to get hidden states for the head.
            # NOTE(review): these hidden states presumably depend on whichever
            # adapter is currently loaded — confirm this matches training.
            with torch.no_grad():
                encoder_out = self.model.wav2vec2(input_values, output_hidden_states=True)
            hidden = encoder_out.last_hidden_state  # (B, T, D)
            pooled = hidden.mean(dim=1).cpu().numpy()  # (B, D) numpy for the head
            class_ids = self.head.predict(pooled)
            # int() accepts array scalars, as the Whisper hook does.
            detected_langs = [self.class_names[int(cid)] for cid in class_ids]
        else:
            # Reuse the already prepared batch instead of processing the audio again.
            detected_langs = self._classify_language(input_values)

        # 2. Transcribe – group by language to minimise adapter switching
        transcriptions = [None] * batch_size
        lang_to_indices = defaultdict(list)
        for i, lang in enumerate(detected_langs):
            lang_to_indices[lang].append(i)

        for lang, indices in lang_to_indices.items():
            batch_input = input_values[indices]
            if self.current_adapter != lang:
                self.set_adapter(lang)

            with torch.no_grad():
                logits = self.model(batch_input).logits

            pred_ids = torch.argmax(logits, dim=-1)
            trans = self.processor.batch_decode(pred_ids, skip_special_tokens=True)

            # Put each transcription back at its original batch position.
            for k, orig_idx in enumerate(indices):
                transcriptions[orig_idx] = trans[k]

        detected_langs = [ISO3_TO_ISO2.get(token, token) for token in detected_langs]

        # Return single values if input was single audio, otherwise lists
        if batch_size == 1:
            return detected_langs[0], transcriptions[0]

        return detected_langs, transcriptions

    def get_dimensions(self):
        """Return the model's hidden size (``config.hidden_size``)."""
        return self.model.config.hidden_size

    def get_device(self):
        """Return the device chosen at construction time."""
        return self.device

    def set_adapter(self, lang_id):
        """Load the adapter and tokenizer vocabulary for ``lang_id``.

        Tries ``lang_id`` itself, then ``lang_id-script_<script>`` for each entry
        of ``POSSIBLE_SCRIPTS``. If every attempt raises ``ValueError``, the
        ``"eng"`` adapter is loaded instead.
        """
        for script in POSSIBLE_SCRIPTS:
            candidate = lang_id if script == "" else lang_id + "-script_" + script
            try:
                self.processor.tokenizer.set_target_lang(candidate)
                self.model.load_adapter(candidate)
            except ValueError:
                continue
            self.current_adapter = lang_id
            return

        # No adapter matched: fall back to English. The requested id is still
        # recorded so the same failing lookup is not repeated for this language.
        fallback = "eng"
        self.processor.tokenizer.set_target_lang(fallback)
        self.model.load_adapter(fallback)
        self.current_adapter = lang_id
|
|
495
|
+
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from .cvx_mlp import Convex_MLP
|
|
2
|
+
from ..utils.model_utils import get_grelu_patterns, grelu_optimal_weights_transform
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jax import jit, tree_util
|
|
5
|
+
|
|
6
|
+
class CVX_GReLU_MLP:
    """Linear operators for a gated-ReLU MLP training problem.

    The instance is registered as a JAX pytree (see the registration below the
    class) so that the ``@jit`` methods can take ``self`` as a traced argument:
    ``X``, ``y``, ``seed``, ``d_diags`` and ``gates`` are dynamic leaves, while
    ``P_S`` and ``beta`` are static and must therefore be hashable.

    Shapes implied by the operations below: ``X`` is ``(n, d)``, ``d_diags`` has
    at least ``P_S`` columns of length ``n``, and operator inputs are ``(d, P_S)``
    in weight space or ``(n,)`` in sample space.
    """

    def __init__(self, X, y, P_S, seed, beta = None, d_diags = None, gates = None):
        self.X = X
        self.y = y
        self.P_S = P_S          # number of gate patterns (static under jit)
        self.seed = seed
        self.beta = beta        # coefficient of the identity term in matvec_A
        self.d_diags = d_diags  # one column of per-sample multipliers per pattern
        self.gates = gates      # gate vectors used by predict()

    def init_model(self):
        """Sample the gate patterns; sets ``d_diags``, ``gates`` and updates ``seed``."""
        self.d_diags, self.gates, self.seed = get_grelu_patterns(self.X, self.P_S, self.seed)

    @jit
    def matvec_Fi(self, i, vec):
        """Apply the i-th block: ``d_diags[:, i] * (X @ vec)``."""
        return self.d_diags[:,i] * (self.X @ vec)

    @jit
    def rmatvec_Fi(self, i, vec):
        """Apply the transpose of the i-th block: ``X.T @ (d_diags[:, i] * vec)``."""
        return self.X.T @ (self.d_diags[:,i] * vec)

    @jit
    def matvec_F(self, vec):
        """Sum of all blocks applied to the matching columns of ``vec``.

        Args:
            vec: Array with ``P_S`` (or more) columns of length ``d``.

        Returns:
            Array of shape ``(n,)``: ``sum_i d_diags[:, i] * (X @ vec[:, i])``.
        """
        # One matrix product instead of a Python loop over patterns, so the
        # traced program does not grow with P_S.
        masks = self.d_diags[:, :self.P_S]
        return jnp.sum(masks * (self.X @ vec[:, :self.P_S]), axis=1)

    @jit
    def rmatvec_F(self, vec):
        """Transpose of :meth:`matvec_F`.

        Args:
            vec: Array of shape ``(n,)``.

        Returns:
            Array of shape ``(d, P_S)`` whose i-th column is
            ``X.T @ (d_diags[:, i] * vec)``.
        """
        # Vectorised over patterns; the result keeps the input precision instead
        # of being written into a default-dtype buffer.
        masks = self.d_diags[:, :self.P_S]
        return self.X.T @ (masks * vec[:, None])

    @jit
    def matvec_A0(self, vec):
        """Apply ``F^T F / n`` to ``vec`` (no ``beta`` term)."""
        return self.rmatvec_F(self.matvec_F(vec)/self.X.shape[0])

    @jit
    def matvec_A(self, vec):
        """Apply ``F^T F / n + beta * I`` to ``vec``."""
        return self.rmatvec_F(self.matvec_F(vec)/self.X.shape[0])+self.beta*vec

    def get_nvcx_weights(self, u):
        """Convert the solution ``u`` via ``grelu_optimal_weights_transform`` (defined in utils)."""
        return grelu_optimal_weights_transform(u, self.P_S, self.X.shape[1])

    def predict(self, data, W1, w2):
        """Evaluate the gated network: ``((data @ gates >= 0) * (data @ W1)) @ w2``."""
        d_g = (data@self.gates)>=0
        return (d_g*(data@W1))@w2

    def _tree_flatten(self):
        """Split the instance into dynamic leaves and static auxiliary data."""
        children = (self.X, self.y, self.seed, self.d_diags, self.gates)  # arrays / dynamic values
        aux_data = {'P_S': self.P_S, 'beta': self.beta}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of :meth:`_tree_flatten`."""
        X, y, seed, d_diags, gates = children
        P_S = aux_data['P_S']
        beta = aux_data['beta']
        return cls(X, y, P_S, seed, beta, d_diags, gates)


# Register the class as a pytree so jit can trace methods that take `self`.
tree_util.register_pytree_node(CVX_GReLU_MLP,
                               CVX_GReLU_MLP._tree_flatten,
                               CVX_GReLU_MLP._tree_unflatten)
|
jaxcld/models/cvx_mlp.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
class Convex_MLP(ABC):
    """Abstract interface for convex MLP formulations.

    Stores the problem data and declares the operator methods a concrete model
    must provide. The method names suggest products with operators ``F``, ``G``
    and ``A`` and their transposes (``rmatvec``), plus batched variants taking
    several vectors at once — NOTE(review): the exact semantics are defined by
    the subclasses; confirm there.
    """

    def __init__(self, X, y, P_S, beta, rho, seed):
        # Problem data and hyper-parameters, stored unchanged.
        self.X, self.y = X, y
        self.P_S, self.beta, self.rho, self.seed = P_S, beta, rho, seed
        # Filled in later (nothing in this class assigns them a value).
        self.d_diags = self.e_diags = None
        self.Xtst = self.ytst = None

    @abstractmethod
    def init_model(self):
        """Prepare the model state (implemented by subclasses)."""

    @abstractmethod
    def rmatvec_Fi(self, i, vec):
        """Transpose product for component ``i`` of ``F`` (implemented by subclasses)."""

    @abstractmethod
    def matvec_F(self, vec):
        """Product of ``F`` with ``vec`` (implemented by subclasses)."""

    @abstractmethod
    def batch_matvec_F(self, vecs):
        """Batched form of :meth:`matvec_F` (implemented by subclasses)."""

    @abstractmethod
    def rmatvec_F(self, vec):
        """Transpose product of ``F`` with ``vec`` (implemented by subclasses)."""

    @abstractmethod
    def batch_rmatvec_F(self, vecs):
        """Batched form of :meth:`rmatvec_F` (implemented by subclasses)."""

    @abstractmethod
    def matvec_G(self, vec):
        """Product of ``G`` with ``vec`` (implemented by subclasses)."""

    @abstractmethod
    def batch_matvec_G(self, vecs):
        """Batched form of :meth:`matvec_G` (implemented by subclasses)."""

    @abstractmethod
    def rmatvec_G(self, vec):
        """Transpose product of ``G`` with ``vec`` (implemented by subclasses)."""

    @abstractmethod
    def batch_rmatvec_G(self,vecs):
        """Batched form of :meth:`rmatvec_G` (implemented by subclasses)."""

    @abstractmethod
    def matvec_A(self, vec):
        """Product of ``A`` with ``vec`` (implemented by subclasses)."""

    @abstractmethod
    def batch_matvec_A(self,vecs):
        """Batched form of :meth:`matvec_A` (implemented by subclasses)."""
|
|
63
|
+
|