lattifai 1.0.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +11 -12
- lattifai/alignment/lattice1_aligner.py +39 -7
- lattifai/alignment/lattice1_worker.py +135 -147
- lattifai/alignment/tokenizer.py +38 -22
- lattifai/audio2.py +1 -1
- lattifai/caption/caption.py +55 -19
- lattifai/cli/__init__.py +2 -0
- lattifai/cli/caption.py +1 -1
- lattifai/cli/diarization.py +110 -0
- lattifai/cli/transcribe.py +3 -1
- lattifai/cli/youtube.py +11 -0
- lattifai/client.py +32 -111
- lattifai/config/alignment.py +14 -0
- lattifai/config/client.py +5 -0
- lattifai/config/transcription.py +4 -0
- lattifai/diarization/lattifai.py +18 -7
- lattifai/mixin.py +26 -5
- lattifai/transcription/__init__.py +1 -1
- lattifai/transcription/base.py +21 -2
- lattifai/transcription/gemini.py +127 -1
- lattifai/transcription/lattifai.py +30 -2
- lattifai/utils.py +62 -69
- lattifai/workflow/youtube.py +55 -57
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/METADATA +352 -56
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/RECORD +29 -28
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/entry_points.txt +2 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/WHEEL +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/top_level.txt +0 -0
lattifai/audio2.py
CHANGED

@@ -36,7 +36,7 @@ class AudioData(namedtuple("AudioData", ["sampling_rate", "ndarray", "path", "st
     @property
     def streaming_mode(self) -> bool:
         """Indicates whether streaming mode is enabled based on streaming_chunk_secs."""
-        if self.streaming_chunk_secs
+        if self.streaming_chunk_secs:
            return self.duration > self.streaming_chunk_secs * 1.1
        return False
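The fix guards the chunk-size comparison with an explicit truthiness test, so a missing or zero `streaming_chunk_secs` disables streaming. A minimal standalone sketch of that behavior (a reduced stand-in, not the packaged `AudioData`, which carries more fields):

```python
from collections import namedtuple

# Reduced stand-in for AudioData; the real namedtuple has more fields.
_Audio = namedtuple("_Audio", ["duration", "streaming_chunk_secs"])


def streaming_mode(audio: _Audio) -> bool:
    """Mirrors the fixed property: a falsy chunk size means no streaming."""
    if audio.streaming_chunk_secs:
        return audio.duration > audio.streaming_chunk_secs * 1.1
    return False


print(streaming_mode(_Audio(duration=600.0, streaming_chunk_secs=30.0)))  # True
print(streaming_mode(_Audio(duration=600.0, streaming_chunk_secs=None)))  # False
```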
lattifai/caption/caption.py
CHANGED

@@ -4,17 +4,19 @@ import json
 import re
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Optional,
+from typing import Any, Dict, List, Optional, TypeVar

 from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike
 from tgt import TextGrid

-from ..config.caption import InputCaptionFormat, OutputCaptionFormat
+from ..config.caption import InputCaptionFormat, OutputCaptionFormat  # noqa: F401
 from .supervision import Supervision
 from .text_parser import normalize_text as normalize_text_fn
 from .text_parser import parse_speaker_text, parse_timestamp_text

+DiarizationOutput = TypeVar("DiarizationOutput")
+

 @dataclass
 class Caption:

@@ -40,7 +42,7 @@ class Caption:
     # Audio Event Detection results
     audio_events: Optional[TextGrid] = None
     # Speaker Diarization results
-    speaker_diarization: Optional[
+    speaker_diarization: Optional[DiarizationOutput] = None
     # Alignment results
     alignments: List[Supervision] = field(default_factory=list)

@@ -272,7 +274,7 @@ class Caption:
         cls,
         transcription: List[Supervision],
         audio_events: Optional[TextGrid] = None,
-        speaker_diarization: Optional[
+        speaker_diarization: Optional[DiarizationOutput] = None,
         language: Optional[str] = None,
         source_path: Optional[Pathlike] = None,
         metadata: Optional[Dict[str, str]] = None,

@@ -283,7 +285,7 @@ class Caption:
         Args:
             transcription: List of transcription supervision segments
             audio_events: Optional TextGrid with audio event detection results
-            speaker_diarization: Optional
+            speaker_diarization: Optional DiarizationOutput with speaker diarization results
             language: Language code
             source_path: Source file path
             metadata: Additional metadata

@@ -384,9 +386,9 @@ class Caption:
         """
         Read speaker diarization TextGrid from file.
         """
-        from
+        from lattifai_core.diarization import DiarizationOutput

-        self.speaker_diarization =
+        self.speaker_diarization = DiarizationOutput.read(path)
         return self.speaker_diarization

     def write_speaker_diarization(

@@ -399,9 +401,7 @@ class Caption:
         if not self.speaker_diarization:
             raise ValueError("No speaker diarization data to write.")

-
-
-        write_to_file(self.speaker_diarization, path, format="long")
+        self.speaker_diarization.write(path)
         return path

     @staticmethod
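The module-level `TypeVar` is a lightweight way to annotate a type that is only imported lazily inside the methods, so `lattifai` stays importable without `lattifai_core` installed. A minimal self-contained sketch of the same pattern (the lazy-import placement mirrors the diff; availability of `lattifai_core` at call time is an assumption):

```python
from dataclasses import dataclass
from typing import Optional, TypeVar

# Stand-in annotation: the real class lives in an optional package and is
# only imported inside the method, exactly as the diff does.
DiarizationOutput = TypeVar("DiarizationOutput")


@dataclass
class CaptionSketch:
    speaker_diarization: Optional[DiarizationOutput] = None

    def read_speaker_diarization(self, path: str):
        # Deferred import: the optional dependency is resolved only when called.
        from lattifai_core.diarization import DiarizationOutput  # assumed installed

        self.speaker_diarization = DiarizationOutput.read(path)
        return self.speaker_diarization
```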
@@ -451,7 +451,10 @@ class Caption:
             else:
                 if include_speaker_in_text and sup.speaker is not None:
                     # Use [SPEAKER]: format for consistency with parsing
-
+                    if not sup.has_custom("original_speaker") or sup.custom["original_speaker"]:
+                        text = f"[{sup.speaker}]: {sup.text}"
+                    else:
+                        text = f"{sup.text}"
                 else:
                     text = sup.text
             f.write(f"[{sup.start:.2f}-{sup.end:.2f}] {text}\n")

@@ -471,7 +474,12 @@ class Caption:
         tg = TextGrid()
         supervisions, words, scores = [], [], {"utterances": [], "words": []}
         for supervision in sorted(alignments, key=lambda x: x.start):
-
+            # Respect `original_speaker` custom flag: default to include speaker when missing
+            if (
+                include_speaker_in_text
+                and supervision.speaker is not None
+                and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+            ):
                 text = f"{supervision.speaker} {supervision.text}"
             else:
                 text = supervision.text

@@ -526,7 +534,10 @@ class Caption:
             )
         else:
             if include_speaker_in_text and sup.speaker is not None:
-
+                if not sup.has_custom("original_speaker") or sup.custom["original_speaker"]:
+                    text = f"{sup.speaker} {sup.text}"
+                else:
+                    text = f"{sup.text}"
             else:
                 text = sup.text
         subs.append(
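All three writers above (plain text, TextGrid, and subtitle) now apply the same rule: emit the speaker prefix unless the segment carries an `original_speaker` custom flag that is falsy. A standalone recreation of the predicate (the `has_custom`/`custom` names follow the diff; the helper itself is illustrative):

```python
from typing import Optional


def should_emit_speaker(speaker: Optional[str], custom: Optional[dict]) -> bool:
    """Mirror of `not sup.has_custom("original_speaker") or sup.custom["original_speaker"]`."""
    if speaker is None:
        return False
    if not custom or "original_speaker" not in custom:
        return True  # flag absent: keep the historical behavior and include the speaker
    return bool(custom["original_speaker"])


print(should_emit_speaker("Alice", None))                         # True
print(should_emit_speaker("Alice", {"original_speaker": False}))  # False
print(should_emit_speaker("Alice", {"original_speaker": True}))   # True
```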
@@ -830,7 +841,16 @@ class Caption:
         if cls._is_youtube_vtt_with_word_timestamps(content):
             return cls._parse_youtube_vtt_with_word_timestamps(content, normalize_text)

-
+        # Match Gemini format: explicit format, or files ending with Gemini.md/Gemini3.md,
+        # or files containing "gemini" in the name with .md extension
+        caption_str = str(caption).lower()
+        is_gemini_format = (
+            format == "gemini"
+            or str(caption).endswith("Gemini.md")
+            or str(caption).endswith("Gemini3.md")
+            or ("gemini" in caption_str and caption_str.endswith(".md"))
+        )
+        if is_gemini_format:
             from .gemini_reader import GeminiReader

             supervisions = GeminiReader.extract_for_alignment(caption)
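Detection is purely name-based apart from the explicit `format == "gemini"` escape hatch. A standalone recreation of the matching logic, handy for checking which filenames will route through `GeminiReader`:

```python
def looks_like_gemini(caption: str, format: str = "auto") -> bool:
    """Recreation of the `is_gemini_format` predicate from the diff above."""
    caption_str = caption.lower()
    return (
        format == "gemini"
        or caption.endswith("Gemini.md")
        or caption.endswith("Gemini3.md")
        or ("gemini" in caption_str and caption_str.endswith(".md"))
    )


print(looks_like_gemini("talk.Gemini.md"))      # True: explicit suffix
print(looks_like_gemini("talk_gemini3.md"))     # True: "gemini" in name, .md extension
print(looks_like_gemini("talk.md"))             # False
print(looks_like_gemini("talk.srt", "gemini"))  # True: explicit format wins
```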
@@ -1242,7 +1262,11 @@ class Caption:
         if include_speaker_in_text:
             file.write("speaker\tstart\tend\ttext\n")
         for supervision in alignments:
-
+            # Respect `original_speaker` custom flag: default to True when missing
+            include_speaker = supervision.speaker and (
+                not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"]
+            )
+            speaker = supervision.speaker if include_speaker else ""
             start_ms = round(1000 * supervision.start)
             end_ms = round(1000 * supervision.end)
             text = supervision.text.strip().replace("\t", " ")

@@ -1280,7 +1304,10 @@ class Caption:
         writer = csv.writer(file)
         writer.writerow(["speaker", "start", "end", "text"])
         for supervision in alignments:
-
+            include_speaker = supervision.speaker and (
+                not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"]
+            )
+            speaker = supervision.speaker if include_speaker else ""
             start_ms = round(1000 * supervision.start)
             end_ms = round(1000 * supervision.end)
             text = supervision.text.strip()

@@ -1318,7 +1345,12 @@ class Caption:
         end = supervision.end
         text = supervision.text.strip().replace("\t", " ")

-
+        # Respect `original_speaker` custom flag when adding speaker prefix
+        if (
+            include_speaker_in_text
+            and supervision.speaker
+            and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+        ):
             text = f"[[{supervision.speaker}]]{text}"

         file.write(f"{start}\t{end}\t{text}\n")

@@ -1364,9 +1396,13 @@ class Caption:
         # Write timestamp line
         file.write(f"{start_time},{end_time}\n")

-        # Write text (with optional speaker)
+        # Write text (with optional speaker). Respect `original_speaker` custom flag.
         text = supervision.text.strip()
-        if
+        if (
+            include_speaker_in_text
+            and supervision.speaker
+            and (not supervision.has_custom("original_speaker") or supervision.custom["original_speaker"])
+        ):
             text = f"{supervision.speaker}: {text}"

         file.write(f"{text}\n")
lattifai/cli/__init__.py
CHANGED

@@ -5,12 +5,14 @@ import nemo_run as run  # noqa: F401
 # Import and re-export entrypoints at package level so NeMo Run can find them
 from lattifai.cli.alignment import align
 from lattifai.cli.caption import convert
+from lattifai.cli.diarization import diarize
 from lattifai.cli.transcribe import transcribe, transcribe_align
 from lattifai.cli.youtube import youtube

 __all__ = [
     "align",
     "convert",
+    "diarize",
     "transcribe",
     "transcribe_align",
     "youtube",
lattifai/cli/diarization.py
ADDED

@@ -0,0 +1,110 @@
+"""Speaker diarization CLI entry point with nemo_run."""
+
+from pathlib import Path
+from typing import Optional
+
+import colorful
+import nemo_run as run
+from typing_extensions import Annotated
+
+from lattifai.client import LattifAI
+from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
+from lattifai.utils import safe_print
+
+__all__ = ["diarize"]
+
+
+@run.cli.entrypoint(name="run", namespace="diarization")
+def diarize(
+    input_media: Optional[str] = None,
+    input_caption: Optional[str] = None,
+    output_caption: Optional[str] = None,
+    media: Annotated[Optional[MediaConfig], run.Config[MediaConfig]] = None,
+    caption: Annotated[Optional[CaptionConfig], run.Config[CaptionConfig]] = None,
+    client: Annotated[Optional[ClientConfig], run.Config[ClientConfig]] = None,
+    alignment: Annotated[Optional[AlignmentConfig], run.Config[AlignmentConfig]] = None,
+    diarization: Annotated[Optional[DiarizationConfig], run.Config[DiarizationConfig]] = None,
+):
+    """Run speaker diarization on aligned captions and audio."""
+
+    media_config = media or MediaConfig()
+    caption_config = caption or CaptionConfig()
+    diarization_config = diarization or DiarizationConfig()
+
+    if input_media and media_config.input_path:
+        raise ValueError("Cannot specify both positional input_media and media.input_path.")
+    if input_media:
+        media_config.set_input_path(input_media)
+    if not media_config.input_path:
+        raise ValueError("Input media path must be provided via positional input_media or media.input_path.")
+
+    if input_caption and caption_config.input_path:
+        raise ValueError("Cannot specify both positional input_caption and caption.input_path.")
+    if input_caption:
+        caption_config.set_input_path(input_caption)
+    if not caption_config.input_path:
+        raise ValueError("Input caption path must be provided via positional input_caption or caption.input_path.")
+
+    if output_caption and caption_config.output_path:
+        raise ValueError("Cannot specify both positional output_caption and caption.output_path.")
+    if output_caption:
+        caption_config.set_output_path(output_caption)
+
+    diarization_config.enabled = True
+
+    client_instance = LattifAI(
+        client_config=client,
+        alignment_config=alignment,
+        caption_config=caption_config,
+        diarization_config=diarization_config,
+    )
+
+    safe_print(colorful.cyan("🎧 Loading media for diarization..."))
+    media_audio = client_instance.audio_loader(
+        media_config.input_path,
+        channel_selector=media_config.channel_selector,
+        streaming_chunk_secs=media_config.streaming_chunk_secs,
+    )
+
+    safe_print(colorful.cyan("📖 Loading caption segments..."))
+    caption_obj = client_instance._read_caption(
+        caption_config.input_path,
+        input_caption_format=None if caption_config.input_format == "auto" else caption_config.input_format,
+        verbose=False,
+    )
+
+    if not caption_obj.alignments:
+        caption_obj.alignments = caption_obj.supervisions
+
+    if not caption_obj.alignments:
+        raise ValueError("Caption does not contain segments for diarization.")
+
+    if caption_config.output_path:
+        output_path = caption_config.output_path
+    else:
+        from datetime import datetime
+
+        input_caption_path = Path(caption_config.input_path)
+        timestamp = datetime.now().strftime("%Y%m%d_%H")
+        default_output = (
+            input_caption_path.parent / f"{input_caption_path.stem}.diarized.{timestamp}.{caption_config.output_format}"
+        )
+        caption_config.set_output_path(default_output)
+        output_path = caption_config.output_path
+
+    safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
+    diarized_caption = client_instance.speaker_diarization(
+        input_media=media_audio,
+        caption=caption_obj,
+        output_caption_path=output_path,
+    )
+
+    return diarized_caption
+
+
+def main():
+    run.cli.main(diarize)
+
+
+if __name__ == "__main__":
+    main()
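Since the entrypoint is registered under the `diarization` namespace (and `entry_points.txt` gains two entries in this release), the function can also be driven programmatically. A hedged sketch; the paths are placeholders and calling the decorated function directly, outside nemo_run, is an assumption:

```python
# Hypothetical direct call of the new entrypoint from Python.
from lattifai.cli.diarization import diarize

caption = diarize(
    input_media="interview.wav",
    input_caption="interview.aligned.srt",
    output_caption="interview.diarized.srt",
)
```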
lattifai/cli/transcribe.py
CHANGED

@@ -117,7 +117,9 @@ def transcribe(

     # Create transcriber
     if not transcription_config.lattice_model_path:
-        transcription_config.lattice_model_path = _resolve_model_path(
+        transcription_config.lattice_model_path = _resolve_model_path(
+            "LattifAI/Lattice-1", getattr(transcription_config, "model_hub", "huggingface")
+        )
     transcriber = create_transcriber(transcription_config=transcription_config)

     safe_print(colorful.cyan(f"🎤 Starting transcription with {transcriber.name}..."))
lattifai/cli/youtube.py
CHANGED

@@ -25,6 +25,7 @@ def youtube(
     caption: Annotated[Optional[CaptionConfig], run.Config[CaptionConfig]] = None,
     transcription: Annotated[Optional[TranscriptionConfig], run.Config[TranscriptionConfig]] = None,
     diarization: Annotated[Optional[DiarizationConfig], run.Config[DiarizationConfig]] = None,
+    use_transcription: bool = False,
 ):
     """
     Download media from YouTube (when needed) and align captions.

@@ -55,6 +56,11 @@ def youtube(
             Fields: gemini_api_key, model_name, language, device
         diarization: Speaker diarization configuration.
             Fields: enabled, num_speakers, min_speakers, max_speakers, device
+        use_transcription: If True, skip YouTube caption download and directly use
+            transcription.model_name to transcribe. If False (default), first try to
+            download YouTube captions; if download fails (no captions available or
+            errors like HTTP 429), automatically fallback to transcription if
+            transcription.model_name is configured.

     Examples:
         # Download from YouTube and align (positional argument)

@@ -108,7 +114,11 @@ def youtube(
         transcription_config=transcription,
         diarization_config=diarization,
     )
+
     # Call the client's youtube method
+    # If use_transcription=True, skip YouTube caption download and use transcription directly.
+    # If use_transcription=False (default), try YouTube captions first; on failure,
+    # automatically fallback to transcription if transcription.model_name is configured.
     return lattifai_client.youtube(
         url=media_config.input_path,
         output_dir=media_config.output_dir,

@@ -118,6 +128,7 @@ def youtube(
         split_sentence=caption_config.split_sentence,
         channel_selector=media_config.channel_selector,
         streaming_chunk_secs=media_config.streaming_chunk_secs,
+        use_transcription=use_transcription,
     )
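A hedged usage sketch of the new flag. The positional URL follows the "positional argument" example in the docstring above, and the `TranscriptionConfig` import path matches the other config imports in this diff; both are assumptions, and the URL is a placeholder:

```python
from lattifai.cli.youtube import youtube
from lattifai.config import TranscriptionConfig  # import path assumed

# Skip YouTube caption download entirely and transcribe with the configured model.
youtube(
    "https://www.youtube.com/watch?v=VIDEO_ID",
    transcription=TranscriptionConfig(model_name="gemini-3-flash-preview"),
    use_transcription=True,
)
```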
lattifai/client.py
CHANGED

@@ -56,6 +56,7 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):

         # Initialize base API client
         super().__init__(config=client_config)
+        self.config = client_config

         # Initialize all configs with defaults
         alignment_config, transcription_config, diarization_config = self._init_configs(

@@ -106,7 +107,13 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
         )

         if not input_caption:
-
+            output_dir = None
+            if output_caption_path:
+                output_dir = Path(str(output_caption_path)).parent
+                output_dir.mkdir(parents=True, exist_ok=True)
+            caption = self._transcribe(
+                media_audio, source_lang=self.caption_config.source_lang, is_async=False, output_dir=output_dir
+            )
         else:
             caption = self._read_caption(input_caption, input_caption_format)

@@ -260,18 +267,13 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
             caption.supervisions = supervisions
             caption.alignments = alignments

-
-            if self.diarization_config.enabled and self.diarizer:
-                safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
-                caption = self.speaker_diarization(
-                    input_media=media_audio,
-                    caption=caption,
-                    output_caption_path=output_caption_path,
-                )
-            elif output_caption_path:
+            if output_caption_path:
                 self._write_caption(caption, output_caption_path)

-
+            # Profile if enabled
+            if self.config.profile:
+                self.aligner.profile()
+
         except (CaptionProcessingError, LatticeEncodingError, AlignmentError, LatticeDecodingError):
             # Re-raise our specific errors as-is
             raise

@@ -284,6 +286,17 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
                 context={"original_error": str(e), "error_type": e.__class__.__name__},
             )

+        # Step 5: Speaker diarization
+        if self.diarization_config.enabled and self.diarizer:
+            safe_print(colorful.cyan("🗣️ Performing speaker diarization..."))
+            caption = self.speaker_diarization(
+                input_media=media_audio,
+                caption=caption,
+                output_caption_path=output_caption_path,
+            )
+
+        return caption
+
     def speaker_diarization(
         self,
         input_media: AudioData,

@@ -315,7 +328,14 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
         caption.read_speaker_diarization(diarization_file)

         diarization, alignments = self.diarizer.diarize_with_alignments(
-            input_media,
+            input_media,
+            caption.alignments,
+            diarization=caption.speaker_diarization,
+            alignment_fn=self.aligner.alignment,
+            transcribe_fn=self.transcriber.transcribe_numpy if self.transcriber else None,
+            separate_fn=self.aligner.separate if self.aligner.worker.separator_ort else None,
+            debug=self.diarizer.config.debug,
+            output_path=output_caption_path,
         )
         caption.alignments = alignments
         caption.speaker_diarization = diarization

@@ -324,105 +344,6 @@ class LattifAI(LattifAIClientMixin, SyncAPIClient):
         if output_caption_path:
             self._write_caption(caption, output_caption_path)

-        if self.diarizer.config.debug:
-            # debug
-            from tgt import Interval, IntervalTier, TextGrid, write_to_file
-
-            debug_tg = TextGrid()
-            transcript_tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="transcript",
-                objects=[Interval(sup.start, sup.end, sup.text) for sup in caption.alignments],
-            )
-            debug_tg.add_tier(transcript_tier)
-
-            speaker_tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="speaker",
-                objects=[Interval(sup.start, sup.end, sup.speaker) for sup in caption.alignments],
-            )
-            debug_tg.add_tier(speaker_tier)
-
-            from collections import defaultdict
-
-            spk2intervals = defaultdict(lambda: [])
-            num_multispk = 0
-
-            segments, skipks = [], []
-            for k, supervision in enumerate(caption.alignments):  # TODO: alignments themselves can overlap, e.g. [event]
-                # supervision = caption.alignments[k]
-                if supervision.custom.get("speaker", []):
-                    num_multispk += 1
-                else:
-                    continue
-
-                if k in skipks:
-                    continue
-
-                for speaker in supervision.custom.get("speaker", []):
-                    for name, start_time, end_time in speaker:
-                        spk2intervals[name].append(Interval(start_time, end_time, name))
-
-                _segments = []
-                if k > 0:
-                    _segments.append(caption.alignments[k - 1])
-                _segments.append(supervision)
-                while k + 1 < len(caption.alignments):
-                    skipks.append(k + 1)
-                    next_sup = caption.alignments[k + 1]
-                    if not next_sup.custom.get("speaker", []):
-                        k += 1
-                        break
-                    _segments.append(next_sup)
-                    k += 1
-
-                if segments:
-                    if _segments[0].start >= segments[-1][-1].end:
-                        segments.append(_segments)
-                    else:
-                        if _segments[1:]:
-                            segments.append(_segments[1:])
-                        else:
-                            pass
-                else:
-                    segments.append(_segments)
-
-            print(
-                f"Number of multi-speaker segments: {num_multispk}/{len(caption.alignments)} segments: {len(segments)}"
-            )
-
-            for speaker, intervals in sorted(spk2intervals.items(), key=lambda x: x[0]):
-                speaker_tier = IntervalTier(
-                    start_time=0, end_time=input_media.duration, name=speaker, objects=intervals
-                )
-                debug_tg.add_tier(speaker_tier)
-
-            for tier in caption.speaker_diarization.tiers:
-                tier.name = f"Diarization-{tier.name}"
-                debug_tg.add_tier(tier)
-
-            tier = IntervalTier(
-                start_time=0,
-                end_time=input_media.duration,
-                name="resegment",
-                objects=[
-                    Interval(round(sup.start, 2), round(sup.end, 2), sup.text)
-                    for _segments in segments
-                    for sup in _segments
-                ],
-            )
-            debug_tg.add_tier(tier)
-
-            # if caption.audio_events:
-            #     for tier in caption.audio_events.tiers:
-            #         # tier.name = f"{tier.name}"
-            #         debug_tg.add_tier(tier)
-
-            debug_tgt_file = Path(str(output_caption_path)).with_suffix(".DiarizationDebug.TextGrid")
-            write_to_file(debug_tg, debug_tgt_file, format="long")
-
         return caption

     def youtube(
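With diarization now running as step 5, after alignment error handling, and `diarize_with_alignments` receiving the aligner, transcriber, and separator hooks directly, the client-level entry can also be invoked on its own. A hedged sketch following the new CLI in this diff (the `_read_caption` helper is private, paths are placeholders, and construction details are illustrative):

```python
from lattifai.client import LattifAI
from lattifai.config import DiarizationConfig

client = LattifAI(diarization_config=DiarizationConfig(enabled=True))
audio = client.audio_loader("talk.wav")
caption = client._read_caption("talk.aligned.srt", input_caption_format=None)  # private helper, as used by the CLI
client.speaker_diarization(
    input_media=audio,
    caption=caption,
    output_caption_path="talk.diarized.srt",
)
```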
lattifai/config/alignment.py
CHANGED

@@ -21,6 +21,9 @@ class AlignmentConfig:
     model_name: str = "LattifAI/Lattice-1"
     """Model identifier or path to local model directory (e.g., 'LattifAI/Lattice-1')."""

+    model_hub: Literal["huggingface", "modelscope"] = "huggingface"
+    """Which model hub to use when resolving remote model names: 'huggingface' or 'modelscope'."""
+
     device: Literal["cpu", "cuda", "mps", "auto"] = "auto"
     """Computation device: 'cpu' for CPU, 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon."""

@@ -79,6 +82,17 @@ class AlignmentConfig:
     Default: 10000. Typical range: 1000-20000.
     """

+    # Alignment timing configuration
+    start_margin: float = 0.08
+    """Maximum start time margin (in seconds) to extend segment boundaries at the beginning.
+    Default: 0.08. Typical range: 0.0-0.5.
+    """
+
+    end_margin: float = 0.20
+    """Maximum end time margin (in seconds) to extend segment boundaries at the end.
+    Default: 0.20. Typical range: 0.0-0.5.
+    """
+
     client_wrapper: Optional["SyncAPIClient"] = field(default=None, repr=False)
     """Reference to the SyncAPIClient instance. Auto-set during client initialization."""
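A sketch using the new fields (the import path matches the one used by the new diarization CLI; the values shown are the documented defaults and ranges):

```python
from lattifai.config import AlignmentConfig

config = AlignmentConfig(
    model_hub="modelscope",  # resolve "LattifAI/Lattice-1" from ModelScope instead of Hugging Face
    start_margin=0.08,       # extend segment starts by up to 80 ms
    end_margin=0.20,         # extend segment ends by up to 200 ms
)
```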
lattifai/config/client.py
CHANGED

@@ -26,6 +26,11 @@ class ClientConfig:
     default_headers: Optional[Dict[str, str]] = field(default=None)
     """Optional static headers to include in all requests."""

+    profile: bool = False
+    """Enable profiling of client operations tasks.
+    When True, prints detailed timing information for various stages of the process.
+    """
+
     def __post_init__(self):
         """Validate and auto-populate configuration after initialization."""
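A sketch enabling the new flag. Per the client.py hunk above, a profiled run ends with `self.aligner.profile()` printing stage timings:

```python
from lattifai.client import LattifAI
from lattifai.config import ClientConfig

# After an alignment run, the client calls aligner.profile() to print timings.
client = LattifAI(client_config=ClientConfig(profile=True))
```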
lattifai/config/transcription.py
CHANGED

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
 SUPPORTED_TRANSCRIPTION_MODELS = Literal[
     "gemini-2.5-pro",
     "gemini-3-pro-preview",
+    "gemini-3-flash-preview",
     "nvidia/parakeet-tdt-0.6b-v3",
     "nvidia/canary-1b-v2",
     "iic/SenseVoiceSmall",

@@ -50,6 +51,9 @@ class TranscriptionConfig:
     lattice_model_path: Optional[str] = None
     """Path to local LattifAI model. Will be auto-set in LattifAI client."""

+    model_hub: Literal["huggingface", "modelscope"] = "huggingface"
+    """Which model hub to use when resolving lattice models for transcription."""
+
     client_wrapper: Optional["SyncAPIClient"] = field(default=None, repr=False)
     """Reference to the SyncAPIClient instance. Auto-set during client initialization."""
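A sketch combining the two additions, the newly supported model and the hub selector (the `lattifai.config` import path is assumed to match the other configs in this diff):

```python
from lattifai.config import TranscriptionConfig  # import path assumed

config = TranscriptionConfig(
    model_name="gemini-3-flash-preview",  # newly supported in 1.2.0
    model_hub="modelscope",               # used when resolving lattice models
)
```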