jaxcld 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. jaxcld/__init__.py +47 -0
  2. jaxcld/models/__init__.py +11 -0
  3. jaxcld/models/asr_model.py +495 -0
  4. jaxcld/models/cvx_grelu_mlp.py +74 -0
  5. jaxcld/models/cvx_mlp.py +63 -0
  6. jaxcld/models/cvx_relu_mlp.py +120 -0
  7. jaxcld/models/get_model.py +26 -0
  8. jaxcld/models/grelu_mlp.py +44 -0
  9. jaxcld/models/lang_detect_head.py +152 -0
  10. jaxcld/models/relu_mlp.py +71 -0
  11. jaxcld/models/two_layer_mlp.py +11 -0
  12. jaxcld/optimizers/__init__.py +4 -0
  13. jaxcld/optimizers/adamW.py +39 -0
  14. jaxcld/optimizers/admm.py +103 -0
  15. jaxcld/optimizers/dadapt_adamW.py +38 -0
  16. jaxcld/optimizers/dist_shampoo/__init__.py +4 -0
  17. jaxcld/optimizers/dist_shampoo/distributed_shampoo.py +2831 -0
  18. jaxcld/optimizers/dist_shampoo/quantization_utils.py +115 -0
  19. jaxcld/optimizers/pcg.py +69 -0
  20. jaxcld/optimizers/sgd.py +38 -0
  21. jaxcld/optimizers/shampoo.py +37 -0
  22. jaxcld/optimizers/yogi.py +36 -0
  23. jaxcld/preconditioner/__init__.py +4 -0
  24. jaxcld/preconditioner/nystrom.py +102 -0
  25. jaxcld/training/__init__.py +8 -0
  26. jaxcld/training/train.py +164 -0
  27. jaxcld/training/train_no_jit.py +126 -0
  28. jaxcld/utils/__init__.py +4 -0
  29. jaxcld/utils/linops_utils.py +50 -0
  30. jaxcld/utils/load_data.py +459 -0
  31. jaxcld/utils/metric_utils.py +59 -0
  32. jaxcld/utils/model_utils.py +113 -0
  33. jaxcld/utils/opt_utils.py +31 -0
  34. jaxcld/utils/proximal_utils.py +22 -0
  35. jaxcld/utils/train_utils.py +7 -0
  36. jaxcld/utils/whisper_dataloader.py +142 -0
  37. jaxcld-0.1.0.dist-info/METADATA +80 -0
  38. jaxcld-0.1.0.dist-info/RECORD +40 -0
  39. jaxcld-0.1.0.dist-info/WHEEL +5 -0
  40. jaxcld-0.1.0.dist-info/top_level.txt +1 -0
jaxcld/__init__.py ADDED
@@ -0,0 +1,47 @@
1
"""
Public API of the `jaxcld` package.

Intended usage:

    from jaxcld import ASRModel, CVXNNLangDetectHead
"""

from __future__ import annotations

__version__ = "0.1.0"

# Symbols exported by the package; they are resolved lazily by the
# module-level ``__getattr__`` so importing ``jaxcld`` stays lightweight.
__all__ = [
    "ASRModel",
    "CVXNNLangDetectHead",
    "NNLangDetectHead",
    "SVMLangDetectHead",
]
19
+
20
+
21
def __getattr__(name: str):
    """Resolve the exported model classes on first access (PEP 562 hook).

    Keeping these imports lazy lets ``import jaxcld`` succeed even when the
    heavy optional dependencies (torch, transformers) are absent.

    Raises:
        AttributeError: ``name`` is not one of the lazily exported symbols.
        ImportError: a module needed to import the requested symbol is missing.
    """
    lazy_names = ("ASRModel", "CVXNNLangDetectHead", "NNLangDetectHead", "SVMLangDetectHead")
    if name not in lazy_names:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    try:
        if name == "ASRModel":
            from .models.asr_model import ASRModel as resolved
        elif name == "CVXNNLangDetectHead":
            from .models.lang_detect_head import CVXNNLangDetectHead as resolved
        elif name == "NNLangDetectHead":
            from .models.lang_detect_head import NNLangDetectHead as resolved
        else:
            from .models.lang_detect_head import SVMLangDetectHead as resolved
    except ModuleNotFoundError as e:
        raise ImportError(
            "Missing optional dependency. Install jaxcld with its runtime dependencies, e.g. "
            "`pip install -e .` (or `pip install .`) and ensure `torch`, `torchaudio`, and "
            "`transformers` are available."
        ) from e
    return resolved
jaxcld/models/__init__.py ADDED
@@ -0,0 +1,11 @@
1
"""Model implementations and language-detection heads.

This module is deliberately kept lightweight: it must not import torch or
transformers at import time. Import symbols from their defining modules, e.g.:

    from jaxcld.models.asr_model import ASRModel
    from jaxcld.models.lang_detect_head import CVXNNLangDetectHead
"""

# Nothing is re-exported from the package namespace itself.
__all__ = list()
11
+
jaxcld/models/asr_model.py ADDED
@@ -0,0 +1,495 @@
1
+ import os, time, torch, types, pickle
2
+ import torch.nn as nn
3
+ from safetensors.torch import load_file
4
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor, Wav2Vec2ForCTC, AutoProcessor, AutoModelForAudioClassification
5
+
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ from datasets import load_from_disk
9
+ import torch
10
+ import torchaudio
11
+ from typing import Tuple
12
+ from abc import ABC, abstractmethod
13
+ from collections import defaultdict
14
+
15
# Precision used for the ASR / LID models and the input tensors fed to them.
dtype = torch.float16

# ISO 639-1 -> ISO 639-3 codes for the languages handled by the MMS back-end.
ISO2_TO_ISO3 = dict(
    en="eng",
    zh="zho",
    hi="hin",
    id="ind",
    ms="zlm",
)

# ISO 639-3 -> ISO 639-1, applied to detected language labels before they are
# returned. Every listed Chinese variety collapses to "zh".
ISO3_TO_ISO2 = {
    code: "zh"
    for code in ("cdo", "cmn", "cpx", "czh", "hak", "hsn", "mnp", "nan", "wuu", "yue")
}
ISO3_TO_ISO2.update(eng="en", hin="hi", zlm="ms", ind="id")

# Used by MMS adapter selection to try multiple script-specific adapters
# ("" means the bare language id with no script suffix).
POSSIBLE_SCRIPTS = [""] + "Latn Cyrl Arab Deva Hans Hant".split()
44
+
45
class ASRModel(ABC):
    """Common interface for the speech-recognition back-ends (Whisper, MMS).

    Subclasses wrap a pretrained model and provide feature extraction
    (``load_data``), an optional pluggable language-detection head and
    transcription (``predict``).
    """

    def __init__(self, model_name, config=None):
        """Store the model identifier and configuration.

        Args:
            model_name: Hugging Face model id.
            config: Optional settings dict (e.g. ``"languages"``). The dict object
                is kept as-is (not copied), so values a subclass persists into it
                stay visible to the caller. ``None`` means a fresh empty dict.
        """
        self.model_name = model_name
        # A fresh dict per instance: a shared default would leak values written by
        # subclasses (e.g. the inferred "languages") between unrelated models.
        self.config = {} if config is None else config

    @classmethod
    def from_pretrained(cls, model_name, config=None):
        """Instantiate the back-end selected by ``model_name``'s prefix.

        ``"openai/whisper..."`` -> ``Whisper``; ``"facebook/mms..."`` -> ``MMS``.

        Raises:
            ValueError: if the prefix matches no supported back-end.
        """
        if config is None:
            config = {}  # never share one mutable default across instances
        if model_name.startswith("openai/whisper"):
            return Whisper(model_name, config)
        if model_name.startswith("facebook/mms"):
            return MMS(model_name, config)
        raise ValueError(f"Unknown model name: {model_name}")

    @abstractmethod
    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train"):
        """Return ``(features, labels, ...)`` extracted from a dataset split."""

    def load_data_jax(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train"):
        """Call ``load_data`` and return its features and labels as JAX arrays.

        Only the first two values returned by ``load_data`` are used, so
        back-ends that append extra metadata to the returned tuple also work.

        Returns:
            ``(A, y)`` as ``jax.numpy`` arrays (features, labels).
        """
        result = self.load_data(dataset_path, caller_script, data_seed, dataset_split)
        A, y = result[0], result[1]
        return jnp.array(A), jnp.array(y)

    @abstractmethod
    def set_lang_detect_head(self, lang_detect_head):
        """Install (or clear with ``None``) a custom language-detection head."""

    @abstractmethod
    def predict(self, audio):
        """Runs transcription on the audio, returns list of language tokens and transcriptions"""

    @abstractmethod
    def get_dimensions(self):
        """Return the dimensionality of the pooled features."""

    @abstractmethod
    def get_device(self):
        """Return the torch device inputs should be moved to."""
86
+
87
+
88
+
89
def whisper_custom_retrieve_init_tokens_creator(asr_model, languages, tokenizer):
    """Build a replacement for Whisper's ``_retrieve_init_tokens`` hook.

    The returned function (meant to be bound onto the Whisper model) runs the
    encoder, asks ``asr_model.head`` for one class index per batch item, maps
    each index into ``languages`` and returns decoder prompts of the form
    ``[<|startoftranscript|>, <|lang|>, <|transcribe|>]``. The chosen language
    codes are appended to ``asr_model.lang_tokens``.

    Args:
        asr_model: object exposing ``head`` (with ``predict``) and a
            ``lang_tokens`` list.
        languages: language codes; head class index ``i`` means ``languages[i]``.
        tokenizer: Whisper tokenizer used to look up the special tokens.

    The returned function raises ``ValueError`` if ``languages`` is empty and
    ``KeyError`` if a chosen code has no ``<|code|>`` Whisper language token.
    """
    # Special-token ids do not change between calls; look them up once.
    sot_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
    transcribe_token_id = tokenizer.convert_tokens_to_ids("<|transcribe|>")

    def _class_id_to_language(cid):
        # Invalid or out-of-range predictions fall back to the first language.
        # The explicit range check stops negative ids from silently wrapping
        # around to the end of the list.
        try:
            index = int(cid)
        except (TypeError, ValueError, RuntimeError, OverflowError):
            return languages[0]
        return languages[index] if 0 <= index < len(languages) else languages[0]

    def _custom_retrieve_init_tokens(self, input_features, batch_size, generation_config=None, **kwargs):
        # Validate before paying for an encoder forward pass.
        if not languages:
            raise ValueError("config['languages'] must be provided (non-empty) when using a custom language detection head.")

        encoder_outputs = self.model.encoder(input_features, return_dict=True)
        hidden = encoder_outputs.last_hidden_state
        # NOTE(review): the head receives the un-pooled encoder states as a torch
        # tensor here; confirm its predict() accepts that form.
        class_ids = asr_model.head.predict(hidden)

        chosen_langs = [_class_id_to_language(cid) for cid in class_ids]

        lang_to_id = self.generation_config.lang_to_id
        init_tokens = [
            [sot_token_id, lang_to_id[f"<|{lang}|>"], transcribe_token_id]
            for lang in chosen_langs
        ]
        asr_model.lang_tokens.extend(chosen_langs)

        return torch.tensor(init_tokens, dtype=torch.long, device=input_features.device)

    return _custom_retrieve_init_tokens
123
+
124
+
125
class Whisper(ASRModel):
    """ASR back-end built on Hugging Face ``WhisperForConditionalGeneration``.

    The model is loaded with ``device_map="auto"`` and cast to the module-level
    ``dtype``. Whisper's own language detection is used unless a custom head is
    installed with ``set_lang_detect_head``.
    """

    def __init__(self, model_name, config=None):
        # Fresh dict per instance: load_data() writes "languages" into self.config,
        # so a shared mutable default ({}) would leak state between instances.
        if config is None:
            config = {}
        super().__init__(model_name, config)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        self.model.to(dtype=dtype)
        self.model.config.forced_decoder_ids = None
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.head = None  # None -> Whisper's built-in language detection

    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train", shuffle=True):
        """Extract mean-pooled encoder features and integer language labels.

        Args:
            dataset_path: Directory readable by ``datasets.load_from_disk``.
            caller_script: Accepted for interface compatibility; unused here.
            data_seed: Seed for NumPy's global RNG (controls the shuffle).
            dataset_split: Split to read (default ``"train"``).
            shuffle: Whether to randomly permute the samples.

        Returns:
            ``(A, y)`` NumPy arrays: ``A`` is ``(n, d_model)`` pooled encoder
            states, ``y`` is ``(n,)`` indices into ``self.config["languages"]``.

        Raises:
            ValueError: if no sample has readable audio and a known label.
        """
        np.random.seed(data_seed)

        dataset = load_from_disk(dataset_path)
        train_data = dataset[dataset_split]
        print(f"Loaded {len(train_data)} train samples")

        languages = self.config.get("languages")
        if not languages:
            # Infer languages from the dataset split and persist for downstream
            # consumers (set_lang_detect_head reads config["languages"]).
            languages = sorted({sample.get("lang") for sample in train_data if sample.get("lang") is not None})
            self.config["languages"] = languages

        lang_to_index = {lang: i for i, lang in enumerate(languages)}

        self.model.eval()
        # Building a Resample kernel is costly; reuse one per source sampling rate.
        resamplers = {}

        def extract_pooled_hidden(audio):
            """Return the mean-pooled last encoder state ``(d_model,)``, or None if the file is missing."""
            if isinstance(audio, dict) and audio.get('array') is not None:
                audio_arr = audio['array']
                sr = audio['sampling_rate']
            else:
                # Either a dict without a decoded array, or a bare path.
                audio_path = audio['path'] if isinstance(audio, dict) else audio
                if not os.path.exists(audio_path):
                    return None
                waveform, sr = torchaudio.load(audio_path)
                audio_arr = waveform.mean(0).numpy()  # down-mix channels to mono

            if sr != 16000:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(sr, 16000)
                # torchaudio's Resample transform precomputes a float32 kernel by
                # default; feed float32 so e.g. float64 arrays do not mismatch it.
                waveform_f32 = torch.as_tensor(np.asarray(audio_arr, dtype=np.float32))
                audio_arr = resamplers[sr](waveform_f32).numpy()

            inputs = self.processor(audio_arr, sampling_rate=16000, return_tensors='pt').to(self.get_device(), dtype=dtype)
            with torch.no_grad():
                encoder_outputs = self.model.model.encoder(inputs.input_features, output_hidden_states=True)
            hidden = encoder_outputs.last_hidden_state.squeeze(0)  # (frames, d_model)
            return hidden.mean(0).cpu().numpy()

        features = []
        labels = []
        for sample in train_data:
            # Check the label first so unlabelled samples never hit the encoder.
            label = lang_to_index.get(sample.get("lang"))
            if label is None:
                continue
            hidden = extract_pooled_hidden(sample['audio'])
            if hidden is None:
                continue  # missing audio file
            features.append(hidden)
            labels.append(label)

        if not features:
            raise ValueError("No valid audio samples found")
        print(f"Extracted {len(features)} valid samples across {len(languages)} language(s)")

        A = np.array(features)
        y = np.array(labels, dtype=int)

        if shuffle:
            perm = np.random.permutation(A.shape[0])
            A = A[perm]
            y = y[perm]

        return A, y

    def set_lang_detect_head(self, lang_detect_head):
        """Install a custom language-detection head, or restore Whisper's own with ``None``.

        Raises:
            ValueError: if a head is given but ``config['languages']`` is empty.
        """
        self.head = lang_detect_head
        if self.head is None:
            # Drop a previously installed per-instance override so generation
            # falls back to the class implementation instead of calling
            # predict() on a None head.
            vars(self.model).pop("_retrieve_init_tokens", None)
            return
        languages = self.config.get("languages")
        if not languages:
            raise ValueError("config['languages'] must be provided when using a custom language detection head.")
        self.model._retrieve_init_tokens = types.MethodType(
            whisper_custom_retrieve_init_tokens_creator(self, languages, self.processor.tokenizer),
            self.model,
        )

    def _detect_language_vanilla(self, input_features):
        """Detect one language per batch item with Whisper's decoder.

        Feeds ``<|startoftranscript|>`` and takes the most likely next token,
        restricted to Whisper's language tokens so that a non-language token
        can never be chosen.

        Returns:
            list of language codes such as ``['en']``, one per batch item.
        """
        batch_size = input_features.shape[0]
        device = input_features.device
        sot_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        decoder_input_ids = torch.full((batch_size, 1), sot_token_id, dtype=torch.long, device=device)
        with torch.no_grad():
            logits = self.model(input_features, decoder_input_ids=decoder_input_ids).logits[:, -1, :]

        # Only compare the scores of the language tokens (softmax is unnecessary
        # for an argmax).
        lang_token_ids = torch.tensor(
            list(self.model.generation_config.lang_to_id.values()),
            dtype=torch.long,
            device=logits.device,
        )
        best = lang_token_ids[logits.index_select(-1, lang_token_ids).argmax(dim=-1)]
        return [self.id_to_lang(int(tid)) for tid in best]

    def predict(self, audio):
        """Transcribe ``audio`` and report the language used for each item.

        Returns:
            ``(lang_tokens, transcription)``: language codes and decoded texts.
        """
        input_features = self.processor(audio, sampling_rate=16000, return_tensors="pt").input_features
        input_features = input_features.to(self.get_device(), dtype=dtype)

        # Filled by the custom _retrieve_init_tokens hook when a head is installed.
        self.lang_tokens = []
        predicted_ids = self.model.generate(input_features, repetition_penalty=1.1, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold=1.35)
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        if self.head is None or getattr(self.head, "SKIP", False):
            self.lang_tokens = self._detect_language_vanilla(input_features)
        return self.lang_tokens, transcription

    def get_dimensions(self):
        """Hidden size of the encoder states (``config.d_model``)."""
        return self.model.config.d_model

    def get_device(self):
        """Device of the last encoder layer's parameters (``device_map="auto"`` may spread layers)."""
        return next(self.model.model.encoder.layers[-1].parameters()).device

    def lang_to_id(self, lang):
        """Token id of Whisper's ``<|lang|>`` token; raises KeyError if unknown."""
        return self.model.generation_config.lang_to_id[f"<|{lang}|>"]

    def id_to_lang(self, tid):
        """Inverse of ``lang_to_id``: token id -> language code, or "" if unknown."""
        id_to_token = {token_id: token for token, token_id in self.model.generation_config.lang_to_id.items()}
        # Tokens look like "<|en|>"; strip the two delimiter characters on each side.
        return id_to_token.get(tid, " ")[2:-2]
288
+
289
+
290
class MMS(ASRModel):
    """ASR back-end built on Meta's MMS.

    Uses a ``Wav2Vec2ForCTC`` model with per-language adapters for
    transcription, plus the ``facebook/mms-lid-126`` classifier for language
    identification when no custom head is installed.
    """

    def __init__(self, model_name: str, config=None):
        # Fresh dict per instance: load_data() writes "languages" into self.config,
        # so a shared mutable default ({}) would leak state between instances.
        if config is None:
            config = {}
        super().__init__(model_name, config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device, dtype=dtype)
        self.lid_model = AutoModelForAudioClassification.from_pretrained("facebook/mms-lid-126").to(self.device, dtype=dtype)
        self.head = None
        self.current_adapter = None  # language whose adapter is currently loaded

        self.iso2_to_iso3 = ISO2_TO_ISO3
        self.languages = config.get("languages") or []
        # Back-compat: older config used "class_names" for iso2 codes.
        if not self.languages and config.get("class_names"):
            self.languages = list(config.get("class_names"))
        # Head class index i -> language code passed to set_adapter (ISO 639-3 where mapped).
        self.class_names = [self.iso2_to_iso3.get(code, code) for code in self.languages]

    def load_data(self, dataset_path: str, caller_script: str = None, data_seed: int = 42, dataset_split: str = "train", shuffle=True):
        """Extract mean-pooled Wav2Vec2 encoder features and language labels.

        Args:
            dataset_path: Directory readable by ``datasets.load_from_disk``.
            caller_script: Accepted for interface compatibility; unused here.
            data_seed: Seed for NumPy's global RNG (controls the shuffle).
            dataset_split: Split to read (default ``"train"``).
            shuffle: Whether to randomly permute the samples.

        Returns:
            ``(A, y, num_languages)``: ``A`` is ``(n, hidden_size)`` pooled
            encoder states, ``y`` is ``(n,)`` indices into the language list.

        Raises:
            ValueError: if no sample has readable audio and a known label.
        """
        np.random.seed(data_seed)

        dataset = load_from_disk(dataset_path)
        train_data = dataset[dataset_split]
        print(f"Loaded {len(train_data)} train samples")

        languages = self.config.get("languages") or self.languages
        if not languages:
            languages = sorted({sample.get("lang") for sample in train_data if sample.get("lang") is not None})
        self.config["languages"] = languages
        self.languages = languages
        self.class_names = [self.iso2_to_iso3.get(code, code) for code in self.languages]
        lang_to_index = {lang: i for i, lang in enumerate(languages)}

        # Building a Resample kernel is costly; reuse one per source sampling rate.
        resamplers = {}

        def extract_pooled_hidden(audio):
            """Return the mean-pooled last encoder state ``(hidden_size,)``, or None if the file is missing."""
            if isinstance(audio, dict) and audio.get('array') is not None:
                audio_arr = audio['array']
                sr = audio['sampling_rate']
            else:
                # Either a dict without a decoded array, or a bare path.
                audio_path = audio['path'] if isinstance(audio, dict) else audio
                if not os.path.exists(audio_path):
                    return None
                waveform, sr = torchaudio.load(audio_path)
                audio_arr = waveform.mean(0).numpy()  # down-mix channels to mono

            if sr != 16000:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(sr, 16000)
                # torchaudio's Resample transform precomputes a float32 kernel by
                # default; feed float32 so e.g. float64 arrays do not mismatch it.
                waveform_f32 = torch.as_tensor(np.asarray(audio_arr, dtype=np.float32))
                audio_arr = resamplers[sr](waveform_f32).numpy()

            inputs = self.processor(audio_arr, sampling_rate=16000, return_tensors='pt').to(self.get_device(), dtype=dtype)
            with torch.no_grad():
                encoder_outputs = self.model.wav2vec2(inputs.input_values, output_hidden_states=True)
            hidden = encoder_outputs.last_hidden_state.squeeze(0)  # (frames, hidden_size)
            return hidden.mean(0).cpu().numpy()

        features = []
        labels = []
        for sample in train_data:
            # Check the label first so unlabelled samples never hit the encoder.
            label = lang_to_index.get(sample.get("lang"))
            if label is None:
                continue
            hidden = extract_pooled_hidden(sample['audio'])
            if hidden is None:
                continue  # missing audio file
            features.append(hidden)
            labels.append(label)

        if not features:
            raise ValueError("No valid audio samples found")
        print(f"Extracted {len(features)} valid samples across {len(languages)} language(s)")

        A = np.array(features)
        y = np.array(labels, dtype=int)

        if shuffle:
            perm = np.random.permutation(A.shape[0])
            A = A[perm]
            y = y[perm]

        return A, y, len(languages)

    def set_lang_detect_head(self, lang_detect_head):
        """Install a custom language-detection head (``None`` restores the LID model)."""
        self.head = lang_detect_head

    def _attention_mask(self, inputs):
        """Processor attention mask moved to ``self.device``, or None if none was produced."""
        # The processor only returns a mask when its feature extractor is
        # configured to; padded positions must then be masked out.
        mask = inputs.get("attention_mask")
        return None if mask is None else mask.to(self.device)

    def _detect_language_vanilla(self, audio_list):
        """Classify each clip with ``lid_model`` and return its ``id2label`` labels."""
        inputs = self.processor(audio_list, sampling_rate=16000, padding="longest", return_tensors="pt")
        input_values = inputs.input_values.to(self.device, dtype=dtype)
        attention_mask = self._attention_mask(inputs)
        with torch.no_grad():
            logits = self.lid_model(input_values, attention_mask=attention_mask).logits
        pred_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        return [self.lid_model.config.id2label[pid] for pid in pred_ids]

    def predict(self, audio):
        """Identify the language of, and transcribe, one clip or a list of clips.

        Args:
            audio: a 16 kHz waveform or a list of them.

        Returns:
            ``(language, transcription)`` when exactly one clip is given (bare or
            as a one-element list), otherwise ``(languages, transcriptions)``
            lists in input order. Codes found in ``ISO3_TO_ISO2`` are returned
            as ISO 639-1; others are returned unchanged.

        Raises:
            ValueError: a custom head is installed but no language list is known.
        """
        if not isinstance(audio, list):
            audio = [audio]
        batch_size = len(audio)

        # Prepare the padded batch once; reused for detection and transcription.
        inputs = self.processor(audio, sampling_rate=16000, padding="longest", return_tensors="pt")
        input_values = inputs.input_values.to(self.device, dtype=dtype)
        attention_mask = self._attention_mask(inputs)

        # 1. Detect language(s)
        if self.head is not None:
            if not self.class_names:
                raise ValueError(
                    "config['languages'] must be provided (or load_data called) before "
                    "predicting with a custom language detection head."
                )
            with torch.no_grad():
                encoder_out = self.model.wav2vec2(input_values, attention_mask=attention_mask, output_hidden_states=True)
            # NOTE(review): for clips of different lengths this mean also spans
            # frames derived from padding, unlike the single-clip features that
            # load_data builds; confirm this is acceptable for the head.
            pooled = encoder_out.last_hidden_state.mean(dim=1).cpu().numpy()  # (B, D)
            class_ids = self.head.predict(pooled)
            detected_langs = [self.class_names[int(cid)] for cid in class_ids]
        else:
            detected_langs = self._detect_language_vanilla(audio)

        # 2. Transcribe, grouped by language to minimise adapter switching.
        transcriptions = [None] * batch_size
        lang_to_indices = defaultdict(list)
        for i, lang in enumerate(detected_langs):
            lang_to_indices[lang].append(i)

        for lang, indices in lang_to_indices.items():
            if self.current_adapter != lang:
                self.set_adapter(lang)
            group_mask = None if attention_mask is None else attention_mask[indices]
            with torch.no_grad():
                logits = self.model(input_values[indices], attention_mask=group_mask).logits
            pred_ids = torch.argmax(logits, dim=-1)
            texts = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
            for orig_idx, text in zip(indices, texts):
                transcriptions[orig_idx] = text

        detected_langs = [ISO3_TO_ISO2.get(code, code) for code in detected_langs]

        if batch_size == 1:
            return detected_langs[0], transcriptions[0]
        return detected_langs, transcriptions

    def get_dimensions(self):
        """Hidden size of the Wav2Vec2 encoder states."""
        return self.model.config.hidden_size

    def get_device(self):
        """Device chosen at construction (CUDA if available, else CPU)."""
        return self.device

    def set_adapter(self, lang_id):
        """Load the transcription adapter and tokenizer vocabulary for ``lang_id``.

        Tries ``lang_id`` and then ``lang_id + "-script_" + script`` for each
        entry of ``POSSIBLE_SCRIPTS``. The first candidate for which neither
        ``set_target_lang`` nor ``load_adapter`` raises ``ValueError`` wins.
        Otherwise the English adapter is loaded. In both cases ``lang_id`` is
        recorded as current, so the fallback is not retried on every call.
        """
        # NOTE(review): confirm the POSSIBLE_SCRIPTS suffixes match the adapter
        # ids actually present in the MMS checkpoint's vocabulary.
        for script in POSSIBLE_SCRIPTS:
            candidate = lang_id if script == "" else lang_id + "-script_" + script
            try:
                self.processor.tokenizer.set_target_lang(candidate)
                self.model.load_adapter(candidate)
            except ValueError:
                continue
            self.current_adapter = lang_id
            return

        fallback = "eng"
        self.processor.tokenizer.set_target_lang(fallback)
        self.model.load_adapter(fallback)
        self.current_adapter = lang_id
+
jaxcld/models/cvx_grelu_mlp.py ADDED
@@ -0,0 +1,74 @@
1
+ from .cvx_mlp import Convex_MLP
2
+ from ..utils.model_utils import get_grelu_patterns, grelu_optimal_weights_transform
3
+ import jax.numpy as jnp
4
+ from jax import jit, tree_util
5
+
6
class CVX_GReLU_MLP:
    """Convex formulation of a two-layer gated-ReLU MLP as linear operators.

    With ``n`` samples ``X`` of shape ``(n, d)`` and ``P_S`` activation
    patterns, block ``i`` is ``F_i = diag(d_diags[:, i]) @ X``. The operator
    ``F = [F_1 ... F_{P_S}]`` acts on a weight matrix of shape ``(d, P_S)``.

    Instances are registered as JAX pytrees, so the ``@jit`` methods trace
    ``self``. ``P_S`` and ``beta`` are static (aux data): changing them
    triggers recompilation.
    """

    def __init__(self, X, y, P_S, seed, beta=None, d_diags=None, gates=None):
        self.X = X
        self.y = y
        self.P_S = P_S        # number of activation patterns (static under jit)
        self.seed = seed
        self.beta = beta      # ridge weight added in matvec_A (static under jit)
        self.d_diags = d_diags  # (n, P_S) pattern columns; set by init_model
        self.gates = gates      # gate vectors used by predict; set by init_model

    def init_model(self):
        """Compute ``d_diags``/``gates`` via ``get_grelu_patterns`` and update ``seed``."""
        self.d_diags, self.gates, self.seed = get_grelu_patterns(self.X, self.P_S, self.seed)

    @jit
    def matvec_Fi(self, i, vec):
        """``F_i @ vec``: shape ``(n,)``."""
        return self.d_diags[:, i] * (self.X @ vec)

    @jit
    def rmatvec_Fi(self, i, vec):
        """``F_i.T @ vec``: shape ``(d,)``."""
        return self.X.T @ (self.d_diags[:, i] * vec)

    @jit
    def matvec_F(self, vec):
        """``F @ vec = sum_i d_diags[:, i] * (X @ vec[:, i])``, shape ``(n,)``.

        Only the first ``P_S`` columns of ``vec`` (shape ``(d, >= P_S)``) are used.
        """
        k = self.P_S
        # One (n, d) @ (d, P_S) matmul instead of P_S matvecs unrolled into the trace.
        return jnp.sum(self.d_diags[:, :k] * (self.X @ vec[:, :k]), axis=1)

    @jit
    def rmatvec_F(self, vec):
        """``F.T @ vec``: column ``i`` is ``X.T @ (d_diags[:, i] * vec)``, shape ``(d, P_S)``."""
        k = self.P_S
        return self.X.T @ (self.d_diags[:, :k] * vec[:, None])

    @jit
    def matvec_A0(self, vec):
        """``F.T F vec / n`` (the data term without regularization)."""
        return self.rmatvec_F(self.matvec_F(vec) / self.X.shape[0])

    @jit
    def matvec_A(self, vec):
        """``F.T F vec / n + beta * vec``."""
        return self.rmatvec_F(self.matvec_F(vec) / self.X.shape[0]) + self.beta * vec

    def get_nvcx_weights(self, u):
        """Map convex weights ``u`` back to non-convex network weights."""
        return grelu_optimal_weights_transform(u, self.P_S, self.X.shape[1])

    def predict(self, data, W1, w2):
        """Gated-ReLU forward pass: ``((data @ gates >= 0) * (data @ W1)) @ w2``."""
        d_g = (data @ self.gates) >= 0
        return (d_g * (data @ W1)) @ w2

    def _tree_flatten(self):
        children = (self.X, self.y, self.seed, self.d_diags, self.gates)  # arrays / dynamic values
        # Static values. A tuple (unlike a dict) is hashable, which JAX expects
        # of pytree aux data because it takes part in jit cache lookups.
        aux_data = (self.P_S, self.beta)
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        X, y, seed, d_diags, gates = children
        P_S, beta = aux_data
        return cls(X, y, P_S, seed, beta, d_diags, gates)


tree_util.register_pytree_node(CVX_GReLU_MLP,
                               CVX_GReLU_MLP._tree_flatten,
                               CVX_GReLU_MLP._tree_unflatten)
jaxcld/models/cvx_mlp.py ADDED
@@ -0,0 +1,63 @@
1
+ from abc import ABC, abstractmethod
2
+
3
class Convex_MLP(ABC):
    """Abstract interface for convex reformulations of two-layer MLPs.

    Concrete subclasses expose the operators ``F``, ``G`` and ``A`` through
    matrix-vector products (``matvec_*``), their adjoints (``rmatvec_*``) and
    batched variants.
    """

    def __init__(self, X, y, P_S, beta, rho, seed):
        # Training data.
        self.X, self.y = X, y
        # Hyper-parameters.
        self.P_S = P_S
        self.beta = beta
        self.rho = rho
        self.seed = seed
        # Not set here; left as None until populated elsewhere.
        self.d_diags = self.e_diags = None
        self.Xtst = self.ytst = None

    @abstractmethod
    def init_model(self):
        """Prepare model state (e.g. activation patterns)."""

    @abstractmethod
    def rmatvec_Fi(self, i, vec):
        """Apply the adjoint of the ``i``-th block of ``F`` to ``vec``."""

    @abstractmethod
    def matvec_F(self, vec):
        """Apply ``F`` to ``vec``."""

    @abstractmethod
    def batch_matvec_F(self, vecs):
        """Apply ``F`` to a batch of vectors."""

    @abstractmethod
    def rmatvec_F(self, vec):
        """Apply the adjoint of ``F`` to ``vec``."""

    @abstractmethod
    def batch_rmatvec_F(self, vecs):
        """Apply the adjoint of ``F`` to a batch of vectors."""

    @abstractmethod
    def matvec_G(self, vec):
        """Apply ``G`` to ``vec``."""

    @abstractmethod
    def batch_matvec_G(self, vecs):
        """Apply ``G`` to a batch of vectors."""

    @abstractmethod
    def rmatvec_G(self, vec):
        """Apply the adjoint of ``G`` to ``vec``."""

    @abstractmethod
    def batch_rmatvec_G(self, vecs):
        """Apply the adjoint of ``G`` to a batch of vectors."""

    @abstractmethod
    def matvec_A(self, vec):
        """Apply ``A`` to ``vec``."""

    @abstractmethod
    def batch_matvec_A(self, vecs):
        """Apply ``A`` to a batch of vectors."""
63
+