lattifai 0.4.5__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +61 -47
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/alignment/lattice1_worker.py +185 -0
- lattifai/{tokenizer → alignment}/phonemizer.py +4 -4
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +244 -169
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/gemini_reader.py +30 -30
- lattifai/{io → caption}/gemini_writer.py +17 -17
- lattifai/{io → caption}/supervision.py +4 -3
- lattifai/caption/text_parser.py +145 -0
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +460 -251
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +98 -91
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/{workflows → transcription}/prompts/__init__.py +4 -4
- lattifai/types.py +30 -0
- lattifai/utils.py +16 -44
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/base.py +22 -22
- lattifai/{workflows → workflow}/file_manager.py +239 -215
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +1 -1
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -325
- lattifai/bin/align.py +0 -296
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -42
- lattifai/io/reader.py +0 -85
- lattifai/io/text_parser.py +0 -75
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -90
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workers/lattice1_alpha.py +0 -284
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -10
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.5.dist-info/METADATA +0 -808
- lattifai-0.4.5.dist-info/RECORD +0 -39
- lattifai-0.4.5.dist-info/entry_points.txt +0 -3
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED

```diff
@@ -1,34 +1,45 @@
-import os
 import sys
 import warnings
+from importlib.metadata import version
 
+# Re-export I/O classes
+from .caption import Caption
+
+# Re-export client classes
+from .client import LattifAI
+
+# Re-export config classes
+from .config import (
+    AUDIO_FORMATS,
+    MEDIA_FORMATS,
+    VIDEO_FORMATS,
+    AlignmentConfig,
+    CaptionConfig,
+    ClientConfig,
+    DiarizationConfig,
+    MediaConfig,
+)
 from .errors import (
     AlignmentError,
     APIError,
     AudioFormatError,
     AudioLoadError,
     AudioProcessingError,
+    CaptionParseError,
+    CaptionProcessingError,
     ConfigurationError,
     DependencyError,
     LatticeDecodingError,
     LatticeEncodingError,
     LattifAIError,
     ModelLoadError,
-    SubtitleParseError,
-    SubtitleProcessingError,
 )
-from .
-
-try:
-    from importlib.metadata import version
-except ImportError:
-    # Python < 3.8
-    from importlib_metadata import version
+from .logging import get_logger, set_log_level, setup_logger
 
 try:
-    __version__ = version(
+    __version__ = version("lattifai")
 except Exception:
-    __version__ =
+    __version__ = "0.1.0"  # fallback version
 
 
 # Check and auto-install k2 if not present
@@ -39,15 +50,15 @@ def _check_and_install_k2():
     except ImportError:
         import subprocess
 
-        print(
+        print("k2 is not installed. Attempting to install k2...")
         try:
-            subprocess.check_call([sys.executable,
-            subprocess.check_call([sys.executable,
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "install-k2"])
+            subprocess.check_call([sys.executable, "-m", "install_k2"])
             import k2  # Try importing again after installation
 
-            print(
+            print("k2 installed successfully.")
         except Exception as e:
-            warnings.warn(f
+            warnings.warn(f"Failed to install k2 automatically. Please install it manually. Error: {e}")
     return True
 
 
@@ -55,35 +66,38 @@ def _check_and_install_k2():
 _check_and_install_k2()
 
 
-# Lazy import for LattifAI to avoid dependency issues during basic import
-def __getattr__(name):
-    if name == 'LattifAI':
-        from .client import LattifAI
-
-        return LattifAI
-    if name == 'AsyncLattifAI':
-        from .client import AsyncLattifAI
-
-        return AsyncLattifAI
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
 __all__ = [
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Client classes
+    "LattifAI",
+    # Config classes
+    "AlignmentConfig",
+    "ClientConfig",
+    "CaptionConfig",
+    "DiarizationConfig",
+    "MediaConfig",
+    "AUDIO_FORMATS",
+    "VIDEO_FORMATS",
+    "MEDIA_FORMATS",
+    # Error classes
+    "LattifAIError",
+    "AudioProcessingError",
+    "AudioLoadError",
+    "AudioFormatError",
+    "CaptionProcessingError",
+    "CaptionParseError",
+    "AlignmentError",
+    "LatticeEncodingError",
+    "LatticeDecodingError",
+    "ModelLoadError",
+    "DependencyError",
+    "APIError",
+    "ConfigurationError",
+    # Logging
+    "setup_logger",
+    "get_logger",
+    "set_log_level",
+    # I/O
+    "Caption",
+    # Version
+    "__version__",
 ]
```
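The 1.0.0 `__init__.py` flattens the public surface: the client, configs, errors, logging helpers, and `Caption` are all eagerly importable from the package root, and the old `__getattr__`-based lazy import is gone. A minimal sketch of what that surface permits (the constructor arguments and log level below are illustrative assumptions, not taken from the diff):

```python
# Sketch against the 1.0.0 public surface shown in __all__ above.
from lattifai import (
    LattifAI,
    LattifAIError,
    get_logger,
    set_log_level,
)

set_log_level("INFO")          # logging helpers are now re-exported at the root
logger = get_logger(__name__)

try:
    client = LattifAI()        # exact constructor arguments are not shown in this diff
except LattifAIError as e:     # all package errors share the LattifAIError base
    logger.error("client setup failed: %s", e)
```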
lattifai/alignment/lattice1_aligner.py
ADDED

```diff
@@ -0,0 +1,119 @@
+"""Lattice-1 Aligner implementation."""
+
+from typing import Any, List, Optional, Tuple
+
+import colorful
+import torch
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Supervision
+from lattifai.config import AlignmentConfig
+from lattifai.errors import (
+    AlignmentError,
+    LatticeDecodingError,
+    LatticeEncodingError,
+)
+from lattifai.utils import _resolve_model_path
+
+from .lattice1_worker import _load_worker
+from .tokenizer import _load_tokenizer
+
+ClientType = Any
+
+
+class Lattice1Aligner(object):
+    """Synchronous LattifAI client with config-driven architecture."""
+
+    def __init__(
+        self,
+        config: AlignmentConfig,
+    ) -> None:
+        self.config = config
+
+        if config.client_wrapper is None:
+            raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
+
+        client_wrapper = config.client_wrapper
+        model_path = _resolve_model_path(config.model_name)
+
+        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
+        self.worker = _load_worker(model_path, config.device)
+
+        self.frame_shift = self.worker.frame_shift
+
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        return self.worker.emission(audio.to(self.worker.device))
+
+    def alignment(
+        self,
+        audio: AudioData,
+        supervisions: List[Supervision],
+        split_sentence: Optional[bool] = False,
+        return_details: Optional[bool] = False,
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+        verbose: bool = True,
+    ) -> Tuple[List[Supervision], List[Supervision]]:
+        """
+        Perform alignment on audio and supervisions.
+
+        Args:
+            audio: Audio file path
+            supervisions: List of supervision segments to align
+            split_sentence: Enable sentence re-splitting
+
+        Returns:
+            Tuple of (supervisions, alignments)
+
+        Raises:
+            LatticeEncodingError: If lattice graph generation fails
+            AlignmentError: If audio alignment fails
+            LatticeDecodingError: If lattice decoding fails
+        """
+        try:
+            if verbose:
+                print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
+            try:
+                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                    supervisions, split_sentence=split_sentence
+                )
+                if verbose:
+                    print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+            except Exception as e:
+                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
+                raise LatticeEncodingError(text_content, original_error=e)
+
+            if verbose:
+                print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            try:
+                lattice_results = self.worker.alignment(audio, lattice_graph, emission=emission, offset=offset)
+                if verbose:
+                    print(colorful.green(" ✓ Lattice search completed"))
+            except Exception as e:
+                raise AlignmentError(
+                    f"Audio alignment failed for {audio}",
+                    media_path=str(audio),
+                    context={"original_error": str(e)},
+                )
+
+            if verbose:
+                print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+            try:
+                alignments = self.tokenizer.detokenize(
+                    lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
+                )
+                if verbose:
+                    print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+            except LatticeDecodingError as e:
+                print(colorful.red(" x Failed to decode lattice alignment results"))
+                raise e
+            except Exception as e:
+                print(colorful.red(" x Failed to decode lattice alignment results"))
+                raise LatticeDecodingError(lattice_id, original_error=e)
+
+            return (supervisions, alignments)
+
+        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
+            raise
+        except Exception as e:
+            raise e
```
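`Lattice1Aligner.alignment` is a three-step pipeline: tokenize supervisions into a lattice graph, search the graph against the audio, then detokenize the lattice results, with each step wrapped in its own typed error. A hedged usage sketch follows; construction of the `AlignmentConfig`, `AudioData`, and `Supervision` arguments is assumed from how they are used above, since their exact fields are not in this diff:

```python
# Hypothetical driver for Lattice1Aligner; the caller is expected to supply
# a config (with client_wrapper set), an AudioData object, and supervisions.
from lattifai.alignment.lattice1_aligner import Lattice1Aligner
from lattifai.errors import AlignmentError, LatticeDecodingError, LatticeEncodingError


def align_file(config, audio, supervisions):
    aligner = Lattice1Aligner(config)  # raises ValueError if config.client_wrapper is unset
    try:
        supervisions, alignments = aligner.alignment(
            audio,
            supervisions,
            split_sentence=False,  # keep the caller's sentence boundaries
            verbose=True,          # print the Step 2/3/4 progress lines
        )
        return alignments
    except LatticeEncodingError:
        raise  # the text could not be turned into a lattice graph
    except (AlignmentError, LatticeDecodingError):
        raise  # search or decoding failed; both carry original_error context
```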
lattifai/alignment/lattice1_worker.py
ADDED

```diff
@@ -0,0 +1,185 @@
+import json
+import time
+from collections import defaultdict
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import onnxruntime as ort
+import torch
+from lhotse import FbankConfig
+from lhotse.features.kaldi.layers import Wav2LogFilterBank
+from lhotse.utils import Pathlike
+
+from lattifai.audio2 import AudioData
+from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+
+
+class Lattice1Worker:
+    """Worker for processing audio with LatticeGraph."""
+
+    def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
+        try:
+            self.config = json.load(open(f"{model_path}/config.json"))
+        except Exception as e:
+            raise ModelLoadError(f"config from {model_path}", original_error=e)
+
+        # SessionOptions
+        sess_options = ort.SessionOptions()
+        # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess_options.intra_op_num_threads = num_threads  # CPU cores
+        sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
+
+        providers = []
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")
+
+        try:
+            self.acoustic_ort = ort.InferenceSession(
+                f"{model_path}/acoustic_opt.onnx",
+                sess_options,
+                providers=providers + ["CPUExecutionProvider", "CoreMLExecutionProvider"],
+            )
+        except Exception as e:
+            raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)
+
+        try:
+            config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+            config_dict = config.to_dict()
+            config_dict.pop("device")
+            self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+        except Exception as e:
+            raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
+
+        self.device = torch.device(device)
+        self.timings = defaultdict(lambda: 0.0)
+
+    @property
+    def frame_shift(self) -> float:
+        return 0.02  # 20 ms
+
+    @torch.inference_mode()
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        _start = time.time()
+        # audio -> features -> emission
+        features = self.extractor(audio)  # (1, T, D)
+        if features.shape[1] > 6000:
+            features_list = torch.split(features, 6000, dim=1)
+            emissions = []
+            for features in features_list:
+                ort_inputs = {
+                    "features": features.cpu().numpy(),
+                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+                }
+                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+                emissions.append(emission)
+            emission = torch.cat(
+                [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+            )  # (1, T, vocab_size)
+        else:
+            ort_inputs = {
+                "features": features.cpu().numpy(),
+                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
+            }
+            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+            emission = torch.from_numpy(emission).to(self.device)
+
+        self.timings["emission"] += time.time() - _start
+        return emission  # (1, T, vocab_size) torch
+
+    def alignment(
+        self,
+        audio: AudioData,
+        lattice_graph: Tuple[str, int, float],
+        emission: Optional[torch.Tensor] = None,
+        offset: float = 0.0,
+    ) -> Dict[str, Any]:
+        """Process audio with LatticeGraph.
+
+        Args:
+            audio: AudioData object
+            lattice_graph: LatticeGraph data
+
+        Returns:
+            Processed LatticeGraph
+
+        Raises:
+            AudioLoadError: If audio cannot be loaded
+            DependencyError: If required dependencies are missing
+            AlignmentError: If alignment process fails
+        """
+        if emission is None:
+            try:
+                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
+            except Exception as e:
+                raise AlignmentError(
+                    "Failed to compute acoustic features from audio",
+                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                    context={"original_error": str(e)},
+                )
+
+        try:
+            import k2
+        except ImportError:
+            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
+
+        try:
+            from lattifai_core.lattice.decode import align_segments
+        except ImportError:
+            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+
+        lattice_graph_str, final_state, acoustic_scale = lattice_graph
+
+        _start = time.time()
+        try:
+            # graph
+            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+            decoding_graph.requires_grad_(False)
+            decoding_graph = k2.arc_sort(decoding_graph)
+            decoding_graph.skip_id = int(final_state)
+            decoding_graph.return_id = int(final_state + 1)
+        except Exception as e:
+            raise AlignmentError(
+                "Failed to create decoding graph from lattice",
+                context={"original_error": str(e), "lattice_graph_length": len(lattice_graph_str)},
+            )
+        self.timings["decoding_graph"] += time.time() - _start
+
+        _start = time.time()
+        if self.device.type == "mps":
+            device = "cpu"  # k2 does not support mps yet
+        else:
+            device = self.device
+
+        try:
+            results, labels = align_segments(
+                emission.to(device) * acoustic_scale,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=200,
+                output_beam=80,
+                min_active_states=400,
+                max_active_states=10000,
+                subsampling_factor=1,
+                reject_low_confidence=False,
+            )
+        except Exception as e:
+            raise AlignmentError(
+                "Failed to perform forced alignment",
+                media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
+            )
+        self.timings["align_segments"] += time.time() - _start
+
+        channel = 0
+        return emission, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1Worker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f"worker from {model_path}", original_error=e)
```
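The worker returns raw frame-level results together with `frame_shift` (0.02 s) and the caller's `offset`, so mapping a frame index back to an absolute timestamp is a single multiply-add. A sketch of that conversion; only the `frame_shift` and `offset` semantics come from the code above, while the `(start_frame, end_frame)` span layout is a hypothetical illustration:

```python
# Mapping emission-frame indices back to seconds.
# frame_shift=0.02 comes from Lattice1Worker.frame_shift; the span layout
# (start_frame, end_frame) is an assumption for illustration.
def frames_to_seconds(start_frame: int, end_frame: int,
                      frame_shift: float = 0.02, offset: float = 0.0):
    start = offset + start_frame * frame_shift
    end = offset + end_frame * frame_shift
    return start, end

# e.g. frames 150..275 in a segment that starts 30 s into the media
print(frames_to_seconds(150, 275, offset=30.0))  # (33.0, 35.5)
```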
lattifai/{tokenizer → alignment}/phonemizer.py
RENAMED

```diff
@@ -1,16 +1,16 @@
 import re
 from typing import List, Optional, Union
 
-from
+from g2pp.phonemizer import Phonemizer  # g2p-phonemizer
 from num2words import num2words
 
-LANGUAGE =
+LANGUAGE = "omni"
 
 
 class G2Phonemizer:
     def __init__(self, model_checkpoint, device):
         self.phonemizer = Phonemizer.from_checkpoint(model_checkpoint, device=device).predictor
-        self.pattern = re.compile(r
+        self.pattern = re.compile(r"\d+")
 
     def num2words(self, word, lang: str):
         matches = self.pattern.findall(word)
@@ -31,7 +31,7 @@ class G2Phonemizer:
         is_list = False
 
         predictions = self.phonemizer(
-            [self.num2words(word.replace(
+            [self.num2words(word.replace(" .", ".").replace(".", " ."), lang=lang or "en") for word in words],
             lang=LANGUAGE,
             batch_size=min(batch_size or len(words), 128),
             num_prons=num_prons,
```
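The `num2words` preprocessing matters because the G2P model only sees letter sequences: digits must be spelled out per language before phonemization. A quick illustration using the `num2words` library's real API (the surrounding wiring is from the class above):

```python
from num2words import num2words

print(num2words(42))             # "forty-two"
print(num2words(42, lang="fr"))  # "quarante-deux"
```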
lattifai/alignment/segmenter.py
ADDED

```diff
@@ -0,0 +1,166 @@
+"""Segmented alignment for long audio files."""
+
+from typing import List, Optional, Tuple
+
+import colorful
+
+from lattifai.audio2 import AudioData
+from lattifai.caption import Caption, Supervision
+from lattifai.config import AlignmentConfig
+
+from .tokenizer import END_PUNCTUATION
+
+
+class Segmenter:
+    """
+    Handles segmented alignment for long audio/video files.
+
+    Instead of aligning the entire audio at once (which can be slow and memory-intensive
+    for long files), this class splits the alignment into manageable segments based on
+    caption boundaries, time intervals, or an adaptive strategy.
+    """
+
+    def __init__(self, config: AlignmentConfig):
+        """
+        Initialize segmented aligner.
+
+        Args:
+            config: Alignment configuration with segmentation parameters
+        """
+        self.config = config
+
+    def __call__(
+        self,
+        caption: Caption,
+        max_duration: Optional[float] = None,
+    ) -> List[Tuple[float, float, List[Supervision]]]:
+        """
+        Create segments based on caption boundaries and gaps.
+
+        Splits when:
+        1. Gap between captions exceeds segment_max_gap
+        2. Duration approaches max_duration (adaptive mode only) and there's a reasonable break
+        3. Duration significantly exceeds max_duration (adaptive mode only)
+
+        Args:
+            caption: Caption object with supervisions
+            max_duration: Optional maximum segment duration (enables adaptive behavior)
+
+        Returns:
+            List of (start_time, end_time, supervisions) tuples for each segment
+        """
+        if not max_duration:
+            max_duration = self.config.segment_duration
+
+        if not caption.supervisions:
+            return []
+
+        supervisions = sorted(caption.supervisions, key=lambda s: s.start)
+
+        segments = []
+        current_segment_sups = []
+
+        def should_skipalign(sups):
+            return len(sups) == 1 and sups[0].text.strip().startswith("[") and sups[0].text.strip().endswith("]")
+
+        for i, sup in enumerate(supervisions):
+            if not current_segment_sups:
+                current_segment_sups.append(sup)
+                if should_skipalign(current_segment_sups):
+                    # Single [APPLAUSE] caption, make its own segment
+                    segments.append(
+                        (current_segment_sups[0].start, current_segment_sups[-1].end, current_segment_sups, True)
+                    )
+                    current_segment_sups = []
+                continue
+
+            prev_sup = supervisions[i - 1]
+
+            gap = max(sup.start - prev_sup.end, 0.0)
+            # Always split on large gaps (natural breaks)
+            exclude_max_gap = False
+            if gap > self.config.segment_max_gap:
+                exclude_max_gap = True
+
+            endswith_punc = any(sup.text.endswith(punc) for punc in END_PUNCTUATION)
+
+            # Adaptive duration control
+            segment_duration = sup.end - current_segment_sups[0].start
+
+            # Split if approaching duration limit and there's a reasonable break
+            should_split = False
+            if segment_duration >= max_duration * 0.8 and gap >= 1.0:
+                should_split = True
+
+            # Force split if duration exceeded significantly
+            exclude_max_duration = False
+            if segment_duration >= max_duration * 1.2:
+                exclude_max_duration = True
+
+            # [APPLAUSE] [APPLAUSE] [MUSIC]
+            if sup.text.strip().startswith("[") and sup.text.strip().endswith("]"):
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Add current supervision as its own segment
+                segments.append((sup.start, sup.end, [sup], True))
+
+                # Reset for new segment
+                current_segment_sups = []
+                continue
+
+            if (should_split and endswith_punc) or exclude_max_gap or exclude_max_duration:
+                # Close current segment
+                if current_segment_sups:
+                    segment_start = current_segment_sups[0].start
+                    segment_end = current_segment_sups[-1].end + min(gap / 2.0, 2.0)
+                    segments.append(
+                        (segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups))
+                    )
+
+                # Start new segment
+                current_segment_sups = [sup]
+            else:
+                current_segment_sups.append(sup)
+
+        # Add final segment
+        if current_segment_sups:
+            segment_start = current_segment_sups[0].start
+            segment_end = current_segment_sups[-1].end + 2.0
+            segments.append((segment_start, segment_end, current_segment_sups, should_skipalign(current_segment_sups)))
+
+        return segments
+
+    def print_segment_info(
+        self,
+        segments: List[Tuple[float, float, List[Supervision]]],
+        verbose: bool = True,
+    ) -> None:
+        """
+        Print information about created segments.
+
+        Args:
+            segments: List of segment tuples
+            verbose: Whether to print detailed info
+        """
+        if not verbose:
+            return
+
+        total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
+
+        print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
+        for i, (start, end, sups, _) in enumerate(segments, 1):
+            duration = end - start
+            print(
+                colorful.white(
+                    f" Segment {i:04d}: {start:8.2f}s - {end:8.2f}s "
+                    f"(duration: {duration:8.2f}s, supervisions: {len(sups) if isinstance(sups, list) else 1:4d})"
+                )
+            )
+
+        print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
```