lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. lattifai/__init__.py +42 -27
  2. lattifai/alignment/__init__.py +6 -0
  3. lattifai/alignment/lattice1_aligner.py +119 -0
  4. lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
  5. lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
  6. lattifai/alignment/segmenter.py +166 -0
  7. lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
  8. lattifai/audio2.py +211 -0
  9. lattifai/caption/__init__.py +20 -0
  10. lattifai/caption/caption.py +1275 -0
  11. lattifai/{io → caption}/supervision.py +1 -0
  12. lattifai/{io → caption}/text_parser.py +53 -10
  13. lattifai/cli/__init__.py +17 -0
  14. lattifai/cli/alignment.py +153 -0
  15. lattifai/cli/caption.py +204 -0
  16. lattifai/cli/server.py +19 -0
  17. lattifai/cli/transcribe.py +197 -0
  18. lattifai/cli/youtube.py +128 -0
  19. lattifai/client.py +455 -246
  20. lattifai/config/__init__.py +20 -0
  21. lattifai/config/alignment.py +73 -0
  22. lattifai/config/caption.py +178 -0
  23. lattifai/config/client.py +46 -0
  24. lattifai/config/diarization.py +67 -0
  25. lattifai/config/media.py +335 -0
  26. lattifai/config/transcription.py +84 -0
  27. lattifai/diarization/__init__.py +5 -0
  28. lattifai/diarization/lattifai.py +89 -0
  29. lattifai/errors.py +41 -34
  30. lattifai/logging.py +116 -0
  31. lattifai/mixin.py +552 -0
  32. lattifai/server/app.py +420 -0
  33. lattifai/transcription/__init__.py +76 -0
  34. lattifai/transcription/base.py +108 -0
  35. lattifai/transcription/gemini.py +219 -0
  36. lattifai/transcription/lattifai.py +103 -0
  37. lattifai/types.py +30 -0
  38. lattifai/utils.py +3 -31
  39. lattifai/workflow/__init__.py +22 -0
  40. lattifai/workflow/agents.py +6 -0
  41. lattifai/{workflows → workflow}/file_manager.py +81 -57
  42. lattifai/workflow/youtube.py +564 -0
  43. lattifai-1.0.0.dist-info/METADATA +736 -0
  44. lattifai-1.0.0.dist-info/RECORD +52 -0
  45. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
  46. lattifai-1.0.0.dist-info/entry_points.txt +13 -0
  47. lattifai/base_client.py +0 -126
  48. lattifai/bin/__init__.py +0 -3
  49. lattifai/bin/agent.py +0 -324
  50. lattifai/bin/align.py +0 -295
  51. lattifai/bin/cli_base.py +0 -25
  52. lattifai/bin/subtitle.py +0 -210
  53. lattifai/io/__init__.py +0 -43
  54. lattifai/io/reader.py +0 -86
  55. lattifai/io/utils.py +0 -15
  56. lattifai/io/writer.py +0 -102
  57. lattifai/tokenizer/__init__.py +0 -3
  58. lattifai/workers/__init__.py +0 -3
  59. lattifai/workflows/__init__.py +0 -34
  60. lattifai/workflows/agents.py +0 -12
  61. lattifai/workflows/gemini.py +0 -167
  62. lattifai/workflows/prompts/README.md +0 -22
  63. lattifai/workflows/prompts/gemini/README.md +0 -24
  64. lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
  65. lattifai/workflows/youtube.py +0 -931
  66. lattifai-0.4.6.dist-info/METADATA +0 -806
  67. lattifai-0.4.6.dist-info/RECORD +0 -39
  68. lattifai-0.4.6.dist-info/entry_points.txt +0 -3
  69. /lattifai/{io → caption}/gemini_reader.py +0 -0
  70. /lattifai/{io → caption}/gemini_writer.py +0 -0
  71. /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
  72. /lattifai/{workflows → workflow}/base.py +0 -0
  73. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  74. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py CHANGED
@@ -1,28 +1,40 @@
 import sys
 import warnings
+from importlib.metadata import version
 
+# Re-export I/O classes
+from .caption import Caption
+
+# Re-export client classes
+from .client import LattifAI
+
+# Re-export config classes
+from .config import (
+    AUDIO_FORMATS,
+    MEDIA_FORMATS,
+    VIDEO_FORMATS,
+    AlignmentConfig,
+    CaptionConfig,
+    ClientConfig,
+    DiarizationConfig,
+    MediaConfig,
+)
 from .errors import (
     AlignmentError,
     APIError,
     AudioFormatError,
     AudioLoadError,
     AudioProcessingError,
+    CaptionParseError,
+    CaptionProcessingError,
     ConfigurationError,
     DependencyError,
     LatticeDecodingError,
     LatticeEncodingError,
     LattifAIError,
     ModelLoadError,
-    SubtitleParseError,
-    SubtitleProcessingError,
 )
-from .io import SubtitleIO
-
-try:
-    from importlib.metadata import version
-except ImportError:
-    # Python < 3.8
-    from importlib_metadata import version
+from .logging import get_logger, set_log_level, setup_logger
 
 try:
     __version__ = version("lattifai")
@@ -54,28 +66,25 @@ def _check_and_install_k2():
 _check_and_install_k2()
 
 
-# Lazy import for LattifAI to avoid dependency issues during basic import
-def __getattr__(name):
-    if name == "LattifAI":
-        from .client import LattifAI
-
-        return LattifAI
-    if name == "AsyncLattifAI":
-        from .client import AsyncLattifAI
-
-        return AsyncLattifAI
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
 __all__ = [
-    "LattifAI",  # noqa: F822
-    "AsyncLattifAI",  # noqa: F822
+    # Client classes
+    "LattifAI",
+    # Config classes
+    "AlignmentConfig",
+    "ClientConfig",
+    "CaptionConfig",
+    "DiarizationConfig",
+    "MediaConfig",
+    "AUDIO_FORMATS",
+    "VIDEO_FORMATS",
+    "MEDIA_FORMATS",
+    # Error classes
     "LattifAIError",
     "AudioProcessingError",
     "AudioLoadError",
     "AudioFormatError",
-    "SubtitleProcessingError",
-    "SubtitleParseError",
+    "CaptionProcessingError",
+    "CaptionParseError",
     "AlignmentError",
     "LatticeEncodingError",
    "LatticeDecodingError",
@@ -83,6 +92,12 @@ __all__ = [
     "DependencyError",
     "APIError",
     "ConfigurationError",
-    "SubtitleIO",
+    # Logging
+    "setup_logger",
+    "get_logger",
+    "set_log_level",
+    # I/O
+    "Caption",
+    # Version
     "__version__",
 ]
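
The net change to the public API surface: the 0.4.6 lazy __getattr__ shim is removed, AsyncLattifAI is no longer exported, and the config, logging, error, and caption helpers are re-exported flat at the package root. A minimal sketch of the resulting 1.0.0 import style (illustrative only; constructor arguments are not shown in this file):

from lattifai import AUDIO_FORMATS, AlignmentConfig, Caption, LattifAI, get_logger

logger = get_logger(__name__)  # logging helpers are now package-level re-exports
client = LattifAI()            # eager import; replaces the 0.4.6 __getattr__ lazy lookup
print(AUDIO_FORMATS)           # format constants re-exported from lattifai.config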
lattifai/alignment/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""Alignment module for LattifAI forced alignment."""
+
+from .lattice1_aligner import Lattice1Aligner
+from .segmenter import Segmenter
+
+__all__ = ["Lattice1Aligner", "Segmenter"]
lattifai/alignment/lattice1_aligner.py ADDED
@@ -0,0 +1,119 @@
+"""Lattice-1 Aligner implementation."""
+
+from typing import Any, List, Optional, Tuple
+
+import colorful
+import torch
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
+from lattifai.config import AlignmentConfig
+from lattifai.errors import (
+    AlignmentError,
+    LatticeDecodingError,
+    LatticeEncodingError,
+)
+from lattifai.utils import _resolve_model_path
+
+from .lattice1_worker import _load_worker
+from .tokenizer import _load_tokenizer
+
+ClientType = Any
+
+
+class Lattice1Aligner(object):
+    """Lattice-1 forced aligner with config-driven architecture."""
+
+    def __init__(
+        self,
+        config: AlignmentConfig,
+    ) -> None:
+        self.config = config
+
+        if config.client_wrapper is None:
+            raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
+
+        client_wrapper = config.client_wrapper
+        model_path = _resolve_model_path(config.model_name)
+
+        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.worker = _load_worker(model_path, config.device)
+
+        self.frame_shift = self.worker.frame_shift
+
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        return self.worker.emission(audio.to(self.worker.device))
+
+    def alignment(
+        self,
+        audio: AudioData,
+        supervisions: List[Supervision],
+        split_sentence: Optional[bool] = False,
+        return_details: Optional[bool] = False,
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+        verbose: bool = True,
+    ) -> Tuple[List[Supervision], List[Supervision]]:
+        """
+        Perform alignment on audio and supervisions.
+
+        Args:
+            audio: AudioData containing the media to align
+            supervisions: List of supervision segments to align
+            split_sentence: Enable sentence re-splitting
+
+        Returns:
+            Tuple of (supervisions, alignments)
+
+        Raises:
+            LatticeEncodingError: If lattice graph generation fails
+            AlignmentError: If audio alignment fails
+            LatticeDecodingError: If lattice decoding fails
+        """
+        try:
+            if verbose:
+                print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+            try:
+                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                    supervisions, split_sentence=split_sentence
+                )
+                if verbose:
+                    print(colorful.green(f"  ✓ Generated lattice graph with ID: {lattice_id}"))
+            except Exception as e:
+                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
+                raise LatticeEncodingError(text_content, original_error=e)
+
+            if verbose:
+                print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            try:
+                lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                if verbose:
+                    print(colorful.green("  ✓ Lattice search completed"))
+            except Exception as e:
+                raise AlignmentError(
+                    f"Audio alignment failed for {audio}",
+                    media_path=str(audio),
+                    context={"original_error": str(e)},
+                )
+
+            if verbose:
+                print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+            try:
+                alignments = self.tokenizer.detokenize(
+                    lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                )
+                if verbose:
+                    print(colorful.green(f"  ✓ Successfully aligned {len(alignments)} segments"))
+            except LatticeDecodingError as e:
+                print(colorful.red("  x Failed to decode lattice alignment results"))
+                raise e
+            except Exception as e:
+                print(colorful.red("  x Failed to decode lattice alignment results"))
+                raise LatticeDecodingError(lattice_id, original_error=e)
+
+            return (supervisions, alignments)
+
+        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
+            raise
+        except Exception as e:
+            raise e
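
Steps 2-4 above form a tokenize, lattice-search, detokenize pipeline. A hypothetical end-to-end sketch (the AlignmentConfig field values, the provenance of client_wrapper, and the construction of the AudioData and Supervision objects are assumptions; only the names and signatures come from this diff):

from lattifai.alignment import Lattice1Aligner
from lattifai.config import AlignmentConfig

config = AlignmentConfig(model_name="lattice-1", device="cpu")  # hypothetical field values
config.client_wrapper = client_wrapper  # must be set by the owning client, or __init__ raises ValueError

aligner = Lattice1Aligner(config)
supervisions, alignments = aligner.alignment(
    audio,                 # lattifai.audio2.AudioData, loaded elsewhere
    supervisions,          # List[Supervision] carrying the text to align
    split_sentence=False,  # keep the original sentence boundaries
)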
lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} RENAMED
@@ -1,73 +1,20 @@
 import json
 import time
 from collections import defaultdict
-from typing import Any, BinaryIO, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import onnxruntime as ort
-import soundfile as sf
 import torch
 from lhotse import FbankConfig
-from lhotse.augmentation import get_or_create_resampler
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
 
-from lattifai.errors import AlignmentError, AudioFormatError, AudioLoadError, DependencyError, ModelLoadError
+from lattifai.audio2 import AudioData
+from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
 
-ChannelSelectorType = Union[int, Iterable[int], str]
 
-
-def resample_audio(
-    audio_sr: Tuple[torch.Tensor, int],
-    sampling_rate: int,
-    device: Optional[str],
-    channel_selector: Optional[ChannelSelectorType] = "average",
-) -> torch.Tensor:
-    """
-    return:
-        (1, T)
-    """
-    audio, sr = audio_sr
-
-    if channel_selector is None:
-        # keep the original multi-channel signal
-        tensor = audio
-    elif isinstance(channel_selector, int):
-        assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
-        tensor = audio[channel_selector : channel_selector + 1].clone()
-        del audio
-    elif isinstance(channel_selector, str):
-        assert channel_selector == "average"
-        tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
-        del audio
-    else:
-        assert isinstance(channel_selector, Iterable)
-        num_channels = audio.shape[0]
-        print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
-        assert isinstance(channel_selector, Iterable)
-        if max(channel_selector) >= num_channels:
-            raise ValueError(
-                f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
-            )
-        tensor = audio[channel_selector]
-
-    tensor = tensor.to(device)
-    if sr != sampling_rate:
-        resampler = get_or_create_resampler(sr, sampling_rate).to(device=device)
-        length = tensor.size(-1)
-        chunk_size = sampling_rate * 3600
-        if length > chunk_size:
-            resampled_chunks = []
-            for i in range(0, length, chunk_size):
-                resampled_chunks.append(resampler(tensor[..., i : i + chunk_size]))
-            tensor = torch.cat(resampled_chunks, dim=-1)
-        else:
-            tensor = resampler(tensor)
-
-    return tensor
-
-
-class Lattice1AlphaWorker:
+class Lattice1Worker:
     """Worker for processing audio with LatticeGraph."""
 
     def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
@@ -109,6 +56,10 @@ class Lattice1AlphaWorker:
         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
 
+    @property
+    def frame_shift(self) -> float:
+        return 0.02  # 20 ms
+
     @torch.inference_mode()
     def emission(self, audio: torch.Tensor) -> torch.Tensor:
         _start = time.time()
@@ -138,68 +89,17 @@ class Lattice1AlphaWorker:
         self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch
 
-    def load_audio(
-        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = "average"
-    ) -> Tuple[torch.Tensor, int]:
-        # load audio
-        try:
-            waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
-            waveform = waveform.T  # (channels, samples)
-        except Exception as primary_error:
-            # Fallback to PyAV for formats not supported by soundfile
-            try:
-                import av
-            except ImportError:
-                raise DependencyError(
-                    "av (PyAV)", install_command="pip install av", context={"primary_error": str(primary_error)}
-                )
-
-            try:
-                container = av.open(audio)
-                audio_stream = next((s for s in container.streams if s.type == "audio"), None)
-
-                if audio_stream is None:
-                    raise AudioFormatError(str(audio), "No audio stream found in file")
-
-                # Resample to target sample rate during decoding
-                audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float
-
-                frames = []
-                for frame in container.decode(audio_stream):
-                    # Convert frame to numpy array
-                    array = frame.to_ndarray()
-                    # Ensure shape is (channels, samples)
-                    if array.ndim == 1:
-                        array = array.reshape(1, -1)
-                    elif array.ndim == 2 and array.shape[0] > array.shape[1]:
-                        array = array.T
-                    frames.append(array)
-
-                container.close()
-
-                if not frames:
-                    raise AudioFormatError(str(audio), "No audio data found in file")
-
-                # Concatenate all frames
-                waveform = np.concatenate(frames, axis=1)
-                sample_rate = audio_stream.codec_context.sample_rate
-            except Exception as e:
-                raise AudioLoadError(str(audio), original_error=e)
-
-        return resample_audio(
-            (torch.from_numpy(waveform), sample_rate),
-            self.config.get("sampling_rate", 16000),
-            device=self.device.type,
-            channel_selector=channel_selector,
-        )
-
     def alignment(
-        self, audio: Union[Union[Pathlike, BinaryIO], torch.tensor], lattice_graph: Tuple[str, int, float]
+        self,
+        audio: AudioData,
+        lattice_graph: Tuple[str, int, float],
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
 
         Args:
-            audio: Audio file path or binary data
+            audio: AudioData object
            lattice_graph: LatticeGraph data
 
         Returns:
@@ -210,22 +110,15 @@ class Lattice1AlphaWorker:
             DependencyError: If required dependencies are missing
             AlignmentError: If alignment process fails
         """
-        # load audio
-        if isinstance(audio, torch.Tensor):
-            waveform = audio
-        else:
-            waveform = self.load_audio(audio)  # (1, L)
-
-        _start = time.time()
-        try:
-            emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
-        except Exception as e:
-            raise AlignmentError(
-                "Failed to compute acoustic features from audio",
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                context={"original_error": str(e)},
-            )
-        self.timings["emission"] += time.time() - _start
+        if emission is None:
+            try:
+                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
+            except Exception as e:
+                raise AlignmentError(
+                    "Failed to compute acoustic features from audio",
+                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                    context={"original_error": str(e)},
+                )
 
         try:
             import k2
@@ -275,10 +168,18 @@ class Lattice1AlphaWorker:
         except Exception as e:
             raise AlignmentError(
                 "Failed to perform forced alignment",
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
                 context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
             )
         self.timings["align_segments"] += time.time() - _start
 
         channel = 0
-        return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
+        return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1Worker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f"worker from {model_path}", original_error=e)
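
Two functional changes accompany the rename: the emission tensor can now be precomputed and passed in (so features for a long file are computed once and reused per segment), and the hard-coded 0.02 / 0.0 tail of the return tuple becomes self.frame_shift and the caller-supplied offset. The frame-to-time arithmetic those two values imply, as a sketch (the actual consumer is the tokenizer's detokenize path):

def frame_to_seconds(frame_index: int, frame_shift: float = 0.02, offset: float = 0.0) -> float:
    """Map an emission frame index to a timestamp in the source media."""
    return offset + frame_index * frame_shift

frame_to_seconds(150)               # 3.0  -> 3 s into the segment
frame_to_seconds(150, offset=60.0)  # 63.0 -> same frame when the segment starts at 1 minute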
lattifai/{tokenizer → alignment}/phonemizer.py RENAMED
@@ -1,7 +1,7 @@
 import re
 from typing import List, Optional, Union
 
-from dp.phonemizer import Phonemizer  # g2p-phonemizer
+from g2pp.phonemizer import Phonemizer  # g2p-phonemizer
 from num2words import num2words
 
 LANGUAGE = "omni"
lattifai/alignment/segmenter.py ADDED
@@ -0,0 +1,166 @@
+"""Segmented alignment for long audio files."""
+
+from typing import List, Optional, Tuple
+
+import colorful
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Caption, Supervision
+from lattifai.config import AlignmentConfig
+
+from .tokenizer import END_PUNCTUATION
+
+
+class Segmenter:
+    """
+    Handles segmented alignment for long audio/video files.
+
+    Instead of aligning the entire audio at once (which can be slow and memory-intensive
+    for long files), this class splits the alignment into manageable segments based on
+    caption boundaries, time intervals, or an adaptive strategy.
+    """
+
+    def __init__(self, config: AlignmentConfig):
+        """
+        Initialize segmented aligner.
+
+        Args:
+            config: Alignment configuration with segmentation parameters
+        """
+        self.config = config
+
+    def __call__(
+        self,
+        caption: Caption,
+        max_duration: Optional[float] = None,
+    ) -> List[Tuple[float, float, List[Supervision], bool]]:
+        """
+        Create segments based on caption boundaries and gaps.
+
+        Splits when:
+        1. Gap between captions exceeds segment_max_gap
+        2. Duration approaches max_duration (adaptive mode only) and there's a reasonable break
+        3. Duration significantly exceeds max_duration (adaptive mode only)
+
+        Args:
+            caption: Caption object with supervisions
+            max_duration: Optional maximum segment duration (enables adaptive behavior)
+
+        Returns:
+            List of (start_time, end_time, supervisions, skip_align) tuples for each segment
+        """
+        if not max_duration:
+            max_duration = self.config.segment_duration
+
+        if not caption.supervisions:
+            return []
+
+        supervisions = sorted(caption.supervisions, key=lambda s: s.start)
+
+        segments = []
+        current_segment_sups = []
+
+        def should_skipalign(sups):
+            return len(sups) == 1 and sups[0].text.strip().startswith("[") and sups[0].text.strip().endswith("]")
+
+        for i, sup in enumerate(supervisions):
+            if not current_segment_sups:
+                current_segment_sups.append(sup)
+                if should_skipalign(current_segment_sups):
+                    # Single [APPLAUSE] caption, make its own segment
+                    segments.append(
+                        (current_segment_sups[0].start, current_segment_sups[-1].end, current_segment_sups, True)
+                    )
+                    current_segment_sups = []
+                continue
+
+            prev_sup = supervisions[i - 1]
+
+            gap = max(sup.start - prev_sup.end, 0.0)
+            # Always split on large gaps (natural breaks)
+            exclude_max_gap = False
+            if gap > self.config.segment_max_gap:
+                exclude_max_gap = True
+
+            endswith_punc = any(sup.text.endswith(punc) for punc in END_PUNCTUATION)
+
+            # Adaptive duration control
+            segment_duration = sup.end - current_segment_sups[0].start
+
+            # Split if approaching duration limit and there's a reasonable break
+            should_split = False
+            if segment_duration >= max_duration * 0.8 and gap >= 1.0:
+                should_split = True
+
+            # Force split if duration exceeded significantly
+            exclude_max_duration = False
+            if segment_duration >= max_duration * 1.2:
+                exclude_max_duration = True
+
+            # [APPLAUSE] [APPLAUSE] [MUSIC]
+            if sup.text.strip().startswith("[") and sup.text.strip().endswith("]"):
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Add current supervision as its own segment
+                segments.append((sup.start, sup.end, [sup], True))
+
+                # Reset for a new segment
+                current_segment_sups = []
+                continue
+
+            if (should_split and endswith_punc) or exclude_max_gap or exclude_max_duration:
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Start new segment
+                current_segment_sups = [sup]
+            else:
+                current_segment_sups.append(sup)
+
+        # Add final segment
+        if current_segment_sups:
+            segment_start = current_segment_sups[0].start
+            segment_end = current_segment_sups[-1].end + 2.0
+            segments.append((segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups)))
+
+        return segments
+
+    def print_segment_info(
+        self,
+        segments: List[Tuple[float, float, List[Supervision], bool]],
+        verbose: bool = True,
+    ) -> None:
+        """
+        Print information about created segments.
+
+        Args:
+            segments: List of segment tuples
+            verbose: Whether to print detailed info
+        """
+        if not verbose:
+            return
+
+        total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
+
+        print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+        for i, (start, end, sups, _) in enumerate(segments, 1):
+            duration = end - start
+            print(
+                colorful.white(
+                    f"  Segment {i:04d}: {start:8.2f}s - {end:8.2f}s "
+                    f"(duration: {duration:8.2f}s, supervisions: {len(sups) if isinstance(sups, list) else 1:4d})"
+                )
+            )
+
+        print(colorful.green(f"  Total: {total_sups} supervisions across {len(segments)} segments"))
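
Note that __call__ returns 4-tuples whose last element is a skip-align flag, set for bracketed non-speech captions such as [APPLAUSE]. A hypothetical driver loop over the segmenter output (the Caption source and the downstream slicing step are assumptions):

from lattifai.alignment import Segmenter

segmenter = Segmenter(config)  # config: AlignmentConfig providing segment_duration and segment_max_gap
segments = segmenter(caption)  # caption: lattifai.caption.Caption with parsed supervisions
segmenter.print_segment_info(segments)

for start, end, sups, skip_align in segments:
    if skip_align:
        continue  # e.g. a lone [APPLAUSE] caption kept as its own unaligned segment
    # align only the audio slice covering [start, end] against sups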