lattifai 1.0.5__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,8 @@ class Lattice1Aligner(object):
             raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
 
         client_wrapper = config.client_wrapper
-        model_path = _resolve_model_path(config.model_name)
+        # Resolve model path using configured model hub
+        model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))
 
         self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
         self.worker = _load_worker(model_path, config.device, config)
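`_resolve_model_path` itself is not part of this diff; a minimal sketch of what hub-aware resolution typically looks like, assuming the standard `huggingface_hub` and `modelscope` download helpers (this is an illustration, not the package's actual implementation):

    # Hypothetical sketch only -- the real _resolve_model_path is not shown in this diff.
    from pathlib import Path


    def resolve_model_path_sketch(model_name: str, model_hub: str = "huggingface") -> str:
        """Return a local directory for model_name, downloading from the chosen hub if needed."""
        if Path(model_name).is_dir():  # already a local checkout
            return model_name
        if model_hub == "modelscope":
            from modelscope import snapshot_download  # assumed available
            return snapshot_download(model_name)
        from huggingface_hub import snapshot_download
        return snapshot_download(repo_id=model_name)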
@@ -53,6 +54,29 @@ class Lattice1Aligner(object):
         """
         return self.worker.emission(ndarray)
 
+    def separate(self, audio: np.ndarray) -> np.ndarray:
+        """Separate audio using separator model.
+
+        Args:
+            audio: np.ndarray object containing the audio to separate, shape (1, T)
+
+        Returns:
+            Separated audio as numpy array
+
+        Raises:
+            RuntimeError: If separator model is not available
+        """
+        if self.worker.separator_ort is None:
+            raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
+
+        # Run separator model
+        separator_output = self.worker.separator_ort.run(
+            None,
+            {"audio": audio},
+        )
+
+        return separator_output[0]
+
     def alignment(
         self,
         audio: AudioData,
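A usage sketch for the new method, given an initialized `Lattice1Aligner`; the `(1, T)` shape comes from the docstring, while float32 mono at 16 kHz is an assumption about what the ONNX graph expects:

    import numpy as np

    audio = np.zeros((1, 16000), dtype=np.float32)  # one second of silence (assumed 16 kHz)
    try:
        vocals = aligner.separate(audio)
    except RuntimeError:
        vocals = audio  # model directory has no separator.onnx; fall back to the raw mix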
@@ -120,7 +144,12 @@ class Lattice1Aligner(object):
             safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
             try:
                 alignments = self.tokenizer.detokenize(
-                    lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                    lattice_id,
+                    lattice_results,
+                    supervisions=supervisions,
+                    return_details=return_details,
+                    start_margin=self.config.start_margin,
+                    end_margin=self.config.end_margin,
                 )
                 if verbose:
                     safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
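Per the AlignmentConfig docstrings added later in this diff, the margins are maximum boundary extensions in seconds, applied during server-side decoding. An illustration of the intended effect (not the service's actual code):

    def apply_margins_sketch(start: float, end: float,
                             start_margin: float = 0.08, end_margin: float = 0.20) -> tuple:
        """Illustrative only: widen a segment by at most the configured margins."""
        return max(0.0, start - start_margin), end + end_margin

    print(apply_margins_sketch(1.00, 2.50))  # (0.92, 2.70)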
@@ -1,6 +1,7 @@
 import json
 import time
 from collections import defaultdict
+from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
@@ -73,6 +74,19 @@ class Lattice1Worker:
         else:
             self.extractor = None  # ONNX model includes feature extractor
 
+        # Initialize separator if available
+        separator_model_path = Path(model_path) / "separator.onnx"
+        if separator_model_path.exists():
+            try:
+                self.separator_ort = ort.InferenceSession(
+                    str(separator_model_path),
+                    providers=providers + ["CPUExecutionProvider"],
+                )
+            except Exception as e:
+                raise ModelLoadError(f"separator model from {model_path}", original_error=e)
+        else:
+            self.separator_ort = None
+
         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
 
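The same optional-asset pattern, isolated into a runnable sketch that assumes only `onnxruntime` is installed (the `providers` list and `ModelLoadError` above are package-specific):

    from pathlib import Path
    from typing import Optional

    import onnxruntime as ort


    def load_optional_session(model_dir: str, filename: str = "separator.onnx") -> Optional[ort.InferenceSession]:
        """Return an InferenceSession if the .onnx file exists, else None."""
        model_file = Path(model_dir) / filename
        if not model_file.exists():
            return None
        # CPUExecutionProvider acts as the fallback, mirroring the worker code above.
        return ort.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])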
@@ -214,7 +214,7 @@ class LatticeTokenizer:
             else:
                 with open(words_model_path, "rb") as f:
                     data = pickle.load(f)
-        except pickle.UnpicklingError as e:
+        except Exception as e:
            del e
            import msgpack
 
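The widened `except` now routes any pickle failure, not just `UnpicklingError`, to the msgpack fallback. A standalone sketch of the load-with-fallback pattern (file layout and function name are illustrative):

    import pickle
    from pathlib import Path

    import msgpack  # fallback deserializer


    def load_words_model(path: str):
        """Try pickle first; on any failure, re-read the same bytes as msgpack."""
        raw = Path(path).read_bytes()
        try:
            return pickle.loads(raw)
        except Exception:
            return msgpack.unpackb(raw, raw=False)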
@@ -434,6 +434,8 @@ class LatticeTokenizer:
         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
         supervisions: List[Supervision],
         return_details: bool = False,
+        start_margin: float = 0.08,
+        end_margin: float = 0.20,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
@@ -448,6 +450,8 @@ class LatticeTokenizer:
                 "channel": channel,
                 "return_details": False if return_details is None else return_details,
                 "destroy_lattice": True,
+                "start_margin": start_margin,
+                "end_margin": end_margin,
             },
         )
         if response.status_code == 400:
@@ -538,12 +542,9 @@ def _load_tokenizer(
     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
     """Instantiate tokenizer with consistent error handling."""
-    try:
-        return tokenizer_cls.from_pretrained(
-            client_wrapper=client_wrapper,
-            model_path=model_path,
-            model_name=model_name,
-            device=device,
-        )
-    except Exception as e:
-        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
+    return tokenizer_cls.from_pretrained(
+        client_wrapper=client_wrapper,
+        model_path=model_path,
+        model_name=model_name,
+        device=device,
+    )
@@ -4,17 +4,19 @@ import json
 import re
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, TypeVar
 
 from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike
 from tgt import TextGrid
 
-from ..config.caption import InputCaptionFormat, OutputCaptionFormat
+from ..config.caption import InputCaptionFormat, OutputCaptionFormat  # noqa: F401
 from .supervision import Supervision
 from .text_parser import normalize_text as normalize_text_fn
 from .text_parser import parse_speaker_text, parse_timestamp_text
 
+DiarizationOutput = TypeVar("DiarizationOutput")
+
 
 @dataclass
 class Caption:
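Here `DiarizationOutput` is a module-level `TypeVar` standing in for the real class, so annotations can mention it without importing `lattifai_core` at module load time; the concrete class is imported lazily inside `read_speaker_diarization` further down. The more conventional equivalent uses a `TYPE_CHECKING` guard, shown here as an alternative sketch:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
        from lattifai_core.diarization import DiarizationOutput


    class CaptionSketch:
        speaker_diarization: "Optional[DiarizationOutput]" = None

        def read_speaker_diarization(self, path: str) -> "Optional[DiarizationOutput]":
            # Runtime import stays lazy, exactly as in the method shown later in this diff.
            from lattifai_core.diarization import DiarizationOutput

            self.speaker_diarization = DiarizationOutput.read(path)
            return self.speaker_diarization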
@@ -40,7 +42,7 @@ class Caption:
     # Audio Event Detection results
     audio_events: Optional[TextGrid] = None
     # Speaker Diarization results
-    speaker_diarization: Optional[TextGrid] = None
+    speaker_diarization: Optional[DiarizationOutput] = None
     # Alignment results
     alignments: List[Supervision] = field(default_factory=list)
 
@@ -272,7 +274,7 @@ class Caption:
         cls,
         transcription: List[Supervision],
         audio_events: Optional[TextGrid] = None,
-        speaker_diarization: Optional[TextGrid] = None,
+        speaker_diarization: Optional[DiarizationOutput] = None,
         language: Optional[str] = None,
         source_path: Optional[Pathlike] = None,
         metadata: Optional[Dict[str, str]] = None,
@@ -283,7 +285,7 @@ class Caption:
         Args:
             transcription: List of transcription supervision segments
             audio_events: Optional TextGrid with audio event detection results
-            speaker_diarization: Optional TextGrid with speaker diarization results
+            speaker_diarization: Optional DiarizationOutput with speaker diarization results
             language: Language code
             source_path: Source file path
             metadata: Additional metadata
@@ -384,9 +386,9 @@ class Caption:
         """
         Read speaker diarization TextGrid from file.
         """
-        from tgt import read_textgrid
+        from lattifai_core.diarization import DiarizationOutput
 
-        self.speaker_diarization = read_textgrid(path)
+        self.speaker_diarization = DiarizationOutput.read(path)
         return self.speaker_diarization
 
     def write_speaker_diarization(
  def write_speaker_diarization(
@@ -399,9 +401,7 @@ class Caption:
399
401
  if not self.speaker_diarization:
400
402
  raise ValueError("No speaker diarization data to write.")
401
403
 
402
- from tgt import write_to_file
403
-
404
- write_to_file(self.speaker_diarization, path, format="long")
404
+ self.speaker_diarization.write(path)
405
405
  return path
406
406
 
407
407
  @staticmethod
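Given a `Caption` instance, the two methods above now round-trip through `lattifai_core` (file names here are illustrative; the on-disk format is whatever `DiarizationOutput.read`/`write` use):

    caption.read_speaker_diarization("session1.diarization")        # returns and stores a DiarizationOutput
    caption.write_speaker_diarization("session1.copy.diarization")  # delegates to DiarizationOutput.write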
@@ -451,7 +451,10 @@ class Caption:
             else:
                 if include_speaker_in_text and sup.speaker is not None:
                     # Use [SPEAKER]: format for consistency with parsing
-                    text = f"[{sup.speaker}]: {sup.text}"
+                    if not sup.has_custom("original_speaker") or sup.custom["original_speaker"]:
+                        text = f"[{sup.speaker}]: {sup.text}"
+                    else:
+                        text = f"{sup.text}"
                 else:
                     text = sup.text
             f.write(f"[{sup.start:.2f}-{sup.end:.2f}] {text}\n")
@@ -471,7 +474,12 @@ class Caption:
         tg = TextGrid()
         supervisions, words, scores = [], [], {"utterances": [], "words": []}
         for supervision in sorted(alignments, key=lambda x: x.start):
-            if include_speaker_in_text and supervision.speaker is not None:
+            # Respect `original_speaker` custom flag: default to include speaker when missing
+            if (
+                include_speaker_in_text
+                and supervision.speaker is not None
+                and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+            ):
                 text = f"{supervision.speaker} {supervision.text}"
             else:
                 text = supervision.text
@@ -526,7 +534,10 @@ class Caption:
                 )
             else:
                 if include_speaker_in_text and sup.speaker is not None:
-                    text = f"{sup.speaker} {sup.text}"
+                    if not sup.has_custom("original_speaker") or sup.custom["original_speaker"]:
+                        text = f"{sup.speaker} {sup.text}"
+                    else:
+                        text = f"{sup.text}"
                 else:
                     text = sup.text
                 subs.append(
@@ -830,7 +841,16 @@ class Caption:
         if cls._is_youtube_vtt_with_word_timestamps(content):
             return cls._parse_youtube_vtt_with_word_timestamps(content, normalize_text)
 
-        if format == "gemini" or str(caption).endswith("Gemini.md") or str(caption).endswith("Gemini3.md"):
+        # Match Gemini format: explicit format, or files ending with Gemini.md/Gemini3.md,
+        # or files containing "gemini" in the name with .md extension
+        caption_str = str(caption).lower()
+        is_gemini_format = (
+            format == "gemini"
+            or str(caption).endswith("Gemini.md")
+            or str(caption).endswith("Gemini3.md")
+            or ("gemini" in caption_str and caption_str.endswith(".md"))
+        )
+        if is_gemini_format:
             from .gemini_reader import GeminiReader
 
             supervisions = GeminiReader.extract_for_alignment(caption)
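The relaxed predicate makes detection case-insensitive and substring-based for .md files. A quick check of which names now route to GeminiReader:

    # All four match in 1.1.0; only the first two matched in 1.0.5.
    for name in ("talk.Gemini.md", "talk.Gemini3.md", "talk.gemini.md", "gemini_notes.md"):
        lowered = name.lower()
        assert "gemini" in lowered and lowered.endswith(".md")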
@@ -1242,7 +1262,11 @@ class Caption:
         if include_speaker_in_text:
             file.write("speaker\tstart\tend\ttext\n")
         for supervision in alignments:
-            speaker = supervision.speaker or ""
+            # Respect `original_speaker` custom flag: default to True when missing
+            include_speaker = supervision.speaker and (
+                not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"]
+            )
+            speaker = supervision.speaker if include_speaker else ""
             start_ms = round(1000 * supervision.start)
             end_ms = round(1000 * supervision.end)
             text = supervision.text.strip().replace("\t", " ")
@@ -1280,7 +1304,10 @@ class Caption:
         writer = csv.writer(file)
         writer.writerow(["speaker", "start", "end", "text"])
         for supervision in alignments:
-            speaker = supervision.speaker or ""
+            include_speaker = supervision.speaker and (
+                not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"]
+            )
+            speaker = supervision.speaker if include_speaker else ""
             start_ms = round(1000 * supervision.start)
             end_ms = round(1000 * supervision.end)
             text = supervision.text.strip()
@@ -1318,7 +1345,12 @@ class Caption:
             end = supervision.end
             text = supervision.text.strip().replace("\t", " ")
 
-            if include_speaker_in_text and supervision.speaker:
+            # Respect `original_speaker` custom flag when adding speaker prefix
+            if (
+                include_speaker_in_text
+                and supervision.speaker
+                and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+            ):
                 text = f"[[{supervision.speaker}]]{text}"
 
             file.write(f"{start}\t{end}\t{text}\n")
@@ -1364,9 +1396,13 @@ class Caption:
             # Write timestamp line
             file.write(f"{start_time},{end_time}\n")
 
-            # Write text (with optional speaker)
+            # Write text (with optional speaker). Respect `original_speaker` custom flag.
             text = supervision.text.strip()
-            if include_speaker_in_text and supervision.speaker:
+            if (
+                include_speaker_in_text
+                and supervision.speaker
+                and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+            ):
                 text = f"{supervision.speaker}: {text}"
 
             file.write(f"{text}\n")
lattifai/cli/__init__.py CHANGED
@@ -5,12 +5,14 @@ import nemo_run as run  # noqa: F401
 # Import and re-export entrypoints at package level so NeMo Run can find them
 from lattifai.cli.alignment import align
 from lattifai.cli.caption import convert
+from lattifai.cli.diarization import diarize
 from lattifai.cli.transcribe import transcribe, transcribe_align
 from lattifai.cli.youtube import youtube
 
 __all__ = [
     "align",
     "convert",
+    "diarize",
     "transcribe",
     "transcribe_align",
     "youtube",
lattifai/cli/caption.py CHANGED
@@ -14,7 +14,7 @@ from lattifai.utils import safe_print
 def convert(
     input_path: Pathlike,
     output_path: Pathlike,
-    include_speaker_in_text: bool = True,
+    include_speaker_in_text: bool = False,
     normalize_text: bool = False,
 ):
     """
lattifai/cli/diarization.py ADDED
@@ -0,0 +1,108 @@
+"""Speaker diarization CLI entry point with nemo_run."""
+
+from pathlib import Path
+from typing import Optional
+
+import colorful
+import nemo_run as run
+from typing_extensions import Annotated
+
+from lattifai.client import LattifAI
+from lattifai.config import CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
+from lattifai.utils import safe_print
+
+__all__ = ["diarize"]
+
+
+@run.cli.entrypoint(name="run", namespace="diarization")
+def diarize(
+    input_media: Optional[str] = None,
+    input_caption: Optional[str] = None,
+    output_caption: Optional[str] = None,
+    media: Annotated[Optional[MediaConfig], run.Config[MediaConfig]] = None,
+    caption: Annotated[Optional[CaptionConfig], run.Config[CaptionConfig]] = None,
+    client: Annotated[Optional[ClientConfig], run.Config[ClientConfig]] = None,
+    diarization: Annotated[Optional[DiarizationConfig], run.Config[DiarizationConfig]] = None,
+):
+    """Run speaker diarization on aligned captions and audio."""
+
+    media_config = media or MediaConfig()
+    caption_config = caption or CaptionConfig()
+    diarization_config = diarization or DiarizationConfig()
+
+    if input_media and media_config.input_path:
+        raise ValueError("Cannot specify both positional input_media and media.input_path.")
+    if input_media:
+        media_config.set_input_path(input_media)
+    if not media_config.input_path:
+        raise ValueError("Input media path must be provided via positional input_media or media.input_path.")
+
+    if input_caption and caption_config.input_path:
+        raise ValueError("Cannot specify both positional input_caption and caption.input_path.")
+    if input_caption:
+        caption_config.set_input_path(input_caption)
+    if not caption_config.input_path:
+        raise ValueError("Input caption path must be provided via positional input_caption or caption.input_path.")
+
+    if output_caption and caption_config.output_path:
+        raise ValueError("Cannot specify both positional output_caption and caption.output_path.")
+    if output_caption:
+        caption_config.set_output_path(output_caption)
+
+    diarization_config.enabled = True
+
+    client_instance = LattifAI(
+        client_config=client,
+        caption_config=caption_config,
+        diarization_config=diarization_config,
+    )
+
+    safe_print(colorful.cyan("🎧 Loading media for diarization..."))
+    media_audio = client_instance.audio_loader(
+        media_config.input_path,
+        channel_selector=media_config.channel_selector,
+        streaming_chunk_secs=media_config.streaming_chunk_secs,
+    )
+
+    safe_print(colorful.cyan("📖 Loading caption segments..."))
+    caption_obj = client_instance._read_caption(
+        caption_config.input_path,
+        input_caption_format=None if caption_config.input_format == "auto" else caption_config.input_format,
+        verbose=False,
+    )
+
+    if not caption_obj.alignments:
+        caption_obj.alignments = caption_obj.supervisions
+
+    if not caption_obj.alignments:
+        raise ValueError("Caption does not contain segments for diarization.")
+
+    if caption_config.output_path:
+        output_path = caption_config.output_path
+    else:
+        from datetime import datetime
+
+        input_caption_path = Path(caption_config.input_path)
+        timestamp = datetime.now().strftime("%Y%m%d_%H")
+        default_output = (
+            input_caption_path.parent / f"{input_caption_path.stem}.diarized.{timestamp}.{caption_config.output_format}"
+        )
+        caption_config.set_output_path(default_output)
+        output_path = caption_config.output_path
+
+    safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
+    diarized_caption = client_instance.speaker_diarization(
+        input_media=media_audio,
+        caption=caption_obj,
+        output_caption_path=output_path,
+    )
+
+    return diarized_caption
+
+
+def main():
+    run.cli.main(diarize)
+
+
+if __name__ == "__main__":
+    main()
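Because the module defines a `main()` guard, it can be invoked as `python -m lattifai.cli.diarization` via NeMo Run's argument syntax; programmatically, the entrypoint is a plain function (file names are illustrative):

    from lattifai.cli.diarization import diarize

    caption = diarize(
        input_media="interview.wav",
        input_caption="interview.aligned.srt",
        output_caption="interview.diarized.srt",
    )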
@@ -117,7 +117,9 @@ def transcribe(
 
     # Create transcriber
     if not transcription_config.lattice_model_path:
-        transcription_config.lattice_model_path = _resolve_model_path("LattifAI/Lattice-1")
+        transcription_config.lattice_model_path = _resolve_model_path(
+            "LattifAI/Lattice-1", getattr(transcription_config, "model_hub", "huggingface")
+        )
     transcriber = create_transcriber(transcription_config=transcription_config)
 
     safe_print(colorful.cyan(f"🎤 Starting transcription with {transcriber.name}..."))
lattifai/client.py CHANGED
@@ -106,7 +106,13 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
         )
 
         if not input_caption:
-            caption = self._transcribe(media_audio, source_lang=self.caption_config.source_lang, is_async=False)
+            output_dir = None
+            if output_caption_path:
+                output_dir = Path(str(output_caption_path)).parent
+                output_dir.mkdir(parents=True, exist_ok=True)
+            caption = self._transcribe(
+                media_audio, source_lang=self.caption_config.source_lang, is_async=False, output_dir=output_dir
+            )
         else:
             caption = self._read_caption(input_caption, input_caption_format)
 
@@ -260,18 +266,9 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
             caption.supervisions = supervisions
             caption.alignments = alignments
 
-            # Step 5: Speaker diarization
-            if self.diarization_config.enabled and self.diarizer:
-                safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
-                caption = self.speaker_diarization(
-                    input_media=media_audio,
-                    caption=caption,
-                    output_caption_path=output_caption_path,
-                )
-            elif output_caption_path:
+            if output_caption_path:
                 self._write_caption(caption, output_caption_path)
 
-            return caption
         except (CaptionProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
             # Re-raise our specific errors as-is
             raise
@@ -284,6 +281,17 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
                 context={"original_error": str(e), "error_type": e.__class__.__name__},
             )
 
+        # Step 5: Speaker diarization
+        if self.diarization_config.enabled and self.diarizer:
+            safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
+            caption = self.speaker_diarization(
+                input_media=media_audio,
+                caption=caption,
+                output_caption_path=output_caption_path,
+            )
+
+        return caption
+
     def speaker_diarization(
         self,
         input_media: AudioData,
@@ -315,7 +323,14 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
             caption.read_speaker_diarization(diarization_file)
 
         diarization, alignments = self.diarizer.diarize_with_alignments(
-            input_media, caption.alignments, diarization=caption.speaker_diarization
+            input_media,
+            caption.alignments,
+            diarization=caption.speaker_diarization,
+            alignment_fn=self.aligner.alignment,
+            transcribe_fn=self.transcriber.transcribe_numpy if self.transcriber else None,
+            separate_fn=self.aligner.separate if self.aligner.worker.separator_ort else None,
+            debug=self.diarizer.config.debug,
+            output_path=output_caption_path,
        )
         caption.alignments = alignments
         caption.speaker_diarization = diarization
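`diarize_with_alignments` now receives optional capability callbacks; passing `None` (as for `separate_fn` when no separator is bundled) lets the diarizer feature-detect. A sketch of the consumer side under that assumption, since the diarizer's internals are not part of this diff:

    def process_overlap_region(chunk, separate_fn=None, transcribe_fn=None):
        """Illustrative: apply source separation, then re-transcribe, when available."""
        if separate_fn is not None:
            chunk = separate_fn(chunk)   # e.g. Lattice1Aligner.separate
        if transcribe_fn is not None:
            return transcribe_fn(chunk)  # e.g. transcriber.transcribe_numpy
        return None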
@@ -324,105 +339,6 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
         if output_caption_path:
             self._write_caption(caption, output_caption_path)
 
-        if self.diarizer.config.debug:
-            # debug
-            from tgt import Interval, IntervalTier, TextGrid, write_to_file
-
-            debug_tg = TextGrid()
-            transcript_tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="transcript",
-                objects=[Interval(sup.start, sup.end, sup.text) for sup in caption.alignments],
-            )
-            debug_tg.add_tier(transcript_tier)
-
-            speaker_tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="speaker",
-                objects=[Interval(sup.start, sup.end, sup.speaker) for sup in caption.alignments],
-            )
-            debug_tg.add_tier(speaker_tier)
-
-            from collections import defaultdict
-
-            spk2intervals = defaultdict(lambda: [])
-            num_multispk = 0
-
-            segments, skipks = [], []
-            for k, supervision in enumerate(caption.alignments):  # TODO: alignments may themselves overlap, e.g. [event]
-                # supervision = caption.alignments[k]
-                if supervision.custom.get("speaker", []):
-                    num_multispk += 1
-                else:
-                    continue
-
-                if k in skipks:
-                    continue
-
-                for speaker in supervision.custom.get("speaker", []):
-                    for name, start_time, end_time in speaker:
-                        spk2intervals[name].append(Interval(start_time, end_time, name))
-
-                _segments = []
-                if k > 0:
-                    _segments.append(caption.alignments[k - 1])
-                _segments.append(supervision)
-                while k + 1 < len(caption.alignments):
-                    skipks.append(k + 1)
-                    next_sup = caption.alignments[k + 1]
-                    if not next_sup.custom.get("speaker", []):
-                        k += 1
-                        break
-                    _segments.append(next_sup)
-                    k += 1
-
-                if segments:
-                    if _segments[0].start >= segments[-1][-1].end:
-                        segments.append(_segments)
-                    else:
-                        if _segments[1:]:
-                            segments.append(_segments[1:])
-                        else:
-                            pass
-                else:
-                    segments.append(_segments)
-
-            print(
-                f"Number of multi-speaker segments: {num_multispk}/{len(caption.alignments)} segments: {len(segments)}"
-            )
-
-            for speaker, intervals in sorted(spk2intervals.items(), key=lambda x: x[0]):
-                speaker_tier = IntervalTier(
-                    start_time=0, end_time=input_media.duration, name=speaker, objects=intervals
-                )
-                debug_tg.add_tier(speaker_tier)
-
-            for tier in caption.speaker_diarization.tiers:
-                tier.name = f"Diarization-{tier.name}"
-                debug_tg.add_tier(tier)
-
-            tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="resegment",
-                objects=[
-                    Interval(round(sup.start, 2), round(sup.end, 2), sup.text)
-                    for _segments in segments
-                    for sup in _segments
-                ],
-            )
-            debug_tg.add_tier(tier)
-
-            # if caption.audio_events:
-            #     for tier in caption.audio_events.tiers:
-            #         # tier.name = f"{tier.name}"
-            #         debug_tg.add_tier(tier)
-
-            debug_tgt_file = Path(str(output_caption_path)).with_suffix(".DiarizationDebug.TextGrid")
-            write_to_file(debug_tg, debug_tgt_file, format="long")
-
         return caption
 
     def youtube(
@@ -21,6 +21,9 @@ class AlignmentConfig:
     model_name: str = "LattifAI/Lattice-1"
     """Model identifier or path to local model directory (e.g., 'LattifAI/Lattice-1')."""
 
+    model_hub: Literal["huggingface", "modelscope"] = "huggingface"
+    """Which model hub to use when resolving remote model names: 'huggingface' or 'modelscope'."""
+
     device: Literal["cpu", "cuda", "mps", "auto"] = "auto"
     """Computation device: 'cpu' for CPU, 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon."""
 
@@ -79,6 +82,17 @@ class AlignmentConfig:
     Default: 10000. Typical range: 1000-20000.
     """
 
+    # Alignment timing configuration
+    start_margin: float = 0.08
+    """Maximum start time margin (in seconds) to extend segment boundaries at the beginning.
+    Default: 0.08. Typical range: 0.0-0.5.
+    """
+
+    end_margin: float = 0.20
+    """Maximum end time margin (in seconds) to extend segment boundaries at the end.
+    Default: 0.20. Typical range: 0.0-0.5.
+    """
+
     client_wrapper: Optional["SyncAPIClient"] = field(default=None, repr=False)
     """Reference to the SyncAPIClient instance. Auto-set during client initialization."""
 
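Putting the new AlignmentConfig fields together; importing AlignmentConfig from lattifai.config is an assumption, mirroring the config imports in the diarization CLI above:

    from lattifai.config import AlignmentConfig  # assumed export

    config = AlignmentConfig(
        model_name="LattifAI/Lattice-1",
        model_hub="modelscope",  # resolve from ModelScope instead of Hugging Face
        start_margin=0.05,       # widen segment starts by at most 50 ms
        end_margin=0.30,         # widen segment ends by at most 300 ms
    )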
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
 SUPPORTED_TRANSCRIPTION_MODELS = Literal[
     "gemini-2.5-pro",
     "gemini-3-pro-preview",
+    "gemini-3-flash-preview",
     "nvidia/parakeet-tdt-0.6b-v3",
     "nvidia/canary-1b-v2",
     "iic/SenseVoiceSmall",
@@ -50,6 +51,9 @@ class TranscriptionConfig:
     lattice_model_path: Optional[str] = None
     """Path to local LattifAI model. Will be auto-set in LattifAI client."""
 
+    model_hub: Literal["huggingface", "modelscope"] = "huggingface"
+    """Which model hub to use when resolving lattice models for transcription."""
+
     client_wrapper: Optional["SyncAPIClient"] = field(default=None, repr=False)
     """Reference to the SyncAPIClient instance. Auto-set during client initialization."""
 
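And likewise for transcription, where 1.1.0 also registers `gemini-3-flash-preview` as a supported model. Only the new field is shown, since TranscriptionConfig's other fields are outside this diff (the import path is an assumption, as above):

    from lattifai.config import TranscriptionConfig  # assumed export

    config = TranscriptionConfig(model_hub="modelscope")  # lattice model fetched from ModelScope when needed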