lattifai 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. lattifai/_init.py +20 -0
  2. lattifai/alignment/__init__.py +2 -3
  3. lattifai/alignment/lattice1_aligner.py +117 -4
  4. lattifai/alignment/lattice1_worker.py +47 -4
  5. lattifai/alignment/segmenter.py +3 -2
  6. lattifai/alignment/text_align.py +2 -1
  7. lattifai/alignment/tokenizer.py +56 -29
  8. lattifai/audio2.py +162 -183
  9. lattifai/cli/alignment.py +5 -0
  10. lattifai/cli/caption.py +6 -6
  11. lattifai/cli/transcribe.py +1 -5
  12. lattifai/cli/youtube.py +3 -0
  13. lattifai/client.py +41 -12
  14. lattifai/config/__init__.py +21 -3
  15. lattifai/config/alignment.py +7 -0
  16. lattifai/config/caption.py +13 -243
  17. lattifai/config/client.py +16 -0
  18. lattifai/config/event.py +102 -0
  19. lattifai/config/transcription.py +25 -1
  20. lattifai/data/__init__.py +8 -0
  21. lattifai/data/caption.py +228 -0
  22. lattifai/errors.py +78 -53
  23. lattifai/event/__init__.py +65 -0
  24. lattifai/event/lattifai.py +166 -0
  25. lattifai/mixin.py +22 -17
  26. lattifai/transcription/base.py +2 -1
  27. lattifai/transcription/gemini.py +147 -16
  28. lattifai/transcription/lattifai.py +8 -11
  29. lattifai/types.py +1 -1
  30. lattifai/youtube/client.py +143 -48
  31. {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/METADATA +129 -58
  32. lattifai-1.3.1.dist-info/RECORD +57 -0
  33. lattifai/__init__.py +0 -88
  34. lattifai/alignment/sentence_splitter.py +0 -350
  35. lattifai/caption/__init__.py +0 -96
  36. lattifai/caption/caption.py +0 -661
  37. lattifai/caption/formats/__init__.py +0 -199
  38. lattifai/caption/formats/base.py +0 -211
  39. lattifai/caption/formats/gemini.py +0 -722
  40. lattifai/caption/formats/json.py +0 -194
  41. lattifai/caption/formats/lrc.py +0 -309
  42. lattifai/caption/formats/nle/__init__.py +0 -9
  43. lattifai/caption/formats/nle/audition.py +0 -561
  44. lattifai/caption/formats/nle/avid.py +0 -423
  45. lattifai/caption/formats/nle/fcpxml.py +0 -549
  46. lattifai/caption/formats/nle/premiere.py +0 -589
  47. lattifai/caption/formats/pysubs2.py +0 -642
  48. lattifai/caption/formats/sbv.py +0 -147
  49. lattifai/caption/formats/tabular.py +0 -338
  50. lattifai/caption/formats/textgrid.py +0 -193
  51. lattifai/caption/formats/ttml.py +0 -652
  52. lattifai/caption/formats/vtt.py +0 -469
  53. lattifai/caption/parsers/__init__.py +0 -9
  54. lattifai/caption/parsers/text_parser.py +0 -147
  55. lattifai/caption/standardize.py +0 -636
  56. lattifai/caption/supervision.py +0 -34
  57. lattifai/caption/utils.py +0 -474
  58. lattifai-1.2.2.dist-info/RECORD +0 -76
  59. {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/WHEEL +0 -0
  60. {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/entry_points.txt +0 -0
  61. {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/licenses/LICENSE +0 -0
  62. {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/top_level.txt +0 -0
lattifai/_init.py ADDED
@@ -0,0 +1,20 @@
+ """Environment configuration for LattifAI.
+
+ Import this module early to suppress warnings before other imports.
+
+ Usage:
+     import lattifai._init  # noqa: F401
+     from lattifai.client import LattifAI
+ """
+
+ import os
+ import warnings
+
+ # Suppress SWIG deprecation warnings before any imports
+ warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+ # Suppress PyTorch transformer nested tensor warning
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+ # Disable tokenizers parallelism warning
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
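These filters only affect warnings raised after they are installed, which is why the module must be imported before anything noisy. A self-contained demonstration of that ordering rule (the warning text is illustrative, modeled on the PyTorch message the filter targets):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # baseline: record every warning
    # Install the same style of message-regex filter used in lattifai._init ...
    warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
    # ... then emit a matching warning: it is swallowed rather than recorded.
    warnings.warn("enable_nested_tensor is True, but self.use_nested_tensor is False", UserWarning)

print(len(caught))  # 0
```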
lattifai/alignment/__init__.py CHANGED
@@ -1,15 +1,14 @@
  """Alignment module for LattifAI forced alignment."""
 
+ from lattifai.caption import SentenceSplitter
+
  from .lattice1_aligner import Lattice1Aligner
  from .segmenter import Segmenter
- from .sentence_splitter import SentenceSplitter
- from .text_align import align_supervisions_and_transcription
  from .tokenizer import tokenize_multilingual_text
 
  __all__ = [
      "Lattice1Aligner",
      "Segmenter",
      "SentenceSplitter",
-     "align_supervisions_and_transcription",
      "tokenize_multilingual_text",
  ]
lattifai/alignment/lattice1_aligner.py CHANGED
@@ -1,6 +1,6 @@
  """Lattice-1 Aligner implementation."""
 
- from typing import Any, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union
 
  import colorful
  import numpy as np
@@ -118,7 +118,10 @@ class Lattice1Aligner(object):
          safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
          try:
              supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
-                 supervisions, split_sentence=split_sentence, boost=self.config.boost
+                 supervisions,
+                 split_sentence=split_sentence,
+                 boost=self.config.boost,
+                 transition_penalty=self.config.transition_penalty,
              )
              if verbose:
                  safe_print(colorful.green(f"  ✓ Generated lattice graph with ID: {lattice_id}"))
@@ -162,12 +165,34 @@ class Lattice1Aligner(object):
                  return_details=return_details,
                  start_margin=self.config.start_margin,
                  end_margin=self.config.end_margin,
+                 check_sanity=True,
              )
              if verbose:
                  safe_print(colorful.green(f"  ✓ Successfully aligned {len(alignments)} segments"))
-         except LatticeDecodingError:
+         except LatticeDecodingError as e:
              safe_print(colorful.red("  x Failed to decode lattice alignment results"))
-             raise
+             _alignments = self.tokenizer.detokenize(
+                 lattice_id,
+                 lattice_results,
+                 supervisions=supervisions,
+                 return_details=return_details,
+                 start_margin=self.config.start_margin,
+                 end_margin=self.config.end_margin,
+                 check_sanity=False,
+             )
+             # Check for score anomalies (media-text mismatch)
+             anomaly = _detect_score_anomalies(_alignments)
+             if anomaly:
+                 anomaly_str = _format_anomaly_warning(anomaly)
+                 del _alignments
+                 raise LatticeDecodingError(
+                     lattice_id,
+                     message=colorful.yellow("Score anomaly detected - media and text mismatch:\n" + anomaly_str),
+                     skip_help=True,  # anomaly info is more specific than default help
+                 )
+             else:
+                 del _alignments
+                 raise e
          except Exception as e:
              safe_print(colorful.red("  x Failed to decode lattice alignment results"))
              raise LatticeDecodingError(lattice_id, original_error=e)
@@ -177,3 +202,91 @@ class Lattice1Aligner(object):
      def profile(self) -> None:
          """Print profiling statistics."""
          self.worker.profile()
+
+
+ def _detect_score_anomalies(
+     alignments: List[Supervision],
+     drop_threshold: float = 0.08,
+     window_size: int = 5,
+ ) -> Optional[Dict[str, Any]]:
+     """Detect score anomalies indicating alignment mismatch.
+
+     Compares average of window_size segments before vs after each position.
+     When the drop is significant, it indicates the audio doesn't match
+     the text starting at that position.
+
+     Args:
+         alignments: List of aligned supervisions with scores
+         drop_threshold: Minimum drop between before/after averages to trigger
+         window_size: Number of segments to average on each side
+
+     Returns:
+         Dict with anomaly info if found, None otherwise
+     """
+     scores = [s.score for s in alignments if s.score is not None]
+     if len(scores) < window_size * 2:
+         return None
+
+     for i in range(window_size, len(scores) - window_size):
+         before_avg = np.mean(scores[i - window_size : i])
+         after_avg = np.mean(scores[i : i + window_size])
+         drop = before_avg - after_avg
+
+         # Trigger: significant drop between before and after windows
+         if drop > drop_threshold:
+             # Find the exact mutation point (largest single-step drop)
+             max_drop = 0
+             mutation_idx = i
+             for j in range(i - 1, min(i + window_size, len(scores) - 1)):
+                 single_drop = scores[j] - scores[j + 1]
+                 if single_drop > max_drop:
+                     max_drop = single_drop
+                     mutation_idx = j + 1
+
+             # Segments: last normal + anomaly segments
+             last_normal = alignments[mutation_idx - 1] if mutation_idx > 0 else None
+             anomaly_segments = [
+                 alignments[j] for j in range(mutation_idx, min(mutation_idx + window_size, len(alignments)))
+             ]
+
+             return {
+                 "mutation_index": mutation_idx,
+                 "before_avg": round(before_avg, 4),
+                 "after_avg": round(after_avg, 4),
+                 "window_drop": round(drop, 4),
+                 "mutation_drop": round(max_drop, 4),
+                 "last_normal": last_normal,
+                 "segments": anomaly_segments,
+             }
+
+     return None
+
+
+ def _format_anomaly_warning(anomaly: Dict[str, Any]) -> str:
+     """Format anomaly detection result as warning message."""
+     lines = [
+         f"⚠️ Score anomaly detected at segment #{anomaly['mutation_index']}",
+         f"   Window avg: {anomaly['before_avg']:.4f} → {anomaly['after_avg']:.4f} (drop: {anomaly['window_drop']:.4f})",  # noqa: E501
+         f"   Mutation drop: {anomaly['mutation_drop']:.4f}",
+         "",
+     ]
+
+     # Show last normal segment
+     if anomaly.get("last_normal"):
+         seg = anomaly["last_normal"]
+         text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+         lines.append(f'   [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+     # Separator - mutation point
+     lines.append("   " + "─" * 60)
+     lines.append(f"   ⬇️ MUTATION: The following {len(anomaly['segments'])}+ segments don't match audio")
+     lines.append("   " + "─" * 60)
+
+     # Show anomaly segments
+     for seg in anomaly["segments"]:
+         text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+         lines.append(f'   [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+     lines.append("")
+     lines.append("   Possible causes: Transcription error, missing content, or wrong audio region")
+     return "\n".join(lines)
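For intuition about the detector added above, here is a worked run on synthetic scores. `FakeSupervision` is a hypothetical stand-in carrying only the fields `_detect_score_anomalies` reads; real callers pass `lattifai.caption.Supervision` objects:

```python
from dataclasses import dataclass


@dataclass
class FakeSupervision:  # hypothetical stand-in, not the real Supervision class
    start: float
    end: float
    text: str
    score: float


# Ten well-aligned segments (score 0.95) followed by five mismatched ones (0.60).
scores = [0.95] * 10 + [0.60] * 5
segs = [FakeSupervision(i * 2.0, (i + 1) * 2.0, f"segment {i}", s) for i, s in enumerate(scores)]

anomaly = _detect_score_anomalies(segs)
# With window_size=5, the before/after window averages first differ by more
# than drop_threshold=0.08 at i=7 (0.95 vs 0.81); the largest single-step
# drop (0.95 -> 0.60) then pinpoints the mutation at index 10.
print(anomaly["mutation_index"])  # 10
print(anomaly["mutation_drop"])   # 0.35
```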
lattifai/alignment/lattice1_worker.py CHANGED
@@ -35,6 +35,8 @@ class Lattice1Worker:
          sess_options.intra_op_num_threads = num_threads  # CPU cores
          sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
          sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
+         # Suppress CoreMLExecutionProvider warnings about partial graph support
+         sess_options.log_severity_level = 3  # ERROR level only
 
          acoustic_model_path = f"{model_path}/acoustic_opt.onnx"
 
@@ -191,12 +193,17 @@ class Lattice1Worker:
                  float(output_beam),
                  int(min_active_states),
                  int(max_active_states),
+                 allow_partial=True,
              )
 
-             # Streaming mode
+             # Streaming mode with confidence score accumulation
              total_duration = audio.duration
              total_minutes = int(total_duration / 60.0)
 
+             max_probs = []
+             aligned_probs = []
+             prev_labels_len = 0
+
              with tqdm(
                  total=total_minutes,
                  desc=f"Processing audio ({total_minutes} min)",
@@ -208,13 +215,43 @@ class Lattice1Worker:
                  chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
                  intersecter.decode(chunk_emission[0])
 
+                 __start = time.time()
+                 # Get partial labels and compute confidence stats for this chunk
+                 partial_labels = intersecter.get_partial_labels()
+                 chunk_len = chunk_emission.shape[1]
+
+                 # Get labels for current chunk (new labels since last chunk)
+                 chunk_labels = partial_labels[prev_labels_len : prev_labels_len + chunk_len]
+                 prev_labels_len = len(partial_labels)
+
+                 # Compute emission-based confidence stats
+                 probs = np.exp(chunk_emission[0])  # [T, V]
+                 max_probs.append(np.max(probs, axis=-1))  # [T]
+
+                 # Handle case where chunk_labels length might differ from chunk_len
+                 if len(chunk_labels) == chunk_len:
+                     aligned_probs.append(probs[np.arange(chunk_len), chunk_labels])
+                 else:
+                     # Fallback: use max probs as aligned probs (approximate)
+                     aligned_probs.append(np.max(probs, axis=-1))
+
+                 del chunk_emission, probs  # Free memory
+                 self.timings["align_>labels"] += time.time() - __start
+
                  # Update progress
                  chunk_duration = int(chunk.duration / 60.0)
                  pbar.update(chunk_duration)
 
-             emission_result = None
+             # Build emission_stats for confidence calculation
+             emission_stats = {
+                 "max_probs": np.concatenate(max_probs),
+                 "aligned_probs": np.concatenate(aligned_probs),
+             }
+
              # Get results from intersecter
+             __start = time.time()
              results, labels = intersecter.finish()
+             self.timings["align_>finish"] += time.time() - __start
          else:
              # Batch mode
              if emission is None:
@@ -230,13 +267,19 @@ class Lattice1Worker:
                  float(output_beam),
                  int(min_active_states),
                  int(max_active_states),
+                 allow_partial=True,
              )
-             emission_result = emission
+             # Compute emission_stats from full emission (same format as streaming)
+             probs = np.exp(emission[0])  # [T, V]
+             emission_stats = {
+                 "max_probs": np.max(probs, axis=-1),  # [T]
+                 "aligned_probs": probs[np.arange(probs.shape[0]), labels[0]],  # [T]
+             }
 
          self.timings["align_segments"] += time.time() - _start
 
          channel = 0
-         return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+         return emission_stats, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
      def profile(self) -> None:
          """Print formatted profiling statistics."""
lattifai/alignment/segmenter.py CHANGED
@@ -5,11 +5,12 @@ from typing import List, Optional, Tuple
  import colorful
 
  from lattifai.audio2 import AudioData
- from lattifai.caption import Caption, Supervision
+ from lattifai.caption import Supervision
  from lattifai.config import AlignmentConfig
+ from lattifai.data import Caption
  from lattifai.utils import safe_print
 
- from .sentence_splitter import END_PUNCTUATION
+ from .punctuation import END_PUNCTUATION
 
 
  class Segmenter:
lattifai/alignment/text_align.py CHANGED
@@ -9,7 +9,8 @@ import regex
  from error_align import error_align
  from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
 
- from lattifai.caption import Caption, Supervision
+ from lattifai.caption import Supervision
+ from lattifai.data import Caption
  from lattifai.utils import safe_print
 
  from .punctuation import PUNCTUATION
lattifai/alignment/tokenizer.py CHANGED
@@ -6,9 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
  import numpy as np
 
- # from lattifai.caption import Supervision
- from lhotse.supervision import SupervisionSegment as Supervision  # NOTE: Transcriber SupervisionSegment
-
+ from lattifai.caption import SentenceSplitter, Supervision
  from lattifai.caption import normalize_text as normalize_html_text
  from lattifai.errors import (
      LATTICE_DECODING_FAILURE_HELP,
@@ -17,9 +15,7 @@ from lattifai.errors import (
      QuotaExceededError,
  )
 
- from .phonemizer import G2Phonemizer
  from .punctuation import PUNCTUATION, PUNCTUATION_SPACE
- from .sentence_splitter import SentenceSplitter
  from .text_align import TextAlignResult
 
  MAXIMUM_WORD_LENGTH = 40
@@ -174,13 +170,16 @@ class LatticeTokenizer:
          tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
          tokenizer.oov_word = data["oov_word"]
 
+         # Lazy load G2P model only if it exists (avoids PyTorch dependency)
          g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
-         if g2pp_model_path:
-             tokenizer.g2p_model = G2Phonemizer(g2pp_model_path, device=device)
+         g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
+         g2p_path = g2pp_model_path or g2p_model_path
+         if g2p_path:
+             from .phonemizer import G2Phonemizer
+
+             tokenizer.g2p_model = G2Phonemizer(g2p_path, device=device)
          else:
-             g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
-             if g2p_model_path:
-                 tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+             tokenizer.g2p_model = None
 
          tokenizer.device = device
          tokenizer.add_special_tokens()
 
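The comment above marks the design change: `.phonemizer` (which pulls in PyTorch) is now imported only when a G2P model file actually exists on disk. A minimal sketch of the same lazy-import pattern, assuming the absolute module path `lattifai.alignment.phonemizer` and the `G2Phonemizer(path, device=...)` signature shown in the diff:

```python
from pathlib import Path
from typing import Optional


def load_g2p_if_present(model_path: str, device: str = "cpu") -> Optional[object]:
    """Return a G2Phonemizer when a model file exists, else None.

    The heavy import sits inside the branch, so environments without a
    G2P model never pay the PyTorch import cost.
    """
    for name in ("g2pp.bin", "g2p.bin"):  # g2pp.bin preferred, as in the diff
        candidate = Path(model_path) / name
        if candidate.exists():
            from lattifai.alignment.phonemizer import G2Phonemizer  # deferred import

            return G2Phonemizer(str(candidate), device=device)
    return None
```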
@@ -246,9 +245,24 @@ class LatticeTokenizer:
          self.init_sentence_splitter()
          return self.sentence_splitter.split_sentences(supervisions, strip_whitespace=strip_whitespace)
 
+     def _get_client_info(self) -> Dict[str, Optional[str]]:
+         """Get client identification info for usage tracking."""
+         try:
+             from importlib.metadata import version
+
+             return {"client_name": "python-sdk", "client_version": version("lattifai")}
+         except Exception:
+             return {"client_name": "python-sdk", "client_version": "unknown"}
+
      def tokenize(
-         self, supervisions: Union[List[Supervision], TextAlignResult], split_sentence: bool = False, boost: float = 0.0
+         self,
+         supervisions: Union[List[Supervision], TextAlignResult],
+         split_sentence: bool = False,
+         boost: float = 0.0,
+         transition_penalty: Optional[float] = 0.0,
      ) -> Tuple[str, Dict[str, Any]]:
+         client_info = self._get_client_info()
+
          if isinstance(supervisions[0], Supervision):
              if split_sentence:
                  supervisions = self.split_sentences(supervisions)
@@ -260,6 +274,8 @@ class LatticeTokenizer:
                      "model_name": self.model_name,
                      "supervisions": [s.to_dict() for s in supervisions],
                      "pronunciation_dictionaries": pronunciation_dictionaries,
+                     **client_info,
+                     "transition_penalty": transition_penalty,
                  },
              )
          else:
@@ -274,6 +290,7 @@ class LatticeTokenizer:
                      "transcription": [s.to_dict() for s in supervisions[1]],
                      "pronunciation_dictionaries": pronunciation_dictionaries,
                      "boost": boost,
+                     **client_info,
                  },
              )
 
@@ -297,8 +314,10 @@ class LatticeTokenizer:
          return_details: bool = False,
          start_margin: float = 0.08,
          end_margin: float = 0.20,
+         check_sanity: bool = True,
      ) -> List[Supervision]:
-         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+         emission_stats, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+         # emission_stats is a dict with 'max_probs' and 'aligned_probs' (unified for batch and streaming)
          if isinstance(supervisions[0], Supervision):
              response = self.client_wrapper.post(
                  "detokenize",
@@ -314,6 +333,7 @@ class LatticeTokenizer:
                      "destroy_lattice": True,
                      "start_margin": start_margin,
                      "end_margin": end_margin,
+                     "check_sanity": check_sanity,
                  },
              )
          else:
@@ -331,6 +351,7 @@ class LatticeTokenizer:
                      "destroy_lattice": True,
                      "start_margin": start_margin,
                      "end_margin": end_margin,
+                     "check_sanity": check_sanity,
                  },
              )
 
@@ -350,9 +371,8 @@ class LatticeTokenizer:
 
          alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
 
-         if emission is not None and return_details:
-             # Add emission confidence scores for segments and word-level alignments
-             _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
+         # Add emission confidence scores for segments and word-level alignments
+         _add_confidence_scores(alignments, emission_stats, frame_shift, offset)
 
          if isinstance(supervisions[0], Supervision):
              alignments = _update_alignments_speaker(supervisions, alignments)
@@ -365,8 +385,7 @@ class LatticeTokenizer:
 
  def _add_confidence_scores(
      supervisions: List[Supervision],
-     emission: np.ndarray,
-     labels: List[int],
+     emission_stats: Dict[str, np.ndarray],
      frame_shift: float,
      offset: float = 0.0,
  ) -> None:
@@ -379,29 +398,37 @@ def _add_confidence_scores(
 
      Args:
          supervisions: List of Supervision objects to add scores to (modified in-place)
-         emission: Emission tensor with shape [batch, time, vocab_size]
-         labels: Token labels corresponding to aligned tokens
+         emission_stats: Dict with 'max_probs' and 'aligned_probs' arrays
          frame_shift: Frame shift in seconds for converting frames to time
+         offset: Time offset in seconds
      """
-     tokens = np.array(labels, dtype=np.int64)
+     max_probs = emission_stats["max_probs"]
+     aligned_probs = emission_stats["aligned_probs"]
+     diffprobs_full = max_probs - aligned_probs
 
      for supervision in supervisions:
          start_frame = int((supervision.start - offset) / frame_shift)
          end_frame = int((supervision.end - offset) / frame_shift)
 
-         # Compute segment-level confidence
-         probabilities = np.exp(emission[0, start_frame:end_frame])
-         aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
-         diffprobs = np.max(probabilities, axis=-1) - aligned
-         supervision.score = round(1.0 - diffprobs.mean(), ndigits=4)
+         # Clamp to valid range
+         start_frame = max(0, min(start_frame, len(diffprobs_full) - 1))
+         end_frame = max(start_frame + 1, min(end_frame, len(diffprobs_full)))
+
+         diffprobs = diffprobs_full[start_frame:end_frame]
+         if len(diffprobs) > 0:
+             supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
 
-         # Compute word-level confidence if alignment exists
+         # Word-level confidence
          if hasattr(supervision, "alignment") and supervision.alignment:
              words = supervision.alignment.get("word", [])
              for w, item in enumerate(words):
-                 start = int((item.start - offset) / frame_shift) - start_frame
-                 end = int((item.end - offset) / frame_shift) - start_frame
-                 words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean(), ndigits=4))
+                 start = int((item.start - offset) / frame_shift)
+                 end = int((item.end - offset) / frame_shift)
+                 start = max(0, min(start, len(diffprobs_full) - 1))
+                 end = max(start + 1, min(end, len(diffprobs_full)))
+                 word_diffprobs = diffprobs_full[start:end]
+                 if len(word_diffprobs) > 0:
+                     words[w] = item._replace(score=round(1.0 - word_diffprobs.mean().item(), ndigits=4))
 
 
  def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
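A worked numeric check of the clamped scoring above, using the 20 ms frame shift noted in lattice1_worker.py (the per-frame gaps are made up; a uniform gap keeps the result independent of exact slice boundaries):

```python
import numpy as np

frame_shift = 0.02  # 20 ms frames, per the worker's frame_shift comment
offset = 0.0
diffprobs_full = np.full(50, 0.05)  # 1 s of audio, uniform max-vs-aligned gap

# A segment whose end time (1.2 s) overruns the available audio:
start_frame = int((0.8 - offset) / frame_shift)  # ~40
end_frame = int((1.2 - offset) / frame_shift)    # ~59, past the last frame

# The clamping mirrors _add_confidence_scores and guarantees a non-empty,
# in-range slice, so .mean() can never produce NaN on an empty array.
start_frame = max(0, min(start_frame, len(diffprobs_full) - 1))
end_frame = max(start_frame + 1, min(end_frame, len(diffprobs_full)))  # -> 50

score = round(1.0 - diffprobs_full[start_frame:end_frame].mean().item(), 4)
print(score)  # 0.95
```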