lattifai 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +10 -0
- lattifai/alignment/lattice1_aligner.py +64 -15
- lattifai/alignment/lattice1_worker.py +135 -50
- lattifai/alignment/segmenter.py +3 -2
- lattifai/alignment/tokenizer.py +14 -13
- lattifai/audio2.py +269 -70
- lattifai/caption/caption.py +213 -19
- lattifai/cli/__init__.py +2 -0
- lattifai/cli/alignment.py +2 -1
- lattifai/cli/app_installer.py +35 -33
- lattifai/cli/caption.py +9 -19
- lattifai/cli/diarization.py +108 -0
- lattifai/cli/server.py +3 -1
- lattifai/cli/transcribe.py +55 -38
- lattifai/cli/youtube.py +1 -0
- lattifai/client.py +42 -121
- lattifai/config/alignment.py +37 -2
- lattifai/config/caption.py +1 -1
- lattifai/config/media.py +23 -3
- lattifai/config/transcription.py +4 -0
- lattifai/diarization/lattifai.py +18 -7
- lattifai/errors.py +7 -3
- lattifai/mixin.py +45 -16
- lattifai/server/app.py +2 -1
- lattifai/transcription/__init__.py +1 -1
- lattifai/transcription/base.py +21 -2
- lattifai/transcription/gemini.py +127 -1
- lattifai/transcription/lattifai.py +30 -2
- lattifai/utils.py +96 -28
- lattifai/workflow/file_manager.py +15 -13
- lattifai/workflow/youtube.py +16 -1
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/METADATA +86 -22
- lattifai-1.1.0.dist-info/RECORD +57 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/entry_points.txt +2 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/licenses/LICENSE +1 -1
- lattifai-1.0.4.dist-info/RECORD +0 -56
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/WHEEL +0 -0
- {lattifai-1.0.4.dist-info → lattifai-1.1.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
@@ -1,7 +1,17 @@
+import os
 import sys
 import warnings
 from importlib.metadata import version
 
+# Suppress SWIG deprecation warnings before any imports
+warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+# Suppress PyTorch transformer nested tensor warning
+warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+# Disable tokenizers parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 # Re-export I/O classes
 from .caption import Caption
 
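The new import-time setup relies on message-pattern warning filters: only warnings whose text matches the given regex are hidden, and the filters are installed before the noisy libraries are imported. A minimal standalone sketch of the same technique:

    import os
    import warnings

    # Hide only warnings whose message matches the regex; everything else stays visible.
    warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
    warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")

    # Setting this before the HuggingFace tokenizers library is used avoids its
    # fork-related parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"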
lattifai/alignment/lattice1_aligner.py
CHANGED
@@ -3,6 +3,7 @@
 from typing import Any, List, Optional, Tuple
 
 import colorful
+import numpy as np
 import torch
 
 from lattifai.audio2 import AudioData
@@ -13,7 +14,7 @@ from lattifai.errors import (
     LatticeDecodingError,
     LatticeEncodingError,
 )
-from lattifai.utils import _resolve_model_path
+from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
 from .tokenizer import _load_tokenizer
@@ -34,15 +35,47 @@ class Lattice1Aligner(object):
            raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")
 
        client_wrapper = config.client_wrapper
-
+        # Resolve model path using configured model hub
+        model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))
 
        self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
-        self.worker = _load_worker(model_path, config.device)
+        self.worker = _load_worker(model_path, config.device, config)
 
        self.frame_shift = self.worker.frame_shift
 
-    def emission(self,
-
+    def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
+        return self.worker.emission(ndarray)
+
+    def separate(self, audio: np.ndarray) -> np.ndarray:
+        """Separate audio using separator model.
+
+        Args:
+            audio: np.ndarray object containing the audio to separate, shape (1, T)
+
+        Returns:
+            Separated audio as numpy array
+
+        Raises:
+            RuntimeError: If separator model is not available
+        """
+        if self.worker.separator_ort is None:
+            raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
+
+        # Run separator model
+        separator_output = self.worker.separator_ort.run(
+            None,
+            {"audio": audio},
+        )
+
+        return separator_output[0]
 
    def alignment(
        self,
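A hedged usage sketch of the two new public methods; it assumes an already-constructed Lattice1Aligner instance named aligner and 16 kHz mono audio shaped (1, T) as the docstrings above describe:

    import numpy as np

    # Ten seconds of illustrative silence; real input would be decoded audio samples.
    audio = np.zeros((1, 16000 * 10), dtype=np.float32)

    emission = aligner.emission(audio)  # torch.Tensor of shape (1, T_frames, vocab_size)

    try:
        separated = aligner.separate(audio)  # np.ndarray; behaviour depends on the separator model
    except RuntimeError:
        # separator.onnx is optional, so fall back to the original audio when it is absent.
        separated = audio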
@@ -72,23 +105,34 @@
        """
        try:
            if verbose:
-
+                safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
            try:
                supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
                    supervisions, split_sentence=split_sentence
                )
                if verbose:
-
+                    safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
            except Exception as e:
                text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
                raise LatticeEncodingError(text_content, original_error=e)
 
            if verbose:
-
+                safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                if audio.streaming_chunk_secs:
+                    safe_print(
+                        colorful.yellow(
+                            f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
+                        )
+                    )
            try:
-                lattice_results = self.worker.alignment(
+                lattice_results = self.worker.alignment(
+                    audio,
+                    lattice_graph,
+                    emission=emission,
+                    offset=offset,
+                )
                if verbose:
-
+                    safe_print(colorful.green(" ✓ Lattice search completed"))
            except Exception as e:
                raise AlignmentError(
                    f"Audio alignment failed for {audio}",
@@ -97,18 +141,23 @@
                )
 
            if verbose:
-
+                safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
            try:
                alignments = self.tokenizer.detokenize(
-                    lattice_id,
+                    lattice_id,
+                    lattice_results,
+                    supervisions=supervisions,
+                    return_details=return_details,
+                    start_margin=self.config.start_margin,
+                    end_margin=self.config.end_margin,
                )
                if verbose:
-
+                    safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
            except LatticeDecodingError as e:
-
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                raise e
            except Exception as e:
-
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                raise LatticeDecodingError(lattice_id, original_error=e)
 
        return (supervisions, alignments)
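The new start_margin / end_margin values travel from the alignment config through Lattice1Aligner.alignment() into tokenizer.detokenize(), where they are forwarded with the decoding request. A hedged sketch of driving them from configuration; the constructor and alignment() arguments shown here are illustrative, only the two field names come from the hunk above:

    # Margins are in seconds and are presumably applied around decoded segment boundaries.
    config.start_margin = 0.08
    config.end_margin = 0.20

    aligner = Lattice1Aligner(config)
    # verbose=True prints the Step 2-4 progress lines via safe_print.
    supervisions, alignments = aligner.alignment(audio, supervisions, verbose=True)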
lattifai/alignment/lattice1_worker.py
CHANGED
@@ -1,6 +1,7 @@
 import json
 import time
 from collections import defaultdict
+from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
@@ -9,6 +10,7 @@ import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
+from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
@@ -17,12 +19,17 @@ from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
 class Lattice1Worker:
    """Worker for processing audio with LatticeGraph."""
 
-    def __init__(
+    def __init__(
+        self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8, config: Optional[Any] = None
+    ) -> None:
        try:
-            self.
+            self.model_config = json.load(open(f"{model_path}/config.json"))
        except Exception as e:
            raise ModelLoadError(f"config from {model_path}", original_error=e)
 
+        # Store alignment config with beam search parameters
+        self.alignment_config = config
+
        # SessionOptions
        sess_options = ort.SessionOptions()
        # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -67,6 +74,19 @@ class Lattice1Worker:
        else:
            self.extractor = None  # ONNX model includes feature extractor
 
+        # Initialize separator if available
+        separator_model_path = Path(model_path) / "separator.onnx"
+        if separator_model_path.exists():
+            try:
+                self.separator_ort = ort.InferenceSession(
+                    str(separator_model_path),
+                    providers=providers + ["CPUExecutionProvider"],
+                )
+            except Exception as e:
+                raise ModelLoadError(f"separator model from {model_path}", original_error=e)
+        else:
+            self.separator_ort = None
+
        self.device = torch.device(device)
        self.timings = defaultdict(lambda: 0.0)
 
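The separator session above follows a common optional-ONNX-model pattern: probe for the file, build an InferenceSession only when it exists, and keep CPUExecutionProvider as the final fallback. A generic sketch of that pattern; the helper name is hypothetical, not part of the package:

    from pathlib import Path
    from typing import List, Optional

    import onnxruntime as ort

    def load_optional_session(model_dir: str, filename: str, providers: List[str]) -> Optional[ort.InferenceSession]:
        """Return an InferenceSession for model_dir/filename, or None when the file is absent."""
        path = Path(model_dir) / filename
        if not path.exists():
            return None
        # CPUExecutionProvider is appended last so inference still works without an accelerator.
        return ort.InferenceSession(str(path), providers=providers + ["CPUExecutionProvider"])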
@@ -75,59 +95,74 @@
        return 0.02  # 20 ms
 
    @torch.inference_mode()
-    def emission(self,
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
        _start = time.time()
        if self.extractor is not None:
            # audio -> features -> emission
+            audio = torch.from_numpy(ndarray).to(self.device)
+            if audio.shape[1] < 160:
+                audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
            features = self.extractor(audio)  # (1, T, D)
            if features.shape[1] > 6000:
-                features_list = torch.split(features, 6000, dim=1)
                emissions = []
-                for
+                for start in range(0, features.size(1), 6000):
+                    _features = features[:, start : start + 6000, :]
                    ort_inputs = {
-                        "features":
-                        "feature_lengths": np.array([
+                        "features": _features.cpu().numpy(),
+                        "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
                    }
                    emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                    emissions.append(emission)
                emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                )  # (1, T, vocab_size)
+                del emissions
            else:
                ort_inputs = {
                    "features": features.cpu().numpy(),
                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                }
                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(self.device)
+                emission = torch.from_numpy(emission).to(device or self.device)
        else:
+            if ndarray.shape[1] < 160:
+                ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
            CHUNK_SIZE = 60 * 16000  # 60 seconds
-            if
-            audio_list = torch.split(audio, CHUNK_SIZE, dim=1)
+            if ndarray.shape[1] > CHUNK_SIZE:
                emissions = []
-                for
+                for start in range(0, ndarray.shape[1], CHUNK_SIZE):
                    emission = self.acoustic_ort.run(
                        None,
                        {
-                            "audios":
+                            "audios": ndarray[:, start : start + CHUNK_SIZE],
                        },
-                    )
-
-
-                    emissions.append(emission)
+                    )  # (1, T, vocab_size) numpy
+                    emissions.append(emission[0])
+
                emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                )  # (1, T, vocab_size)
+                del emissions
            else:
                emission = self.acoustic_ort.run(
                    None,
                    {
-                        "audios":
+                        "audios": ndarray,
                    },
-                )
-
-
-
+                )  # (1, T, vocab_size) numpy
+                emission = torch.from_numpy(emission[0]).to(device or self.device)
+
+        if acoustic_scale != 1.0:
+            emission = emission.mul_(acoustic_scale)
 
        self.timings["emission"] += time.time() - _start
        return emission  # (1, T, vocab_size) torch
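The chunking in emission() is a standard windowed-inference pattern: run the acoustic model on fixed-size slices and concatenate the results along the time axis. A self-contained sketch of the same idea with a stand-in model function (names here are illustrative, not the package API):

    import numpy as np
    import torch

    CHUNK_SIZE = 60 * 16000  # 60 s of 16 kHz samples, matching the constant above

    def chunked_emission(samples: np.ndarray, run_model) -> torch.Tensor:
        """Apply run_model to consecutive <=60 s windows of (1, T) audio and concatenate on time."""
        pieces = []
        for start in range(0, samples.shape[1], CHUNK_SIZE):
            window = samples[:, start : start + CHUNK_SIZE]
            pieces.append(torch.from_numpy(run_model(window)))  # each piece is (1, frames, vocab)
        return torch.cat(pieces, dim=1)

    # Dummy "model" that emits one frame per 320 samples over a 10-symbol vocabulary.
    dummy = lambda w: np.zeros((1, w.shape[1] // 320, 10), dtype=np.float32)
    print(chunked_emission(np.zeros((1, 70 * 16000), dtype=np.float32), dummy).shape)  # (1, 3500, 10)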
@@ -144,6 +179,9 @@
        Args:
            audio: AudioData object
            lattice_graph: LatticeGraph data
+            emission: Pre-computed emission tensor (ignored if streaming=True)
+            offset: Time offset for the audio
+            streaming: If True, use streaming mode for memory-efficient processing
 
        Returns:
            Processed LatticeGraph
@@ -153,16 +191,6 @@
            DependencyError: If required dependencies are missing
            AlignmentError: If alignment process fails
        """
-        if emission is None:
-            try:
-                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
-            except Exception as e:
-                raise AlignmentError(
-                    "Failed to compute acoustic features from audio",
-                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                    context={"original_error": str(e)},
-                )
-
        try:
            import k2
        except ImportError:
@@ -177,7 +205,7 @@
 
        _start = time.time()
        try:
-            # graph
+            # Create decoding graph
            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
            decoding_graph.requires_grad_(False)
            decoding_graph = k2.arc_sort(decoding_graph)
@@ -190,39 +218,96 @@
            )
        self.timings["decoding_graph"] += time.time() - _start
 
-        _start = time.time()
        if self.device.type == "mps":
            device = "cpu"  # k2 does not support mps yet
        else:
            device = self.device
 
-
+        _start = time.time()
+
+        # Get beam search parameters from config or use defaults
+        search_beam = self.alignment_config.search_beam or 200
+        output_beam = self.alignment_config.output_beam or 80
+        min_active_states = self.alignment_config.min_active_states or 400
+        max_active_states = self.alignment_config.max_active_states or 10000
+
+        if emission is None and audio.streaming_mode:
+            # Streaming mode: pass emission iterator to align_segments
+            # The align_segments function will automatically detect the iterator
+            # and use k2.OnlineDenseIntersecter for memory-efficient processing
+
+            def emission_iterator():
+                """Generate emissions for each audio chunk with progress tracking."""
+                total_duration = audio.duration
+                processed_duration = 0.0
+                total_minutes = int(total_duration / 60.0)
+
+                with tqdm(
+                    total=total_minutes,
+                    desc=f"Processing audio ({total_minutes} min)",
+                    unit="min",
+                    unit_scale=False,
+                    unit_divisor=1,
+                ) as pbar:
+                    for chunk in audio.iter_chunks():
+                        chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
+
+                        # Update progress based on chunk duration in minutes
+                        chunk_duration = int(chunk.duration / 60.0)
+                        pbar.update(chunk_duration)
+                        processed_duration += chunk_duration
+
+                        yield chunk_emission
+
+            # Calculate total frames for supervision_segments
+            total_frames = int(audio.duration / self.frame_shift)
+
            results, labels = align_segments(
-
+                emission_iterator(),  # Pass iterator for streaming
                decoding_graph.to(device),
-                torch.tensor([
-                search_beam=
-                output_beam=
-                min_active_states=
-                max_active_states=
+                torch.tensor([total_frames], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
                subsampling_factor=1,
                reject_low_confidence=False,
            )
-
-
-
-
-
+
+            # For streaming, don't return emission tensor to save memory
+            emission_result = None
+        else:
+            # Batch mode: compute full emission tensor and pass to align_segments
+            if emission is None:
+                emission = self.emission(
+                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
+                )  # (1, T, vocab_size)
+            else:
+                emission = emission.to(device) * acoustic_scale
+
+            results, labels = align_segments(
+                emission,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
+                subsampling_factor=1,
+                reject_low_confidence=False,
            )
+
+            emission_result = emission
+
        self.timings["align_segments"] += time.time() - _start
 
        channel = 0
-        return
+        return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
 
-def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
    """Instantiate lattice worker with consistent error handling."""
    try:
-        return Lattice1Worker(model_path, device=device, num_threads=8)
+        return Lattice1Worker(model_path, device=device, num_threads=8, config=config)
    except Exception as e:
        raise ModelLoadError(f"worker from {model_path}", original_error=e)
lattifai/alignment/segmenter.py
CHANGED
@@ -7,6 +7,7 @@ import colorful
 from lattifai.audio2 import AudioData
 from lattifai.caption import Caption, Supervision
 from lattifai.config import AlignmentConfig
+from lattifai.utils import safe_print
 
 from .tokenizer import END_PUNCTUATION
 
@@ -153,7 +154,7 @@ class Segmenter:
 
        total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
 
-
+        safe_print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
        for i, (start, end, sups, _) in enumerate(segments, 1):
            duration = end - start
            print(
@@ -163,4 +164,4 @@
                )
            )
 
-
+        safe_print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
lattifai/alignment/tokenizer.py
CHANGED
@@ -214,7 +214,7 @@ class LatticeTokenizer:
            else:
                with open(words_model_path, "rb") as f:
                    data = pickle.load(f)
-        except
+        except Exception as e:
            del e
            import msgpack
 
@@ -335,7 +335,7 @@
            flush_segment(s, None)
 
        assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
-        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
 
        supervisions, remainder = [], ""
        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
@@ -434,6 +434,8 @@
        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
        supervisions: List[Supervision],
        return_details: bool = False,
+        start_margin: float = 0.08,
+        end_margin: float = 0.20,
    ) -> List[Supervision]:
        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
        response = self.client_wrapper.post(
@@ -448,9 +450,11 @@
                "channel": channel,
                "return_details": False if return_details is None else return_details,
                "destroy_lattice": True,
+                "start_margin": start_margin,
+                "end_margin": end_margin,
            },
        )
-        if response.status_code ==
+        if response.status_code == 400:
            raise LatticeDecodingError(
                lattice_id,
                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
@@ -466,7 +470,7 @@
 
        alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
 
-        if return_details:
+        if emission is not None and return_details:
            # Add emission confidence scores for segments and word-level alignments
            _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
 
@@ -538,12 +542,9 @@ def _load_tokenizer(
    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
    """Instantiate tokenizer with consistent error handling."""
-
-
-
-
-
-
-    )
-    except Exception as e:
-        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
+    return tokenizer_cls.from_pretrained(
+        client_wrapper=client_wrapper,
+        model_path=model_path,
+        model_name=model_name,
+        device=device,
+    )