lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +42 -27
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
- lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/supervision.py +1 -0
- lattifai/{io → caption}/text_parser.py +53 -10
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +455 -246
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +41 -34
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/types.py +30 -0
- lattifai/utils.py +3 -31
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/file_manager.py +81 -57
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -324
- lattifai/bin/align.py +0 -295
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -43
- lattifai/io/reader.py +0 -86
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -102
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -12
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.6.dist-info/METADATA +0 -806
- lattifai-0.4.6.dist-info/RECORD +0 -39
- lattifai-0.4.6.dist-info/entry_points.txt +0 -3
- /lattifai/{io → caption}/gemini_reader.py +0 -0
- /lattifai/{io → caption}/gemini_writer.py +0 -0
- /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
- /lattifai/{workflows → workflow}/base.py +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
@@ -1,28 +1,40 @@
 import sys
 import warnings
+from importlib.metadata import version
 
+# Re-export I/O classes
+from .caption import Caption
+
+# Re-export client classes
+from .client import LattifAI
+
+# Re-export config classes
+from .config import (
+    AUDIO_FORMATS,
+    MEDIA_FORMATS,
+    VIDEO_FORMATS,
+    AlignmentConfig,
+    CaptionConfig,
+    ClientConfig,
+    DiarizationConfig,
+    MediaConfig,
+)
 from .errors import (
     AlignmentError,
     APIError,
     AudioFormatError,
     AudioLoadError,
     AudioProcessingError,
+    CaptionParseError,
+    CaptionProcessingError,
     ConfigurationError,
     DependencyError,
     LatticeDecodingError,
     LatticeEncodingError,
     LattifAIError,
     ModelLoadError,
-    SubtitleParseError,
-    SubtitleProcessingError,
 )
-from .
-
-try:
-    from importlib.metadata import version
-except ImportError:
-    # Python < 3.8
-    from importlib_metadata import version
+from .logging import get_logger, set_log_level, setup_logger
 
 try:
     __version__ = version("lattifai")
@@ -54,28 +66,25 @@ def _check_and_install_k2():
 _check_and_install_k2()
 
 
-# Lazy import for LattifAI to avoid dependency issues during basic import
-def __getattr__(name):
-    if name == "LattifAI":
-        from .client import LattifAI
-
-        return LattifAI
-    if name == "AsyncLattifAI":
-        from .client import AsyncLattifAI
-
-        return AsyncLattifAI
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
 __all__ = [
-
-    "
+    # Client classes
+    "LattifAI",
+    # Config classes
+    "AlignmentConfig",
+    "ClientConfig",
+    "CaptionConfig",
+    "DiarizationConfig",
+    "MediaConfig",
+    "AUDIO_FORMATS",
+    "VIDEO_FORMATS",
+    "MEDIA_FORMATS",
+    # Error classes
     "LattifAIError",
     "AudioProcessingError",
     "AudioLoadError",
     "AudioFormatError",
-    "
-    "
+    "CaptionProcessingError",
+    "CaptionParseError",
     "AlignmentError",
     "LatticeEncodingError",
     "LatticeDecodingError",
@@ -83,6 +92,12 @@ __all__ = [
     "DependencyError",
     "APIError",
     "ConfigurationError",
-
+    # Logging
+    "setup_logger",
+    "get_logger",
+    "set_log_level",
+    # I/O
+    "Caption",
+    # Version
     "__version__",
 ]
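For orientation, 1.0.0 replaces the old lazy `__getattr__` loader with eager top-level imports, and `AsyncLattifAI` is no longer exported. A minimal sketch of the new import surface; the names come from the diff above, but the zero-argument constructor calls are assumptions the diff does not confirm:

```python
# Names are from the 1.0.0 __init__.py diff; signatures are assumed.
from lattifai import Caption, CaptionConfig, LattifAI, get_logger

logger = get_logger(__name__)  # logging helpers are now re-exported at top level
client = LattifAI()            # imported eagerly, no longer routed through __getattr__
```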
lattifai/alignment/lattice1_aligner.py
ADDED
@@ -0,0 +1,119 @@
+"""Lattice-1 Aligner implementation."""
+
+from typing import Any, List, Optional, Tuple
+
+import colorful
+import torch
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
+from lattifai.config import AlignmentConfig
+from lattifai.errors import (
+    AlignmentError,
+    LatticeDecodingError,
+    LatticeEncodingError,
+)
+from lattifai.utils import _resolve_model_path
+
+from .lattice1_worker import _load_worker
+from .tokenizer import _load_tokenizer
+
+ClientType = Any
+
+
+class Lattice1Aligner(object):
+    """Synchronous LattifAI client with config-driven architecture."""
+
+    def __init__(
+        self,
+        config: AlignmentConfig,
+    ) -> None:
+        self.config = config
+
+        if config.client_wrapper is None:
+            raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
+
+        client_wrapper = config.client_wrapper
+        model_path = _resolve_model_path(config.model_name)
+
+        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.worker = _load_worker(model_path, config.device)
+
+        self.frame_shift = self.worker.frame_shift
+
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        return self.worker.emission(audio.to(self.worker.device))
+
+    def alignment(
+        self,
+        audio: AudioData,
+        supervisions: List[Supervision],
+        split_sentence: Optional[bool] = False,
+        return_details: Optional[bool] = False,
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+        verbose: bool = True,
+    ) -> Tuple[List[Supervision], List[Supervision]]:
+        """
+        Perform alignment on audio and supervisions.
+
+        Args:
+            audio: Audio file path
+            supervisions: List of supervision segments to align
+            split_sentence: Enable sentence re-splitting
+
+        Returns:
+            Tuple of (supervisions, alignments)
+
+        Raises:
+            LatticeEncodingError: If lattice graph generation fails
+            AlignmentError: If audio alignment fails
+            LatticeDecodingError: If lattice decoding fails
+        """
+        try:
+            if verbose:
+                print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+            try:
+                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                    supervisions, split_sentence=split_sentence
+                )
+                if verbose:
+                    print(colorful.green(f"  ✓ Generated lattice graph with ID: {lattice_id}"))
+            except Exception as e:
+                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
+                raise LatticeEncodingError(text_content, original_error=e)
+
+            if verbose:
+                print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            try:
+                lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                if verbose:
+                    print(colorful.green("  ✓ Lattice search completed"))
+            except Exception as e:
+                raise AlignmentError(
+                    f"Audio alignment failed for {audio}",
+                    media_path=str(audio),
+                    context={"original_error": str(e)},
+                )
+
+            if verbose:
+                print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+            try:
+                alignments = self.tokenizer.detokenize(
+                    lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                )
+                if verbose:
+                    print(colorful.green(f"  ✓ Successfully aligned {len(alignments)} segments"))
+            except LatticeDecodingError as e:
+                print(colorful.red("  x Failed to decode lattice alignment results"))
+                raise e
+            except Exception as e:
+                print(colorful.red("  x Failed to decode lattice alignment results"))
+                raise LatticeDecodingError(lattice_id, original_error=e)
+
+            return (supervisions, alignments)
+
+        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
+            raise
+        except Exception as e:
+            raise e
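To make the aligner's three reported steps (lattice graph creation, lattice search, decoding) concrete, here is a minimal usage sketch. Only `Lattice1Aligner`'s signatures come from the diff; how `AudioData`, the `Supervision` list, and a fully populated `AlignmentConfig` are obtained is assumed:

```python
# Hypothetical wiring; in practice the LattifAI client injects client_wrapper
# into the AlignmentConfig before constructing the aligner.
from lattifai.alignment import Lattice1Aligner

aligner = Lattice1Aligner(config)  # config: AlignmentConfig with client_wrapper set
supervisions, alignments = aligner.alignment(
    audio,                 # lattifai.audio2.AudioData wrapping the media
    supervisions,          # list[Supervision] carrying the caption text
    split_sentence=False,  # keep the caller's sentence boundaries
    offset=0.0,            # shift returned timestamps by this many seconds
)
```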
lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py}
RENAMED
@@ -1,73 +1,20 @@
 import json
 import time
 from collections import defaultdict
-from typing import Any,
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import onnxruntime as ort
-import soundfile as sf
 import torch
 from lhotse import FbankConfig
-from lhotse.augmentation import get_or_create_resampler
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
 
-from lattifai.
+from lattifai.audio2 import AudioData
+from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
 
-ChannelSelectorType = Union[int, Iterable[int], str]
 
-
-def resample_audio(
-    audio_sr: Tuple[torch.Tensor, int],
-    sampling_rate: int,
-    device: Optional[str],
-    channel_selector: Optional[ChannelSelectorType] = "average",
-) -> torch.Tensor:
-    """
-    return:
-        (1, T)
-    """
-    audio, sr = audio_sr
-
-    if channel_selector is None:
-        # keep the original multi-channel signal
-        tensor = audio
-    elif isinstance(channel_selector, int):
-        assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
-        tensor = audio[channel_selector : channel_selector + 1].clone()
-        del audio
-    elif isinstance(channel_selector, str):
-        assert channel_selector == "average"
-        tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
-        del audio
-    else:
-        assert isinstance(channel_selector, Iterable)
-        num_channels = audio.shape[0]
-        print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
-        assert isinstance(channel_selector, Iterable)
-        if max(channel_selector) >= num_channels:
-            raise ValueError(
-                f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
-            )
-        tensor = audio[channel_selector]
-
-    tensor = tensor.to(device)
-    if sr != sampling_rate:
-        resampler = get_or_create_resampler(sr, sampling_rate).to(device=device)
-        length = tensor.size(-1)
-        chunk_size = sampling_rate * 3600
-        if length > chunk_size:
-            resampled_chunks = []
-            for i in range(0, length, chunk_size):
-                resampled_chunks.append(resampler(tensor[..., i : i + chunk_size]))
-            tensor = torch.cat(resampled_chunks, dim=-1)
-        else:
-            tensor = resampler(tensor)
-
-    return tensor
-
-
-class Lattice1AlphaWorker:
+class Lattice1Worker:
     """Worker for processing audio with LatticeGraph."""
 
     def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
@@ -109,6 +56,10 @@ class Lattice1AlphaWorker:
         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
 
+    @property
+    def frame_shift(self) -> float:
+        return 0.02  # 20 ms
+
     @torch.inference_mode()
     def emission(self, audio: torch.Tensor) -> torch.Tensor:
         _start = time.time()
@@ -138,68 +89,17 @@ class Lattice1AlphaWorker:
         self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch
 
-    def load_audio(
-        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = "average"
-    ) -> Tuple[torch.Tensor, int]:
-        # load audio
-        try:
-            waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
-            waveform = waveform.T  # (channels, samples)
-        except Exception as primary_error:
-            # Fallback to PyAV for formats not supported by soundfile
-            try:
-                import av
-            except ImportError:
-                raise DependencyError(
-                    "av (PyAV)", install_command="pip install av", context={"primary_error": str(primary_error)}
-                )
-
-            try:
-                container = av.open(audio)
-                audio_stream = next((s for s in container.streams if s.type == "audio"), None)
-
-                if audio_stream is None:
-                    raise AudioFormatError(str(audio), "No audio stream found in file")
-
-                # Resample to target sample rate during decoding
-                audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float
-
-                frames = []
-                for frame in container.decode(audio_stream):
-                    # Convert frame to numpy array
-                    array = frame.to_ndarray()
-                    # Ensure shape is (channels, samples)
-                    if array.ndim == 1:
-                        array = array.reshape(1, -1)
-                    elif array.ndim == 2 and array.shape[0] > array.shape[1]:
-                        array = array.T
-                    frames.append(array)
-
-                container.close()
-
-                if not frames:
-                    raise AudioFormatError(str(audio), "No audio data found in file")
-
-                # Concatenate all frames
-                waveform = np.concatenate(frames, axis=1)
-                sample_rate = audio_stream.codec_context.sample_rate
-            except Exception as e:
-                raise AudioLoadError(str(audio), original_error=e)
-
-        return resample_audio(
-            (torch.from_numpy(waveform), sample_rate),
-            self.config.get("sampling_rate", 16000),
-            device=self.device.type,
-            channel_selector=channel_selector,
-        )
-
     def alignment(
-        self,
+        self,
+        audio: AudioData,
+        lattice_graph: Tuple[str, int, float],
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
 
         Args:
-            audio:
+            audio: AudioData object
            lattice_graph: LatticeGraph data
 
         Returns:
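Worth noting in this hunk: the soundfile/PyAV loading and chunked resampling code is deleted rather than moved, because 1.0.0 centralizes media decoding in lattifai/audio2.py (+211 in the file list above); the worker now consumes a pre-loaded tensor. A one-line sketch of the new division of labor, with `AudioData`'s construction left as an assumption:

```python
# audio is a lattifai.audio2.AudioData; the worker no longer loads files itself.
# The .tensor attribute is shown in the next hunk's emission call.
emission = worker.emission(audio.tensor.to(worker.device))  # (1, T, vocab_size)
```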
@@ -210,22 +110,15 @@ class Lattice1AlphaWorker:
             DependencyError: If required dependencies are missing
             AlignmentError: If alignment process fails
         """
-
-
-
-
-
-
-
-
-
-        except Exception as e:
-            raise AlignmentError(
-                "Failed to compute acoustic features from audio",
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                context={"original_error": str(e)},
-            )
-        self.timings["emission"] += time.time() - _start
+        if emission is None:
+            try:
+                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
+            except Exception as e:
+                raise AlignmentError(
+                    "Failed to compute acoustic features from audio",
+                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                    context={"original_error": str(e)},
+                )
 
         try:
             import k2
@@ -275,10 +168,18 @@ class Lattice1AlphaWorker:
         except Exception as e:
             raise AlignmentError(
                 "Failed to perform forced alignment",
-
+                media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
                 context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
             )
         self.timings["align_segments"] += time.time() - _start
 
         channel = 0
-        return emission, results, labels,
+        return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1Worker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f"worker from {model_path}", original_error=e)
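Since `alignment()` now returns the 20 ms `frame_shift` and the caller's `offset` alongside the search results, mapping a lattice frame index back to media time is a single expression. A sketch; the internal layout of `results` and `labels` is not shown in the diff and is left untouched here:

```python
emission, results, labels, frame_shift, offset, channel = worker.alignment(audio, lattice_graph)

def frame_to_seconds(frame_index: int) -> float:
    # frame_shift is 0.02 s per Lattice1Worker.frame_shift; offset re-bases
    # timestamps when the audio is a segment cut from a longer recording.
    return offset + frame_index * frame_shift
```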
lattifai/alignment/segmenter.py
ADDED
@@ -0,0 +1,166 @@
+"""Segmented alignment for long audio files."""
+
+from typing import List, Optional, Tuple
+
+import colorful
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Caption, Supervision
+from lattifai.config import AlignmentConfig
+
+from .tokenizer import END_PUNCTUATION
+
+
+class Segmenter:
+    """
+    Handles segmented alignment for long audio/video files.
+
+    Instead of aligning the entire audio at once (which can be slow and memory-intensive
+    for long files), this class splits the alignment into manageable segments based on
+    caption boundaries, time intervals, or an adaptive strategy.
+    """
+
+    def __init__(self, config: AlignmentConfig):
+        """
+        Initialize segmented aligner.
+
+        Args:
+            config: Alignment configuration with segmentation parameters
+        """
+        self.config = config
+
+    def __call__(
+        self,
+        caption: Caption,
+        max_duration: Optional[float] = None,
+    ) -> List[Tuple[float, float, List[Supervision]]]:
+        """
+        Create segments based on caption boundaries and gaps.
+
+        Splits when:
+        1. Gap between captions exceeds segment_max_gap
+        2. Duration approaches max_duration (adaptive mode only) and there's a reasonable break
+        3. Duration significantly exceeds max_duration (adaptive mode only)
+
+        Args:
+            caption: Caption object with supervisions
+            max_duration: Optional maximum segment duration (enables adaptive behavior)
+
+        Returns:
+            List of (start_time, end_time, supervisions) tuples for each segment
+        """
+        if not max_duration:
+            max_duration = self.config.segment_duration
+
+        if not caption.supervisions:
+            return []
+
+        supervisions = sorted(caption.supervisions, key=lambda s: s.start)
+
+        segments = []
+        current_segment_sups = []
+
+        def should_skipalign(sups):
+            return len(sups) == 1 and sups[0].text.strip().startswith("[") and sups[0].text.strip().endswith("]")
+
+        for i, sup in enumerate(supervisions):
+            if not current_segment_sups:
+                current_segment_sups.append(sup)
+                if should_skipalign(current_segment_sups):
+                    # Single [APPLAUSE] caption, make its own segment
+                    segments.append(
+                        (current_segment_sups[0].start, current_segment_sups[-1].end, current_segment_sups, True)
+                    )
+                    current_segment_sups = []
+                continue
+
+            prev_sup = supervisions[i - 1]
+
+            gap = max(sup.start - prev_sup.end, 0.0)
+            # Always split on large gaps (natural breaks)
+            exclude_max_gap = False
+            if gap > self.config.segment_max_gap:
+                exclude_max_gap = True
+
+            endswith_punc = any(sup.text.endswith(punc) for punc in END_PUNCTUATION)
+
+            # Adaptive duration control
+            segment_duration = sup.end - current_segment_sups[0].start
+
+            # Split if approaching duration limit and there's a reasonable break
+            should_split = False
+            if segment_duration >= max_duration * 0.8 and gap >= 1.0:
+                should_split = True
+
+            # Force split if duration exceeded significantly
+            exclude_max_duration = False
+            if segment_duration >= max_duration * 1.2:
+                exclude_max_duration = True
+
+            # [APPLAUSE] [APPLAUSE] [MUSIC]
+            if sup.text.strip().startswith("[") and sup.text.strip().endswith("]"):
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Add current supervision as its own segment
+                segments.append((sup.start, sup.end, [sup], True))
+
+                # Update reset for new segment
+                current_segment_sups = []
+                continue
+
+            if (should_split and endswith_punc) or exclude_max_gap or exclude_max_duration:
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Start new segment
+                current_segment_sups = [sup]
+            else:
+                current_segment_sups.append(sup)
+
+        # Add final segment
+        if current_segment_sups:
+            segment_start = current_segment_sups[0].start
+            segment_end = current_segment_sups[-1].end + 2.0
+            segments.append((segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups)))
+
+        return segments
+
+    def print_segment_info(
+        self,
+        segments: List[Tuple[float, float, List[Supervision]]],
+        verbose: bool = True,
+    ) -> None:
+        """
+        Print information about created segments.
+
+        Args:
+            segments: List of segment tuples
+            verbose: Whether to print detailed info
+        """
+        if not verbose:
+            return
+
+        total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
+
+        print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+        for i, (start, end, sups, _) in enumerate(segments, 1):
+            duration = end - start
+            print(
+                colorful.white(
+                    f"  Segment {i:04d}: {start:8.2f}s - {end:8.2f}s "
+                    f"(duration: {duration:8.2f}s, supervisions: {len(sups) if isinstance(sups, list) else 1:4d})"
+                )
+            )
+
+        print(colorful.green(f"  Total: {total_sups} supervisions across {len(segments)} segments"))