lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. lattifai/__init__.py +42 -27
  2. lattifai/alignment/__init__.py +6 -0
  3. lattifai/alignment/lattice1_aligner.py +119 -0
  4. lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
  5. lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
  6. lattifai/alignment/segmenter.py +166 -0
  7. lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
  8. lattifai/audio2.py +211 -0
  9. lattifai/caption/__init__.py +20 -0
  10. lattifai/caption/caption.py +1275 -0
  11. lattifai/{io → caption}/supervision.py +1 -0
  12. lattifai/{io → caption}/text_parser.py +53 -10
  13. lattifai/cli/__init__.py +17 -0
  14. lattifai/cli/alignment.py +153 -0
  15. lattifai/cli/caption.py +204 -0
  16. lattifai/cli/server.py +19 -0
  17. lattifai/cli/transcribe.py +197 -0
  18. lattifai/cli/youtube.py +128 -0
  19. lattifai/client.py +455 -246
  20. lattifai/config/__init__.py +20 -0
  21. lattifai/config/alignment.py +73 -0
  22. lattifai/config/caption.py +178 -0
  23. lattifai/config/client.py +46 -0
  24. lattifai/config/diarization.py +67 -0
  25. lattifai/config/media.py +335 -0
  26. lattifai/config/transcription.py +84 -0
  27. lattifai/diarization/__init__.py +5 -0
  28. lattifai/diarization/lattifai.py +89 -0
  29. lattifai/errors.py +41 -34
  30. lattifai/logging.py +116 -0
  31. lattifai/mixin.py +552 -0
  32. lattifai/server/app.py +420 -0
  33. lattifai/transcription/__init__.py +76 -0
  34. lattifai/transcription/base.py +108 -0
  35. lattifai/transcription/gemini.py +219 -0
  36. lattifai/transcription/lattifai.py +103 -0
  37. lattifai/types.py +30 -0
  38. lattifai/utils.py +3 -31
  39. lattifai/workflow/__init__.py +22 -0
  40. lattifai/workflow/agents.py +6 -0
  41. lattifai/{workflows → workflow}/file_manager.py +81 -57
  42. lattifai/workflow/youtube.py +564 -0
  43. lattifai-1.0.0.dist-info/METADATA +736 -0
  44. lattifai-1.0.0.dist-info/RECORD +52 -0
  45. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
  46. lattifai-1.0.0.dist-info/entry_points.txt +13 -0
  47. lattifai/base_client.py +0 -126
  48. lattifai/bin/__init__.py +0 -3
  49. lattifai/bin/agent.py +0 -324
  50. lattifai/bin/align.py +0 -295
  51. lattifai/bin/cli_base.py +0 -25
  52. lattifai/bin/subtitle.py +0 -210
  53. lattifai/io/__init__.py +0 -43
  54. lattifai/io/reader.py +0 -86
  55. lattifai/io/utils.py +0 -15
  56. lattifai/io/writer.py +0 -102
  57. lattifai/tokenizer/__init__.py +0 -3
  58. lattifai/workers/__init__.py +0 -3
  59. lattifai/workflows/__init__.py +0 -34
  60. lattifai/workflows/agents.py +0 -12
  61. lattifai/workflows/gemini.py +0 -167
  62. lattifai/workflows/prompts/README.md +0 -22
  63. lattifai/workflows/prompts/gemini/README.md +0 -24
  64. lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
  65. lattifai/workflows/youtube.py +0 -931
  66. lattifai-0.4.6.dist-info/METADATA +0 -806
  67. lattifai-0.4.6.dist-info/RECORD +0 -39
  68. lattifai-0.4.6.dist-info/entry_points.txt +0 -3
  69. /lattifai/{io → caption}/gemini_reader.py +0 -0
  70. /lattifai/{io → caption}/gemini_writer.py +0 -0
  71. /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
  72. /lattifai/{workflows → workflow}/base.py +0 -0
  73. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
  74. {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
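The headline change in this release is the client API surface: the 0.4.6 constructor took flat keyword arguments (api_key, model_name_or_path, device, ...) and alignment() returned a (supervisions, output_path) tuple, while 1.0.0 moves to config objects and a Caption return value, as the lattifai/client.py diff below shows. A minimal before/after sketch inferred only from the signatures visible in that diff; the exact ClientConfig fields and whether ClientConfig() reads LATTIFAI_API_KEY from the environment (implied by the docstring template, not shown) are assumptions:

    # 0.4.6 (old): flat kwargs, tuple result
    from lattifai import LattifAI
    client = LattifAI(api_key="YOUR_KEY", model_name_or_path="Lattifai/Lattice-1-Alpha")
    alignments, out_path = client.alignment(
        "speech.wav", "transcript.srt", output_subtitle_path="aligned.srt"
    )

    # 1.0.0 (new): config-driven, returns a Caption
    from lattifai import LattifAI
    from lattifai.config import ClientConfig
    client = LattifAI(client_config=ClientConfig())  # passing None applies the same defaults
    caption = client.alignment(
        input_media="speech.wav",
        input_caption="transcript.srt",  # omit to transcribe the media instead
        output_caption_path="aligned.srt",
    )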
lattifai/client.py CHANGED
@@ -1,297 +1,503 @@
- """LattifAI client implementation."""
+ """LattifAI client implementation with config-driven architecture."""
 
- import asyncio
- import os
- from typing import Dict, List, Optional, Tuple, Union
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Optional, Union
 
  import colorful
+ from lattifai_core.client import SyncAPIClient
  from lhotse.utils import Pathlike
 
- from lattifai.base_client import AsyncAPIClient, SyncAPIClient
+ from lattifai.alignment import Lattice1Aligner, Segmenter
+ from lattifai.audio2 import AudioData, AudioLoader
+ from lattifai.caption import Caption, InputCaptionFormat
+ from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, TranscriptionConfig
  from lattifai.errors import (
      AlignmentError,
-     ConfigurationError,
+     CaptionProcessingError,
      LatticeDecodingError,
      LatticeEncodingError,
-     LattifAIError,
-     SubtitleProcessingError,
-     handle_exception,
  )
- from lattifai.io import SubtitleFormat, SubtitleIO, Supervision
- from lattifai.tokenizer import AsyncLatticeTokenizer
- from lattifai.utils import _load_tokenizer, _load_worker, _resolve_model_path, _select_device
+ from lattifai.mixin import LattifAIClientMixin
 
+ if TYPE_CHECKING:
+     from lattifai.diarization import LattifAIDiarizer  # noqa: F401
 
- class LattifAI(SyncAPIClient):
-     """Synchronous LattifAI client."""
+
+ class LattifAI(LattifAIClientMixin, SyncAPIClient):
+     __doc__ = LattifAIClientMixin._CLASS_DOC.format(
+         sync_or_async="Synchronous",
+         sync_or_async_lower="synchronous",
+         client_class="LattifAI",
+         await_keyword="",
+         async_note="",
+         transcriber_note=" (initialized if TranscriptionConfig provided)",
+     )
 
      def __init__(
          self,
-         *,
-         api_key: Optional[str] = None,
-         model_name_or_path: str = "Lattifai/Lattice-1-Alpha",
-         device: Optional[str] = None,
-         base_url: Optional[str] = None,
-         timeout: Union[float, int] = 120.0,
-         max_retries: int = 2,
-         default_headers: Optional[Dict[str, str]] = None,
+         client_config: Optional[ClientConfig] = None,
+         alignment_config: Optional[AlignmentConfig] = None,
+         caption_config: Optional[CaptionConfig] = None,
+         transcription_config: Optional[TranscriptionConfig] = None,
+         diarization_config: Optional[DiarizationConfig] = None,
      ) -> None:
-         if api_key is None:
-             api_key = os.environ.get("LATTIFAI_API_KEY")
-         if api_key is None:
-             raise ConfigurationError(
-                 "The api_key client option must be set either by passing api_key to the client "
-                 "or by setting the LATTIFAI_API_KEY environment variable"
-             )
+         __doc__ = LattifAIClientMixin._INIT_DOC.format(
+             client_class="LattifAI",
+             sync_or_async_lower="synchronous",
+             config_desc="model and behavior configuration",
+             default_desc="default settings (Lattice-1 model, auto device selection)",
+             caption_note=" (auto-detect format)",
+             transcription_note=". If provided with valid API key, enables transcription capabilities (e.g., Gemini for YouTube videos)",
+             api_key_source="and LATTIFAI_API_KEY env var is not set",
+         )
+         if client_config is None:
+             client_config = ClientConfig()
 
-         if base_url is None:
-             base_url = os.environ.get("LATTIFAI_BASE_URL")
-         if not base_url:
-             base_url = "https://api.lattifai.com/v1"
-
-         super().__init__(
-             api_key=api_key,
-             base_url=base_url,
-             timeout=timeout,
-             max_retries=max_retries,
-             default_headers=default_headers,
+         # Initialize base API client
+         super().__init__(config=client_config)
+
+         # Initialize all configs with defaults
+         alignment_config, transcription_config, diarization_config = self._init_configs(
+             alignment_config, transcription_config, diarization_config
          )
 
-         model_path = _resolve_model_path(model_name_or_path)
-         device = _select_device(device)
+         # Store configs
+         if caption_config is None:
+             caption_config = CaptionConfig()
+         self.caption_config = caption_config
 
-         self.tokenizer = _load_tokenizer(self, model_path, device)
-         self.worker = _load_worker(model_path, device)
-         self.device = device
+         # audio loader
+         self.audio_loader = AudioLoader(device=alignment_config.device)
 
-     def alignment(
-         self,
-         audio: Pathlike,
-         subtitle: Pathlike,
-         format: Optional[SubtitleFormat] = None,
-         split_sentence: bool = False,
-         return_details: bool = False,
-         output_subtitle_path: Optional[Pathlike] = None,
-     ) -> Tuple[List[Supervision], Optional[Pathlike]]:
-         """Perform alignment on audio and subtitle/text.
+         # aligner
+         self.aligner = Lattice1Aligner(config=alignment_config)
 
-         Args:
-             audio: Audio file path
-             subtitle: Subtitle/Text to align with audio
-             format: Input subtitle format (srt, vtt, ass, txt). Auto-detected if None
-             split_sentence: Enable intelligent sentence re-splitting based on punctuation semantics
-             return_details: Return word-level alignment details in Supervision.alignment field
-             output_subtitle_path: Output path for aligned subtitle (optional)
+         # Initialize diarizer if enabled
+         self.diarization_config = diarization_config
+         self.diarizer: Optional["LattifAIDiarizer"] = None
+         if self.diarization_config.enabled:
+             from lattifai.diarization import LattifAIDiarizer  # noqa: F811
 
-         Returns:
-             Tuple containing:
-             - List of aligned Supervision objects with timing information
-             - Output subtitle path (if output_subtitle_path was provided)
+             self.diarizer = LattifAIDiarizer(config=self.diarization_config)
 
-         Raises:
-             SubtitleProcessingError: If subtitle file cannot be parsed
-             LatticeEncodingError: If lattice graph generation fails
-             AlignmentError: If audio alignment fails
-             LatticeDecodingError: If lattice decoding fails
-         """
+         # Initialize shared components (transcriber, downloader)
+         self._init_shared_components(transcription_config)
+
+     def alignment(
+         self,
+         input_media: Union[Pathlike, AudioData],
+         input_caption: Optional[Union[Pathlike, Caption]] = None,
+         output_caption_path: Optional[Pathlike] = None,
+         input_caption_format: Optional[InputCaptionFormat] = None,
+         split_sentence: Optional[bool] = None,
+         channel_selector: Optional[str | int] = "average",
+     ) -> Caption:
          try:
-             # step1: parse text or subtitles
-             print(colorful.cyan(f"📖 Step 1: Reading subtitle file from {subtitle}"))
-             try:
-                 supervisions = SubtitleIO.read(subtitle, format=format)
-                 print(colorful.green(f" ✓ Parsed {len(supervisions)} subtitle segments"))
-             except Exception as e:
-                 raise SubtitleProcessingError(
-                     f"Failed to parse subtitle file: {subtitle}",
-                     subtitle_path=str(subtitle),
-                     context={"original_error": str(e)},
+             # Step 1: Get caption
+             if isinstance(input_media, AudioData):
+                 media_audio = input_media
+             else:
+                 media_audio = self.audio_loader(
+                     input_media,
+                     channel_selector=channel_selector,
                  )
 
-             # step2: make lattice by call Lattifai API
-             print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
-             try:
-                 supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
-                     supervisions, split_sentence=split_sentence
-                 )
-                 print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
-             except Exception as e:
-                 text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
-                 raise LatticeEncodingError(text_content, original_error=e)
-
-             # step3: search lattice graph with audio
-             print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with audio: {audio}"))
-             try:
-                 lattice_results = self.worker.alignment(audio, lattice_graph)
-                 print(colorful.green(" ✓ Lattice search completed"))
-             except Exception as e:
-                 raise AlignmentError(
-                     f"Audio alignment failed for {audio}",
-                     audio_path=str(audio),
-                     subtitle_path=str(subtitle),
-                     context={"original_error": str(e)},
+             if not input_caption:
+                 caption = self._transcribe(media_audio, source_lang=self.caption_config.source_lang, is_async=False)
+             else:
+                 caption = self._read_caption(input_caption, input_caption_format)
+
+             output_caption_path = output_caption_path or self.caption_config.output_path
+
+             # Step 2: Check if segmented alignment is needed
+             alignment_strategy = self.aligner.config.strategy
+
+             if alignment_strategy != "entire" or caption.transcription:
+                 print(colorful.cyan(f"🔄 Using segmented alignment strategy: {alignment_strategy}"))
+
+                 if caption.supervisions and alignment_strategy == "transcription":
+                     # raise NotImplementedError("Transcription-based alignment is not yet implemented.")
+                     assert (
+                         "gemini" not in self.transcriber.name.lower()
+                     ), "Transcription-based alignment is not supported with Gemini transcriber."
+                     assert (
+                         caption.supervisions
+                     ), "Input caption should contain supervisions when using transcription-based alignment."
+                     if not caption.transcription:
+                         import asyncio
+
+                         print(colorful.cyan("📝 Transcribing media for alignment..."))
+                         if output_caption_path:
+                             transcript_file = (
+                                 Path(str(output_caption_path)).parent
+                                 / f"{Path(str(media_audio)).stem}_{self.transcriber.file_name}"
+                             )
+                             if transcript_file.exists():
+                                 # print(colorful.cyan(f"Reading existing transcription from {transcript_file}"))
+                                 transcript = self._read_caption(transcript_file, verbose=False)
+                                 caption.transcription = transcript.supervisions
+                                 caption.audio_events = transcript.audio_events
+
+                         if not caption.transcription:
+                             transcript = asyncio.run(
+                                 self.transcriber.transcribe(media_audio, language=self.caption_config.source_lang)
+                             )
+                             caption.transcription = transcript.transcription
+                             caption.audio_events = transcript.audio_events
+
+                     # Align caption.supervisions with transcription to get segments
+                     import regex
+                     from error_align import ErrorAlign, error_align  # noqa: F401
+                     from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, OpType
+
+                     JOIN_TOKEN = "❄"
+                     if JOIN_TOKEN not in DELIMITERS:
+                         DELIMITERS.add(JOIN_TOKEN)
+
+                     def custom_tokenizer(text: str) -> list:
+                         """Default tokenizer that splits text into words based on whitespace.
+
+                         Args:
+                             text (str): The input text to tokenize.
+
+                         Returns:
+                             list: A list of tokens (words).
+
+                         """
+                         # Escape JOIN_TOKEN for use in regex pattern
+                         escaped_join_token = regex.escape(JOIN_TOKEN)
+                         return list(
+                             regex.finditer(
+                                 rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{escaped_join_token})",
+                                 text,
+                                 regex.UNICODE | regex.VERBOSE,
+                             )
+                         )
+
+                     alignments = error_align(
+                         f"{JOIN_TOKEN}".join(sup.text for sup in caption.supervisions),
+                         f"{JOIN_TOKEN}".join(sup.text for sup in caption.transcription),
+                         tokenizer=custom_tokenizer,
+                     )
+
+                     for align in alignments:
+                         if align.hyp == JOIN_TOKEN and align.op_type == OpType.MATCH:
+                             pass
+
+                         # if align.op_type == OpType.MATCH:
+                         #     continue
+                         # elif align.op_type in (OpType.INSERT, OpType.DELETE, OpType.SUBSTITUTE):
+                         #     # print(colorful.yellow(f"⚠️ Alignment warning: {op}"))
+                         #     pass
+
+                     raise NotImplementedError("Transcription-based segmentation is not yet implemented.")
+                 else:
+                     if caption.transcription:
+                         if not caption.supervisions:  # youtube + transcription case
+                             segments = [(sup.start, sup.end, [sup], not sup.text) for sup in caption.transcription]
+                         else:
+                             raise NotImplementedError(
+                                 f"Input caption with both supervisions and transcription(strategy={alignment_strategy}) is not supported."
+                             )
+                     elif self.aligner.config.trust_caption_timestamps:
+                         # Create segmenter
+                         segmenter = Segmenter(self.aligner.config)
+                         # Create segments from caption
+                         segments = segmenter(caption)
+                     else:
+                         raise NotImplementedError(
+                             "Segmented alignment without trusting input timestamps is not yet implemented."
+                         )
+
+                 # align each segment
+                 supervisions, alignments = [], []
+                 for i, (start, end, _supervisions, skipalign) in enumerate(segments, 1):
+                     print(
+                         colorful.green(
+                             f" ⏩ aligning segment {i:04d}/{len(segments):04d}: {start:8.2f}s - {end:8.2f}s"
+                         )
+                     )
+                     if skipalign:
+                         supervisions.extend(_supervisions)
+                         alignments.extend(_supervisions)  # may overlap with supervisions, but harmless
+                         continue
+
+                     offset = round(start, 4)
+                     emission = self.aligner.emission(
+                         media_audio.tensor[
+                             :, int(start * media_audio.sampling_rate) : int(end * media_audio.sampling_rate)
+                         ]
+                     )
+
+                     # Align segment
+                     _supervisions, _alignments = self.aligner.alignment(
+                         media_audio,
+                         _supervisions,
+                         split_sentence=split_sentence or self.caption_config.split_sentence,
+                         return_details=self.caption_config.word_level
+                         or (output_caption_path and str(output_caption_path).endswith(".TextGrid")),
+                         emission=emission,
+                         offset=offset,
+                         verbose=False,
+                     )
+
+                     supervisions.extend(_supervisions)
+                     alignments.extend(_alignments)
+             else:
+                 # Step 2-4: Standard single-pass alignment
+                 supervisions, alignments = self.aligner.alignment(
+                     media_audio,
+                     caption.supervisions,
+                     split_sentence=split_sentence or self.caption_config.split_sentence,
+                     return_details=self.caption_config.word_level
+                     or (output_caption_path and str(output_caption_path).endswith(".TextGrid")),
                  )
 
-             # step4: decode lattice results to aligned segments
-             print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
-             try:
-                 alignments = self.tokenizer.detokenize(
-                     lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+             # Update caption with aligned results
+             caption.supervisions = supervisions
+             caption.alignments = alignments
+
+             # Step 5: Speaker diarization
+             if self.diarization_config.enabled and self.diarizer:
+                 print(colorful.cyan("🗣️ Performing speaker diarization..."))
+                 caption = self.speaker_diarization(
+                     input_media=media_audio,
+                     caption=caption,
+                     output_caption_path=output_caption_path,
                  )
-                 print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
-             except LatticeDecodingError as e:
-                 print(colorful.red(" x Failed to decode lattice alignment results"))
-                 raise e
-             except Exception as e:
-                 print(colorful.red(" x Failed to decode lattice alignment results"))
-                 raise LatticeDecodingError(lattice_id, original_error=e)
-
-             # step5: export alignments to target format
-             if output_subtitle_path:
-                 try:
-                     SubtitleIO.write(alignments, output_path=output_subtitle_path)
-                     print(colorful.green(f"🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}"))
-                 except Exception as e:
-                     raise SubtitleProcessingError(
-                         f"Failed to write output file: {output_subtitle_path}",
-                         subtitle_path=str(output_subtitle_path),
-                         context={"original_error": str(e)},
-                     )
-             return (alignments, output_subtitle_path)
+             elif output_caption_path:
+                 self._write_caption(caption, output_caption_path)
 
-         except (SubtitleProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
+             return caption
+         except (CaptionProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
              # Re-raise our specific errors as-is
              raise
          except Exception as e:
              # Catch any unexpected errors and wrap them
              raise AlignmentError(
                  "Unexpected error during alignment process",
-                 audio_path=str(audio),
-                 subtitle_path=str(subtitle),
+                 media_path=str(input_media),
+                 caption_path=str(input_caption),
                  context={"original_error": str(e), "error_type": e.__class__.__name__},
              )
 
-
- class AsyncLattifAI(AsyncAPIClient):
-     """Asynchronous LattifAI client."""
-
-     def __init__(
+     def speaker_diarization(
          self,
-         *,
-         api_key: Optional[str] = None,
-         model_name_or_path: str = "Lattifai/Lattice-1-Alpha",
-         device: Optional[str] = None,
-         base_url: Optional[str] = None,
-         timeout: Union[float, int] = 120.0,
-         max_retries: int = 2,
-         default_headers: Optional[Dict[str, str]] = None,
-     ) -> None:
-         if api_key is None:
-             api_key = os.environ.get("LATTIFAI_API_KEY")
-         if api_key is None:
-             raise ConfigurationError(
-                 "The api_key client option must be set either by passing api_key to the client "
-                 "or by setting the LATTIFAI_API_KEY environment variable"
-             )
-
-         if base_url is None:
-             base_url = os.environ.get("LATTIFAI_BASE_URL")
-         if not base_url:
-             base_url = "https://api.lattifai.com/v1"
-
-         super().__init__(
-             api_key=api_key,
-             base_url=base_url,
-             timeout=timeout,
-             max_retries=max_retries,
-             default_headers=default_headers,
-         )
+         input_media: AudioData,
+         caption: Caption,
+         output_caption_path: Optional[Pathlike] = None,
+     ) -> Caption:
+         """
+         Perform speaker diarization on aligned caption.
 
-         model_path = _resolve_model_path(model_name_or_path)
-         device = _select_device(device)
+         Args:
+             input_media: AudioData object
+             caption: Caption object with aligned segments
+             output_caption_path: Optional path to write diarized caption
 
-         self.tokenizer = _load_tokenizer(self, model_path, device, tokenizer_cls=AsyncLatticeTokenizer)
-         self.worker = _load_worker(model_path, device)
-         self.device = device
+         Returns:
+             Caption object with speaker labels assigned
 
-     async def alignment(
-         self,
-         audio: Pathlike,
-         subtitle: Pathlike,
-         format: Optional[SubtitleFormat] = None,
-         split_sentence: bool = False,
-         return_details: bool = False,
-         output_subtitle_path: Optional[Pathlike] = None,
-     ) -> Tuple[List[Supervision], Optional[Pathlike]]:
-         try:
-             print(colorful.cyan(f"📖 Step 1: Reading subtitle file from {subtitle}"))
-             try:
-                 supervisions = await asyncio.to_thread(SubtitleIO.read, subtitle, format=format)
-                 print(colorful.green(f" ✓ Parsed {len(supervisions)} subtitle segments"))
-             except Exception as e:
-                 raise SubtitleProcessingError(
-                     f"Failed to parse subtitle file: {subtitle}",
-                     subtitle_path=str(subtitle),
-                     context={"original_error": str(e)},
+         Raises:
+             RuntimeError: If diarizer is not initialized or diarization fails
+         """
+         if not self.diarizer:
+             raise RuntimeError("Diarizer not initialized. Set diarization_config.enabled=True")
+
+         # Perform diarization and assign speaker labels to caption alignments
+         if output_caption_path:
+             diarization_file = Path(str(output_caption_path)).with_suffix(".SpkDiar")
+             if diarization_file.exists():
+                 print(colorful.cyan(f"Reading existing speaker diarization from {diarization_file}"))
+                 caption.read_speaker_diarization(diarization_file)
+
+         diarization, alignments = self.diarizer.diarize_with_alignments(
+             input_media, caption.alignments, diarization=caption.speaker_diarization
+         )
+         caption.alignments = alignments
+         caption.speaker_diarization = diarization
+
+         # Write output if requested
+         if output_caption_path:
+             self._write_caption(caption, output_caption_path)
+
+         if self.diarizer.config.debug:
+             # debug
+             from tgt import Interval, IntervalTier, TextGrid, write_to_file
+
+             debug_tg = TextGrid()
+             transcript_tier = IntervalTier(
+                 start_time=0,
+                 end_time=input_media.duration,
+                 name="transcript",
+                 objects=[Interval(sup.start, sup.end, sup.text) for sup in caption.alignments],
              )
+             debug_tg.add_tier(transcript_tier)
 
-             print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
-             try:
-                 supervisions, lattice_id, lattice_graph = await self.tokenizer.tokenize(
-                     supervisions,
-                     split_sentence=split_sentence,
+             speaker_tier = IntervalTier(
+                 start_time=0,
+                 end_time=input_media.duration,
+                 name="speaker",
+                 objects=[Interval(sup.start, sup.end, sup.speaker) for sup in caption.alignments],
              )
-                 print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
-             except Exception as e:
-                 text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
-                 raise LatticeEncodingError(text_content, original_error=e)
-
-             print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with audio: {audio}"))
-             try:
-                 lattice_results = await asyncio.to_thread(self.worker.alignment, audio, lattice_graph)
-                 print(colorful.green(" ✓ Lattice search completed"))
-             except Exception as e:
-                 raise AlignmentError(
-                     f"Audio alignment failed for {audio}",
-                     audio_path=str(audio),
-                     subtitle_path=str(subtitle),
-                     context={"original_error": str(e)},
+             debug_tg.add_tier(speaker_tier)
+
+             from collections import defaultdict
+
+             spk2intervals = defaultdict(lambda: [])
+             num_multispk = 0
+
+             segments, skipks = [], []
352
+ # supervision = caption.alignments[k]
353
+ if supervision.custom.get("speaker", []):
354
+ num_multispk += 1
355
+ else:
356
+ continue
357
+
358
+ if k in skipks:
359
+ continue
360
+
361
+ for speaker in supervision.custom.get("speaker", []):
362
+ for name, start_time, end_time in speaker:
363
+ spk2intervals[name].append(Interval(start_time, end_time, name))
364
+
365
+ _segments = []
366
+ if k > 0:
367
+ _segments.append(caption.alignments[k - 1])
368
+ _segments.append(supervision)
369
+ while k + 1 < len(caption.alignments):
370
+ skipks.append(k + 1)
371
+ next_sup = caption.alignments[k + 1]
372
+ if not next_sup.custom.get("speaker", []):
373
+ k += 1
374
+ break
375
+ _segments.append(next_sup)
376
+ k += 1
377
+
378
+ if segments:
379
+ if _segments[0].start >= segments[-1][-1].end:
380
+ segments.append(_segments)
381
+ else:
382
+ if _segments[1:]:
383
+ segments.append(_segments[1:])
384
+ else:
385
+ pass
386
+ else:
387
+ segments.append(_segments)
388
+
389
+ print(
390
+ f"Number of multi-speaker segments: {num_multispk}/{len(caption.alignments)} segments: {len(segments)}"
258
391
  )
259
392
 
260
- print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
261
- try:
262
- alignments = await self.tokenizer.detokenize(
263
- lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
264
- )
265
- print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
266
- except LatticeDecodingError as e:
267
- print(colorful.red(" x Failed to decode lattice alignment results"))
268
- raise e
269
- except Exception as e:
270
- print(colorful.red(" x Failed to decode lattice alignment results"))
271
- raise LatticeDecodingError(lattice_id, original_error=e)
272
-
273
- if output_subtitle_path:
274
- try:
275
- await asyncio.to_thread(SubtitleIO.write, alignments, output_subtitle_path)
276
- print(colorful.green(f"🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}"))
277
- except Exception as e:
278
- raise SubtitleProcessingError(
279
- f"Failed to write output file: {output_subtitle_path}",
280
- subtitle_path=str(output_subtitle_path),
281
- context={"original_error": str(e)},
393
+ for speaker, intervals in sorted(spk2intervals.items(), key=lambda x: x[0]):
394
+ speaker_tier = IntervalTier(
395
+ start_time=0, end_time=input_media.duration, name=speaker, objects=intervals
282
396
  )
397
+ debug_tg.add_tier(speaker_tier)
398
+
399
+ for tier in caption.speaker_diarization.tiers:
400
+ tier.name = f"Diarization-{tier.name}"
401
+ debug_tg.add_tier(tier)
402
+
403
+ tier = IntervalTier(
404
+ start_time=0,
405
+ end_time=input_media.duration,
406
+ name="resegment",
407
+ objects=[
408
+ Interval(round(sup.start, 2), round(sup.end, 2), sup.text)
409
+ for _segments in segments
410
+ for sup in _segments
411
+ ],
412
+ )
413
+ debug_tg.add_tier(tier)
283
414
 
284
- return (alignments, output_subtitle_path)
415
+ # if caption.audio_events:
416
+ # for tier in caption.audio_events.tiers:
417
+ # # tier.name = f"{tier.name}"
418
+ # debug_tg.add_tier(tier)
285
419
 
286
- except (SubtitleProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
287
- raise
288
- except Exception as e:
289
- raise AlignmentError(
290
- "Unexpected error during alignment process",
291
- audio_path=str(audio),
292
- subtitle_path=str(subtitle),
293
- context={"original_error": str(e), "error_type": e.__class__.__name__},
294
- )
420
+ debug_tgt_file = Path(str(output_caption_path)).with_suffix(".DiarizationDebug.TextGrid")
421
+ write_to_file(debug_tg, debug_tgt_file, format="long")
422
+
423
+ return caption
424
+
425
+ def youtube(
426
+ self,
427
+ url: str,
428
+ output_dir: Optional[Pathlike] = None,
429
+ media_format: Optional[str] = None,
430
+ source_lang: Optional[str] = None,
431
+ force_overwrite: bool = False,
432
+ output_caption_path: Optional[Pathlike] = None,
433
+ split_sentence: Optional[bool] = None,
434
+ use_transcription: bool = False,
435
+ channel_selector: Optional[str | int] = "average",
436
+ ) -> Caption:
437
+ # Prepare output directory and media format
438
+ output_dir = self._prepare_youtube_output_dir(output_dir)
439
+ media_format = self._determine_media_format(media_format)
440
+
441
+ print(colorful.cyan(f"🎬 Starting YouTube workflow for: {url}"))
442
+
443
+ # Step 1: Download media
444
+ media_file = self._download_media_sync(url, output_dir, media_format, force_overwrite)
445
+
446
+ media_audio = self.audio_loader(media_file, channel_selector=channel_selector)
447
+
448
+ # Step 2: Get or create captions (download or transcribe)
449
+ caption = self._download_or_transcribe_caption(
450
+ url,
451
+ output_dir,
452
+ media_audio,
453
+ force_overwrite,
454
+ source_lang or self.caption_config.source_lang,
455
+ is_async=False,
456
+ use_transcription=use_transcription,
457
+ )
458
+
459
+ # Step 3: Generate output path if not provided
460
+ output_caption_path = self._generate_output_caption_path(output_caption_path, media_file, output_dir)
461
+
462
+ # Step 4: Perform alignment
463
+ print(colorful.cyan("🔗 Performing forced alignment..."))
464
+
465
+ caption: Caption = self.alignment(
466
+ input_media=media_audio,
467
+ input_caption=caption,
468
+ output_caption_path=output_caption_path,
469
+ split_sentence=split_sentence,
470
+ channel_selector=channel_selector,
471
+ )
472
+
473
+ return caption
474
+
475
+
476
+ # Set docstrings for LattifAI methods
477
+ LattifAI.alignment.__doc__ = LattifAIClientMixin._ALIGNMENT_DOC.format(
478
+ async_prefix="",
479
+ async_word="",
480
+ timing_desc="each word",
481
+ concurrency_note="",
482
+ async_suffix1="",
483
+ async_suffix2="",
484
+ async_suffix3="",
485
+ async_suffix4="",
486
+ async_suffix5="",
487
+ format_default="auto-detects",
488
+ export_note=" in the same format as input (or config default)",
489
+ timing_note=" (start, duration, text)",
490
+ example_imports="client = LattifAI()",
491
+ example_code="""alignments, output_path = client.alignment(
492
+ ... input_media="speech.wav",
493
+ ... input_caption="transcript.srt",
494
+ ... output_caption_path="aligned.srt"
495
+ ... )
496
+ >>> for seg in alignments:
497
+ ... print(f"{seg.start:.2f}s - {seg.end:.2f}s: {seg.text}")""",
498
+ )
499
+
500
+ LattifAI.youtube.__doc__ = LattifAIClientMixin._YOUTUBE_METHOD_DOC.format(client_class="LattifAI", await_keyword="")
295
501
 
296
502
 
297
503
  if __name__ == "__main__":
@@ -299,14 +505,17 @@ if __name__ == "__main__":
299
505
  import sys
300
506
 
301
507
  if len(sys.argv) == 5:
302
- audio, subtitle, output, split_sentence = sys.argv[1:]
508
+ audio, caption, output, split_sentence = sys.argv[1:]
303
509
  split_sentence = split_sentence.lower() in ("true", "1", "yes")
304
510
  else:
305
511
  audio = "tests/data/SA1.wav"
306
- subtitle = "tests/data/SA1.TXT"
512
+ caption = "tests/data/SA1.TXT"
307
513
  output = None
308
514
  split_sentence = False
309
515
 
310
- (alignments, output_subtitle_path) = client.alignment(
311
- audio, subtitle, output_subtitle_path=output, split_sentence=split_sentence, return_details=True
516
+ (alignments, output_caption_path) = client.alignment(
517
+ input_media=audio,
518
+ input_caption=caption,
519
+ output_caption_path=output,
520
+ split_sentence=split_sentence,
312
521
  )
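
The youtube() method added in 1.0.0 chains the steps shown above: download media, load audio, fetch or transcribe captions, then call alignment(). A hedged usage sketch derived only from the signature in this diff; the URL and output_dir values are placeholders, and unlisted behavior follows the method's own step messages:

    client = LattifAI()
    caption = client.youtube(
        "https://www.youtube.com/watch?v=VIDEO_ID",  # placeholder URL
        output_dir="downloads",          # defaults are applied when omitted
        source_lang="en",
        use_transcription=False,         # True: transcribe instead of fetching captions
    )
    # Caption exposes the aligned segments, as used throughout the diff above
    for sup in caption.supervisions:
        print(f"{sup.start:.2f}s  {sup.text}")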