lattifai 1.0.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/__init__.py CHANGED
@@ -52,28 +52,27 @@ except Exception:
     __version__ = "0.1.0"  # fallback version


-# Check and auto-install k2 if not present
-def _check_and_install_k2():
-    """Check if k2 is installed and attempt to install it if not."""
+# Check and auto-install k2py if not present
+def _check_and_install_k2py():
+    """Check if k2py is installed and attempt to install it if not."""
     try:
-        import k2
+        import k2py
     except ImportError:
         import subprocess

-        print("k2 is not installed. Attempting to install k2...")
+        print("k2py is not installed. Attempting to install k2py...")
         try:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "install-k2"])
-            subprocess.check_call([sys.executable, "-m", "install_k2"])
-            import k2  # Try importing again after installation
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
+            import k2py  # Try importing again after installation

-            print("k2 installed successfully.")
+            print("k2py installed successfully.")
         except Exception as e:
-            warnings.warn(f"Failed to install k2 automatically. Please install it manually. Error: {e}")
+            warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
     return True


-# Auto-install k2 on first import
-_check_and_install_k2()
+# Auto-install k2py on first import
+_check_and_install_k2py()


 __all__ = [
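Note on the hunk above: the package now bootstraps its lattice backend with a single `pip install k2py` instead of the old two-step `install-k2` helper. A minimal sketch of performing the same check ahead of time, mirroring the calls in the new code (whether a prebuilt `k2py` wheel exists for your platform is an assumption):

    # Pre-install k2py so the import-time auto-installer becomes a no-op.
    import subprocess
    import sys

    try:
        import k2py  # noqa: F401
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
        import k2py  # noqa: F401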
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple
 
 import colorful
 import numpy as np
-import torch
 
 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision
@@ -35,31 +34,55 @@ class Lattice1Aligner(object):
             raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")

         client_wrapper = config.client_wrapper
-        model_path = _resolve_model_path(config.model_name)
+        # Resolve model path using configured model hub
+        model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))

-        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.tokenizer = _load_tokenizer(
+            client_wrapper, model_path, config.model_name, config.device, model_hub=config.model_hub
+        )
         self.worker = _load_worker(model_path, config.device, config)

         self.frame_shift = self.worker.frame_shift

-    def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.

         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)

         Returns:
-            Emission tensor of shape (1, T, vocab_size)
+            Emission numpy array of shape (1, T, vocab_size)
         """
         return self.worker.emission(ndarray)

+    def separate(self, audio: np.ndarray) -> np.ndarray:
+        """Separate audio using separator model.
+
+        Args:
+            audio: np.ndarray object containing the audio to separate, shape (1, T)
+
+        Returns:
+            Separated audio as numpy array
+
+        Raises:
+            RuntimeError: If separator model is not available
+        """
+        if self.worker.separator_ort is None:
+            raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
+        # Run separator model
+        separator_output = self.worker.separator_ort.run(
+            None,
+            {"audios": audio},
+        )
+        return separator_output[0]
+
     def alignment(
         self,
         audio: AudioData,
         supervisions: List[Supervision],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
-        emission: Optional[torch.Tensor] = None,
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
         verbose: bool = True,
     ) -> Tuple[List[Supervision], List[Supervision]]:
@@ -120,7 +143,12 @@ class Lattice1Aligner(object):
             safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
         try:
             alignments = self.tokenizer.detokenize(
-                lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
             )
             if verbose:
                 safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
@@ -137,3 +165,7 @@ class Lattice1Aligner(object):
             raise
         except Exception as e:
             raise e
+
+    def profile(self) -> None:
+        """Print profiling statistics."""
+        self.worker.profile()
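The aligner hunks above move the public surface from torch tensors to numpy arrays and add two methods: `separate()`, which runs an optional `separator.onnx` through the worker, and `profile()`, which delegates to the worker's timing report. A rough usage sketch; the `aligner` construction, its config, and the audio loading are assumptions not shown in this diff:

    import numpy as np

    # `aligner` is an already constructed Lattice1Aligner (construction not shown here).
    audio = np.zeros((1, 16000), dtype=np.float32)  # 1 s of silence at 16 kHz (dtype assumed)

    vocals = aligner.separate(audio)     # raises RuntimeError if separator.onnx is missing
    emission = aligner.emission(vocals)  # (1, T, vocab_size) numpy array (shape assumed compatible)
    aligner.profile()                    # prints the worker's timing table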
@@ -1,11 +1,12 @@
 import json
 import time
 from collections import defaultdict
+from pathlib import Path
 from typing import Any, Dict, Optional, Tuple

+import colorful
 import numpy as np
 import onnxruntime as ort
-import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
@@ -13,6 +14,7 @@ from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+from lattifai.utils import safe_print


 class Lattice1Worker:
@@ -60,104 +62,84 @@ class Lattice1Worker:
         except Exception as e:
             raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)

+        # Get vocab_size from model output
+        self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]
+
         # get input_names
         input_names = [inp.name for inp in self.acoustic_ort.get_inputs()]
-        if "audios" not in input_names:
+        assert "audios" in input_names, f"Input name audios not found in {input_names}"
+
+        # Initialize separator if available
+        separator_model_path = Path(model_path) / "separator.onnx"
+        if separator_model_path.exists():
             try:
-                config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
-                config_dict = config.to_dict()
-                config_dict.pop("device")
-                self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+                self.separator_ort = ort.InferenceSession(
+                    str(separator_model_path),
+                    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+                )
             except Exception as e:
-                raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
+                raise ModelLoadError(f"separator model from {model_path}", original_error=e)
         else:
-            self.extractor = None  # ONNX model includes feature extractor
+            self.separator_ort = None

-        self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)

     @property
     def frame_shift(self) -> float:
         return 0.02  # 20 ms

-    @torch.inference_mode()
-    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.

         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)

         Returns:
-            Emission tensor of shape (1, T, vocab_size)
+            Emission numpy array of shape (1, T, vocab_size)
         """
         _start = time.time()
-        if self.extractor is not None:
-            # audio -> features -> emission
-            audio = torch.from_numpy(ndarray).to(self.device)
-            if audio.shape[1] < 160:
-                audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
-            features = self.extractor(audio)  # (1, T, D)
-            if features.shape[1] > 6000:
-                emissions = []
-                for start in range(0, features.size(1), 6000):
-                    _features = features[:, start : start + 6000, :]
-                    ort_inputs = {
-                        "features": _features.cpu().numpy(),
-                        "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
-                    }
-                    emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                    emissions.append(emission)
-                emission = torch.cat(
-                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                )  # (1, T, vocab_size)
-                del emissions
-            else:
-                ort_inputs = {
-                    "features": features.cpu().numpy(),
-                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
-                }
-                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(device or self.device)
+
+        if ndarray.shape[1] < 160:
+            ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
+        CHUNK_SIZE = 60 * 16000  # 60 seconds
+        total_samples = ndarray.shape[1]
+
+        if total_samples > CHUNK_SIZE:
+            frame_samples = int(16000 * self.frame_shift)
+            emissions = np.empty((1, total_samples // frame_samples + 1, self.vocab_size), dtype=np.float32)
+            for start in range(0, total_samples, CHUNK_SIZE):
+                chunk = ndarray[:, start : start + CHUNK_SIZE]
+                if chunk.shape[1] < 160:
+                    chunk = np.pad(chunk, ((0, 0), (0, 320 - chunk.shape[1])), mode="constant")
+
+                emission_out = self.acoustic_ort.run(None, {"audios": chunk})[0]
+                if acoustic_scale != 1.0:
+                    emission_out *= acoustic_scale
+                sf = start // frame_samples  # start frame
+                lf = sf + emission_out.shape[1]  # last frame
+                emissions[0, sf:lf, :] = emission_out
+            emissions[:, lf:, :] = 0.0
         else:
-            if ndarray.shape[1] < 160:
-                ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
-
-            CHUNK_SIZE = 60 * 16000  # 60 seconds
-            if ndarray.shape[1] > CHUNK_SIZE:
-                emissions = []
-                for start in range(0, ndarray.shape[1], CHUNK_SIZE):
-                    emission = self.acoustic_ort.run(
-                        None,
-                        {
-                            "audios": ndarray[:, start : start + CHUNK_SIZE],
-                        },
-                    )  # (1, T, vocab_size) numpy
-                    emissions.append(emission[0])
-
-                emission = torch.cat(
-                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                )  # (1, T, vocab_size)
-                del emissions
-            else:
-                emission = self.acoustic_ort.run(
-                    None,
-                    {
-                        "audios": ndarray,
-                    },
-                )  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission[0]).to(device or self.device)
+            emission_out = self.acoustic_ort.run(
+                None,
+                {
+                    "audios": ndarray,
+                },
+            )  # (1, T, vocab_size) numpy
+            emissions = emission_out[0]

-        if acoustic_scale != 1.0:
-            emission = emission.mul_(acoustic_scale)
+            if acoustic_scale != 1.0:
+                emissions *= acoustic_scale

         self.timings["emission"] += time.time() - _start
-        return emission  # (1, T, vocab_size) torch
+        return emissions  # (1, T, vocab_size) numpy

     def alignment(
         self,
         audio: AudioData,
         lattice_graph: Tuple[str, int, float],
-        emission: Optional[torch.Tensor] = None,
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
@@ -165,7 +147,7 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
-            emission: Pre-computed emission tensor (ignored if streaming=True)
+            emission: Pre-computed emission numpy array (ignored if streaming=True)
             offset: Time offset for the audio
             streaming: If True, use streaming mode for memory-efficient processing

@@ -178,25 +160,18 @@ class Lattice1Worker:
             AlignmentError: If alignment process fails
         """
         try:
-            import k2
-        except ImportError:
-            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
-
-        try:
-            from lattifai_core.lattice.decode import align_segments
+            import k2py as k2
         except ImportError:
-            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+            raise DependencyError("k2py", install_command="pip install k2py")

         lattice_graph_str, final_state, acoustic_scale = lattice_graph

         _start = time.time()
         try:
-            # Create decoding graph
-            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
-            decoding_graph.requires_grad_(False)
-            decoding_graph = k2.arc_sort(decoding_graph)
-            decoding_graph.skip_id = int(final_state)
-            decoding_graph.return_id = int(final_state + 1)
+            # Create decoding graph using k2py
+            graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)
+            decoding_graph = graph_dict["fsa"]
+            aux_labels = graph_dict["aux_labels"]
         except Exception as e:
             raise AlignmentError(
                 "Failed to create decoding graph from lattice",
@@ -204,11 +179,6 @@ class Lattice1Worker:
             )
         self.timings["decoding_graph"] += time.time() - _start

-        if self.device.type == "mps":
-            device = "cpu"  # k2 does not support mps yet
-        else:
-            device = self.device
-
         _start = time.time()

         # Get beam search parameters from config or use defaults
@@ -218,71 +188,54 @@ class Lattice1Worker:
         max_active_states = self.alignment_config.max_active_states or 10000

         if emission is None and audio.streaming_mode:
-            # Streaming mode: pass emission iterator to align_segments
-            # The align_segments function will automatically detect the iterator
-            # and use k2.OnlineDenseIntersecter for memory-efficient processing
-
-            def emission_iterator():
-                """Generate emissions for each audio chunk with progress tracking."""
-                total_duration = audio.duration
-                processed_duration = 0.0
-                total_minutes = int(total_duration / 60.0)
-
-                with tqdm(
-                    total=total_minutes,
-                    desc=f"Processing audio ({total_minutes} min)",
-                    unit="min",
-                    unit_scale=False,
-                    unit_divisor=1,
-                ) as pbar:
-                    for chunk in audio.iter_chunks():
-                        chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
-
-                        # Update progress based on chunk duration in minutes
-                        chunk_duration = int(chunk.duration / 60.0)
-                        pbar.update(chunk_duration)
-                        processed_duration += chunk_duration
-
-                        yield chunk_emission
-
-            # Calculate total frames for supervision_segments
-            total_frames = int(audio.duration / self.frame_shift)
-
-            results, labels = align_segments(
-                emission_iterator(),  # Pass iterator for streaming
-                decoding_graph.to(device),
-                torch.tensor([total_frames], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+            # Initialize OnlineDenseIntersecter for streaming
+            intersecter = k2.OnlineDenseIntersecter(
+                decoding_graph,
+                aux_labels,
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )

-            # For streaming, don't return emission tensor to save memory
+            # Streaming mode
+            total_duration = audio.duration
+            total_minutes = int(total_duration / 60.0)
+
+            with tqdm(
+                total=total_minutes,
+                desc=f"Processing audio ({total_minutes} min)",
+                unit="min",
+                unit_scale=False,
+                unit_divisor=1,
+            ) as pbar:
+                for chunk in audio.iter_chunks():
+                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
+                    intersecter.decode(chunk_emission[0])
+
+                    # Update progress
+                    chunk_duration = int(chunk.duration / 60.0)
+                    pbar.update(chunk_duration)
+
             emission_result = None
+            # Get results from intersecter
+            results, labels = intersecter.finish()
         else:
-            # Batch mode: compute full emission tensor and pass to align_segments
+            # Batch mode
             if emission is None:
-                emission = self.emission(
-                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
-                )  # (1, T, vocab_size)
+                emission = self.emission(audio.ndarray, acoustic_scale=acoustic_scale)  # (1, T, vocab_size)
             else:
-                emission = emission.to(device) * acoustic_scale
-
-            results, labels = align_segments(
-                emission,
-                decoding_graph.to(device),
-                torch.tensor([emission.shape[1]], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+                if acoustic_scale != 1.0:
+                    emission *= acoustic_scale
+            # Use AlignSegments directly
+            results, labels = k2.AlignSegments(
+                graph_dict,
+                emission[0],  # Pass the prepared scores
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )
-
             emission_result = emission

         self.timings["align_segments"] += time.time() - _start
@@ -290,6 +243,41 @@ class Lattice1Worker:
         channel = 0
         return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms

+    def profile(self) -> None:
+        """Print formatted profiling statistics."""
+        if not self.timings:
+            return
+
+        safe_print(colorful.bold(colorful.cyan("\n⏱️ Alignment Profiling")))
+        safe_print(colorful.gray("─" * 44))
+        safe_print(
+            f"{colorful.bold('Phase'.ljust(21))} "
+            f"{colorful.bold('Time'.ljust(12))} "
+            f"{colorful.bold('Percent'.rjust(8))}"
+        )
+        safe_print(colorful.gray("─" * 44))
+
+        total_time = sum(self.timings.values())
+
+        # Sort by duration descending
+        sorted_stats = sorted(self.timings.items(), key=lambda x: x[1], reverse=True)
+
+        for name, duration in sorted_stats:
+            percentage = (duration / total_time * 100) if total_time > 0 else 0.0
+            # Name: Cyan, Time: Yellow, Percent: Gray
+            safe_print(
+                f"{name:<20} "
+                f"{colorful.yellow(f'{duration:7.4f}s'.ljust(12))} "
+                f"{colorful.gray(f'{percentage:.2f}%'.rjust(8))}"
+            )
+
+        safe_print(colorful.gray("─" * 44))
+        # Pad "Total Time" before coloring to ensure correct alignment (ANSI codes don't count for width)
+        safe_print(
+            f"{colorful.bold('Total Time'.ljust(20))} "
+            f"{colorful.bold(colorful.yellow(f'{total_time:7.4f}s'.ljust(12)))}\n"
+        )
+


 def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
@@ -4,7 +4,7 @@ import re
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

-import torch
+import numpy as np

 from lattifai.alignment.phonemizer import G2Phonemizer
 from lattifai.caption import Supervision
@@ -121,6 +121,7 @@ class LatticeTokenizer:
     def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.model_name = ""
+        self.model_hub: Optional[str] = None
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
         self.dictionaries = defaultdict(lambda: [])
@@ -142,10 +143,20 @@ class LatticeTokenizer:
         elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
             providers.append("MPSExecutionProvider")

-        sat = SaT(
-            "sat-3l-sm",
-            ort_providers=providers + ["CPUExecutionProvider"],
-        )
+        if self.model_hub == "modelscope":
+            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+            downloaded_path = ms_snapshot("LattifAI/OmniTokenizer")
+            sat = SaT(
+                f"{downloaded_path}/sat-3l-sm",
+                tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        else:
+            sat = SaT(
+                "sat-3l-sm",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
         self.sentence_splitter = sat

     @staticmethod
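The branch added above lets the SaT sentence splitter load from a local ModelScope snapshot when `model_hub == "modelscope"` instead of resolving "sat-3l-sm" by name. A minimal sketch of the download step it relies on, mirroring the call in the hunk (the exact cache location is an assumption):

    from modelscope.hub.snapshot_download import snapshot_download

    # Resolves the LattifAI/OmniTokenizer repo to a local directory that contains
    # the sat-3l-sm model and the xlm-roberta-base tokenizer referenced above.
    local_dir = snapshot_download("LattifAI/OmniTokenizer")
    print(local_dir)  # a path under the ModelScope cache; depends on your setup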
@@ -200,6 +211,7 @@ class LatticeTokenizer:
         client_wrapper: Any,
         model_path: str,
         model_name: str,
+        model_hub: Optional[str] = None,
         device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
@@ -214,7 +226,7 @@ class LatticeTokenizer:
         else:
             with open(words_model_path, "rb") as f:
                 data = pickle.load(f)
-        except pickle.UnpicklingError as e:
+        except Exception as e:
             del e
             import msgpack

227
239
 
228
240
  tokenizer = cls(client_wrapper=client_wrapper)
229
241
  tokenizer.model_name = model_name
242
+ tokenizer.model_hub = model_hub
230
243
  tokenizer.words = data["words"]
231
244
  tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
232
245
  tokenizer.oov_word = data["oov_word"]
@@ -431,9 +444,11 @@ class LatticeTokenizer:
     def detokenize(
         self,
         lattice_id: str,
-        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        lattice_results: Tuple[np.ndarray, Any, Any, float, float],
         supervisions: List[Supervision],
         return_details: bool = False,
+        start_margin: float = 0.08,
+        end_margin: float = 0.20,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
  response = self.client_wrapper.post(
@@ -448,6 +463,8 @@ class LatticeTokenizer:
448
463
  "channel": channel,
449
464
  "return_details": False if return_details is None else return_details,
450
465
  "destroy_lattice": True,
466
+ "start_margin": start_margin,
467
+ "end_margin": end_margin,
451
468
  },
452
469
  )
453
470
  if response.status_code == 400:
@@ -477,7 +494,7 @@ class LatticeTokenizer:
 

 def _add_confidence_scores(
     supervisions: List[Supervision],
-    emission: torch.Tensor,
+    emission: np.ndarray,
     labels: List[int],
     frame_shift: float,
     offset: float = 0.0,
@@ -495,17 +512,17 @@ def _add_confidence_scores(
         labels: Token labels corresponding to aligned tokens
         frame_shift: Frame shift in seconds for converting frames to time
     """
-    tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
+    tokens = np.array(labels, dtype=np.int64)

     for supervision in supervisions:
         start_frame = int((supervision.start - offset) / frame_shift)
         end_frame = int((supervision.end - offset) / frame_shift)

         # Compute segment-level confidence
-        probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
+        probabilities = np.exp(emission[0, start_frame:end_frame])
         aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
-        diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
-        supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
+        diffprobs = np.max(probabilities, axis=-1) - aligned
+        supervision.score = round(1.0 - diffprobs.mean(), ndigits=4)

         # Compute word-level confidence if alignment exists
         if hasattr(supervision, "alignment") and supervision.alignment:
@@ -513,7 +530,7 @@ def _add_confidence_scores(
             for w, item in enumerate(words):
                 start = int((item.start - offset) / frame_shift) - start_frame
                 end = int((item.end - offset) / frame_shift) - start_frame
-                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean(), ndigits=4))


 def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
@@ -535,15 +552,14 @@ def _load_tokenizer(
     model_name: str,
     device: str,
     *,
+    model_hub: Optional[str] = None,
     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
     """Instantiate tokenizer with consistent error handling."""
-    try:
-        return tokenizer_cls.from_pretrained(
-            client_wrapper=client_wrapper,
-            model_path=model_path,
-            model_name=model_name,
-            device=device,
-        )
-    except Exception as e:
-        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
+    return tokenizer_cls.from_pretrained(
+        client_wrapper=client_wrapper,
+        model_path=model_path,
+        model_name=model_name,
+        model_hub=model_hub,
+        device=device,
+    )
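Two behavioral notes on this last hunk: `_load_tokenizer` now forwards the new keyword-only `model_hub` argument, and it no longer wraps failures in `ModelLoadError`, so exceptions from `from_pretrained` propagate unchanged. A hedged call sketch; the argument values are placeholders, and "huggingface" appears to be the default hub elsewhere in this diff:

    tokenizer = _load_tokenizer(
        client_wrapper,          # API client wrapper, as before
        "/path/to/model",        # placeholder model_path
        "lattice-1",             # placeholder model_name
        "cpu",                   # device
        model_hub="modelscope",  # new keyword-only argument; omit it to keep prior behavior
    )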