lattifai 0.4.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. lattifai/__init__.py +61 -47
  2. lattifai/alignment/__init__.py +6 -0
  3. lattifai/alignment/lattice1_aligner.py +119 -0
  4. lattifai/alignment/lattice1_worker.py +185 -0
  5. lattifai/{tokenizer → alignment}/phonemizer.py +4 -4
  6. lattifai/alignment/segmenter.py +166 -0
  7. lattifai/{tokenizer → alignment}/tokenizer.py +244 -169
  8. lattifai/audio2.py +211 -0
  9. lattifai/caption/__init__.py +20 -0
  10. lattifai/caption/caption.py +1275 -0
  11. lattifai/{io → caption}/gemini_reader.py +30 -30
  12. lattifai/{io → caption}/gemini_writer.py +17 -17
  13. lattifai/{io → caption}/supervision.py +4 -3
  14. lattifai/caption/text_parser.py +145 -0
  15. lattifai/cli/__init__.py +17 -0
  16. lattifai/cli/alignment.py +153 -0
  17. lattifai/cli/caption.py +204 -0
  18. lattifai/cli/server.py +19 -0
  19. lattifai/cli/transcribe.py +197 -0
  20. lattifai/cli/youtube.py +128 -0
  21. lattifai/client.py +460 -251
  22. lattifai/config/__init__.py +20 -0
  23. lattifai/config/alignment.py +73 -0
  24. lattifai/config/caption.py +178 -0
  25. lattifai/config/client.py +46 -0
  26. lattifai/config/diarization.py +67 -0
  27. lattifai/config/media.py +335 -0
  28. lattifai/config/transcription.py +84 -0
  29. lattifai/diarization/__init__.py +5 -0
  30. lattifai/diarization/lattifai.py +89 -0
  31. lattifai/errors.py +98 -91
  32. lattifai/logging.py +116 -0
  33. lattifai/mixin.py +552 -0
  34. lattifai/server/app.py +420 -0
  35. lattifai/transcription/__init__.py +76 -0
  36. lattifai/transcription/base.py +108 -0
  37. lattifai/transcription/gemini.py +219 -0
  38. lattifai/transcription/lattifai.py +103 -0
  39. lattifai/{workflows → transcription}/prompts/__init__.py +4 -4
  40. lattifai/types.py +30 -0
  41. lattifai/utils.py +16 -44
  42. lattifai/workflow/__init__.py +22 -0
  43. lattifai/workflow/agents.py +6 -0
  44. lattifai/{workflows → workflow}/base.py +22 -22
  45. lattifai/{workflows → workflow}/file_manager.py +239 -215
  46. lattifai/workflow/youtube.py +564 -0
  47. lattifai-1.0.0.dist-info/METADATA +736 -0
  48. lattifai-1.0.0.dist-info/RECORD +52 -0
  49. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
  50. lattifai-1.0.0.dist-info/entry_points.txt +13 -0
  51. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +1 -1
  52. lattifai/base_client.py +0 -126
  53. lattifai/bin/__init__.py +0 -3
  54. lattifai/bin/agent.py +0 -325
  55. lattifai/bin/align.py +0 -296
  56. lattifai/bin/cli_base.py +0 -25
  57. lattifai/bin/subtitle.py +0 -210
  58. lattifai/io/__init__.py +0 -42
  59. lattifai/io/reader.py +0 -85
  60. lattifai/io/text_parser.py +0 -75
  61. lattifai/io/utils.py +0 -15
  62. lattifai/io/writer.py +0 -90
  63. lattifai/tokenizer/__init__.py +0 -3
  64. lattifai/workers/__init__.py +0 -3
  65. lattifai/workers/lattice1_alpha.py +0 -284
  66. lattifai/workflows/__init__.py +0 -34
  67. lattifai/workflows/agents.py +0 -10
  68. lattifai/workflows/gemini.py +0 -167
  69. lattifai/workflows/prompts/README.md +0 -22
  70. lattifai/workflows/prompts/gemini/README.md +0 -24
  71. lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
  72. lattifai/workflows/youtube.py +0 -931
  73. lattifai-0.4.5.dist-info/METADATA +0 -808
  74. lattifai-0.4.5.dist-info/RECORD +0 -39
  75. lattifai-0.4.5.dist-info/entry_points.txt +0 -3
  76. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py CHANGED
@@ -1,34 +1,45 @@
-import os
 import sys
 import warnings
+from importlib.metadata import version
 
+# Re-export I/O classes
+from .caption import Caption
+
+# Re-export client classes
+from .client import LattifAI
+
+# Re-export config classes
+from .config import (
+    AUDIO_FORMATS,
+    MEDIA_FORMATS,
+    VIDEO_FORMATS,
+    AlignmentConfig,
+    CaptionConfig,
+    ClientConfig,
+    DiarizationConfig,
+    MediaConfig,
+)
 from .errors import (
     AlignmentError,
     APIError,
     AudioFormatError,
     AudioLoadError,
     AudioProcessingError,
+    CaptionParseError,
+    CaptionProcessingError,
     ConfigurationError,
     DependencyError,
     LatticeDecodingError,
     LatticeEncodingError,
     LattifAIError,
     ModelLoadError,
-    SubtitleParseError,
-    SubtitleProcessingError,
 )
-from .io import SubtitleIO
-
-try:
-    from importlib.metadata import version
-except ImportError:
-    # Python < 3.8
-    from importlib_metadata import version
+from .logging import get_logger, set_log_level, setup_logger
 
 try:
-    __version__ = version('lattifai')
+    __version__ = version("lattifai")
 except Exception:
-    __version__ = '0.1.0'  # fallback version
+    __version__ = "0.1.0"  # fallback version
 
 
 # Check and auto-install k2 if not present
@@ -39,15 +50,15 @@ def _check_and_install_k2():
     except ImportError:
         import subprocess
 
-        print('k2 is not installed. Attempting to install k2...')
+        print("k2 is not installed. Attempting to install k2...")
         try:
-            subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'install-k2'])
-            subprocess.check_call([sys.executable, '-m', 'install_k2'])
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "install-k2"])
+            subprocess.check_call([sys.executable, "-m", "install_k2"])
 
             import k2  # Try importing again after installation
 
-            print('k2 installed successfully.')
+            print("k2 installed successfully.")
         except Exception as e:
-            warnings.warn(f'Failed to install k2 automatically. Please install it manually. Error: {e}')
+            warnings.warn(f"Failed to install k2 automatically. Please install it manually. Error: {e}")
     return True
 
 
@@ -55,35 +66,38 @@ def _check_and_install_k2():
 _check_and_install_k2()
 
 
-# Lazy import for LattifAI to avoid dependency issues during basic import
-def __getattr__(name):
-    if name == 'LattifAI':
-        from .client import LattifAI
-
-        return LattifAI
-    if name == 'AsyncLattifAI':
-        from .client import AsyncLattifAI
-
-        return AsyncLattifAI
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
 __all__ = [
-    'LattifAI',  # noqa: F822
-    'AsyncLattifAI',  # noqa: F822
-    'LattifAIError',
-    'AudioProcessingError',
-    'AudioLoadError',
-    'AudioFormatError',
-    'SubtitleProcessingError',
-    'SubtitleParseError',
-    'AlignmentError',
-    'LatticeEncodingError',
-    'LatticeDecodingError',
-    'ModelLoadError',
-    'DependencyError',
-    'APIError',
-    'ConfigurationError',
-    'SubtitleIO',
-    '__version__',
+    # Client classes
+    "LattifAI",
+    # Config classes
+    "AlignmentConfig",
+    "ClientConfig",
+    "CaptionConfig",
+    "DiarizationConfig",
+    "MediaConfig",
+    "AUDIO_FORMATS",
+    "VIDEO_FORMATS",
+    "MEDIA_FORMATS",
+    # Error classes
+    "LattifAIError",
+    "AudioProcessingError",
+    "AudioLoadError",
+    "AudioFormatError",
+    "CaptionProcessingError",
+    "CaptionParseError",
+    "AlignmentError",
+    "LatticeEncodingError",
+    "LatticeDecodingError",
+    "ModelLoadError",
+    "DependencyError",
+    "APIError",
+    "ConfigurationError",
+    # Logging
+    "setup_logger",
+    "get_logger",
+    "set_log_level",
+    # I/O
+    "Caption",
+    # Version
+    "__version__",
 ]
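Taken together, the 1.0.0 __init__ replaces the lazy __getattr__ loader with eager imports, renames the Subtitle* errors to Caption*, and drops SubtitleIO and AsyncLattifAI from the public exports. A minimal migration sketch, assuming the renames map one-to-one (the re-export comments above suggest Caption supersedes SubtitleIO):

    # 0.4.5 (removed):
    #   from lattifai import SubtitleIO, SubtitleParseError, SubtitleProcessingError
    # 1.0.0:
    from lattifai import (
        Caption,                 # caption I/O, re-exported from lattifai.caption
        CaptionParseError,       # was SubtitleParseError
        CaptionProcessingError,  # was SubtitleProcessingError
        LattifAI,                # now a plain eager import, no __getattr__ indirection
        __version__,
    )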
lattifai/alignment/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""Alignment module for LattifAI forced alignment."""
+
+from .lattice1_aligner import Lattice1Aligner
+from .segmenter import Segmenter
+
+__all__ = ["Lattice1Aligner", "Segmenter"]
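The new subpackage gives the alignment components a stable import path:

    from lattifai.alignment import Lattice1Aligner, Segmenter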
lattifai/alignment/lattice1_aligner.py ADDED
@@ -0,0 +1,119 @@
+"""Lattice-1 Aligner implementation."""
+
+from typing import Any, List, Optional, Tuple
+
+import colorful
+import torch
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
+from lattifai.config import AlignmentConfig
+from lattifai.errors import (
+    AlignmentError,
+    LatticeDecodingError,
+    LatticeEncodingError,
+)
+from lattifai.utils import _resolve_model_path
+
+from .lattice1_worker import _load_worker
+from .tokenizer import _load_tokenizer
+
+ClientType = Any
+
+
+class Lattice1Aligner(object):
+    """Synchronous LattifAI client with config-driven architecture."""
+
+    def __init__(
+        self,
+        config: AlignmentConfig,
+    ) -> None:
+        self.config = config
+
+        if config.client_wrapper is None:
+            raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
+
+        client_wrapper = config.client_wrapper
+        model_path = _resolve_model_path(config.model_name)
+
+        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.worker = _load_worker(model_path, config.device)
+
+        self.frame_shift = self.worker.frame_shift
+
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        return self.worker.emission(audio.to(self.worker.device))
+
+    def alignment(
+        self,
+        audio: AudioData,
+        supervisions: List[Supervision],
+        split_sentence: Optional[bool] = False,
+        return_details: Optional[bool] = False,
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+        verbose: bool = True,
+    ) -> Tuple[List[Supervision], List[Supervision]]:
+        """
+        Perform alignment on audio and supervisions.
+
+        Args:
+            audio: Audio file path
+            supervisions: List of supervision segments to align
+            split_sentence: Enable sentence re-splitting
+
+        Returns:
+            Tuple of (supervisions, alignments)
+
+        Raises:
+            LatticeEncodingError: If lattice graph generation fails
+            AlignmentError: If audio alignment fails
+            LatticeDecodingError: If lattice decoding fails
+        """
+        try:
+            if verbose:
+                print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+            try:
+                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                    supervisions, split_sentence=split_sentence
+                )
+                if verbose:
+                    print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+            except Exception as e:
+                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
+                raise LatticeEncodingError(text_content, original_error=e)
+
+            if verbose:
+                print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            try:
+                lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                if verbose:
+                    print(colorful.green(" ✓ Lattice search completed"))
+            except Exception as e:
+                raise AlignmentError(
+                    f"Audio alignment failed for {audio}",
+                    media_path=str(audio),
+                    context={"original_error": str(e)},
+                )
+
+            if verbose:
+                print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+            try:
+                alignments = self.tokenizer.detokenize(
+                    lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                )
+                if verbose:
+                    print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+            except LatticeDecodingError as e:
+                print(colorful.red(" x Failed to decode lattice alignment results"))
+                raise e
+            except Exception as e:
+                print(colorful.red(" x Failed to decode lattice alignment results"))
+                raise LatticeDecodingError(lattice_id, original_error=e)
+
+            return (supervisions, alignments)
+
+        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
+            raise
+        except Exception as e:
+            raise e
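A usage sketch of the three-step pipeline above (tokenize into a lattice graph, search it against the audio, decode the results), assuming an AlignmentConfig whose client_wrapper, model_name, and device are already populated; audio and supervisions stand in for real inputs:

    from lattifai.alignment import Lattice1Aligner

    aligner = Lattice1Aligner(config)  # config: a populated AlignmentConfig

    # audio: lattifai.audio2.AudioData; supervisions: List[Supervision]
    supervisions, alignments = aligner.alignment(audio, supervisions, split_sentence=False)
    for sup in alignments:
        print(sup.start, sup.end, sup.text)  # Supervision fields used the same way in segmenter.py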
lattifai/alignment/lattice1_worker.py ADDED
@@ -0,0 +1,185 @@
+import json
+import time
+from collections import defaultdict
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import onnxruntime as ort
+import torch
+from lhotse import FbankConfig
+from lhotse.features.kaldi.layers import Wav2LogFilterBank
+from lhotse.utils import Pathlike
+
+from lattifai.audio2 import AudioData
+from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+
+
+class Lattice1Worker:
+    """Worker for processing audio with LatticeGraph."""
+
+    def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
+        try:
+            self.config = json.load(open(f"{model_path}/config.json"))
+        except Exception as e:
+            raise ModelLoadError(f"config from {model_path}", original_error=e)
+
+        # SessionOptions
+        sess_options = ort.SessionOptions()
+        # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess_options.intra_op_num_threads = num_threads  # CPU cores
+        sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
+
+        providers = []
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")
+
+        try:
+            self.acoustic_ort = ort.InferenceSession(
+                f"{model_path}/acoustic_opt.onnx",
+                sess_options,
+                providers=providers + ["CPUExecutionProvider", "CoreMLExecutionProvider"],
+            )
+        except Exception as e:
+            raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)
+
+        try:
+            config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+            config_dict = config.to_dict()
+            config_dict.pop("device")
+            self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+        except Exception as e:
+            raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
+
+        self.device = torch.device(device)
+        self.timings = defaultdict(lambda: 0.0)
+
+    @property
+    def frame_shift(self) -> float:
+        return 0.02  # 20 ms
+
+    @torch.inference_mode()
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        _start = time.time()
+        # audio -> features -> emission
+        features = self.extractor(audio)  # (1, T, D)
+        if features.shape[1] > 6000:
+            features_list = torch.split(features, 6000, dim=1)
+            emissions = []
+            for features in features_list:
+                ort_inputs = {
+                    "features": features.cpu().numpy(),
+                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+                }
+                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+                emissions.append(emission)
+            emission = torch.cat(
+                [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+            )  # (1, T, vocab_size)
+        else:
+            ort_inputs = {
+                "features": features.cpu().numpy(),
+                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+            }
+            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+            emission = torch.from_numpy(emission).to(self.device)
+
+        self.timings["emission"] += time.time() - _start
+        return emission  # (1, T, vocab_size) torch
+
+    def alignment(
+        self,
+        audio: AudioData,
+        lattice_graph: Tuple[str, int, float],
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+    ) -> Dict[str, Any]:
+        """Process audio with LatticeGraph.
+
+        Args:
+            audio: AudioData object
+            lattice_graph: LatticeGraph data
+
+        Returns:
+            Processed LatticeGraph
+
+        Raises:
+            AudioLoadError: If audio cannot be loaded
+            DependencyError: If required dependencies are missing
+            AlignmentError: If alignment process fails
+        """
+        if emission is None:
+            try:
+                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
+            except Exception as e:
+                raise AlignmentError(
+                    "Failed to compute acoustic features from audio",
+                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                    context={"original_error": str(e)},
+                )
+
+        try:
+            import k2
+        except ImportError:
+            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
+
+        try:
+            from lattifai_core.lattice.decode import align_segments
+        except ImportError:
+            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+
+        lattice_graph_str, final_state, acoustic_scale = lattice_graph
+
+        _start = time.time()
+        try:
+            # graph
+            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+            decoding_graph.requires_grad_(False)
+            decoding_graph = k2.arc_sort(decoding_graph)
+            decoding_graph.skip_id = int(final_state)
+            decoding_graph.return_id = int(final_state + 1)
+        except Exception as e:
+            raise AlignmentError(
+                "Failed to create decoding graph from lattice",
+                context={"original_error": str(e), "lattice_graph_length": len(lattice_graph_str)},
+            )
+        self.timings["decoding_graph"] += time.time() - _start
+
+        _start = time.time()
+        if self.device.type == "mps":
+            device = "cpu"  # k2 does not support mps yet
+        else:
+            device = self.device
+
+        try:
+            results, labels = align_segments(
+                emission.to(device) * acoustic_scale,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=200,
+                output_beam=80,
+                min_active_states=400,
+                max_active_states=10000,
+                subsampling_factor=1,
+                reject_low_confidence=False,
+            )
+        except Exception as e:
+            raise AlignmentError(
+                "Failed to perform forced alignment",
+                media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
+            )
+        self.timings["align_segments"] += time.time() - _start
+
+        channel = 0
+        return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1Worker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f"worker from {model_path}", original_error=e)
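Two constants above pin down the timing model: fbank features use lhotse's default 10 ms hop, so each 6000-frame chunk fed to the ONNX session covers about 60 s of audio, while the emission frame_shift is 20 ms (suggesting the acoustic model subsamples features by 2). A small sketch of the resulting frame/time conversion (helper names are illustrative, not part of the package):

    FBANK_HOP_S = 0.01     # lhotse Fbank default frame shift
    EMISSION_HOP_S = 0.02  # Lattice1Worker.frame_shift

    def chunk_duration_s(max_feature_frames: int = 6000) -> float:
        # Audio covered by one chunk fed to the ONNX session.
        return max_feature_frames * FBANK_HOP_S  # 60.0

    def emission_frame_to_time(frame_idx: int, offset: float = 0.0) -> float:
        # Map an emission frame index back to seconds, mirroring the
        # (frame_shift, offset) values returned by Lattice1Worker.alignment().
        return offset + frame_idx * EMISSION_HOP_S

    print(chunk_duration_s())                       # 60.0
    print(emission_frame_to_time(150, offset=2.5))  # 5.5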
lattifai/{tokenizer → alignment}/phonemizer.py RENAMED
@@ -1,16 +1,16 @@
 import re
 from typing import List, Optional, Union
 
-from dp.phonemizer import Phonemizer  # g2p-phonemizer
+from g2pp.phonemizer import Phonemizer  # g2p-phonemizer
 from num2words import num2words
 
-LANGUAGE = 'omni'
+LANGUAGE = "omni"
 
 
 class G2Phonemizer:
     def __init__(self, model_checkpoint, device):
         self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
-        self.pattern = re.compile(r'\d+')
+        self.pattern = re.compile(r"\d+")
 
     def num2words(self, word, lang: str):
         matches = self.pattern.findall(word)
@@ -31,7 +31,7 @@ class G2Phonemizer:
         is_list = False
 
         predictions = self.phonemizer(
-            [self.num2words(word.replace(' .', '.').replace('.', ' .'), lang=lang or 'en') for word in words],
+            [self.num2words(word.replace(" .", ".").replace(".", " ."), lang=lang or "en") for word in words],
             lang=LANGUAGE,
             batch_size=min(batch_size or len(words), 128),
             num_prons=num_prons,
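The num2words helper above finds digit runs in each word and spells them out before phonemization. A standalone sketch of the same substitution using the num2words package directly (the regex matches the class's; the replacement function is mine):

    import re
    from num2words import num2words

    pattern = re.compile(r"\d+")

    def spell_out_digits(word: str, lang: str = "en") -> str:
        # Replace each digit run with its spelled-out form, e.g. "42" -> "forty-two".
        return pattern.sub(lambda m: num2words(int(m.group()), lang=lang), word)

    print(spell_out_digits("room 7"))  # room seven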
lattifai/alignment/segmenter.py ADDED
@@ -0,0 +1,166 @@
+"""Segmented alignment for long audio files."""
+
+from typing import List, Optional, Tuple
+
+import colorful
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Caption, Supervision
+from lattifai.config import AlignmentConfig
+
+from .tokenizer import END_PUNCTUATION
+
+
+class Segmenter:
+    """
+    Handles segmented alignment for long audio/video files.
+
+    Instead of aligning the entire audio at once (which can be slow and memory-intensive
+    for long files), this class splits the alignment into manageable segments based on
+    caption boundaries, time intervals, or an adaptive strategy.
+    """
+
+    def __init__(self, config: AlignmentConfig):
+        """
+        Initialize segmented aligner.
+
+        Args:
+            config: Alignment configuration with segmentation parameters
+        """
+        self.config = config
+
+    def __call__(
+        self,
+        caption: Caption,
+        max_duration: Optional[float] = None,
+    ) -> List[Tuple[float, float, List[Supervision]]]:
+        """
+        Create segments based on caption boundaries and gaps.
+
+        Splits when:
+        1. Gap between captions exceeds segment_max_gap
+        2. Duration approaches max_duration (adaptive mode only) and there's a reasonable break
+        3. Duration significantly exceeds max_duration (adaptive mode only)
+
+        Args:
+            caption: Caption object with supervisions
+            max_duration: Optional maximum segment duration (enables adaptive behavior)
+
+        Returns:
+            List of (start_time, end_time, supervisions) tuples for each segment
+        """
+        if not max_duration:
+            max_duration = self.config.segment_duration
+
+        if not caption.supervisions:
+            return []
+
+        supervisions = sorted(caption.supervisions, key=lambda s: s.start)
+
+        segments = []
+        current_segment_sups = []
+
+        def should_skipalign(sups):
+            return len(sups) == 1 and sups[0].text.strip().startswith("[") and sups[0].text.strip().endswith("]")
+
+        for i, sup in enumerate(supervisions):
+            if not current_segment_sups:
+                current_segment_sups.append(sup)
+                if should_skipalign(current_segment_sups):
+                    # Single [APPLAUSE] caption, make its own segment
+                    segments.append(
+                        (current_segment_sups[0].start, current_segment_sups[-1].end, current_segment_sups, True)
+                    )
+                    current_segment_sups = []
+                continue
+
+            prev_sup = supervisions[i - 1]
+
+            gap = max(sup.start - prev_sup.end, 0.0)
+            # Always split on large gaps (natural breaks)
+            exclude_max_gap = False
+            if gap > self.config.segment_max_gap:
+                exclude_max_gap = True
+
+            endswith_punc = any(sup.text.endswith(punc) for punc in END_PUNCTUATION)
+
+            # Adaptive duration control
+            segment_duration = sup.end - current_segment_sups[0].start
+
+            # Split if approaching duration limit and there's a reasonable break
+            should_split = False
+            if segment_duration >= max_duration * 0.8 and gap >= 1.0:
+                should_split = True
+
+            # Force split if duration exceeded significantly
+            exclude_max_duration = False
+            if segment_duration >= max_duration * 1.2:
+                exclude_max_duration = True
+
+            # [APPLAUSE] [APPLAUSE] [MUSIC]
+            if sup.text.strip().startswith("[") and sup.text.strip().endswith("]"):
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Add current supervision as its own segment
+                segments.append((sup.start, sup.end, [sup], True))
+
+                # Update reset for new segment
+                current_segment_sups = []
+                continue
+
+            if (should_split and endswith_punc) or exclude_max_gap or exclude_max_duration:
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Start new segment
+                current_segment_sups = [sup]
+            else:
+                current_segment_sups.append(sup)
+
+        # Add final segment
+        if current_segment_sups:
+            segment_start = current_segment_sups[0].start
+            segment_end = current_segment_sups[-1].end + 2.0
+            segments.append((segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups)))
+
+        return segments
+
+    def print_segment_info(
+        self,
+        segments: List[Tuple[float, float, List[Supervision]]],
+        verbose: bool = True,
+    ) -> None:
+        """
+        Print information about created segments.
+
+        Args:
+            segments: List of segment tuples
+            verbose: Whether to print detailed info
+        """
+        if not verbose:
+            return
+
+        total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
+
+        print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+        for i, (start, end, sups, _) in enumerate(segments, 1):
+            duration = end - start
+            print(
+                colorful.white(
+                    f" Segment {i:04d}: {start:8.2f}s - {end:8.2f}s "
+                    f"(duration: {duration:8.2f}s, supervisions: {len(sups) if isinstance(sups, list) else 1:4d})"
+                )
+            )
+
+        print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))