lattifai 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/__init__.py CHANGED
@@ -1,5 +1,4 @@
  import os
- import sys
  import warnings
  from importlib.metadata import version

@@ -52,30 +51,6 @@ except Exception:
  __version__ = "0.1.0"  # fallback version


- # Check and auto-install k2 if not present
- def _check_and_install_k2():
-     """Check if k2 is installed and attempt to install it if not."""
-     try:
-         import k2
-     except ImportError:
-         import subprocess
-
-         print("k2 is not installed. Attempting to install k2...")
-         try:
-             subprocess.check_call([sys.executable, "-m", "pip", "install", "install-k2"])
-             subprocess.check_call([sys.executable, "-m", "install_k2"])
-             import k2  # Try importing again after installation
-
-             print("k2 installed successfully.")
-         except Exception as e:
-             warnings.warn(f"Failed to install k2 automatically. Please install it manually. Error: {e}")
-     return True
-
-
- # Auto-install k2 on first import
- _check_and_install_k2()
-
-
  __all__ = [
      # Client classes
      "LattifAI",
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple

  import colorful
  import numpy as np
- import torch

  from lattifai.audio2 import AudioData
  from lattifai.caption import Supervision
@@ -38,19 +37,21 @@ class Lattice1Aligner(object):
          # Resolve model path using configured model hub
          model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))

-         self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+         self.tokenizer = _load_tokenizer(
+             client_wrapper, model_path, config.model_name, config.device, model_hub=config.model_hub
+         )
          self.worker = _load_worker(model_path, config.device, config)

          self.frame_shift = self.worker.frame_shift

-     def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+     def emission(self, ndarray: np.ndarray) -> np.ndarray:
          """Generate emission probabilities from audio ndarray.

          Args:
              ndarray: Audio data as numpy array of shape (1, T) or (C, T)

          Returns:
-             Emission tensor of shape (1, T, vocab_size)
+             Emission numpy array of shape (1, T, vocab_size)
          """
          return self.worker.emission(ndarray)

@@ -68,13 +69,11 @@ class Lattice1Aligner(object):
          """
          if self.worker.separator_ort is None:
              raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
-
          # Run separator model
          separator_output = self.worker.separator_ort.run(
              None,
-             {"audio": audio},
+             {"audios": audio},
          )
-
          return separator_output[0]

      def alignment(
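Callers that feed the separator ONNX model directly must rename the input from `audio` to `audios`, otherwise onnxruntime rejects the feed with an invalid-input-name error. A minimal sketch, assuming a local `separator.onnx` exported with the new input name and 16 kHz mono input:

```python
import numpy as np
import onnxruntime as ort

# Assumptions: the file name and the (C, T) float32 layout follow the conventions used elsewhere in this diff.
session = ort.InferenceSession("separator.onnx", providers=["CPUExecutionProvider"])

audio = np.zeros((1, 16000), dtype=np.float32)        # one second of silence, shape (C, T)
separated = session.run(None, {"audios": audio})[0]   # 1.1.0 expected the key "audio"
print(separated.shape)
```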
@@ -83,7 +82,7 @@ class Lattice1Aligner(object):
          supervisions: List[Supervision],
          split_sentence: Optional[bool] = False,
          return_details: Optional[bool] = False,
-         emission: Optional[torch.Tensor] = None,
+         emission: Optional[np.ndarray] = None,
          offset: float = 0.0,
          verbose: bool = True,
      ) -> Tuple[List[Supervision], List[Supervision]]:
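Both `emission()` and the optional `emission=` argument of `alignment()` are now plain NumPy arrays, so callers no longer need torch to precompute or cache emissions. A minimal usage sketch, assuming an already constructed `aligner` (a `Lattice1Aligner`) and 16 kHz mono audio:

```python
import numpy as np

# Assumption: `aligner` is a configured Lattice1Aligner; construction is not shown in this diff.
samples = np.zeros((1, 16000), dtype=np.float32)  # one second of 16 kHz mono audio, shape (1, T)

emission = aligner.emission(samples)
# With the worker's 20 ms frame shift, one second of audio yields roughly 16000 / 320 = 50 frames,
# so emission.shape is about (1, 50, vocab_size).
# The same array can be passed back via alignment(..., emission=emission) to skip recomputation.
print(emission.shape, emission.dtype)
```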
@@ -118,7 +117,7 @@ class Lattice1Aligner(object):

          if verbose:
              safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
-             if audio.streaming_chunk_secs:
+             if audio.streaming_mode:
                  safe_print(
                      colorful.yellow(
                          f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
@@ -166,3 +165,7 @@ class Lattice1Aligner(object):
              raise
          except Exception as e:
              raise e
+
+     def profile(self) -> None:
+         """Print profiling statistics."""
+         self.worker.profile()
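The new `profile()` pass-through exposes the worker's timing counters (the worker-side implementation appears further down in this diff). A typical call site, assuming `aligner`, `audio`, and `supervisions` are already prepared:

```python
# Assumption: aligner, audio, and supervisions were created as usual; only the last line is new in 1.2.1.
results = aligner.alignment(audio, supervisions)  # populates the worker's timing counters
aligner.profile()                                 # prints the per-phase timing table
```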
@@ -4,16 +4,15 @@ from collections import defaultdict
  from pathlib import Path
  from typing import Any, Dict, Optional, Tuple

+ import colorful
  import numpy as np
  import onnxruntime as ort
- import torch
- from lhotse import FbankConfig
- from lhotse.features.kaldi.layers import Wav2LogFilterBank
  from lhotse.utils import Pathlike
  from tqdm import tqdm

  from lattifai.audio2 import AudioData
  from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+ from lattifai.utils import safe_print


  class Lattice1Worker:
@@ -61,18 +60,12 @@ class Lattice1Worker:
          except Exception as e:
              raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)

+         # Get vocab_size from model output
+         self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]
+
          # get input_names
          input_names = [inp.name for inp in self.acoustic_ort.get_inputs()]
-         if "audios" not in input_names:
-             try:
-                 config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
-                 config_dict = config.to_dict()
-                 config_dict.pop("device")
-                 self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
-             except Exception as e:
-                 raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
-         else:
-             self.extractor = None  # ONNX model includes feature extractor
+         assert "audios" in input_names, f"Input name audios not found in {input_names}"

          # Initialize separator if available
          separator_model_path = Path(model_path) / "separator.onnx"
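The worker now reads `vocab_size` from the acoustic model's declared output shape and requires the model to accept raw `audios` (feature extraction is expected to be baked into the ONNX graph, hence the removal of the lhotse filter-bank fallback). A quick way to check whether an exported model matches that contract, assuming a local `acoustic.onnx` (the bundled file name is not shown in this diff):

```python
import onnxruntime as ort

# Assumption: "acoustic.onnx" stands in for the acoustic model file inside the model bundle.
session = ort.InferenceSession("acoustic.onnx", providers=["CPUExecutionProvider"])

input_names = [inp.name for inp in session.get_inputs()]
output_shape = session.get_outputs()[0].shape  # e.g. [1, "T", vocab_size]; dynamic axes show up as strings

assert "audios" in input_names, f"model expects {input_names}, not raw audio"
vocab_size = output_shape[-1]
print(input_names, vocab_size)
```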
@@ -80,98 +73,71 @@ class Lattice1Worker:
              try:
                  self.separator_ort = ort.InferenceSession(
                      str(separator_model_path),
-                     providers=providers + ["CPUExecutionProvider"],
+                     providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
                  )
              except Exception as e:
                  raise ModelLoadError(f"separator model from {model_path}", original_error=e)
          else:
              self.separator_ort = None

-         self.device = torch.device(device)
          self.timings = defaultdict(lambda: 0.0)

      @property
      def frame_shift(self) -> float:
          return 0.02  # 20 ms

-     @torch.inference_mode()
-     def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+     def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0) -> np.ndarray:
          """Generate emission probabilities from audio ndarray.

          Args:
              ndarray: Audio data as numpy array of shape (1, T) or (C, T)

          Returns:
-             Emission tensor of shape (1, T, vocab_size)
+             Emission numpy array of shape (1, T, vocab_size)
          """
          _start = time.time()
-         if self.extractor is not None:
-             # audio -> features -> emission
-             audio = torch.from_numpy(ndarray).to(self.device)
-             if audio.shape[1] < 160:
-                 audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
-             features = self.extractor(audio)  # (1, T, D)
-             if features.shape[1] > 6000:
-                 emissions = []
-                 for start in range(0, features.size(1), 6000):
-                     _features = features[:, start : start + 6000, :]
-                     ort_inputs = {
-                         "features": _features.cpu().numpy(),
-                         "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
-                     }
-                     emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                     emissions.append(emission)
-                 emission = torch.cat(
-                     [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                 )  # (1, T, vocab_size)
-                 del emissions
-             else:
-                 ort_inputs = {
-                     "features": features.cpu().numpy(),
-                     "feature_lengths": np.array([features.size(1)], dtype=np.int64),
-                 }
-                 emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                 emission = torch.from_numpy(emission).to(device or self.device)
+
+         if ndarray.shape[1] < 160:
+             ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
+         CHUNK_SIZE = 60 * 16000  # 60 seconds
+         total_samples = ndarray.shape[1]
+
+         if total_samples > CHUNK_SIZE:
+             frame_samples = int(16000 * self.frame_shift)
+             emissions = np.empty((1, total_samples // frame_samples + 1, self.vocab_size), dtype=np.float32)
+             for start in range(0, total_samples, CHUNK_SIZE):
+                 chunk = ndarray[:, start : start + CHUNK_SIZE]
+                 if chunk.shape[1] < 160:
+                     chunk = np.pad(chunk, ((0, 0), (0, 320 - chunk.shape[1])), mode="constant")
+
+                 emission_out = self.acoustic_ort.run(None, {"audios": chunk})[0]
+                 if acoustic_scale != 1.0:
+                     emission_out *= acoustic_scale
+                 sf = start // frame_samples  # start frame
+                 lf = sf + emission_out.shape[1]  # last frame
+                 emissions[0, sf:lf, :] = emission_out
+             emissions[:, lf:, :] = 0.0
          else:
-             if ndarray.shape[1] < 160:
-                 ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
-
-             CHUNK_SIZE = 60 * 16000  # 60 seconds
-             if ndarray.shape[1] > CHUNK_SIZE:
-                 emissions = []
-                 for start in range(0, ndarray.shape[1], CHUNK_SIZE):
-                     emission = self.acoustic_ort.run(
-                         None,
-                         {
-                             "audios": ndarray[:, start : start + CHUNK_SIZE],
-                         },
-                     )  # (1, T, vocab_size) numpy
-                     emissions.append(emission[0])
-
-                 emission = torch.cat(
-                     [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-                 )  # (1, T, vocab_size)
-                 del emissions
-             else:
-                 emission = self.acoustic_ort.run(
-                     None,
-                     {
-                         "audios": ndarray,
-                     },
-                 )  # (1, T, vocab_size) numpy
-                 emission = torch.from_numpy(emission[0]).to(device or self.device)
+             emission_out = self.acoustic_ort.run(
+                 None,
+                 {
+                     "audios": ndarray,
+                 },
+             )  # (1, T, vocab_size) numpy
+             emissions = emission_out[0]

-         if acoustic_scale != 1.0:
-             emission = emission.mul_(acoustic_scale)
+             if acoustic_scale != 1.0:
+                 emissions *= acoustic_scale

          self.timings["emission"] += time.time() - _start
-         return emission  # (1, T, vocab_size) torch
+         return emissions  # (1, T, vocab_size) numpy
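The rewritten `emission()` is pure NumPy: audio longer than 60 seconds is processed in fixed-size chunks, each chunk is run through the ONNX acoustic model, and the per-chunk outputs are written into a preallocated `(1, total_frames, vocab_size)` buffer at the frame offset implied by the 20 ms frame shift. The standalone sketch below reproduces that bookkeeping with a stand-in `run_model` callable so the indexing can be checked in isolation; the names and the vocabulary size are illustrative, not the package's API.

```python
import numpy as np

SAMPLE_RATE = 16000
FRAME_SHIFT = 0.02                              # 20 ms, as in Lattice1Worker.frame_shift
CHUNK_SIZE = 60 * SAMPLE_RATE                   # 60-second chunks, as in the new emission()
FRAME_SAMPLES = int(SAMPLE_RATE * FRAME_SHIFT)  # 320 samples per output frame
VOCAB_SIZE = 500                                # illustrative; the worker reads this from the ONNX output shape


def run_model(chunk: np.ndarray) -> np.ndarray:
    """Stand-in for self.acoustic_ort.run(None, {"audios": chunk})[0]."""
    n_frames = chunk.shape[1] // FRAME_SAMPLES
    return np.zeros((1, n_frames, VOCAB_SIZE), dtype=np.float32)


def chunked_emission(audio: np.ndarray) -> np.ndarray:
    """Compute emissions chunk by chunk into a preallocated buffer (mirrors the 1.2.1 control flow)."""
    total_samples = audio.shape[1]
    if total_samples <= CHUNK_SIZE:
        return run_model(audio)

    emissions = np.zeros((1, total_samples // FRAME_SAMPLES + 1, VOCAB_SIZE), dtype=np.float32)
    for start in range(0, total_samples, CHUNK_SIZE):
        out = run_model(audio[:, start : start + CHUNK_SIZE])
        first = start // FRAME_SAMPLES                    # frame offset of this chunk
        emissions[0, first : first + out.shape[1], :] = out
    return emissions


print(chunked_emission(np.zeros((1, 150 * SAMPLE_RATE), dtype=np.float32)).shape)  # (1, 7501, 500)
```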

      def alignment(
          self,
          audio: AudioData,
          lattice_graph: Tuple[str, int, float],
-         emission: Optional[torch.Tensor] = None,
+         emission: Optional[np.ndarray] = None,
          offset: float = 0.0,
      ) -> Dict[str, Any]:
          """Process audio with LatticeGraph.
@@ -179,7 +145,7 @@ class Lattice1Worker:
          Args:
              audio: AudioData object
              lattice_graph: LatticeGraph data
-             emission: Pre-computed emission tensor (ignored if streaming=True)
+             emission: Pre-computed emission numpy array (ignored if streaming=True)
              offset: Time offset for the audio
              streaming: If True, use streaming mode for memory-efficient processing

@@ -191,26 +157,16 @@ class Lattice1Worker:
              DependencyError: If required dependencies are missing
              AlignmentError: If alignment process fails
          """
-         try:
-             import k2
-         except ImportError:
-             raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
-
-         try:
-             from lattifai_core.lattice.decode import align_segments
-         except ImportError:
-             raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+         import k2py as k2

          lattice_graph_str, final_state, acoustic_scale = lattice_graph

          _start = time.time()
          try:
-             # Create decoding graph
-             decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
-             decoding_graph.requires_grad_(False)
-             decoding_graph = k2.arc_sort(decoding_graph)
-             decoding_graph.skip_id = int(final_state)
-             decoding_graph.return_id = int(final_state + 1)
+             # Create decoding graph using k2py
+             graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)
+             decoding_graph = graph_dict["fsa"]
+             aux_labels = graph_dict["aux_labels"]
          except Exception as e:
              raise AlignmentError(
                  "Failed to create decoding graph from lattice",
@@ -218,11 +174,6 @@ class Lattice1Worker:
              )
          self.timings["decoding_graph"] += time.time() - _start

-         if self.device.type == "mps":
-             device = "cpu"  # k2 does not support mps yet
-         else:
-             device = self.device
-
          _start = time.time()

          # Get beam search parameters from config or use defaults
@@ -232,71 +183,54 @@ class Lattice1Worker:
          max_active_states = self.alignment_config.max_active_states or 10000

          if emission is None and audio.streaming_mode:
-             # Streaming mode: pass emission iterator to align_segments
-             # The align_segments function will automatically detect the iterator
-             # and use k2.OnlineDenseIntersecter for memory-efficient processing
-
-             def emission_iterator():
-                 """Generate emissions for each audio chunk with progress tracking."""
-                 total_duration = audio.duration
-                 processed_duration = 0.0
-                 total_minutes = int(total_duration / 60.0)
-
-                 with tqdm(
-                     total=total_minutes,
-                     desc=f"Processing audio ({total_minutes} min)",
-                     unit="min",
-                     unit_scale=False,
-                     unit_divisor=1,
-                 ) as pbar:
-                     for chunk in audio.iter_chunks():
-                         chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
-
-                         # Update progress based on chunk duration in minutes
-                         chunk_duration = int(chunk.duration / 60.0)
-                         pbar.update(chunk_duration)
-                         processed_duration += chunk_duration
-
-                         yield chunk_emission
-
-             # Calculate total frames for supervision_segments
-             total_frames = int(audio.duration / self.frame_shift)
-
-             results, labels = align_segments(
-                 emission_iterator(),  # Pass iterator for streaming
-                 decoding_graph.to(device),
-                 torch.tensor([total_frames], dtype=torch.int32),
-                 search_beam=search_beam,
-                 output_beam=output_beam,
-                 min_active_states=min_active_states,
-                 max_active_states=max_active_states,
-                 subsampling_factor=1,
-                 reject_low_confidence=False,
+             # Initialize OnlineDenseIntersecter for streaming
+             intersecter = k2.OnlineDenseIntersecter(
+                 decoding_graph,
+                 aux_labels,
+                 float(search_beam),
+                 float(output_beam),
+                 int(min_active_states),
+                 int(max_active_states),
              )

-             # For streaming, don't return emission tensor to save memory
+             # Streaming mode
+             total_duration = audio.duration
+             total_minutes = int(total_duration / 60.0)
+
+             with tqdm(
+                 total=total_minutes,
+                 desc=f"Processing audio ({total_minutes} min)",
+                 unit="min",
+                 unit_scale=False,
+                 unit_divisor=1,
+             ) as pbar:
+                 for chunk in audio.iter_chunks():
+                     chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
+                     intersecter.decode(chunk_emission[0])
+
+                     # Update progress
+                     chunk_duration = int(chunk.duration / 60.0)
+                     pbar.update(chunk_duration)
+
              emission_result = None
+             # Get results from intersecter
+             results, labels = intersecter.finish()
          else:
-             # Batch mode: compute full emission tensor and pass to align_segments
+             # Batch mode
              if emission is None:
-                 emission = self.emission(
-                     audio.ndarray, acoustic_scale=acoustic_scale, device=device
-                 )  # (1, T, vocab_size)
+                 emission = self.emission(audio.ndarray, acoustic_scale=acoustic_scale)  # (1, T, vocab_size)
              else:
-                 emission = emission.to(device) * acoustic_scale
-
-             results, labels = align_segments(
-                 emission,
-                 decoding_graph.to(device),
-                 torch.tensor([emission.shape[1]], dtype=torch.int32),
-                 search_beam=search_beam,
-                 output_beam=output_beam,
-                 min_active_states=min_active_states,
-                 max_active_states=max_active_states,
-                 subsampling_factor=1,
-                 reject_low_confidence=False,
+                 if acoustic_scale != 1.0:
+                     emission *= acoustic_scale
+             # Use AlignSegments directly
+             results, labels = k2.AlignSegments(
+                 graph_dict,
+                 emission[0],  # Pass the prepared scores
+                 float(search_beam),
+                 float(output_beam),
+                 int(min_active_states),
+                 int(max_active_states),
              )
-
              emission_result = emission

          self.timings["align_segments"] += time.time() - _start
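The decode path now drives `k2py` directly: streaming mode feeds per-chunk emissions into an `OnlineDenseIntersecter` and collects the output from `finish()`, while batch mode hands the full emission matrix to `AlignSegments`. The sketch below condenses that control flow; the `k2py` call signatures are assumed to be exactly those shown in this diff, and `graph_dict`, the emissions, and the beam parameters are placeholders for values produced elsewhere in the worker.

```python
import k2py as k2  # assumption: installed separately now that the 1.1.0 k2 auto-installer is gone

# graph_dict is assumed to come from k2.CreateFsaVecFromStr(lattice_graph_str, final_state, False),
# exactly as in the hunk above; it carries the "fsa" and "aux_labels" entries used below.


def decode_streaming(graph_dict, chunk_emissions, search_beam, output_beam, min_active, max_active):
    """Feed (T_chunk, vocab_size) emission chunks into the online intersecter (control flow per this diff)."""
    intersecter = k2.OnlineDenseIntersecter(
        graph_dict["fsa"],
        graph_dict["aux_labels"],
        float(search_beam),
        float(output_beam),
        int(min_active),
        int(max_active),
    )
    for chunk in chunk_emissions:
        intersecter.decode(chunk)
    return intersecter.finish()  # -> (results, labels)


def decode_batch(graph_dict, emission, search_beam, output_beam, min_active, max_active):
    """Align one full (T, vocab_size) emission matrix in a single call (control flow per this diff)."""
    return k2.AlignSegments(
        graph_dict,
        emission,
        float(search_beam),
        float(output_beam),
        int(min_active),
        int(max_active),
    )
```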
@@ -304,6 +238,41 @@ class Lattice1Worker:
          channel = 0
          return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms

+     def profile(self) -> None:
+         """Print formatted profiling statistics."""
+         if not self.timings:
+             return
+
+         safe_print(colorful.bold(colorful.cyan("\n⏱️ Alignment Profiling")))
+         safe_print(colorful.gray("─" * 44))
+         safe_print(
+             f"{colorful.bold('Phase'.ljust(21))} "
+             f"{colorful.bold('Time'.ljust(12))} "
+             f"{colorful.bold('Percent'.rjust(8))}"
+         )
+         safe_print(colorful.gray("─" * 44))
+
+         total_time = sum(self.timings.values())
+
+         # Sort by duration descending
+         sorted_stats = sorted(self.timings.items(), key=lambda x: x[1], reverse=True)
+
+         for name, duration in sorted_stats:
+             percentage = (duration / total_time * 100) if total_time > 0 else 0.0
+             # Name: Cyan, Time: Yellow, Percent: Gray
+             safe_print(
+                 f"{name:<20} "
+                 f"{colorful.yellow(f'{duration:7.4f}s'.ljust(12))} "
+                 f"{colorful.gray(f'{percentage:.2f}%'.rjust(8))}"
+             )
+
+         safe_print(colorful.gray("─" * 44))
+         # Pad "Total Time" before coloring to ensure correct alignment (ANSI codes don't count for width)
+         safe_print(
+             f"{colorful.bold('Total Time'.ljust(20))} "
+             f"{colorful.bold(colorful.yellow(f'{total_time:7.4f}s'.ljust(12)))}\n"
+         )
+

  def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
      """Instantiate lattice worker with consistent error handling."""
@@ -9,7 +9,7 @@ from lattifai.caption import Caption, Supervision
  from lattifai.config import AlignmentConfig
  from lattifai.utils import safe_print

- from .tokenizer import END_PUNCTUATION
+ from .sentence_splitter import END_PUNCTUATION


  class Segmenter: