lattifai 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. lattifai/__init__.py +10 -0
  2. lattifai/alignment/lattice1_aligner.py +64 -15
  3. lattifai/alignment/lattice1_worker.py +135 -50
  4. lattifai/alignment/segmenter.py +3 -2
  5. lattifai/alignment/tokenizer.py +14 -13
  6. lattifai/audio2.py +269 -70
  7. lattifai/caption/caption.py +213 -19
  8. lattifai/cli/__init__.py +2 -0
  9. lattifai/cli/alignment.py +2 -1
  10. lattifai/cli/app_installer.py +35 -33
  11. lattifai/cli/caption.py +9 -19
  12. lattifai/cli/diarization.py +108 -0
  13. lattifai/cli/server.py +3 -1
  14. lattifai/cli/transcribe.py +55 -38
  15. lattifai/cli/youtube.py +1 -0
  16. lattifai/client.py +42 -121
  17. lattifai/config/alignment.py +37 -2
  18. lattifai/config/caption.py +1 -1
  19. lattifai/config/media.py +23 -3
  20. lattifai/config/transcription.py +4 -0
  21. lattifai/diarization/lattifai.py +18 -7
  22. lattifai/errors.py +7 -3
  23. lattifai/mixin.py +45 -16
  24. lattifai/server/app.py +2 -1
  25. lattifai/transcription/__init__.py +1 -1
  26. lattifai/transcription/base.py +21 -2
  27. lattifai/transcription/gemini.py +127 -1
  28. lattifai/transcription/lattifai.py +30 -2
  29. lattifai/utils.py +96 -28
  30. lattifai/workflow/file_manager.py +15 -13
  31. lattifai/workflow/youtube.py +16 -1
  32. {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/METADATA +86 -22
  33. lattifai-1.1.0.dist-info/RECORD +57 -0
  34. {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/entry_points.txt +2 -0
  35. {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/licenses/LICENSE +1 -1
  36. lattifai-1.0.4.dist-info/RECORD +0 -56
  37. {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/WHEEL +0 -0
  38. {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py CHANGED
@@ -1,7 +1,17 @@
+ import os
  import sys
  import warnings
  from importlib.metadata import version

+ # Suppress SWIG deprecation warnings before any imports
+ warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+ # Suppress PyTorch transformer nested tensor warning
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+ # Disable tokenizers parallelism warning
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
  # Re-export I/O classes
  from .caption import Caption

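Note: the import-time setup above only works because it runs before the noisy libraries load, which is why it sits at the very top of the package `__init__`. A minimal standalone sketch of the mechanism (generic Python, not lattifai code):

```python
import os
import warnings

# The message argument is a regex matched against the start of the warning
# text, so ".*SwigPy.*" catches any DeprecationWarning mentioning SwigPy.
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")

# HuggingFace tokenizers reads this variable at import/fork time; setting it
# first silences the fork-related parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

warnings.warn("SwigPyObject is deprecated", DeprecationWarning)  # suppressed
print("no DeprecationWarning surfaced")
```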
lattifai/alignment/lattice1_aligner.py CHANGED
@@ -3,6 +3,7 @@
  from typing import Any, List, Optional, Tuple

  import colorful
+ import numpy as np
  import torch

  from lattifai.audio2 import AudioData
@@ -13,7 +14,7 @@ from lattifai.errors import (
      LatticeDecodingError,
      LatticeEncodingError,
  )
- from lattifai.utils import _resolve_model_path
+ from lattifai.utils import _resolve_model_path, safe_print

  from .lattice1_worker import _load_worker
  from .tokenizer import _load_tokenizer
@@ -34,15 +35,47 @@ class Lattice1Aligner(object):
              raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")

          client_wrapper = config.client_wrapper
-         model_path = _resolve_model_path(config.model_name)
+         # Resolve model path using configured model hub
+         model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))

          self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
-         self.worker = _load_worker(model_path, config.device)
+         self.worker = _load_worker(model_path, config.device, config)

          self.frame_shift = self.worker.frame_shift

-     def emission(self, audio: torch.Tensor) -> torch.Tensor:
-         return self.worker.emission(audio.to(self.worker.device))
+     def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+         """Generate emission probabilities from audio ndarray.
+
+         Args:
+             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+         Returns:
+             Emission tensor of shape (1, T, vocab_size)
+         """
+         return self.worker.emission(ndarray)
+
+     def separate(self, audio: np.ndarray) -> np.ndarray:
+         """Separate audio using separator model.
+
+         Args:
+             audio: np.ndarray object containing the audio to separate, shape (1, T)
+
+         Returns:
+             Separated audio as numpy array
+
+         Raises:
+             RuntimeError: If separator model is not available
+         """
+         if self.worker.separator_ort is None:
+             raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
+
+         # Run separator model
+         separator_output = self.worker.separator_ort.run(
+             None,
+             {"audio": audio},
+         )
+
+         return separator_output[0]

      def alignment(
          self,
@@ -72,23 +105,34 @@ class Lattice1Aligner(object):
          """
          try:
              if verbose:
-                 print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+                 safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
              try:
                  supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
                      supervisions, split_sentence=split_sentence
                  )
                  if verbose:
-                     print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+                     safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
              except Exception as e:
                  text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
                  raise LatticeEncodingError(text_content, original_error=e)

              if verbose:
-                 print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                 safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                 if audio.streaming_chunk_secs:
+                     safe_print(
+                         colorful.yellow(
+                             f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
+                         )
+                     )
              try:
-                 lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                 lattice_results = self.worker.alignment(
+                     audio,
+                     lattice_graph,
+                     emission=emission,
+                     offset=offset,
+                 )
                  if verbose:
-                     print(colorful.green(" ✓ Lattice search completed"))
+                     safe_print(colorful.green(" ✓ Lattice search completed"))
              except Exception as e:
                  raise AlignmentError(
                      f"Audio alignment failed for {audio}",
@@ -97,18 +141,23 @@
                  )

              if verbose:
-                 print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+                 safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
              try:
                  alignments = self.tokenizer.detokenize(
-                     lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                     lattice_id,
+                     lattice_results,
+                     supervisions=supervisions,
+                     return_details=return_details,
+                     start_margin=self.config.start_margin,
+                     end_margin=self.config.end_margin,
                  )
                  if verbose:
-                     print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+                     safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
              except LatticeDecodingError as e:
-                 print(colorful.red(" x Failed to decode lattice alignment results"))
+                 safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                  raise e
              except Exception as e:
-                 print(colorful.red(" x Failed to decode lattice alignment results"))
+                 safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                  raise LatticeDecodingError(lattice_id, original_error=e)

          return (supervisions, alignments)
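Note: with this file's changes, the aligner's public surface consumes numpy arrays instead of torch tensors. A hedged usage sketch against the new `emission()` and `separate()` signatures, assuming `aligner` is an already constructed `Lattice1Aligner` and the audio is 16 kHz mono (both assumptions; construction is not shown in this diff):

```python
import numpy as np

# Hypothetical 3 s of silence, shape (1, T) as the docstrings specify.
audio = np.zeros((1, 3 * 16000), dtype=np.float32)

emission = aligner.emission(audio)  # torch.Tensor of shape (1, T', vocab_size)

# separate() only works when separator.onnx was found at model load time.
try:
    vocals = aligner.separate(audio)  # np.ndarray of separated audio
except RuntimeError:
    vocals = audio  # fall back to the unseparated signal
```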
lattifai/alignment/lattice1_worker.py CHANGED
@@ -1,6 +1,7 @@
  import json
  import time
  from collections import defaultdict
+ from pathlib import Path
  from typing import Any, Dict, Optional, Tuple

  import numpy as np
@@ -9,6 +10,7 @@ import torch
  from lhotse import FbankConfig
  from lhotse.features.kaldi.layers import Wav2LogFilterBank
  from lhotse.utils import Pathlike
+ from tqdm import tqdm

  from lattifai.audio2 import AudioData
  from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
@@ -17,12 +19,17 @@ from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
  class Lattice1Worker:
      """Worker for processing audio with LatticeGraph."""

-     def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
+     def __init__(
+         self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8, config: Optional[Any] = None
+     ) -> None:
          try:
-             self.config = json.load(open(f"{model_path}/config.json"))
+             self.model_config = json.load(open(f"{model_path}/config.json"))
          except Exception as e:
              raise ModelLoadError(f"config from {model_path}", original_error=e)

+         # Store alignment config with beam search parameters
+         self.alignment_config = config
+
          # SessionOptions
          sess_options = ort.SessionOptions()
          # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -67,6 +74,19 @@ class Lattice1Worker:
          else:
              self.extractor = None  # ONNX model includes feature extractor

+         # Initialize separator if available
+         separator_model_path = Path(model_path) / "separator.onnx"
+         if separator_model_path.exists():
+             try:
+                 self.separator_ort = ort.InferenceSession(
+                     str(separator_model_path),
+                     providers=providers + ["CPUExecutionProvider"],
+                 )
+             except Exception as e:
+                 raise ModelLoadError(f"separator model from {model_path}", original_error=e)
+         else:
+             self.separator_ort = None
+
          self.device = torch.device(device)
          self.timings = defaultdict(lambda: 0.0)

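Note: the separator session above is optional and loads with a provider fallback. The same pattern in isolation (the model directory and preferred provider here are illustrative, not lattifai's actual values):

```python
from pathlib import Path

import onnxruntime as ort

model_path = Path("./models/lattice-1")  # hypothetical model directory
separator_model_path = model_path / "separator.onnx"

providers = ["CUDAExecutionProvider"]  # whatever the worker prefers
if separator_model_path.exists():
    # ONNX Runtime tries providers in order, so appending CPUExecutionProvider
    # keeps session creation working on hosts where the preferred one is missing.
    separator_ort = ort.InferenceSession(
        str(separator_model_path),
        providers=providers + ["CPUExecutionProvider"],
    )
else:
    separator_ort = None  # callers such as separate() must handle None
```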
@@ -75,59 +95,74 @@ class Lattice1Worker:
          return 0.02  # 20 ms

      @torch.inference_mode()
-     def emission(self, audio: torch.Tensor) -> torch.Tensor:
+     def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+         """Generate emission probabilities from audio ndarray.
+
+         Args:
+             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+         Returns:
+             Emission tensor of shape (1, T, vocab_size)
+         """
          _start = time.time()
          if self.extractor is not None:
              # audio -> features -> emission
+             audio = torch.from_numpy(ndarray).to(self.device)
+             if audio.shape[1] < 160:
+                 audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
              features = self.extractor(audio)  # (1, T, D)
              if features.shape[1] > 6000:
-                 features_list = torch.split(features, 6000, dim=1)
                  emissions = []
-                 for features in features_list:
+                 for start in range(0, features.size(1), 6000):
+                     _features = features[:, start : start + 6000, :]
                      ort_inputs = {
-                         "features": features.cpu().numpy(),
-                         "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+                         "features": _features.cpu().numpy(),
+                         "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
                      }
                      emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                      emissions.append(emission)
                  emission = torch.cat(
-                     [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                     [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                  )  # (1, T, vocab_size)
+                 del emissions
              else:
                  ort_inputs = {
                      "features": features.cpu().numpy(),
                      "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                  }
                  emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                 emission = torch.from_numpy(emission).to(self.device)
          else:
+             if ndarray.shape[1] < 160:
+                 ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
              CHUNK_SIZE = 60 * 16000  # 60 seconds
-             if audio.shape[1] > CHUNK_SIZE:
-                 audio_list = torch.split(audio, CHUNK_SIZE, dim=1)
+             if ndarray.shape[1] > CHUNK_SIZE:
                  emissions = []
-                 for audios in audio_list:
+                 for start in range(0, ndarray.shape[1], CHUNK_SIZE):
                      emission = self.acoustic_ort.run(
                          None,
                          {
-                             "audios": audios.cpu().numpy(),
+                             "audios": ndarray[:, start : start + CHUNK_SIZE],
                          },
-                     )[
-                         0
-                     ]  # (1, T, vocab_size) numpy
-                     emissions.append(emission)
+                     )  # (1, T, vocab_size) numpy
+                     emissions.append(emission[0])
+
                  emission = torch.cat(
-                     [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                     [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                  )  # (1, T, vocab_size)
+                 del emissions
              else:
                  emission = self.acoustic_ort.run(
                      None,
                      {
-                         "audios": audio.cpu().numpy(),
+                         "audios": ndarray,
                      },
-                 )[
-                     0
-                 ]  # (1, T, vocab_size) numpy
-                 emission = torch.from_numpy(emission).to(self.device)
+                 )  # (1, T, vocab_size) numpy
+                 emission = torch.from_numpy(emission[0]).to(device or self.device)
+
+         if acoustic_scale != 1.0:
+             emission = emission.mul_(acoustic_scale)

          self.timings["emission"] += time.time() - _start
          return emission  # (1, T, vocab_size) torch
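Note: the rewritten loops above drop the `torch.split` bookkeeping in favor of index-range chunking: long inputs go through the ONNX model in fixed windows (6000 feature frames, or 60 s of raw samples) and the per-chunk posteriors are concatenated on the time axis. The pattern in isolation, with a stand-in for the ONNX session (a sketch, not the worker's real model):

```python
import numpy as np

CHUNK_SIZE = 60 * 16000  # 60 s of 16 kHz samples, as in the worker

def run_model(chunk: np.ndarray) -> np.ndarray:
    # Stand-in for acoustic_ort.run(): one output frame per 320 samples (20 ms).
    return np.zeros((1, chunk.shape[1] // 320, 32), dtype=np.float32)

ndarray = np.zeros((1, 150 * 16000), dtype=np.float32)  # 150 s of audio

outputs = []
for start in range(0, ndarray.shape[1], CHUNK_SIZE):
    outputs.append(run_model(ndarray[:, start : start + CHUNK_SIZE]))

emission = np.concatenate(outputs, axis=1)
print(emission.shape)  # (1, 7500, 32): 60 s + 60 s + 30 s stitched back together
```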
@@ -144,6 +179,9 @@
          Args:
              audio: AudioData object
              lattice_graph: LatticeGraph data
+             emission: Pre-computed emission tensor (ignored if streaming=True)
+             offset: Time offset for the audio
+             streaming: If True, use streaming mode for memory-efficient processing

          Returns:
              Processed LatticeGraph
@@ -153,16 +191,6 @@
              DependencyError: If required dependencies are missing
              AlignmentError: If alignment process fails
          """
-         if emission is None:
-             try:
-                 emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
-             except Exception as e:
-                 raise AlignmentError(
-                     "Failed to compute acoustic features from audio",
-                     media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                     context={"original_error": str(e)},
-                 )
-
          try:
              import k2
          except ImportError:
@@ -177,7 +205,7 @@

          _start = time.time()
          try:
-             # graph
+             # Create decoding graph
              decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
              decoding_graph.requires_grad_(False)
              decoding_graph = k2.arc_sort(decoding_graph)
@@ -190,39 +218,96 @@
              )
          self.timings["decoding_graph"] += time.time() - _start

-         _start = time.time()
          if self.device.type == "mps":
              device = "cpu"  # k2 does not support mps yet
          else:
              device = self.device

-         try:
+         _start = time.time()
+
+         # Get beam search parameters from config or use defaults
+         search_beam = self.alignment_config.search_beam or 200
+         output_beam = self.alignment_config.output_beam or 80
+         min_active_states = self.alignment_config.min_active_states or 400
+         max_active_states = self.alignment_config.max_active_states or 10000
+
+         if emission is None and audio.streaming_mode:
+             # Streaming mode: pass emission iterator to align_segments
+             # The align_segments function will automatically detect the iterator
+             # and use k2.OnlineDenseIntersecter for memory-efficient processing
+
+             def emission_iterator():
+                 """Generate emissions for each audio chunk with progress tracking."""
+                 total_duration = audio.duration
+                 processed_duration = 0.0
+                 total_minutes = int(total_duration / 60.0)
+
+                 with tqdm(
+                     total=total_minutes,
+                     desc=f"Processing audio ({total_minutes} min)",
+                     unit="min",
+                     unit_scale=False,
+                     unit_divisor=1,
+                 ) as pbar:
+                     for chunk in audio.iter_chunks():
+                         chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
+
+                         # Update progress based on chunk duration in minutes
+                         chunk_duration = int(chunk.duration / 60.0)
+                         pbar.update(chunk_duration)
+                         processed_duration += chunk_duration
+
+                         yield chunk_emission
+
+             # Calculate total frames for supervision_segments
+             total_frames = int(audio.duration / self.frame_shift)
+
              results, labels = align_segments(
-                 emission.to(device) * acoustic_scale,
+                 emission_iterator(),  # Pass iterator for streaming
                  decoding_graph.to(device),
-                 torch.tensor([emission.shape[1]], dtype=torch.int32),
-                 search_beam=200,
-                 output_beam=80,
-                 min_active_states=400,
-                 max_active_states=10000,
+                 torch.tensor([total_frames], dtype=torch.int32),
+                 search_beam=search_beam,
+                 output_beam=output_beam,
+                 min_active_states=min_active_states,
+                 max_active_states=max_active_states,
                  subsampling_factor=1,
                  reject_low_confidence=False,
              )
-         except Exception as e:
-             raise AlignmentError(
-                 "Failed to perform forced alignment",
-                 media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                 context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
+
+             # For streaming, don't return emission tensor to save memory
+             emission_result = None
+         else:
+             # Batch mode: compute full emission tensor and pass to align_segments
+             if emission is None:
+                 emission = self.emission(
+                     audio.ndarray, acoustic_scale=acoustic_scale, device=device
+                 )  # (1, T, vocab_size)
+             else:
+                 emission = emission.to(device) * acoustic_scale
+
+             results, labels = align_segments(
+                 emission,
+                 decoding_graph.to(device),
+                 torch.tensor([emission.shape[1]], dtype=torch.int32),
+                 search_beam=search_beam,
+                 output_beam=output_beam,
+                 min_active_states=min_active_states,
+                 max_active_states=max_active_states,
+                 subsampling_factor=1,
+                 reject_low_confidence=False,
              )
+
+             emission_result = emission
+
          self.timings["align_segments"] += time.time() - _start

          channel = 0
-         return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+         return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms


- def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+ def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
      """Instantiate lattice worker with consistent error handling."""
      try:
-         return Lattice1Worker(model_path, device=device, num_threads=8)
+         return Lattice1Worker(model_path, device=device, num_threads=8, config=config)
      except Exception as e:
          raise ModelLoadError(f"worker from {model_path}", original_error=e)
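Note: beam-search parameters are no longer hard-coded; the worker now reads them from the `AlignmentConfig` threaded through `_load_worker`, falling back to the old constants (200 / 80 / 400 / 10000) when a field is unset. A sketch of how the `or` fallback composes; the dataclass here is a stand-in, not the real `AlignmentConfig`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ConfigSketch:
    # Stand-in for the beam-search fields the worker reads.
    search_beam: Optional[float] = None
    output_beam: Optional[float] = None
    min_active_states: Optional[int] = None
    max_active_states: Optional[int] = None

cfg = ConfigSketch(search_beam=120)

search_beam = cfg.search_beam or 200  # 120: an explicit value wins
output_beam = cfg.output_beam or 80   # 80: None falls back to the default
min_active_states = cfg.min_active_states or 400
max_active_states = cfg.max_active_states or 10000
```

One caveat of the `or` idiom: a deliberately configured 0 is falsy and would silently fall back to the default.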
lattifai/alignment/segmenter.py CHANGED
@@ -7,6 +7,7 @@ import colorful
  from lattifai.audio2 import AudioData
  from lattifai.caption import Caption, Supervision
  from lattifai.config import AlignmentConfig
+ from lattifai.utils import safe_print

  from .tokenizer import END_PUNCTUATION

@@ -153,7 +154,7 @@ class Segmenter:

          total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)

-         print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+         safe_print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
          for i, (start, end, sups, _) in enumerate(segments, 1):
              duration = end - start
              print(
@@ -163,4 +164,4 @@
                  )
              )

-         print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
+         safe_print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
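Note: `safe_print` replaces bare `print` on the emoji-heavy status lines here and in the aligner; its definition lives in `lattifai/utils.py`, which is not shown in this diff. A hypothetical stand-in for what such a helper typically does (an assumption, not the actual implementation):

```python
import sys

def safe_print(*args, **kwargs) -> None:
    # Hypothetical: degrade gracefully when stdout's encoding (e.g. cp1252 on
    # some Windows consoles) cannot represent emoji or other glyphs.
    try:
        print(*args, **kwargs)
    except UnicodeEncodeError:
        encoding = sys.stdout.encoding or "ascii"
        print(
            " ".join(
                str(a).encode(encoding, errors="replace").decode(encoding)
                for a in args
            ),
            **kwargs,
        )
```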
lattifai/alignment/tokenizer.py CHANGED
@@ -214,7 +214,7 @@ class LatticeTokenizer:
              else:
                  with open(words_model_path, "rb") as f:
                      data = pickle.load(f)
-         except pickle.UnpicklingError as e:
+         except Exception as e:
              del e
              import msgpack

@@ -335,7 +335,7 @@ class LatticeTokenizer:
              flush_segment(s, None)

          assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
-         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)

          supervisions, remainder = [], ""
          for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
@@ -434,6 +434,8 @@ class LatticeTokenizer:
          lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
          supervisions: List[Supervision],
          return_details: bool = False,
+         start_margin: float = 0.08,
+         end_margin: float = 0.20,
      ) -> List[Supervision]:
          emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
          response = self.client_wrapper.post(
@@ -448,9 +450,11 @@ class LatticeTokenizer:
                  "channel": channel,
                  "return_details": False if return_details is None else return_details,
                  "destroy_lattice": True,
+                 "start_margin": start_margin,
+                 "end_margin": end_margin,
              },
          )
-         if response.status_code == 422:
+         if response.status_code == 400:
              raise LatticeDecodingError(
                  lattice_id,
                  original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
@@ -466,7 +470,7 @@ class LatticeTokenizer:

          alignments = [Supervision.from_dict(s) for s in result["supervisions"]]

-         if return_details:
+         if emission is not None and return_details:
              # Add emission confidence scores for segments and word-level alignments
              _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)

@@ -538,12 +542,9 @@ def _load_tokenizer(
      tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
  ) -> LatticeTokenizer:
      """Instantiate tokenizer with consistent error handling."""
-     try:
-         return tokenizer_cls.from_pretrained(
-             client_wrapper=client_wrapper,
-             model_path=model_path,
-             model_name=model_name,
-             device=device,
-         )
-     except Exception as e:
-         raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
+     return tokenizer_cls.from_pretrained(
+         client_wrapper=client_wrapper,
+         model_path=model_path,
+         model_name=model_name,
+         device=device,
+     )
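Note: the new `start_margin` / `end_margin` parameters (0.08 s and 0.20 s by default) travel from `AlignmentConfig` through `detokenize` into the decode request, where the server presumably pads decoded segment boundaries. A worked sketch of what margins of that shape would do; the clamping behavior is an assumption, since the server side is not part of this diff:

```python
start_margin, end_margin = 0.08, 0.20  # detokenize() defaults

def pad_segment(start: float, end: float, media_duration: float) -> tuple:
    # Hypothetical: widen a decoded segment by the margins, clamped to media bounds.
    return max(0.0, start - start_margin), min(media_duration, end + end_margin)

print(pad_segment(12.50, 14.10, media_duration=600.0))  # about (12.42, 14.30)
```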