lattifai-1.0.5-py3-none-any.whl → lattifai-1.2.0-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- lattifai/__init__.py +11 -12
- lattifai/alignment/lattice1_aligner.py +39 -7
- lattifai/alignment/lattice1_worker.py +135 -147
- lattifai/alignment/tokenizer.py +38 -22
- lattifai/audio2.py +1 -1
- lattifai/caption/caption.py +55 -19
- lattifai/cli/__init__.py +2 -0
- lattifai/cli/caption.py +1 -1
- lattifai/cli/diarization.py +110 -0
- lattifai/cli/transcribe.py +3 -1
- lattifai/cli/youtube.py +11 -0
- lattifai/client.py +32 -111
- lattifai/config/alignment.py +14 -0
- lattifai/config/client.py +5 -0
- lattifai/config/transcription.py +4 -0
- lattifai/diarization/lattifai.py +18 -7
- lattifai/mixin.py +26 -5
- lattifai/transcription/__init__.py +1 -1
- lattifai/transcription/base.py +21 -2
- lattifai/transcription/gemini.py +127 -1
- lattifai/transcription/lattifai.py +30 -2
- lattifai/utils.py +62 -69
- lattifai/workflow/youtube.py +55 -57
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/METADATA +352 -56
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/RECORD +29 -28
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/entry_points.txt +2 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/WHEEL +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.0.5.dist-info → lattifai-1.2.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
@@ -52,28 +52,27 @@ except Exception:
     __version__ = "0.1.0"  # fallback version


-# Check and auto-install
-def
-"""Check if
+# Check and auto-install k2py if not present
+def _check_and_install_k2py():
+    """Check if k2py is installed and attempt to install it if not."""
     try:
-        import
+        import k2py
     except ImportError:
         import subprocess

-        print("
+        print("k2py is not installed. Attempting to install k2py...")
         try:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "
-
-            import k2  # Try importing again after installation
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
+            import k2py  # Try importing again after installation

-            print("
+            print("k2py installed successfully.")
         except Exception as e:
-            warnings.warn(f"Failed to install
+            warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
     return True


-# Auto-install
-
+# Auto-install k2py on first import
+_check_and_install_k2py()


 __all__ = [
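Note on the hook above: it runs every time the package is imported, so the first import in a fresh environment may invoke pip. A minimal sketch (not part of the package) of checking the dependency up front instead, assuming k2py is the PyPI name used by the added lines:

import importlib.util

def k2py_available() -> bool:
    # Sketch: verify the dependency before importing lattifai (e.g. in CI images)
    # so the import-time pip fallback added above never has to run.
    return importlib.util.find_spec("k2py") is not None

if not k2py_available():
    raise SystemExit("k2py missing; install it first (pip install k2py)")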
lattifai/alignment/lattice1_aligner.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple

 import colorful
 import numpy as np
-import torch

 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision
@@ -35,31 +34,55 @@ class Lattice1Aligner(object):
             raise ValueError("AlignmentConfig.client_wrapper is not set. It must be initialized by the client.")

         client_wrapper = config.client_wrapper
-
+        # Resolve model path using configured model hub
+        model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))

-        self.tokenizer = _load_tokenizer(
+        self.tokenizer = _load_tokenizer(
+            client_wrapper, model_path, config.model_name, config.device, model_hub=config.model_hub
+        )
         self.worker = _load_worker(model_path, config.device, config)

         self.frame_shift = self.worker.frame_shift

-    def emission(self, ndarray: np.ndarray) ->
+    def emission(self, ndarray: np.ndarray) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.

         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)

         Returns:
-            Emission
+            Emission numpy array of shape (1, T, vocab_size)
         """
         return self.worker.emission(ndarray)

+    def separate(self, audio: np.ndarray) -> np.ndarray:
+        """Separate audio using separator model.
+
+        Args:
+            audio: np.ndarray object containing the audio to separate, shape (1, T)
+
+        Returns:
+            Separated audio as numpy array
+
+        Raises:
+            RuntimeError: If separator model is not available
+        """
+        if self.worker.separator_ort is None:
+            raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
+        # Run separator model
+        separator_output = self.worker.separator_ort.run(
+            None,
+            {"audios": audio},
+        )
+        return separator_output[0]
+
     def alignment(
         self,
         audio: AudioData,
         supervisions: List[Supervision],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
-        emission: Optional[
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
         verbose: bool = True,
     ) -> Tuple[List[Supervision], List[Supervision]]:
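A usage sketch of the torch-free surface shown above. It assumes an already-constructed Lattice1Aligner (the constructor arguments and AlignmentConfig fields are not fully shown in this diff); the method calls mirror the signatures in the hunk.

import numpy as np

def quick_emission_check(aligner) -> None:
    # Sketch only: exercises the numpy-in / numpy-out API added above.
    audio = np.zeros((1, 16000), dtype=np.float32)  # 1 s of 16 kHz audio, shape (1, T)

    emission = aligner.emission(audio)              # np.ndarray of shape (1, T_frames, vocab_size)
    assert emission.ndim == 3

    try:
        vocals = aligner.separate(audio)            # uses separator.onnx when the model dir ships one
    except RuntimeError:
        vocals = audio                              # no separator model available for this checkpoint
    print(emission.shape, vocals.shape)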
@@ -120,7 +143,12 @@ class Lattice1Aligner(object):
         safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
         try:
             alignments = self.tokenizer.detokenize(
-                lattice_id,
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
             )
             if verbose:
                 safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
@@ -137,3 +165,7 @@ class Lattice1Aligner(object):
                 raise
         except Exception as e:
             raise e
+
+    def profile(self) -> None:
+        """Print profiling statistics."""
+        self.worker.profile()
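The new profile() hook simply dumps the worker's accumulated timings, so a typical call site looks like the sketch below (AudioData/Supervision construction omitted; names follow the signatures in this diff).

def align_and_profile(aligner, audio, supervisions):
    # Sketch: run an alignment, then print the per-phase timing table added in this release.
    aligned, details = aligner.alignment(audio, supervisions)
    aligner.profile()  # reports emission / decoding_graph / align_segments time collected in worker.timings
    return aligned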
lattifai/alignment/lattice1_worker.py
CHANGED
@@ -1,11 +1,12 @@
 import json
 import time
 from collections import defaultdict
+from pathlib import Path
 from typing import Any, Dict, Optional, Tuple

+import colorful
 import numpy as np
 import onnxruntime as ort
-import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
@@ -13,6 +14,7 @@ from tqdm import tqdm

 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+from lattifai.utils import safe_print


 class Lattice1Worker:
@@ -60,104 +62,84 @@ class Lattice1Worker:
         except Exception as e:
             raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)

+        # Get vocab_size from model output
+        self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]
+
         # get input_names
         input_names = [inp.name for inp in self.acoustic_ort.get_inputs()]
-
+        assert "audios" in input_names, f"Input name audios not found in {input_names}"
+
+        # Initialize separator if available
+        separator_model_path = Path(model_path) / "separator.onnx"
+        if separator_model_path.exists():
             try:
-
-
-
-
+                self.separator_ort = ort.InferenceSession(
+                    str(separator_model_path),
+                    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+                )
             except Exception as e:
-                raise ModelLoadError(f"
+                raise ModelLoadError(f"separator model from {model_path}", original_error=e)
         else:
-            self.
+            self.separator_ort = None

-        self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)

     @property
     def frame_shift(self) -> float:
         return 0.02  # 20 ms

-
-    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.

         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)

         Returns:
-            Emission
+            Emission numpy array of shape (1, T, vocab_size)
         """
         _start = time.time()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                "features": features.cpu().numpy(),
-                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
-            }
-            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-            emission = torch.from_numpy(emission).to(device or self.device)
+
+        if ndarray.shape[1] < 160:
+            ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
+        CHUNK_SIZE = 60 * 16000  # 60 seconds
+        total_samples = ndarray.shape[1]
+
+        if total_samples > CHUNK_SIZE:
+            frame_samples = int(16000 * self.frame_shift)
+            emissions = np.empty((1, total_samples // frame_samples + 1, self.vocab_size), dtype=np.float32)
+            for start in range(0, total_samples, CHUNK_SIZE):
+                chunk = ndarray[:, start : start + CHUNK_SIZE]
+                if chunk.shape[1] < 160:
+                    chunk = np.pad(chunk, ((0, 0), (0, 320 - chunk.shape[1])), mode="constant")
+
+                emission_out = self.acoustic_ort.run(None, {"audios": chunk})[0]
+                if acoustic_scale != 1.0:
+                    emission_out *= acoustic_scale
+                sf = start // frame_samples  # start frame
+                lf = sf + emission_out.shape[1]  # last frame
+                emissions[0, sf:lf, :] = emission_out
+            emissions[:, lf:, :] = 0.0
         else:
-
-
-
-
-
-
-
-            emission = self.acoustic_ort.run(
-                None,
-                {
-                    "audios": ndarray[:, start : start + CHUNK_SIZE],
-                },
-            )  # (1, T, vocab_size) numpy
-            emissions.append(emission[0])
-
-            emission = torch.cat(
-                [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-            )  # (1, T, vocab_size)
-            del emissions
-        else:
-            emission = self.acoustic_ort.run(
-                None,
-                {
-                    "audios": ndarray,
-                },
-            )  # (1, T, vocab_size) numpy
-            emission = torch.from_numpy(emission[0]).to(device or self.device)
+            emission_out = self.acoustic_ort.run(
+                None,
+                {
+                    "audios": ndarray,
+                },
+            )  # (1, T, vocab_size) numpy
+            emissions = emission_out[0]

-
-
+        if acoustic_scale != 1.0:
+            emissions *= acoustic_scale

         self.timings["emission"] += time.time() - _start
-        return
+        return emissions  # (1, T, vocab_size) numpy

     def alignment(
         self,
         audio: AudioData,
         lattice_graph: Tuple[str, int, float],
-        emission: Optional[
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
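The chunk bookkeeping in the emission hunk above is easiest to verify with concrete numbers: at 16 kHz and a 20 ms frame shift, one frame spans 320 samples and a 60-second chunk yields 3000 frames. A small sketch of that arithmetic, mirroring the hunk's constants:

SAMPLE_RATE = 16000
FRAME_SHIFT = 0.02                                # seconds, matches Lattice1Worker.frame_shift
CHUNK_SIZE = 60 * SAMPLE_RATE                     # 60 s of samples, as in the hunk

frame_samples = int(SAMPLE_RATE * FRAME_SHIFT)    # 320 samples per 20 ms frame
frames_per_chunk = CHUNK_SIZE // frame_samples    # 3000 frames per 60 s chunk

total_samples = 150 * SAMPLE_RATE                 # e.g. a 150 s recording
n_frames = total_samples // frame_samples + 1     # rows preallocated by np.empty above
chunk_offsets = [start // frame_samples for start in range(0, total_samples, CHUNK_SIZE)]
print(frame_samples, frames_per_chunk, n_frames)  # 320 3000 7501
print(chunk_offsets)                              # [0, 3000, 6000] -> "sf" for each chunk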
@@ -165,7 +147,7 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
-            emission: Pre-computed emission
+            emission: Pre-computed emission numpy array (ignored if streaming=True)
             offset: Time offset for the audio
             streaming: If True, use streaming mode for memory-efficient processing

@@ -178,25 +160,18 @@
             AlignmentError: If alignment process fails
         """
         try:
-            import k2
-        except ImportError:
-            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
-
-        try:
-            from lattifai_core.lattice.decode import align_segments
+            import k2py as k2
         except ImportError:
-            raise DependencyError("
+            raise DependencyError("k2py", install_command="pip install k2py")

         lattice_graph_str, final_state, acoustic_scale = lattice_graph

         _start = time.time()
         try:
-            # Create decoding graph
-
-            decoding_graph
-
-            decoding_graph.skip_id = int(final_state)
-            decoding_graph.return_id = int(final_state + 1)
+            # Create decoding graph using k2py
+            graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)
+            decoding_graph = graph_dict["fsa"]
+            aux_labels = graph_dict["aux_labels"]
         except Exception as e:
             raise AlignmentError(
                 "Failed to create decoding graph from lattice",
@@ -204,11 +179,6 @@
             )
         self.timings["decoding_graph"] += time.time() - _start

-        if self.device.type == "mps":
-            device = "cpu"  # k2 does not support mps yet
-        else:
-            device = self.device
-
         _start = time.time()

         # Get beam search parameters from config or use defaults
@@ -218,71 +188,54 @@
         max_active_states = self.alignment_config.max_active_states or 10000

         if emission is None and audio.streaming_mode:
-            #
-
-
-
-
-
-
-
-            total_minutes = int(total_duration / 60.0)
-
-            with tqdm(
-                total=total_minutes,
-                desc=f"Processing audio ({total_minutes} min)",
-                unit="min",
-                unit_scale=False,
-                unit_divisor=1,
-            ) as pbar:
-                for chunk in audio.iter_chunks():
-                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
-
-                    # Update progress based on chunk duration in minutes
-                    chunk_duration = int(chunk.duration / 60.0)
-                    pbar.update(chunk_duration)
-                    processed_duration += chunk_duration
-
-                    yield chunk_emission
-
-            # Calculate total frames for supervision_segments
-            total_frames = int(audio.duration / self.frame_shift)
-
-            results, labels = align_segments(
-                emission_iterator(),  # Pass iterator for streaming
-                decoding_graph.to(device),
-                torch.tensor([total_frames], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+            # Initialize OnlineDenseIntersecter for streaming
+            intersecter = k2.OnlineDenseIntersecter(
+                decoding_graph,
+                aux_labels,
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )

-            #
+            # Streaming mode
+            total_duration = audio.duration
+            total_minutes = int(total_duration / 60.0)
+
+            with tqdm(
+                total=total_minutes,
+                desc=f"Processing audio ({total_minutes} min)",
+                unit="min",
+                unit_scale=False,
+                unit_divisor=1,
+            ) as pbar:
+                for chunk in audio.iter_chunks():
+                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
+                    intersecter.decode(chunk_emission[0])
+
+                    # Update progress
+                    chunk_duration = int(chunk.duration / 60.0)
+                    pbar.update(chunk_duration)
+
             emission_result = None
+            # Get results from intersecter
+            results, labels = intersecter.finish()
         else:
-            # Batch mode
+            # Batch mode
             if emission is None:
-                emission = self.emission(
-                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
-                )  # (1, T, vocab_size)
+                emission = self.emission(audio.ndarray, acoustic_scale=acoustic_scale)  # (1, T, vocab_size)
             else:
-
-
-
-
-
-
-                search_beam
-                output_beam
-                min_active_states
-                max_active_states
-                subsampling_factor=1,
-                reject_low_confidence=False,
+                if acoustic_scale != 1.0:
+                    emission *= acoustic_scale
+            # Use AlignSegments directly
+            results, labels = k2.AlignSegments(
+                graph_dict,
+                emission[0],  # Pass the prepared scores
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )
-
             emission_result = emission

         self.timings["align_segments"] += time.time() - _start
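Condensed, the rewritten decode path reduces to the control flow below. The k2py calls (CreateFsaVecFromStr, OnlineDenseIntersecter, AlignSegments) are used exactly as they appear in the hunk above; the chunk iterator and beam values are stand-ins for audio.iter_chunks() and the AlignmentConfig settings, so treat this as a sketch rather than the shipped implementation.

import k2py as k2  # dependency introduced in this release; see the import hunk above

def decode_lattice(graph_str, final_state, emission_fn, chunks, emission,
                   search_beam, output_beam, min_active_states, max_active_states):
    """Sketch of the streaming-vs-batch branch; beam values come from AlignmentConfig."""
    graph_dict = k2.CreateFsaVecFromStr(graph_str, int(final_state), False)
    if emission is None and chunks is not None:
        # Streaming: feed per-chunk emissions so the full matrix is never materialized.
        intersecter = k2.OnlineDenseIntersecter(
            graph_dict["fsa"], graph_dict["aux_labels"],
            float(search_beam), float(output_beam),
            int(min_active_states), int(max_active_states),
        )
        for chunk in chunks:
            intersecter.decode(emission_fn(chunk)[0])  # (T_chunk, vocab_size) scores
        return intersecter.finish()                    # -> (results, labels)
    # Batch: decode the whole pre-computed emission in one call.
    return k2.AlignSegments(
        graph_dict, emission[0],
        float(search_beam), float(output_beam),
        int(min_active_states), int(max_active_states),
    )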
@@ -290,6 +243,41 @@
         channel = 0
         return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms

+    def profile(self) -> None:
+        """Print formatted profiling statistics."""
+        if not self.timings:
+            return
+
+        safe_print(colorful.bold(colorful.cyan("\n⏱️ Alignment Profiling")))
+        safe_print(colorful.gray("─" * 44))
+        safe_print(
+            f"{colorful.bold('Phase'.ljust(21))} "
+            f"{colorful.bold('Time'.ljust(12))} "
+            f"{colorful.bold('Percent'.rjust(8))}"
+        )
+        safe_print(colorful.gray("─" * 44))
+
+        total_time = sum(self.timings.values())
+
+        # Sort by duration descending
+        sorted_stats = sorted(self.timings.items(), key=lambda x: x[1], reverse=True)
+
+        for name, duration in sorted_stats:
+            percentage = (duration / total_time * 100) if total_time > 0 else 0.0
+            # Name: Cyan, Time: Yellow, Percent: Gray
+            safe_print(
+                f"{name:<20} "
+                f"{colorful.yellow(f'{duration:7.4f}s'.ljust(12))} "
+                f"{colorful.gray(f'{percentage:.2f}%'.rjust(8))}"
+            )
+
+        safe_print(colorful.gray("─" * 44))
+        # Pad "Total Time" before coloring to ensure correct alignment (ANSI codes don't count for width)
+        safe_print(
+            f"{colorful.bold('Total Time'.ljust(20))} "
+            f"{colorful.bold(colorful.yellow(f'{total_time:7.4f}s'.ljust(12)))}\n"
+        )
+

 def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
lattifai/alignment/tokenizer.py
CHANGED
@@ -4,7 +4,7 @@ import re
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

-import
+import numpy as np

 from lattifai.alignment.phonemizer import G2Phonemizer
 from lattifai.caption import Supervision
@@ -121,6 +121,7 @@ class LatticeTokenizer:
     def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.model_name = ""
+        self.model_hub: Optional[str] = None
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
         self.dictionaries = defaultdict(lambda: [])
@@ -142,10 +143,20 @@ class LatticeTokenizer:
         elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
             providers.append("MPSExecutionProvider")

-
-
-
-
+        if self.model_hub == "modelscope":
+            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+            downloaded_path = ms_snapshot("LattifAI/OmniTokenizer")
+            sat = SaT(
+                f"{downloaded_path}/sat-3l-sm",
+                tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        else:
+            sat = SaT(
+                "sat-3l-sm",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
         self.sentence_splitter = sat

     @staticmethod
@@ -200,6 +211,7 @@ class LatticeTokenizer:
         client_wrapper: Any,
         model_path: str,
         model_name: str,
+        model_hub: Optional[str] = None,
         device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
@@ -214,7 +226,7 @@
         else:
             with open(words_model_path, "rb") as f:
                 data = pickle.load(f)
-        except
+        except Exception as e:
             del e
             import msgpack

@@ -227,6 +239,7 @@

         tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.model_name = model_name
+        tokenizer.model_hub = model_hub
         tokenizer.words = data["words"]
         tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
         tokenizer.oov_word = data["oov_word"]
@@ -431,9 +444,11 @@
     def detokenize(
         self,
         lattice_id: str,
-        lattice_results: Tuple[
+        lattice_results: Tuple[np.ndarray, Any, Any, float, float],
         supervisions: List[Supervision],
         return_details: bool = False,
+        start_margin: float = 0.08,
+        end_margin: float = 0.20,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
@@ -448,6 +463,8 @@
                 "channel": channel,
                 "return_details": False if return_details is None else return_details,
                 "destroy_lattice": True,
+                "start_margin": start_margin,
+                "end_margin": end_margin,
             },
         )
         if response.status_code == 400:
@@ -477,7 +494,7 @@

 def _add_confidence_scores(
     supervisions: List[Supervision],
-    emission:
+    emission: np.ndarray,
     labels: List[int],
     frame_shift: float,
     offset: float = 0.0,
@@ -495,17 +512,17 @@ def _add_confidence_scores(
         labels: Token labels corresponding to aligned tokens
         frame_shift: Frame shift in seconds for converting frames to time
     """
-    tokens =
+    tokens = np.array(labels, dtype=np.int64)

     for supervision in supervisions:
         start_frame = int((supervision.start - offset) / frame_shift)
         end_frame = int((supervision.end - offset) / frame_shift)

         # Compute segment-level confidence
-        probabilities = emission[0, start_frame:end_frame]
+        probabilities = np.exp(emission[0, start_frame:end_frame])
         aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
-        diffprobs =
-        supervision.score = round(1.0 - diffprobs.mean()
+        diffprobs = np.max(probabilities, axis=-1) - aligned
+        supervision.score = round(1.0 - diffprobs.mean(), ndigits=4)

         # Compute word-level confidence if alignment exists
         if hasattr(supervision, "alignment") and supervision.alignment:
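The confidence arithmetic above is easier to read with concrete numbers: the emission holds log-probabilities (hence the added np.exp), and a segment scores 1.0 exactly when the aligned token is the per-frame argmax. A small self-contained sketch with fake data:

import numpy as np

rng = np.random.default_rng(0)
log_probs = np.log(rng.dirichlet(np.ones(5), size=(1, 10)))  # fake emission: (1, T=10, vocab=5)
tokens = rng.integers(0, 5, size=10)                         # aligned label per frame

probabilities = np.exp(log_probs[0, 0:10])                   # back to probabilities, as in the hunk
aligned = probabilities[range(10), tokens]                   # probability mass on the aligned token
diffprobs = np.max(probabilities, axis=-1) - aligned         # gap to each frame's best-scoring token
score = round(1.0 - diffprobs.mean(), ndigits=4)             # 1.0 only if the aligned token is always the argmax
print(score)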
@@ -513,7 +530,7 @@ def _add_confidence_scores(
             for w, item in enumerate(words):
                 start = int((item.start - offset) / frame_shift) - start_frame
                 end = int((item.end - offset) / frame_shift) - start_frame
-                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean()
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean(), ndigits=4))


 def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
@@ -535,15 +552,14 @@ def _load_tokenizer(
     model_name: str,
     device: str,
     *,
+    model_hub: Optional[str] = None,
     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
     """Instantiate tokenizer with consistent error handling."""
-
-
-
-
-
-
-
-    except Exception as e:
-        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
+    return tokenizer_cls.from_pretrained(
+        client_wrapper=client_wrapper,
+        model_path=model_path,
+        model_name=model_name,
+        model_hub=model_hub,
+        device=device,
+    )