lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- lattifai/__init__.py +42 -27
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
- lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/supervision.py +1 -0
- lattifai/{io → caption}/text_parser.py +53 -10
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +455 -246
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +41 -34
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/types.py +30 -0
- lattifai/utils.py +3 -31
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/file_manager.py +81 -57
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -324
- lattifai/bin/align.py +0 -295
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -43
- lattifai/io/reader.py +0 -86
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -102
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -12
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.6.dist-info/METADATA +0 -806
- lattifai-0.4.6.dist-info/RECORD +0 -39
- lattifai-0.4.6.dist-info/entry_points.txt +0 -3
- /lattifai/{io → caption}/gemini_reader.py +0 -0
- /lattifai/{io → caption}/gemini_writer.py +0 -0
- /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
- /lattifai/{workflows → workflow}/base.py +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
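The file listing shows a wholesale reorganization of the package: `bin/` is replaced by `cli/`, `io/` is folded into `caption/`, `tokenizer/` and `workers/` merge into `alignment/`, `workflows/` becomes `workflow/`, and new `config/`, `diarization/`, `transcription/`, and `server/` packages appear. The import lines below, taken verbatim from the `client.py` diff that follows, show where the main 1.0.0 entry points now live:

```python
from lattifai.alignment import Lattice1Aligner, Segmenter
from lattifai.audio2 import AudioData, AudioLoader
from lattifai.caption import Caption, InputCaptionFormat
from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, TranscriptionConfig
```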
lattifai/client.py
CHANGED
@@ -1,297 +1,503 @@
-"""LattifAI client implementation."""
+"""LattifAI client implementation with config-driven architecture."""

-import
-import
-from typing import Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Union

 import colorful
+from lattifai_core.client import SyncAPIClient
 from lhotse.utils import Pathlike

-from lattifai.
+from lattifai.alignment import Lattice1Aligner, Segmenter
+from lattifai.audio2 import AudioData, AudioLoader
+from lattifai.caption import Caption, InputCaptionFormat
+from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, TranscriptionConfig
 from lattifai.errors import (
     AlignmentError,
-
+    CaptionProcessingError,
     LatticeDecodingError,
     LatticeEncodingError,
-    LattifAIError,
-    SubtitleProcessingError,
-    handle_exception,
 )
-from lattifai.
-from lattifai.tokenizer import AsyncLatticeTokenizer
-from lattifai.utils import _load_tokenizer, _load_worker, _resolve_model_path, _select_device
+from lattifai.mixin import LattifAIClientMixin

+if TYPE_CHECKING:
+    from lattifai.diarization import LattifAIDiarizer  # noqa: F401

-
-
+
+class LattifAI(LattifAIClientMixin, SyncAPIClient):
+    __doc__ = LattifAIClientMixin._CLASS_DOC.format(
+        sync_or_async="Synchronous",
+        sync_or_async_lower="synchronous",
+        client_class="LattifAI",
+        await_keyword="",
+        async_note="",
+        transcriber_note=" (initialized if TranscriptionConfig provided)",
+    )

     def __init__(
         self,
-
-
-
-
-
-        timeout: Union[float, int] = 120.0,
-        max_retries: int = 2,
-        default_headers: Optional[Dict[str, str]] = None,
+        client_config: Optional[ClientConfig] = None,
+        alignment_config: Optional[AlignmentConfig] = None,
+        caption_config: Optional[CaptionConfig] = None,
+        transcription_config: Optional[TranscriptionConfig] = None,
+        diarization_config: Optional[DiarizationConfig] = None,
     ) -> None:
-
-
-
-
-
-        )
+        __doc__ = LattifAIClientMixin._INIT_DOC.format(
+            client_class="LattifAI",
+            sync_or_async_lower="synchronous",
+            config_desc="model and behavior configuration",
+            default_desc="default settings (Lattice-1 model, auto device selection)",
+            caption_note=" (auto-detect format)",
+            transcription_note=". If provided with valid API key, enables transcription capabilities (e.g., Gemini for YouTube videos)",
+            api_key_source="and LATTIFAI_API_KEY env var is not set",
+        )
+        if client_config is None:
+            client_config = ClientConfig()

-
-
-
-
-
-
-            api_key=api_key,
-            base_url=base_url,
-            timeout=timeout,
-            max_retries=max_retries,
-            default_headers=default_headers,
+        # Initialize base API client
+        super().__init__(config=client_config)
+
+        # Initialize all configs with defaults
+        alignment_config, transcription_config, diarization_config = self._init_configs(
+            alignment_config, transcription_config, diarization_config
         )

-
-
+        # Store configs
+        if caption_config is None:
+            caption_config = CaptionConfig()
+        self.caption_config = caption_config

-
-        self.
-        self.device = device
+        # audio loader
+        self.audio_loader = AudioLoader(device=alignment_config.device)

-
-    self
-        audio: Pathlike,
-        subtitle: Pathlike,
-        format: Optional[SubtitleFormat] = None,
-        split_sentence: bool = False,
-        return_details: bool = False,
-        output_subtitle_path: Optional[Pathlike] = None,
-    ) -> Tuple[List[Supervision], Optional[Pathlike]]:
-        """Perform alignment on audio and subtitle/text.
+        # aligner
+        self.aligner = Lattice1Aligner(config=alignment_config)

-
-
-
-
-
-            return_details: Return word-level alignment details in Supervision.alignment field
-            output_subtitle_path: Output path for aligned subtitle (optional)
+        # Initialize diarizer if enabled
+        self.diarization_config = diarization_config
+        self.diarizer: Optional["LattifAIDiarizer"] = None
+        if self.diarization_config.enabled:
+            from lattifai.diarization import LattifAIDiarizer  # noqa: F811

-
-        Tuple containing:
-            - List of aligned Supervision objects with timing information
-            - Output subtitle path (if output_subtitle_path was provided)
+            self.diarizer = LattifAIDiarizer(config=self.diarization_config)

-
-
-
-
-
-
+        # Initialize shared components (transcriber, downloader)
+        self._init_shared_components(transcription_config)
+
+    def alignment(
+        self,
+        input_media: Union[Pathlike, AudioData],
+        input_caption: Optional[Union[Pathlike, Caption]] = None,
+        output_caption_path: Optional[Pathlike] = None,
+        input_caption_format: Optional[InputCaptionFormat] = None,
+        split_sentence: Optional[bool] = None,
+        channel_selector: Optional[str | int] = "average",
+    ) -> Caption:
         try:
-            #
-
-
-
-
-
-                    f"Failed to parse subtitle file: {subtitle}",
-                    subtitle_path=str(subtitle),
-                    context={"original_error": str(e)},
+            # Step 1: Get caption
+            if isinstance(input_media, AudioData):
+                media_audio = input_media
+            else:
+                media_audio = self.audio_loader(
+                    input_media,
+                    channel_selector=channel_selector,
                 )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if not input_caption:
+                caption = self._transcribe(media_audio, source_lang=self.caption_config.source_lang, is_async=False)
+            else:
+                caption = self._read_caption(input_caption, input_caption_format)
+
+            output_caption_path = output_caption_path or self.caption_config.output_path
+
+            # Step 2: Check if segmented alignment is needed
+            alignment_strategy = self.aligner.config.strategy
+
+            if alignment_strategy != "entire" or caption.transcription:
+                print(colorful.cyan(f"🔄 Using segmented alignment strategy: {alignment_strategy}"))
+
+                if caption.supervisions and alignment_strategy == "transcription":
+                    # raise NotImplementedError("Transcription-based alignment is not yet implemented.")
+                    assert (
+                        "gemini" not in self.transcriber.name.lower()
+                    ), "Transcription-based alignment is not supported with Gemini transcriber."
+                    assert (
+                        caption.supervisions
+                    ), "Input caption should contain supervisions when using transcription-based alignment."
+                    if not caption.transcription:
+                        import asyncio
+
+                        print(colorful.cyan("📝 Transcribing media for alignment..."))
+                        if output_caption_path:
+                            transcript_file = (
+                                Path(str(output_caption_path)).parent
+                                / f"{Path(str(media_audio)).stem}_{self.transcriber.file_name}"
+                            )
+                            if transcript_file.exists():
+                                # print(colorful.cyan(f"Reading existing transcription from {transcript_file}"))
+                                transcript = self._read_caption(transcript_file, verbose=False)
+                                caption.transcription = transcript.supervisions
+                                caption.audio_events = transcript.audio_events
+
+                        if not caption.transcription:
+                            transcript = asyncio.run(
+                                self.transcriber.transcribe(media_audio, language=self.caption_config.source_lang)
+                            )
+                            caption.transcription = transcript.transcription
+                            caption.audio_events = transcript.audio_events
+
+                    # Align caption.supervisions with transcription to get segments
+                    import regex
+                    from error_align import ErrorAlign, error_align  # noqa: F401
+                    from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, OpType
+
+                    JOIN_TOKEN = "❄"
+                    if JOIN_TOKEN not in DELIMITERS:
+                        DELIMITERS.add(JOIN_TOKEN)
+
+                    def custom_tokenizer(text: str) -> list:
+                        """Default tokenizer that splits text into words based on whitespace.
+
+                        Args:
+                            text (str): The input text to tokenize.
+
+                        Returns:
+                            list: A list of tokens (words).
+
+                        """
+                        # Escape JOIN_TOKEN for use in regex pattern
+                        escaped_join_token = regex.escape(JOIN_TOKEN)
+                        return list(
+                            regex.finditer(
+                                rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{escaped_join_token})",
+                                text,
+                                regex.UNICODE | regex.VERBOSE,
+                            )
+                        )
+
+                    alignments = error_align(
+                        f"{JOIN_TOKEN}".join(sup.text for sup in caption.supervisions),
+                        f"{JOIN_TOKEN}".join(sup.text for sup in caption.transcription),
+                        tokenizer=custom_tokenizer,
+                    )
+
+                    for align in alignments:
+                        if align.hyp == JOIN_TOKEN and align.op_type == OpType.MATCH:
+                            pass
+
+                    # if align.op_type == OpType.MATCH:
+                    #     continue
+                    # elif align.op_type in (OpType.INSERT, OpType.DELETE, OpType.SUBSTITUTE):
+                    #     # print(colorful.yellow(f"⚠️ Alignment warning: {op}"))
+                    #     pass
+
+                    raise NotImplementedError("Transcription-based segmentation is not yet implemented.")
+                else:
+                    if caption.transcription:
+                        if not caption.supervisions:  # youtube + transcription case
+                            segments = [(sup.start, sup.end, [sup], not sup.text) for sup in caption.transcription]
+                        else:
+                            raise NotImplementedError(
+                                f"Input caption with both supervisions and transcription(strategy={alignment_strategy}) is not supported."
+                            )
+                    elif self.aligner.config.trust_caption_timestamps:
+                        # Create segmenter
+                        segmenter = Segmenter(self.aligner.config)
+                        # Create segments from caption
+                        segments = segmenter(caption)
+                    else:
+                        raise NotImplementedError(
+                            "Segmented alignment without trusting input timestamps is not yet implemented."
+                        )
+
+                # align each segment
+                supervisions, alignments = [], []
+                for i, (start, end, _supervisions, skipalign) in enumerate(segments, 1):
+                    print(
+                        colorful.green(
+                            f" ⏩ aligning segment {i:04d}/{len(segments):04d}: {start:8.2f}s - {end:8.2f}s"
+                        )
+                    )
+                    if skipalign:
+                        supervisions.extend(_supervisions)
+                        alignments.extend(_supervisions)  # may overlap with supervisions, but harmless
+                        continue
+
+                    offset = round(start, 4)
+                    emission = self.aligner.emission(
+                        media_audio.tensor[
+                            :, int(start * media_audio.sampling_rate) : int(end * media_audio.sampling_rate)
+                        ]
+                    )
+
+                    # Align segment
+                    _supervisions, _alignments = self.aligner.alignment(
+                        media_audio,
+                        _supervisions,
+                        split_sentence=split_sentence or self.caption_config.split_sentence,
+                        return_details=self.caption_config.word_level
+                        or (output_caption_path and str(output_caption_path).endswith(".TextGrid")),
+                        emission=emission,
+                        offset=offset,
+                        verbose=False,
+                    )
+
+                    supervisions.extend(_supervisions)
+                    alignments.extend(_alignments)
+            else:
+                # Step 2-4: Standard single-pass alignment
+                supervisions, alignments = self.aligner.alignment(
+                    media_audio,
+                    caption.supervisions,
+                    split_sentence=split_sentence or self.caption_config.split_sentence,
+                    return_details=self.caption_config.word_level
+                    or (output_caption_path and str(output_caption_path).endswith(".TextGrid")),
                )

-            #
-
-
-
-
+            # Update caption with aligned results
+            caption.supervisions = supervisions
+            caption.alignments = alignments
+
+            # Step 5: Speaker diarization
+            if self.diarization_config.enabled and self.diarizer:
+                print(colorful.cyan("🗣️ Performing speaker diarization..."))
+                caption = self.speaker_diarization(
+                    input_media=media_audio,
+                    caption=caption,
+                    output_caption_path=output_caption_path,
                )
-
-
-            print(colorful.red(" x Failed to decode lattice alignment results"))
-            raise e
-        except Exception as e:
-            print(colorful.red(" x Failed to decode lattice alignment results"))
-            raise LatticeDecodingError(lattice_id, original_error=e)
-
-        # step5: export alignments to target format
-        if output_subtitle_path:
-            try:
-                SubtitleIO.write(alignments, output_path=output_subtitle_path)
-                print(colorful.green(f"🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}"))
-            except Exception as e:
-                raise SubtitleProcessingError(
-                    f"Failed to write output file: {output_subtitle_path}",
-                    subtitle_path=str(output_subtitle_path),
-                    context={"original_error": str(e)},
-                )
-        return (alignments, output_subtitle_path)
+            elif output_caption_path:
+                self._write_caption(caption, output_caption_path)

-
+            return caption
+        except (CaptionProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
             # Re-raise our specific errors as-is
             raise
         except Exception as e:
             # Catch any unexpected errors and wrap them
             raise AlignmentError(
                 "Unexpected error during alignment process",
-
-
+                media_path=str(input_media),
+                caption_path=str(input_caption),
                 context={"original_error": str(e), "error_type": e.__class__.__name__},
             )

-
-class AsyncLattifAI(AsyncAPIClient):
-    """Asynchronous LattifAI client."""
-
-    def __init__(
+    def speaker_diarization(
         self,
-
-
-
-
-
-
-        max_retries: int = 2,
-        default_headers: Optional[Dict[str, str]] = None,
-    ) -> None:
-        if api_key is None:
-            api_key = os.environ.get("LATTIFAI_API_KEY")
-        if api_key is None:
-            raise ConfigurationError(
-                "The api_key client option must be set either by passing api_key to the client "
-                "or by setting the LATTIFAI_API_KEY environment variable"
-            )
-
-        if base_url is None:
-            base_url = os.environ.get("LATTIFAI_BASE_URL")
-        if not base_url:
-            base_url = "https://api.lattifai.com/v1"
-
-        super().__init__(
-            api_key=api_key,
-            base_url=base_url,
-            timeout=timeout,
-            max_retries=max_retries,
-            default_headers=default_headers,
-        )
+        input_media: AudioData,
+        caption: Caption,
+        output_caption_path: Optional[Pathlike] = None,
+    ) -> Caption:
+        """
+        Perform speaker diarization on aligned caption.

-
-
+        Args:
+            input_media: AudioData object
+            caption: Caption object with aligned segments
+            output_caption_path: Optional path to write diarized caption

-
-
-        self.device = device
+        Returns:
+            Caption object with speaker labels assigned

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Raises:
+            RuntimeError: If diarizer is not initialized or diarization fails
+        """
+        if not self.diarizer:
+            raise RuntimeError("Diarizer not initialized. Set diarization_config.enabled=True")
+
+        # Perform diarization and assign speaker labels to caption alignments
+        if output_caption_path:
+            diarization_file = Path(str(output_caption_path)).with_suffix(".SpkDiar")
+            if diarization_file.exists():
+                print(colorful.cyan(f"Reading existing speaker diarization from {diarization_file}"))
+                caption.read_speaker_diarization(diarization_file)
+
+        diarization, alignments = self.diarizer.diarize_with_alignments(
+            input_media, caption.alignments, diarization=caption.speaker_diarization
+        )
+        caption.alignments = alignments
+        caption.speaker_diarization = diarization
+
+        # Write output if requested
+        if output_caption_path:
+            self._write_caption(caption, output_caption_path)
+
+        if self.diarizer.config.debug:
+            # debug
+            from tgt import Interval, IntervalTier, TextGrid, write_to_file
+
+            debug_tg = TextGrid()
+            transcript_tier = IntervalTier(
+                start_time=0,
+                end_time=input_media.duration,
+                name="transcript",
+                objects=[Interval(sup.start, sup.end, sup.text) for sup in caption.alignments],
            )
+            debug_tg.add_tier(transcript_tier)

-
-
-
-
-
+            speaker_tier = IntervalTier(
+                start_time=0,
+                end_time=input_media.duration,
+                name="speaker",
+                objects=[Interval(sup.start, sup.end, sup.speaker) for sup in caption.alignments],
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            debug_tg.add_tier(speaker_tier)
+
+            from collections import defaultdict
+
+            spk2intervals = defaultdict(lambda: [])
+            num_multispk = 0
+
+            segments, skipks = [], []
+            for k, supervision in enumerate(caption.alignments):  # TODO: alignments themselves can overlap, e.g. [event]
+                # supervision = caption.alignments[k]
+                if supervision.custom.get("speaker", []):
+                    num_multispk += 1
+                else:
+                    continue
+
+                if k in skipks:
+                    continue
+
+                for speaker in supervision.custom.get("speaker", []):
+                    for name, start_time, end_time in speaker:
+                        spk2intervals[name].append(Interval(start_time, end_time, name))
+
+                _segments = []
+                if k > 0:
+                    _segments.append(caption.alignments[k - 1])
+                _segments.append(supervision)
+                while k + 1 < len(caption.alignments):
+                    skipks.append(k + 1)
+                    next_sup = caption.alignments[k + 1]
+                    if not next_sup.custom.get("speaker", []):
+                        k += 1
+                        break
+                    _segments.append(next_sup)
+                    k += 1
+
+                if segments:
+                    if _segments[0].start >= segments[-1][-1].end:
+                        segments.append(_segments)
+                    else:
+                        if _segments[1:]:
+                            segments.append(_segments[1:])
+                        else:
+                            pass
+                else:
+                    segments.append(_segments)
+
+            print(
+                f"Number of multi-speaker segments: {num_multispk}/{len(caption.alignments)} segments: {len(segments)}"
            )

-
-
-
-                lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
-            )
-            print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
-        except LatticeDecodingError as e:
-            print(colorful.red(" x Failed to decode lattice alignment results"))
-            raise e
-        except Exception as e:
-            print(colorful.red(" x Failed to decode lattice alignment results"))
-            raise LatticeDecodingError(lattice_id, original_error=e)
-
-        if output_subtitle_path:
-            try:
-                await asyncio.to_thread(SubtitleIO.write, alignments, output_subtitle_path)
-                print(colorful.green(f"🎉🎉🎉🎉🎉 Subtitle file written to: {output_subtitle_path}"))
-            except Exception as e:
-                raise SubtitleProcessingError(
-                    f"Failed to write output file: {output_subtitle_path}",
-                    subtitle_path=str(output_subtitle_path),
-                    context={"original_error": str(e)},
+            for speaker, intervals in sorted(spk2intervals.items(), key=lambda x: x[0]):
+                speaker_tier = IntervalTier(
+                    start_time=0, end_time=input_media.duration, name=speaker, objects=intervals
                )
+                debug_tg.add_tier(speaker_tier)
+
+            for tier in caption.speaker_diarization.tiers:
+                tier.name = f"Diarization-{tier.name}"
+                debug_tg.add_tier(tier)
+
+            tier = IntervalTier(
+                start_time=0,
+                end_time=input_media.duration,
+                name="resegment",
+                objects=[
+                    Interval(round(sup.start, 2), round(sup.end, 2), sup.text)
+                    for _segments in segments
+                    for sup in _segments
+                ],
+            )
+            debug_tg.add_tier(tier)

-
+            # if caption.audio_events:
+            #     for tier in caption.audio_events.tiers:
+            #         # tier.name = f"{tier.name}"
+            #         debug_tg.add_tier(tier)

-
-
-
-
-
-
-
-
-
+            debug_tgt_file = Path(str(output_caption_path)).with_suffix(".DiarizationDebug.TextGrid")
+            write_to_file(debug_tg, debug_tgt_file, format="long")
+
+        return caption
+
+    def youtube(
+        self,
+        url: str,
+        output_dir: Optional[Pathlike] = None,
+        media_format: Optional[str] = None,
+        source_lang: Optional[str] = None,
+        force_overwrite: bool = False,
+        output_caption_path: Optional[Pathlike] = None,
+        split_sentence: Optional[bool] = None,
+        use_transcription: bool = False,
+        channel_selector: Optional[str | int] = "average",
+    ) -> Caption:
+        # Prepare output directory and media format
+        output_dir = self._prepare_youtube_output_dir(output_dir)
+        media_format = self._determine_media_format(media_format)
+
+        print(colorful.cyan(f"🎬 Starting YouTube workflow for: {url}"))
+
+        # Step 1: Download media
+        media_file = self._download_media_sync(url, output_dir, media_format, force_overwrite)
+
+        media_audio = self.audio_loader(media_file, channel_selector=channel_selector)
+
+        # Step 2: Get or create captions (download or transcribe)
+        caption = self._download_or_transcribe_caption(
+            url,
+            output_dir,
+            media_audio,
+            force_overwrite,
+            source_lang or self.caption_config.source_lang,
+            is_async=False,
+            use_transcription=use_transcription,
+        )
+
+        # Step 3: Generate output path if not provided
+        output_caption_path = self._generate_output_caption_path(output_caption_path, media_file, output_dir)
+
+        # Step 4: Perform alignment
+        print(colorful.cyan("🔗 Performing forced alignment..."))
+
+        caption: Caption = self.alignment(
+            input_media=media_audio,
+            input_caption=caption,
+            output_caption_path=output_caption_path,
+            split_sentence=split_sentence,
+            channel_selector=channel_selector,
+        )
+
+        return caption
+
+
+# Set docstrings for LattifAI methods
+LattifAI.alignment.__doc__ = LattifAIClientMixin._ALIGNMENT_DOC.format(
+    async_prefix="",
+    async_word="",
+    timing_desc="each word",
+    concurrency_note="",
+    async_suffix1="",
+    async_suffix2="",
+    async_suffix3="",
+    async_suffix4="",
+    async_suffix5="",
+    format_default="auto-detects",
+    export_note=" in the same format as input (or config default)",
+    timing_note=" (start, duration, text)",
+    example_imports="client = LattifAI()",
+    example_code="""alignments, output_path = client.alignment(
+...     input_media="speech.wav",
+...     input_caption="transcript.srt",
+...     output_caption_path="aligned.srt"
+... )
+>>> for seg in alignments:
+...     print(f"{seg.start:.2f}s - {seg.end:.2f}s: {seg.text}")""",
+)
+
+LattifAI.youtube.__doc__ = LattifAIClientMixin._YOUTUBE_METHOD_DOC.format(client_class="LattifAI", await_keyword="")


 if __name__ == "__main__":
@@ -299,14 +505,17 @@ if __name__ == "__main__":
     import sys

     if len(sys.argv) == 5:
-        audio,
+        audio, caption, output, split_sentence = sys.argv[1:]
         split_sentence = split_sentence.lower() in ("true", "1", "yes")
     else:
         audio = "tests/data/SA1.wav"
-
+        caption = "tests/data/SA1.TXT"
         output = None
         split_sentence = False

-    (alignments,
-        audio,
+    (alignments, output_caption_path) = client.alignment(
+        input_media=audio,
+        input_caption=caption,
+        output_caption_path=output,
+        split_sentence=split_sentence,
    )