lattifai 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lattifai/__init__.py CHANGED
@@ -1,7 +1,17 @@
+import os
 import sys
 import warnings
 from importlib.metadata import version
 
+# Suppress SWIG deprecation warnings before any imports
+warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+# Suppress PyTorch transformer nested tensor warning
+warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+# Disable tokenizers parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 # Re-export I/O classes
 from .caption import Caption
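
The ordering here matters: the warning filters and the TOKENIZERS_PARALLELISM setting only take effect if they run before the modules that emit those warnings are imported, which is why they sit at the top of the package __init__. A minimal standalone sketch of the same pattern (the trailing torch import is illustrative; only the filters and the environment variable come from the hunk above):

    import os
    import warnings

    # Register filters before importing anything that triggers them:
    # SwigPy* DeprecationWarnings typically come from SWIG-generated bindings,
    # and the enable_nested_tensor UserWarning from torch.nn.TransformerEncoder.
    warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
    warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")

    # HuggingFace tokenizers reads this environment variable, so set it before
    # any tokenizer is constructed to avoid the fork-related parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    import torch  # noqa: E402  (imported after the filters on purpose)
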
 
@@ -3,6 +3,7 @@
 from typing import Any, List, Optional, Tuple
 
 import colorful
+import numpy as np
 import torch
 
 from lattifai.audio2 import AudioData
@@ -13,7 +14,7 @@ from lattifai.errors import (
     LatticeDecodingError,
     LatticeEncodingError,
 )
-from lattifai.utils import _resolve_model_path
+from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
 from .tokenizer import _load_tokenizer
@@ -37,12 +38,20 @@ class Lattice1Aligner(object):
         model_path = _resolve_model_path(config.model_name)
 
         self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
-        self.worker = _load_worker(model_path, config.device)
+        self.worker = _load_worker(model_path, config.device, config)
 
         self.frame_shift = self.worker.frame_shift
 
-    def emission(self, audio: torch.Tensor) -> torch.Tensor:
-        return self.worker.emission(audio.to(self.worker.device))
+    def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
+        return self.worker.emission(ndarray)
 
     def alignment(
         self,
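
Lattice1Aligner.emission() now takes a NumPy array instead of a torch.Tensor and forwards it to the worker unchanged. A usage sketch of the new call shape, where aligner is a configured Lattice1Aligner; loading via soundfile and the file name are assumptions, only the (1, T) input and (1, T, vocab_size) output come from the docstring above:

    import numpy as np
    import soundfile as sf  # assumed audio loader for this example

    samples, sample_rate = sf.read("speech.wav", dtype="float32")  # (T,) mono
    ndarray = np.expand_dims(samples, axis=0)                      # (1, T)

    emission = aligner.emission(ndarray)  # torch.Tensor of shape (1, T', vocab_size)
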
@@ -72,23 +81,34 @@ class Lattice1Aligner(object):
         """
         try:
             if verbose:
-                print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+                safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
             try:
                 supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
                     supervisions, split_sentence=split_sentence
                 )
                 if verbose:
-                    print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+                    safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
             except Exception as e:
                 text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
                 raise LatticeEncodingError(text_content, original_error=e)
 
             if verbose:
-                print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                if audio.streaming_chunk_secs:
+                    safe_print(
+                        colorful.yellow(
+                            f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
+                        )
+                    )
             try:
-                lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                lattice_results = self.worker.alignment(
+                    audio,
+                    lattice_graph,
+                    emission=emission,
+                    offset=offset,
+                )
                 if verbose:
-                    print(colorful.green(" ✓ Lattice search completed"))
+                    safe_print(colorful.green(" ✓ Lattice search completed"))
             except Exception as e:
                 raise AlignmentError(
                     f"Audio alignment failed for {audio}",
@@ -97,18 +117,18 @@ class Lattice1Aligner(object):
                 )
 
             if verbose:
-                print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+                safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
             try:
                 alignments = self.tokenizer.detokenize(
                     lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
                 )
                 if verbose:
-                    print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+                    safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
             except LatticeDecodingError as e:
-                print(colorful.red(" x Failed to decode lattice alignment results"))
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                 raise e
             except Exception as e:
-                print(colorful.red(" x Failed to decode lattice alignment results"))
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                 raise LatticeDecodingError(lattice_id, original_error=e)
 
         return (supervisions, alignments)
@@ -9,6 +9,7 @@ import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
+from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
@@ -17,12 +18,17 @@ from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
 class Lattice1Worker:
     """Worker for processing audio with LatticeGraph."""
 
-    def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
+    def __init__(
+        self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8, config: Optional[Any] = None
+    ) -> None:
         try:
-            self.config = json.load(open(f"{model_path}/config.json"))
+            self.model_config = json.load(open(f"{model_path}/config.json"))
         except Exception as e:
             raise ModelLoadError(f"config from {model_path}", original_error=e)
 
+        # Store alignment config with beam search parameters
+        self.alignment_config = config
+
         # SessionOptions
         sess_options = ort.SessionOptions()
         # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -75,59 +81,74 @@ class Lattice1Worker:
         return 0.02  # 20 ms
 
     @torch.inference_mode()
-    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
         _start = time.time()
         if self.extractor is not None:
             # audio -> features -> emission
+            audio = torch.from_numpy(ndarray).to(self.device)
+            if audio.shape[1] < 160:
+                audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
             features = self.extractor(audio)  # (1, T, D)
             if features.shape[1] > 6000:
-                features_list = torch.split(features, 6000, dim=1)
                 emissions = []
-                for features in features_list:
+                for start in range(0, features.size(1), 6000):
+                    _features = features[:, start : start + 6000, :]
                     ort_inputs = {
-                        "features": features.cpu().numpy(),
-                        "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+                        "features": _features.cpu().numpy(),
+                        "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
                     }
                     emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                     emissions.append(emission)
                 emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                 )  # (1, T, vocab_size)
+                del emissions
             else:
                 ort_inputs = {
                     "features": features.cpu().numpy(),
                     "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                 }
                 emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(self.device)
+                emission = torch.from_numpy(emission).to(device or self.device)
         else:
+            if ndarray.shape[1] < 160:
+                ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
            CHUNK_SIZE = 60 * 16000  # 60 seconds
-            if audio.shape[1] > CHUNK_SIZE:
-                audio_list = torch.split(audio, CHUNK_SIZE, dim=1)
+            if ndarray.shape[1] > CHUNK_SIZE:
                 emissions = []
-                for audios in audio_list:
+                for start in range(0, ndarray.shape[1], CHUNK_SIZE):
                     emission = self.acoustic_ort.run(
                         None,
                         {
-                            "audios": audios.cpu().numpy(),
+                            "audios": ndarray[:, start : start + CHUNK_SIZE],
                         },
-                    )[
-                        0
-                    ]  # (1, T, vocab_size) numpy
-                    emissions.append(emission)
+                    )  # (1, T, vocab_size) numpy
+                    emissions.append(emission[0])
+
                 emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                 )  # (1, T, vocab_size)
+                del emissions
             else:
                 emission = self.acoustic_ort.run(
                     None,
                     {
-                        "audios": audio.cpu().numpy(),
+                        "audios": ndarray,
                     },
-                )[
-                    0
-                ]  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(self.device)
+                )  # (1, T, vocab_size) numpy
+                emission = torch.from_numpy(emission[0]).to(device or self.device)
+
+        if acoustic_scale != 1.0:
+            emission = emission.mul_(acoustic_scale)
 
         self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch
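
The rewrite drops torch.split in favour of index-based slicing, so each ONNX Runtime call sees at most one chunk (6000 feature frames, or 60 s of raw audio) and the per-chunk outputs can be released after concatenation. A minimal sketch of that chunk-and-concatenate pattern, independent of the worker class (run_model stands in for the ONNX session call and is an assumption for illustration):

    import numpy as np

    CHUNK_SIZE = 60 * 16000  # 60 seconds at 16 kHz, as in the hunk above

    def run_in_chunks(run_model, ndarray: np.ndarray) -> np.ndarray:
        """Apply run_model to fixed-size time slices and concatenate the results.

        run_model maps a (1, t) slice to a (1, t', vocab_size) array.
        """
        outputs = []
        for start in range(0, ndarray.shape[1], CHUNK_SIZE):
            outputs.append(run_model(ndarray[:, start : start + CHUNK_SIZE]))
        return np.concatenate(outputs, axis=1)
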
@@ -144,6 +165,9 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
+            emission: Pre-computed emission tensor (ignored if streaming=True)
+            offset: Time offset for the audio
+            streaming: If True, use streaming mode for memory-efficient processing
 
         Returns:
             Processed LatticeGraph
@@ -153,16 +177,6 @@
             DependencyError: If required dependencies are missing
             AlignmentError: If alignment process fails
         """
-        if emission is None:
-            try:
-                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
-            except Exception as e:
-                raise AlignmentError(
-                    "Failed to compute acoustic features from audio",
-                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                    context={"original_error": str(e)},
-                )
-
         try:
             import k2
         except ImportError:
@@ -177,7 +191,7 @@
 
         _start = time.time()
         try:
-            # graph
+            # Create decoding graph
             decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
             decoding_graph.requires_grad_(False)
             decoding_graph = k2.arc_sort(decoding_graph)
@@ -190,39 +204,96 @@
             )
         self.timings["decoding_graph"] += time.time() - _start
 
-        _start = time.time()
         if self.device.type == "mps":
             device = "cpu"  # k2 does not support mps yet
         else:
             device = self.device
 
-        try:
+        _start = time.time()
+
+        # Get beam search parameters from config or use defaults
+        search_beam = self.alignment_config.search_beam or 200
+        output_beam = self.alignment_config.output_beam or 80
+        min_active_states = self.alignment_config.min_active_states or 400
+        max_active_states = self.alignment_config.max_active_states or 10000
+
+        if emission is None and audio.streaming_mode:
+            # Streaming mode: pass emission iterator to align_segments
+            # The align_segments function will automatically detect the iterator
+            # and use k2.OnlineDenseIntersecter for memory-efficient processing
+
+            def emission_iterator():
+                """Generate emissions for each audio chunk with progress tracking."""
+                total_duration = audio.duration
+                processed_duration = 0.0
+                total_minutes = int(total_duration / 60.0)
+
+                with tqdm(
+                    total=total_minutes,
+                    desc=f"Processing audio ({total_minutes} min)",
+                    unit="min",
+                    unit_scale=False,
+                    unit_divisor=1,
+                ) as pbar:
+                    for chunk in audio.iter_chunks():
+                        chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
+
+                        # Update progress based on chunk duration in minutes
+                        chunk_duration = int(chunk.duration / 60.0)
+                        pbar.update(chunk_duration)
+                        processed_duration += chunk_duration
+
+                        yield chunk_emission
+
+            # Calculate total frames for supervision_segments
+            total_frames = int(audio.duration / self.frame_shift)
+
             results, labels = align_segments(
-                emission.to(device) * acoustic_scale,
+                emission_iterator(),  # Pass iterator for streaming
                 decoding_graph.to(device),
-                torch.tensor([emission.shape[1]], dtype=torch.int32),
-                search_beam=200,
-                output_beam=80,
-                min_active_states=400,
-                max_active_states=10000,
+                torch.tensor([total_frames], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
                 subsampling_factor=1,
                 reject_low_confidence=False,
             )
-        except Exception as e:
-            raise AlignmentError(
-                "Failed to perform forced alignment",
-                media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
+
+            # For streaming, don't return emission tensor to save memory
+            emission_result = None
+        else:
+            # Batch mode: compute full emission tensor and pass to align_segments
+            if emission is None:
+                emission = self.emission(
+                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
+                )  # (1, T, vocab_size)
+            else:
+                emission = emission.to(device) * acoustic_scale
+
+            results, labels = align_segments(
+                emission,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
+                subsampling_factor=1,
+                reject_low_confidence=False,
             )
+
+            emission_result = emission
+
         self.timings["align_segments"] += time.time() - _start
 
         channel = 0
-        return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+        return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
 
-def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
     try:
-        return Lattice1Worker(model_path, device=device, num_threads=8)
+        return Lattice1Worker(model_path, device=device, num_threads=8, config=config)
     except Exception as e:
         raise ModelLoadError(f"worker from {model_path}", original_error=e)
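
The beam-search parameters are no longer hard-coded: they are read from the alignment config with the old values (200 / 80 / 400 / 10000) as fallbacks via `or`. Note that `value or default` falls back whenever the configured value is falsy, so an explicit 0 is treated the same as None. A small sketch of that lookup pattern, assuming only that the config object exposes these attributes as the hunk does:

    # Defaults mirror the values hard-coded in 1.0.4.
    BEAM_DEFAULTS = {
        "search_beam": 200,
        "output_beam": 80,
        "min_active_states": 400,
        "max_active_states": 10000,
    }

    def beam_params(config) -> dict:
        # getattr keeps the sketch safe if an attribute is missing entirely;
        # `or` reproduces the fallback behaviour used in the hunk above.
        return {name: (getattr(config, name, None) or default) for name, default in BEAM_DEFAULTS.items()}
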
@@ -7,6 +7,7 @@ import colorful
 from lattifai.audio2 import AudioData
 from lattifai.caption import Caption, Supervision
 from lattifai.config import AlignmentConfig
+from lattifai.utils import safe_print
 
 from .tokenizer import END_PUNCTUATION
 
@@ -153,7 +154,7 @@ class Segmenter:
 
         total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
 
-        print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+        safe_print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
         for i, (start, end, sups, _) in enumerate(segments, 1):
             duration = end - start
             print(
@@ -163,4 +164,4 @@
                 )
             )
 
-        print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
+        safe_print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
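
safe_print replaces print on the verbose paths throughout this release. Its implementation lives in lattifai.utils and is not part of this diff; a typical shape for such a helper, offered only as an assumption, is a print that degrades gracefully on consoles that cannot encode the emoji used in these messages:

    def safe_print(*args, **kwargs) -> None:
        # Assumed sketch, not the packaged implementation.
        try:
            print(*args, **kwargs)
        except UnicodeEncodeError:
            # Fall back to an ASCII-safe rendering of each argument.
            print(*(str(a).encode("ascii", "replace").decode("ascii") for a in args), **kwargs)
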
@@ -335,7 +335,7 @@ class LatticeTokenizer:
             flush_segment(s, None)
 
         assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
-        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
 
         supervisions, remainder = [], ""
         for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
@@ -450,7 +450,7 @@ class LatticeTokenizer:
                 "destroy_lattice": True,
             },
         )
-        if response.status_code == 422:
+        if response.status_code == 400:
            raise LatticeDecodingError(
                lattice_id,
                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
@@ -466,7 +466,7 @@
 
         alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
 
-        if return_details:
+        if emission is not None and return_details:
             # Add emission confidence scores for segments and word-level alignments
             _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)