lattifai 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
lattifai/__init__.py CHANGED
@@ -52,28 +52,27 @@ except Exception:
     __version__ = "0.1.0" # fallback version
 
 
-# Check and auto-install k2 if not present
-def _check_and_install_k2():
-    """Check if k2 is installed and attempt to install it if not."""
+# Check and auto-install k2py if not present
+def _check_and_install_k2py():
+    """Check if k2py is installed and attempt to install it if not."""
     try:
-        import k2
+        import k2py
     except ImportError:
         import subprocess
 
-        print("k2 is not installed. Attempting to install k2...")
+        print("k2py is not installed. Attempting to install k2py...")
         try:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "install-k2"])
-            subprocess.check_call([sys.executable, "-m", "install_k2"])
-            import k2 # Try importing again after installation
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
+            import k2py # Try importing again after installation
 
-            print("k2 installed successfully.")
+            print("k2py installed successfully.")
         except Exception as e:
-            warnings.warn(f"Failed to install k2 automatically. Please install it manually. Error: {e}")
+            warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
     return True
 
 
-# Auto-install k2 on first import
-_check_and_install_k2()
+# Auto-install k2py on first import
+_check_and_install_k2py()
 
 
 __all__ = [
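The package now tries to install its alignment backend the first time it is imported. If that import-time install fails (offline machines, locked-down environments), the same dependency can be set up by hand; a minimal sketch of that fallback, running the same command the new code issues via subprocess:

import subprocess
import sys

# Same command the auto-installer runs; assumes network access and a writable environment.
subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])

import k2py  # noqa: F401  # confirm the backend is importable before using lattifai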
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple
 
 import colorful
 import numpy as np
-import torch
 
 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision
@@ -38,19 +37,21 @@ class Lattice1Aligner(object):
         # Resolve model path using configured model hub
         model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))
 
-        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.tokenizer = _load_tokenizer(
+            client_wrapper, model_path, config.model_name, config.device, model_hub=config.model_hub
+        )
         self.worker = _load_worker(model_path, config.device, config)
 
         self.frame_shift = self.worker.frame_shift
 
-    def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.
 
         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
 
         Returns:
-            Emission tensor of shape (1, T, vocab_size)
+            Emission numpy array of shape (1, T, vocab_size)
         """
         return self.worker.emission(ndarray)
 
@@ -68,13 +69,11 @@ class Lattice1Aligner(object):
         """
         if self.worker.separator_ort is None:
             raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
-
         # Run separator model
         separator_output = self.worker.separator_ort.run(
             None,
-            {"audio": audio},
+            {"audios": audio},
         )
-
         return separator_output[0]
 
     def alignment(
@@ -83,7 +82,7 @@ class Lattice1Aligner(object):
         supervisions: List[Supervision],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
-        emission: Optional[torch.Tensor] = None,
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
         verbose: bool = True,
     ) -> Tuple[List[Supervision], List[Supervision]]:
@@ -166,3 +165,7 @@ class Lattice1Aligner(object):
             raise
         except Exception as e:
             raise e
+
+    def profile(self) -> None:
+        """Print profiling statistics."""
+        self.worker.profile()
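With torch removed from the public surface, the aligner consumes and returns plain numpy arrays, and the new profile() method surfaces the worker's timing report. A rough usage sketch; only the emission/alignment/profile signatures come from this diff, the rest is placeholder wiring:

# Hypothetical wiring: `config`, `audio` (AudioData) and `supervisions` (List[Supervision])
# are placeholders; the constructor arguments are not shown in this diff.
aligner = Lattice1Aligner(config)                       # construction assumed
emission = aligner.emission(audio.ndarray)              # np.ndarray, shape (1, T, vocab_size)
alignments, details = aligner.alignment(audio, supervisions, emission=emission)
aligner.profile()                                       # prints the worker's timing table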
@@ -4,9 +4,9 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
+import colorful
 import numpy as np
 import onnxruntime as ort
-import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
@@ -14,6 +14,7 @@ from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+from lattifai.utils import safe_print
 
 
 class Lattice1Worker:
@@ -61,18 +62,12 @@ class Lattice1Worker:
         except Exception as e:
             raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)
 
+        # Get vocab_size from model output
+        self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]
+
         # get input_names
         input_names = [inp.name for inp in self.acoustic_ort.get_inputs()]
-        if "audios" not in input_names:
-            try:
-                config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
-                config_dict = config.to_dict()
-                config_dict.pop("device")
-                self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
-            except Exception as e:
-                raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
-        else:
-            self.extractor = None # ONNX model includes feature extractor
+        assert "audios" in input_names, f"Input name audios not found in {input_names}"
 
         # Initialize separator if available
         separator_model_path = Path(model_path) / "separator.onnx"
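Instead of configuring a Wav2LogFilterBank front end, the worker now requires the exported ONNX model to accept raw audio (an "audios" input) and reads the vocabulary size from the session's output metadata. A standalone illustration of that onnxruntime lookup, assuming a model file path and an export whose vocab axis is a fixed integer rather than a symbolic dimension:

import onnxruntime as ort

sess = ort.InferenceSession("acoustic.onnx", providers=["CPUExecutionProvider"])  # path assumed

output_meta = sess.get_outputs()[0]
print(output_meta.name, output_meta.shape)  # dynamic axes show up as strings, e.g. [1, 'T', 5000]
vocab_size = output_meta.shape[-1]          # valid when the vocab axis was exported as a concrete int

input_names = [inp.name for inp in sess.get_inputs()]
assert "audios" in input_names, f"Input name audios not found in {input_names}"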
@@ -80,98 +75,71 @@ class Lattice1Worker:
             try:
                 self.separator_ort = ort.InferenceSession(
                     str(separator_model_path),
-                    providers=providers + ["CPUExecutionProvider"],
+                    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
                 )
             except Exception as e:
                 raise ModelLoadError(f"separator model from {model_path}", original_error=e)
         else:
             self.separator_ort = None
 
-        self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
 
     @property
     def frame_shift(self) -> float:
         return 0.02 # 20 ms
 
-    @torch.inference_mode()
-    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.
 
         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
 
         Returns:
-            Emission tensor of shape (1, T, vocab_size)
+            Emission numpy array of shape (1, T, vocab_size)
         """
         _start = time.time()
-        if self.extractor is not None:
-            # audio -> features -> emission
-            audio = torch.from_numpy(ndarray).to(self.device)
-            if audio.shape[1] < 160:
-                audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
-            features = self.extractor(audio) # (1, T, D)
-            if features.shape[1] > 6000:
-                emissions = []
-                for start in range(0, features.size(1), 6000):
-                    _features = features[:, start : start + 6000, :]
-                    ort_inputs = {
-                        "features": _features.cpu().numpy(),
-                        "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
-                    }
-                    emission = self.acoustic_ort.run(None, ort_inputs)[0] # (1, T, vocab_size) numpy
-                    emissions.append(emission)
-                emission = torch.cat(
-                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                ) # (1, T, vocab_size)
-                del emissions
-            else:
-                ort_inputs = {
-                    "features": features.cpu().numpy(),
-                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
-                }
-                emission = self.acoustic_ort.run(None, ort_inputs)[0] # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(device or self.device)
+
+        if ndarray.shape[1] < 160:
+            ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
+        CHUNK_SIZE = 60 * 16000 # 60 seconds
+        total_samples = ndarray.shape[1]
+
+        if total_samples > CHUNK_SIZE:
+            frame_samples = int(16000 * self.frame_shift)
+            emissions = np.empty((1, total_samples // frame_samples + 1, self.vocab_size), dtype=np.float32)
+            for start in range(0, total_samples, CHUNK_SIZE):
+                chunk = ndarray[:, start : start + CHUNK_SIZE]
+                if chunk.shape[1] < 160:
+                    chunk = np.pad(chunk, ((0, 0), (0, 320 - chunk.shape[1])), mode="constant")
+
+                emission_out = self.acoustic_ort.run(None, {"audios": chunk})[0]
+                if acoustic_scale != 1.0:
+                    emission_out *= acoustic_scale
+                sf = start // frame_samples # start frame
+                lf = sf + emission_out.shape[1] # last frame
+                emissions[0, sf:lf, :] = emission_out
+            emissions[:, lf:, :] = 0.0
         else:
-            if ndarray.shape[1] < 160:
-                ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
-
-            CHUNK_SIZE = 60 * 16000 # 60 seconds
-            if ndarray.shape[1] > CHUNK_SIZE:
-                emissions = []
-                for start in range(0, ndarray.shape[1], CHUNK_SIZE):
-                    emission = self.acoustic_ort.run(
-                        None,
-                        {
-                            "audios": ndarray[:, start : start + CHUNK_SIZE],
-                        },
-                    ) # (1, T, vocab_size) numpy
-                    emissions.append(emission[0])
-
-                emission = torch.cat(
-                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                ) # (1, T, vocab_size)
-                del emissions
-            else:
-                emission = self.acoustic_ort.run(
-                    None,
-                    {
-                        "audios": ndarray,
-                    },
-                ) # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission[0]).to(device or self.device)
+            emission_out = self.acoustic_ort.run(
+                None,
+                {
+                    "audios": ndarray,
+                },
+            ) # (1, T, vocab_size) numpy
+            emissions = emission_out[0]
 
-        if acoustic_scale != 1.0:
-            emission = emission.mul_(acoustic_scale)
+        if acoustic_scale != 1.0:
+            emissions *= acoustic_scale
 
         self.timings["emission"] += time.time() - _start
-        return emission # (1, T, vocab_size) torch
+        return emissions # (1, T, vocab_size) numpy
 
     def alignment(
         self,
         audio: AudioData,
         lattice_graph: Tuple[str, int, float],
-        emission: Optional[torch.Tensor] = None,
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
@@ -179,7 +147,7 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
-            emission: Pre-computed emission tensor (ignored if streaming=True)
+            emission: Pre-computed emission numpy array (ignored if streaming=True)
             offset: Time offset for the audio
             streaming: If True, use streaming mode for memory-efficient processing
 
@@ -192,25 +160,18 @@ class Lattice1Worker:
             AlignmentError: If alignment process fails
         """
         try:
-            import k2
+            import k2py as k2
         except ImportError:
-            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
-
-        try:
-            from lattifai_core.lattice.decode import align_segments
-        except ImportError:
-            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+            raise DependencyError("k2py", install_command="pip install k2py")
 
         lattice_graph_str, final_state, acoustic_scale = lattice_graph
 
         _start = time.time()
         try:
-            # Create decoding graph
-            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
-            decoding_graph.requires_grad_(False)
-            decoding_graph = k2.arc_sort(decoding_graph)
-            decoding_graph.skip_id = int(final_state)
-            decoding_graph.return_id = int(final_state + 1)
+            # Create decoding graph using k2py
+            graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)
+            decoding_graph = graph_dict["fsa"]
+            aux_labels = graph_dict["aux_labels"]
         except Exception as e:
             raise AlignmentError(
                 "Failed to create decoding graph from lattice",
@@ -218,11 +179,6 @@ class Lattice1Worker:
             )
         self.timings["decoding_graph"] += time.time() - _start
 
-        if self.device.type == "mps":
-            device = "cpu" # k2 does not support mps yet
-        else:
-            device = self.device
-
         _start = time.time()
 
         # Get beam search parameters from config or use defaults
@@ -232,71 +188,54 @@ class Lattice1Worker:
         max_active_states = self.alignment_config.max_active_states or 10000
 
         if emission is None and audio.streaming_mode:
-            # Streaming mode: pass emission iterator to align_segments
-            # The align_segments function will automatically detect the iterator
-            # and use k2.OnlineDenseIntersecter for memory-efficient processing
-
-            def emission_iterator():
-                """Generate emissions for each audio chunk with progress tracking."""
-                total_duration = audio.duration
-                processed_duration = 0.0
-                total_minutes = int(total_duration / 60.0)
-
-                with tqdm(
-                    total=total_minutes,
-                    desc=f"Processing audio ({total_minutes} min)",
-                    unit="min",
-                    unit_scale=False,
-                    unit_divisor=1,
-                ) as pbar:
-                    for chunk in audio.iter_chunks():
-                        chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
-
-                        # Update progress based on chunk duration in minutes
-                        chunk_duration = int(chunk.duration / 60.0)
-                        pbar.update(chunk_duration)
-                        processed_duration += chunk_duration
-
-                        yield chunk_emission
-
-            # Calculate total frames for supervision_segments
-            total_frames = int(audio.duration / self.frame_shift)
-
-            results, labels = align_segments(
-                emission_iterator(), # Pass iterator for streaming
-                decoding_graph.to(device),
-                torch.tensor([total_frames], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+            # Initialize OnlineDenseIntersecter for streaming
+            intersecter = k2.OnlineDenseIntersecter(
+                decoding_graph,
+                aux_labels,
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )
 
-            # For streaming, don't return emission tensor to save memory
+            # Streaming mode
+            total_duration = audio.duration
+            total_minutes = int(total_duration / 60.0)
+
+            with tqdm(
+                total=total_minutes,
+                desc=f"Processing audio ({total_minutes} min)",
+                unit="min",
+                unit_scale=False,
+                unit_divisor=1,
+            ) as pbar:
+                for chunk in audio.iter_chunks():
+                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
+                    intersecter.decode(chunk_emission[0])
+
+                    # Update progress
+                    chunk_duration = int(chunk.duration / 60.0)
+                    pbar.update(chunk_duration)
+
             emission_result = None
+            # Get results from intersecter
+            results, labels = intersecter.finish()
         else:
-            # Batch mode: compute full emission tensor and pass to align_segments
+            # Batch mode
             if emission is None:
-                emission = self.emission(
-                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
-                ) # (1, T, vocab_size)
+                emission = self.emission(audio.ndarray, acoustic_scale=acoustic_scale) # (1, T, vocab_size)
             else:
-                emission = emission.to(device) * acoustic_scale
-
-            results, labels = align_segments(
-                emission,
-                decoding_graph.to(device),
-                torch.tensor([emission.shape[1]], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+                if acoustic_scale != 1.0:
+                    emission *= acoustic_scale
+            # Use AlignSegments directly
+            results, labels = k2.AlignSegments(
+                graph_dict,
+                emission[0], # Pass the prepared scores
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
            )
-
            emission_result = emission
 
        self.timings["align_segments"] += time.time() - _start
@@ -304,6 +243,41 @@ class Lattice1Worker:
         channel = 0
         return emission_result, results, labels, self.frame_shift, offset, channel # frame_shift=20ms
 
+    def profile(self) -> None:
+        """Print formatted profiling statistics."""
+        if not self.timings:
+            return
+
+        safe_print(colorful.bold(colorful.cyan("\n⏱️ Alignment Profiling")))
+        safe_print(colorful.gray("─" * 44))
+        safe_print(
+            f"{colorful.bold('Phase'.ljust(21))} "
+            f"{colorful.bold('Time'.ljust(12))} "
+            f"{colorful.bold('Percent'.rjust(8))}"
+        )
+        safe_print(colorful.gray("─" * 44))
+
+        total_time = sum(self.timings.values())
+
+        # Sort by duration descending
+        sorted_stats = sorted(self.timings.items(), key=lambda x: x[1], reverse=True)
+
+        for name, duration in sorted_stats:
+            percentage = (duration / total_time * 100) if total_time > 0 else 0.0
+            # Name: Cyan, Time: Yellow, Percent: Gray
+            safe_print(
+                f"{name:<20} "
+                f"{colorful.yellow(f'{duration:7.4f}s'.ljust(12))} "
+                f"{colorful.gray(f'{percentage:.2f}%'.rjust(8))}"
+            )
+
+        safe_print(colorful.gray("─" * 44))
+        # Pad "Total Time" before coloring to ensure correct alignment (ANSI codes don't count for width)
+        safe_print(
+            f"{colorful.bold('Total Time'.ljust(20))} "
+            f"{colorful.bold(colorful.yellow(f'{total_time:7.4f}s'.ljust(12)))}\n"
+        )
+
 
 def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
@@ -4,7 +4,7 @@ import re
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
 
-import torch
+import numpy as np
 
 from lattifai.alignment.phonemizer import G2Phonemizer
 from lattifai.caption import Supervision
@@ -121,6 +121,7 @@ class LatticeTokenizer:
     def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.model_name = ""
+        self.model_hub: Optional[str] = None
         self.words: List[str] = []
         self.g2p_model: Any = None # Placeholder for G2P model
         self.dictionaries = defaultdict(lambda: [])
@@ -142,10 +143,20 @@ class LatticeTokenizer:
         elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
             providers.append("MPSExecutionProvider")
 
-        sat = SaT(
-            "sat-3l-sm",
-            ort_providers=providers + ["CPUExecutionProvider"],
-        )
+        if self.model_hub == "modelscope":
+            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+            downloaded_path = ms_snapshot("LattifAI/OmniTokenizer")
+            sat = SaT(
+                f"{downloaded_path}/sat-3l-sm",
+                tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        else:
+            sat = SaT(
+                "sat-3l-sm",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
         self.sentence_splitter = sat
 
     @staticmethod
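When the tokenizer comes from ModelScope, the SaT sentence splitter now resolves its weights from a locally downloaded snapshot instead of the Hugging Face model id. A small sketch of just the download step, using the repo id from this diff; the returned path is a local cache directory containing the sat-3l-sm and xlm-roberta-base subfolders referenced above:

from modelscope.hub.snapshot_download import snapshot_download

# Downloads (or reuses) a local snapshot of the ModelScope repo and returns its directory.
local_dir = snapshot_download("LattifAI/OmniTokenizer")
print(local_dir)  # a path under the local ModelScope cache; exact location varies by setup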
@@ -200,6 +211,7 @@ class LatticeTokenizer:
         client_wrapper: Any,
         model_path: str,
         model_name: str,
+        model_hub: Optional[str] = None,
         device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
@@ -227,6 +239,7 @@ class LatticeTokenizer:
 
         tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.model_name = model_name
+        tokenizer.model_hub = model_hub
         tokenizer.words = data["words"]
         tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
         tokenizer.oov_word = data["oov_word"]
@@ -431,7 +444,7 @@ class LatticeTokenizer:
     def detokenize(
         self,
         lattice_id: str,
-        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        lattice_results: Tuple[np.ndarray, Any, Any, float, float],
         supervisions: List[Supervision],
         return_details: bool = False,
         start_margin: float = 0.08,
@@ -481,7 +494,7 @@
 
 def _add_confidence_scores(
     supervisions: List[Supervision],
-    emission: torch.Tensor,
+    emission: np.ndarray,
     labels: List[int],
     frame_shift: float,
     offset: float = 0.0,
@@ -499,17 +512,17 @@ def _add_confidence_scores(
         labels: Token labels corresponding to aligned tokens
         frame_shift: Frame shift in seconds for converting frames to time
     """
-    tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
+    tokens = np.array(labels, dtype=np.int64)
 
     for supervision in supervisions:
         start_frame = int((supervision.start - offset) / frame_shift)
         end_frame = int((supervision.end - offset) / frame_shift)
 
         # Compute segment-level confidence
-        probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
+        probabilities = np.exp(emission[0, start_frame:end_frame])
         aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
-        diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
-        supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
+        diffprobs = np.max(probabilities, axis=-1) - aligned
+        supervision.score = round(1.0 - diffprobs.mean(), ndigits=4)
 
         # Compute word-level confidence if alignment exists
         if hasattr(supervision, "alignment") and supervision.alignment:
@@ -517,7 +530,7 @@ def _add_confidence_scores(
             for w, item in enumerate(words):
                 start = int((item.start - offset) / frame_shift) - start_frame
                 end = int((item.end - offset) / frame_shift) - start_frame
-                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean(), ndigits=4))
 
 
 def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
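The confidence computation now works directly on the numpy emission: np.exp turns the (presumably log-probability) frame scores back into probabilities, and the score is one minus the average gap between each frame's best token and its aligned token. A toy, self-contained version of that arithmetic:

import numpy as np

# 4 frames x 3 tokens of log-probabilities; `tokens` holds the aligned token id per frame.
log_probs = np.log(np.array([
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
]))
tokens = np.array([0, 0, 1, 2], dtype=np.int64)

probabilities = np.exp(log_probs)
aligned = probabilities[range(len(tokens)), tokens]
diffprobs = np.max(probabilities, axis=-1) - aligned  # 0 when the aligned token is also the argmax
score = round(1.0 - diffprobs.mean(), ndigits=4)
print(score)  # 1.0 here, since every aligned token is the frame's most likely token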
@@ -539,6 +552,7 @@ def _load_tokenizer(
     model_name: str,
     device: str,
     *,
+    model_hub: Optional[str] = None,
     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
     """Instantiate tokenizer with consistent error handling."""
@@ -546,5 +560,6 @@
         client_wrapper=client_wrapper,
         model_path=model_path,
         model_name=model_name,
+        model_hub=model_hub,
         device=device,
     )
lattifai/audio2.py CHANGED
@@ -36,7 +36,7 @@ class AudioData(namedtuple("AudioData", ["sampling_rate", "ndarray", "path", "st
     @property
     def streaming_mode(self) -> bool:
         """Indicates whether streaming mode is enabled based on streaming_chunk_secs."""
-        if self.streaming_chunk_secs is not None:
+        if self.streaming_chunk_secs:
            return self.duration > self.streaming_chunk_secs * 1.1
        return False
 
@@ -8,7 +8,7 @@ import nemo_run as run
 from typing_extensions import Annotated
 
 from lattifai.client import LattifAI
-from lattifai.config import CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
+from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
 from lattifai.utils import safe_print
 
 __all__ = ["diarize"]
@@ -22,6 +22,7 @@ def diarize(
     media: Annotated[Optional[MediaConfig], run.Config[MediaConfig]] = None,
     caption: Annotated[Optional[CaptionConfig], run.Config[CaptionConfig]] = None,
     client: Annotated[Optional[ClientConfig], run.Config[ClientConfig]] = None,
+    alignment: Annotated[Optional[AlignmentConfig], run.Config[AlignmentConfig]] = None,
     diarization: Annotated[Optional[DiarizationConfig], run.Config[DiarizationConfig]] = None,
 ):
     """Run speaker diarization on aligned captions and audio."""
@@ -53,6 +54,7 @@
 
     client_instance = LattifAI(
         client_config=client,
+        alignment_config=alignment,
         caption_config=caption_config,
         diarization_config=diarization_config,
     )
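The diarize entry point now also accepts an optional AlignmentConfig and forwards it to the LattifAI client. A minimal call sketch; every parameter is optional per the signature above, and the individual config fields are not shown in this diff:

# All parameters default to None; `alignment` is the newly accepted config and is passed on
# to LattifAI(alignment_config=...). Construct real config objects in actual use.
diarize(
    media=None,
    caption=None,
    client=None,
    alignment=None,      # an AlignmentConfig instance in real use
    diarization=None,
)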