lattifai 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +11 -12
- lattifai/alignment/lattice1_aligner.py +11 -8
- lattifai/alignment/lattice1_worker.py +125 -151
- lattifai/alignment/tokenizer.py +27 -12
- lattifai/audio2.py +1 -1
- lattifai/cli/diarization.py +3 -1
- lattifai/cli/youtube.py +11 -0
- lattifai/client.py +5 -0
- lattifai/config/client.py +5 -0
- lattifai/mixin.py +7 -4
- lattifai/utils.py +21 -59
- lattifai/workflow/youtube.py +55 -57
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/METADATA +330 -48
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/RECORD +18 -18
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/WHEEL +0 -0
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/entry_points.txt +0 -0
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.1.0.dist-info → lattifai-1.2.0.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
@@ -52,28 +52,27 @@ except Exception:
     __version__ = "0.1.0"  # fallback version
 
 
-# Check and auto-install
-def
-    """Check if
+# Check and auto-install k2py if not present
+def _check_and_install_k2py():
+    """Check if k2py is installed and attempt to install it if not."""
     try:
-        import
+        import k2py
     except ImportError:
         import subprocess
 
-        print("
+        print("k2py is not installed. Attempting to install k2py...")
         try:
-            subprocess.check_call([sys.executable, "-m", "pip", "install", "
-
-            import k2  # Try importing again after installation
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "k2py"])
+            import k2py  # Try importing again after installation
 
-            print("
+            print("k2py installed successfully.")
         except Exception as e:
-            warnings.warn(f"Failed to install
+            warnings.warn(f"Failed to install k2py automatically. Please install it manually. Error: {e}")
     return True
 
 
-# Auto-install
-
+# Auto-install k2py on first import
+_check_and_install_k2py()
 
 
 __all__ = [
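Note on the hunk above: the auto-installed alignment backend switches from k2 to k2py, using the familiar install-on-ImportError pattern. A minimal standalone sketch of that pattern, under the assumption of a generic package name (`ensure_package` is a hypothetical helper, not part of lattifai):

    import subprocess
    import sys
    import warnings


    def ensure_package(name: str) -> None:
        """Import `name`, attempting a one-time pip install if it is missing."""
        try:
            __import__(name)
        except ImportError:
            try:
                # Install into the currently running interpreter's environment.
                subprocess.check_call([sys.executable, "-m", "pip", "install", name])
                __import__(name)  # retry after installation
            except Exception as e:  # pip failure, network error, etc.
                warnings.warn(f"Failed to install {name} automatically: {e}")

Installing at import time keeps `pip install lattifai` itself lightweight, at the cost of a network call on first use.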
lattifai/alignment/lattice1_aligner.py
CHANGED

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple
 
 import colorful
 import numpy as np
-import torch
 
 from lattifai.audio2 import AudioData
 from lattifai.caption import Supervision
@@ -38,19 +37,21 @@ class Lattice1Aligner(object):
         # Resolve model path using configured model hub
         model_path = _resolve_model_path(config.model_name, getattr(config, "model_hub", "huggingface"))
 
-        self.tokenizer = _load_tokenizer(
+        self.tokenizer = _load_tokenizer(
+            client_wrapper, model_path, config.model_name, config.device, model_hub=config.model_hub
+        )
         self.worker = _load_worker(model_path, config.device, config)
 
         self.frame_shift = self.worker.frame_shift
 
-    def emission(self, ndarray: np.ndarray) ->
+    def emission(self, ndarray: np.ndarray) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.
 
         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
 
         Returns:
-            Emission
+            Emission numpy array of shape (1, T, vocab_size)
         """
         return self.worker.emission(ndarray)
 
@@ -68,13 +69,11 @@ class Lattice1Aligner(object):
         """
         if self.worker.separator_ort is None:
             raise RuntimeError("Separator model not available. separator.onnx not found in model path.")
-
         # Run separator model
         separator_output = self.worker.separator_ort.run(
             None,
-            {"
+            {"audios": audio},
         )
-
         return separator_output[0]
 
     def alignment(
@@ -83,7 +82,7 @@
         supervisions: List[Supervision],
         split_sentence: Optional[bool] = False,
         return_details: Optional[bool] = False,
-        emission: Optional[
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
         verbose: bool = True,
     ) -> Tuple[List[Supervision], List[Supervision]]:
@@ -166,3 +165,7 @@
             raise
         except Exception as e:
             raise e
+
+    def profile(self) -> None:
+        """Print profiling statistics."""
+        self.worker.profile()
lattifai/alignment/lattice1_worker.py
CHANGED

@@ -4,9 +4,9 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
+import colorful
 import numpy as np
 import onnxruntime as ort
-import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
@@ -14,6 +14,7 @@ from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
+from lattifai.utils import safe_print
 
 
 class Lattice1Worker:
@@ -61,18 +62,12 @@
         except Exception as e:
             raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)
 
+        # Get vocab_size from model output
+        self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]
+
         # get input_names
         input_names = [inp.name for inp in self.acoustic_ort.get_inputs()]
-
-        try:
-            config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
-            config_dict = config.to_dict()
-            config_dict.pop("device")
-            self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
-        except Exception as e:
-            raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
-        else:
-            self.extractor = None  # ONNX model includes feature extractor
+        assert "audios" in input_names, f"Input name audios not found in {input_names}"
 
         # Initialize separator if available
         separator_model_path = Path(model_path) / "separator.onnx"
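Side note on `self.vocab_size = self.acoustic_ort.get_outputs()[0].shape[-1]`: onnxruntime reports dynamic dimensions as symbolic strings, so this yields an `int` only because the exported model fixes the vocabulary dimension. A quick inspection sketch (the model path is a placeholder):

    import onnxruntime as ort

    sess = ort.InferenceSession("acoustic.onnx", providers=["CPUExecutionProvider"])
    for inp in sess.get_inputs():
        print(inp.name, inp.shape)                # dynamic dims show up as strings, e.g. "T"
    vocab_size = sess.get_outputs()[0].shape[-1]  # an int only when that dim is static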
@@ -80,98 +75,71 @@
             try:
                 self.separator_ort = ort.InferenceSession(
                     str(separator_model_path),
-                    providers=
+                    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
                 )
             except Exception as e:
                 raise ModelLoadError(f"separator model from {model_path}", original_error=e)
         else:
             self.separator_ort = None
 
-        self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
 
     @property
     def frame_shift(self) -> float:
         return 0.02  # 20 ms
 
-
-    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0) -> np.ndarray:
         """Generate emission probabilities from audio ndarray.
 
         Args:
             ndarray: Audio data as numpy array of shape (1, T) or (C, T)
 
         Returns:
-            Emission
+            Emission numpy array of shape (1, T, vocab_size)
         """
         _start = time.time()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                "features": features.cpu().numpy(),
-                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
-            }
-            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-            emission = torch.from_numpy(emission).to(device or self.device)
+
+        if ndarray.shape[1] < 160:
+            ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
+        CHUNK_SIZE = 60 * 16000  # 60 seconds
+        total_samples = ndarray.shape[1]
+
+        if total_samples > CHUNK_SIZE:
+            frame_samples = int(16000 * self.frame_shift)
+            emissions = np.empty((1, total_samples // frame_samples + 1, self.vocab_size), dtype=np.float32)
+            for start in range(0, total_samples, CHUNK_SIZE):
+                chunk = ndarray[:, start : start + CHUNK_SIZE]
+                if chunk.shape[1] < 160:
+                    chunk = np.pad(chunk, ((0, 0), (0, 320 - chunk.shape[1])), mode="constant")
+
+                emission_out = self.acoustic_ort.run(None, {"audios": chunk})[0]
+                if acoustic_scale != 1.0:
+                    emission_out *= acoustic_scale
+                sf = start // frame_samples  # start frame
+                lf = sf + emission_out.shape[1]  # last frame
+                emissions[0, sf:lf, :] = emission_out
+            emissions[:, lf:, :] = 0.0
         else:
-
-
-
-
-
-
-
-            emission = self.acoustic_ort.run(
-                None,
-                {
-                    "audios": ndarray[:, start : start + CHUNK_SIZE],
-                },
-            )  # (1, T, vocab_size) numpy
-            emissions.append(emission[0])
-
-            emission = torch.cat(
-                [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
-            )  # (1, T, vocab_size)
-            del emissions
-        else:
-            emission = self.acoustic_ort.run(
-                None,
-                {
-                    "audios": ndarray,
-                },
-            )  # (1, T, vocab_size) numpy
-            emission = torch.from_numpy(emission[0]).to(device or self.device)
+            emission_out = self.acoustic_ort.run(
+                None,
+                {
+                    "audios": ndarray,
+                },
+            )  # (1, T, vocab_size) numpy
+            emissions = emission_out[0]
 
-
-
+        if acoustic_scale != 1.0:
+            emissions *= acoustic_scale
 
         self.timings["emission"] += time.time() - _start
-        return
+        return emissions  # (1, T, vocab_size) numpy
 
     def alignment(
         self,
         audio: AudioData,
         lattice_graph: Tuple[str, int, float],
-        emission: Optional[
+        emission: Optional[np.ndarray] = None,
         offset: float = 0.0,
     ) -> Dict[str, Any]:
         """Process audio with LatticeGraph.
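To make the chunk bookkeeping above concrete: at 16 kHz with a 20 ms frame shift, a frame spans 320 samples and a 60-second chunk yields 3000 emission frames, so chunk k writes into rows starting at frame 3000k of the preallocated buffer. A worked sketch, assuming a hypothetical 150-second input:

    SAMPLE_RATE = 16000
    FRAME_SHIFT = 0.02                      # 20 ms -> 320 samples per frame
    CHUNK_SIZE = 60 * SAMPLE_RATE           # 960000 samples per chunk

    frame_samples = int(SAMPLE_RATE * FRAME_SHIFT)   # 320
    total_samples = 150 * SAMPLE_RATE                # hypothetical 150 s recording
    n_rows = total_samples // frame_samples + 1      # 7501 preallocated frames

    for start in range(0, total_samples, CHUNK_SIZE):
        sf = start // frame_samples                  # 0, 3000, 6000
        print(f"chunk starting at sample {start} writes from frame {sf}")
    # The final chunk is 30 s (1500 frames), so frames past 7500 stay zero-filled.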
@@ -179,7 +147,7 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
-            emission: Pre-computed emission
+            emission: Pre-computed emission numpy array (ignored if streaming=True)
             offset: Time offset for the audio
             streaming: If True, use streaming mode for memory-efficient processing
 
@@ -192,25 +160,18 @@
             AlignmentError: If alignment process fails
         """
         try:
-            import k2
+            import k2py as k2
         except ImportError:
-            raise DependencyError("
-
-        try:
-            from lattifai_core.lattice.decode import align_segments
-        except ImportError:
-            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
+            raise DependencyError("k2py", install_command="pip install k2py")
 
         lattice_graph_str, final_state, acoustic_scale = lattice_graph
 
         _start = time.time()
         try:
-            # Create decoding graph
-
-            decoding_graph
-
-            decoding_graph.skip_id = int(final_state)
-            decoding_graph.return_id = int(final_state + 1)
+            # Create decoding graph using k2py
+            graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)
+            decoding_graph = graph_dict["fsa"]
+            aux_labels = graph_dict["aux_labels"]
         except Exception as e:
             raise AlignmentError(
                 "Failed to create decoding graph from lattice",
@@ -218,11 +179,6 @@
             )
         self.timings["decoding_graph"] += time.time() - _start
 
-        if self.device.type == "mps":
-            device = "cpu"  # k2 does not support mps yet
-        else:
-            device = self.device
-
         _start = time.time()
 
         # Get beam search parameters from config or use defaults
@@ -232,71 +188,54 @@
         max_active_states = self.alignment_config.max_active_states or 10000
 
         if emission is None and audio.streaming_mode:
-            #
-
-
-
-
-
-
-
-            total_minutes = int(total_duration / 60.0)
-
-            with tqdm(
-                total=total_minutes,
-                desc=f"Processing audio ({total_minutes} min)",
-                unit="min",
-                unit_scale=False,
-                unit_divisor=1,
-            ) as pbar:
-                for chunk in audio.iter_chunks():
-                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
-
-                    # Update progress based on chunk duration in minutes
-                    chunk_duration = int(chunk.duration / 60.0)
-                    pbar.update(chunk_duration)
-                    processed_duration += chunk_duration
-
-                    yield chunk_emission
-
-            # Calculate total frames for supervision_segments
-            total_frames = int(audio.duration / self.frame_shift)
-
-            results, labels = align_segments(
-                emission_iterator(),  # Pass iterator for streaming
-                decoding_graph.to(device),
-                torch.tensor([total_frames], dtype=torch.int32),
-                search_beam=search_beam,
-                output_beam=output_beam,
-                min_active_states=min_active_states,
-                max_active_states=max_active_states,
-                subsampling_factor=1,
-                reject_low_confidence=False,
+            # Initialize OnlineDenseIntersecter for streaming
+            intersecter = k2.OnlineDenseIntersecter(
+                decoding_graph,
+                aux_labels,
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )
 
-            #
+            # Streaming mode
+            total_duration = audio.duration
+            total_minutes = int(total_duration / 60.0)
+
+            with tqdm(
+                total=total_minutes,
+                desc=f"Processing audio ({total_minutes} min)",
+                unit="min",
+                unit_scale=False,
+                unit_divisor=1,
+            ) as pbar:
+                for chunk in audio.iter_chunks():
+                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
+                    intersecter.decode(chunk_emission[0])
+
+                    # Update progress
+                    chunk_duration = int(chunk.duration / 60.0)
+                    pbar.update(chunk_duration)
+
             emission_result = None
+            # Get results from intersecter
+            results, labels = intersecter.finish()
         else:
-            # Batch mode
+            # Batch mode
             if emission is None:
-                emission = self.emission(
-                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
-                )  # (1, T, vocab_size)
+                emission = self.emission(audio.ndarray, acoustic_scale=acoustic_scale)  # (1, T, vocab_size)
             else:
-
-
-
-
-
-
-                search_beam
-                output_beam
-                min_active_states
-                max_active_states
-                subsampling_factor=1,
-                reject_low_confidence=False,
+                if acoustic_scale != 1.0:
+                    emission *= acoustic_scale
+            # Use AlignSegments directly
+            results, labels = k2.AlignSegments(
+                graph_dict,
+                emission[0],  # Pass the prepared scores
+                float(search_beam),
+                float(output_beam),
+                int(min_active_states),
+                int(max_active_states),
             )
-
             emission_result = emission
 
         self.timings["align_segments"] += time.time() - _start
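Pulled out of the hunk above, the new k2py call pattern for both branches looks like this. This is a sketch assembled from the calls in the diff; the beam values, graph string, and emission array are illustrative placeholders:

    import k2py as k2
    import numpy as np

    # Placeholders: a real lattice string and emissions come from the tokenizer/worker.
    lattice_graph_str, final_state = "...", 42
    emission = np.zeros((1, 3000, 500), dtype=np.float32)   # (1, T, vocab_size)

    graph_dict = k2.CreateFsaVecFromStr(lattice_graph_str, int(final_state), False)

    # Streaming: feed one (T, vocab_size) score matrix per chunk, then finish.
    intersecter = k2.OnlineDenseIntersecter(
        graph_dict["fsa"], graph_dict["aux_labels"], 20.0, 8.0, 30, 10000
    )
    for chunk_scores in np.array_split(emission[0], 3):
        intersecter.decode(chunk_scores)
    results, labels = intersecter.finish()

    # Batch: a single call over the full emission matrix.
    results, labels = k2.AlignSegments(graph_dict, emission[0], 20.0, 8.0, 30, 10000)

Note the asymmetry: the streaming path takes the unpacked fsa and aux_labels, while the batch path takes the whole graph_dict.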
@@ -304,6 +243,41 @@ class Lattice1Worker:
         channel = 0
         return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
+    def profile(self) -> None:
+        """Print formatted profiling statistics."""
+        if not self.timings:
+            return
+
+        safe_print(colorful.bold(colorful.cyan("\n⏱️ Alignment Profiling")))
+        safe_print(colorful.gray("─" * 44))
+        safe_print(
+            f"{colorful.bold('Phase'.ljust(21))} "
+            f"{colorful.bold('Time'.ljust(12))} "
+            f"{colorful.bold('Percent'.rjust(8))}"
+        )
+        safe_print(colorful.gray("─" * 44))
+
+        total_time = sum(self.timings.values())
+
+        # Sort by duration descending
+        sorted_stats = sorted(self.timings.items(), key=lambda x: x[1], reverse=True)
+
+        for name, duration in sorted_stats:
+            percentage = (duration / total_time * 100) if total_time > 0 else 0.0
+            # Name: Cyan, Time: Yellow, Percent: Gray
+            safe_print(
+                f"{name:<20} "
+                f"{colorful.yellow(f'{duration:7.4f}s'.ljust(12))} "
+                f"{colorful.gray(f'{percentage:.2f}%'.rjust(8))}"
+            )
+
+        safe_print(colorful.gray("─" * 44))
+        # Pad "Total Time" before coloring to ensure correct alignment (ANSI codes don't count for width)
+        safe_print(
+            f"{colorful.bold('Total Time'.ljust(20))} "
+            f"{colorful.bold(colorful.yellow(f'{total_time:7.4f}s'.ljust(12)))}\n"
+        )
+
 
 def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
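Usage note: after an alignment run, calling `aligner.profile()` (which delegates to the worker method above) prints a table along these lines; the numbers here are purely illustrative, not real measurements:

    ⏱️ Alignment Profiling
    ────────────────────────────────────────────
    Phase                 Time          Percent
    ────────────────────────────────────────────
    align_segments       12.3456s        78.11%
    emission              3.2100s        20.31%
    decoding_graph        0.2500s         1.58%
    ────────────────────────────────────────────
    Total Time           15.8056s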
lattifai/alignment/tokenizer.py
CHANGED
@@ -4,7 +4,7 @@ import re
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
 
-import
+import numpy as np
 
 from lattifai.alignment.phonemizer import G2Phonemizer
 from lattifai.caption import Supervision
@@ -121,6 +121,7 @@ class LatticeTokenizer:
     def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.model_name = ""
+        self.model_hub: Optional[str] = None
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
        self.dictionaries = defaultdict(lambda: [])
@@ -142,10 +143,20 @@
         elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
             providers.append("MPSExecutionProvider")
 
-
-
-
-
+        if self.model_hub == "modelscope":
+            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot
+
+            downloaded_path = ms_snapshot("LattifAI/OmniTokenizer")
+            sat = SaT(
+                f"{downloaded_path}/sat-3l-sm",
+                tokenizer_name_or_path=f"{downloaded_path}/xlm-roberta-base",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
+        else:
+            sat = SaT(
+                "sat-3l-sm",
+                ort_providers=providers + ["CPUExecutionProvider"],
+            )
         self.sentence_splitter = sat
 
     @staticmethod
@@ -200,6 +211,7 @@ class LatticeTokenizer:
         client_wrapper: Any,
         model_path: str,
         model_name: str,
+        model_hub: Optional[str] = None,
         device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
@@ -227,6 +239,7 @@ class LatticeTokenizer:
 
         tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.model_name = model_name
+        tokenizer.model_hub = model_hub
         tokenizer.words = data["words"]
         tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
         tokenizer.oov_word = data["oov_word"]
@@ -431,7 +444,7 @@ class LatticeTokenizer:
     def detokenize(
         self,
         lattice_id: str,
-        lattice_results: Tuple[
+        lattice_results: Tuple[np.ndarray, Any, Any, float, float],
         supervisions: List[Supervision],
         return_details: bool = False,
         start_margin: float = 0.08,
@@ -481,7 +494,7 @@
 
 def _add_confidence_scores(
     supervisions: List[Supervision],
-    emission:
+    emission: np.ndarray,
     labels: List[int],
     frame_shift: float,
     offset: float = 0.0,
@@ -499,17 +512,17 @@ def _add_confidence_scores(
         labels: Token labels corresponding to aligned tokens
         frame_shift: Frame shift in seconds for converting frames to time
     """
-    tokens =
+    tokens = np.array(labels, dtype=np.int64)
 
     for supervision in supervisions:
         start_frame = int((supervision.start - offset) / frame_shift)
         end_frame = int((supervision.end - offset) / frame_shift)
 
         # Compute segment-level confidence
-        probabilities = emission[0, start_frame:end_frame]
+        probabilities = np.exp(emission[0, start_frame:end_frame])
         aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
-        diffprobs =
-        supervision.score = round(1.0 - diffprobs.mean()
+        diffprobs = np.max(probabilities, axis=-1) - aligned
+        supervision.score = round(1.0 - diffprobs.mean(), ndigits=4)
 
         # Compute word-level confidence if alignment exists
         if hasattr(supervision, "alignment") and supervision.alignment:
@@ -517,7 +530,7 @@ def _add_confidence_scores(
             for w, item in enumerate(words):
                 start = int((item.start - offset) / frame_shift) - start_frame
                 end = int((item.end - offset) / frame_shift) - start_frame
-                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean()
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean(), ndigits=4))
 
 
 def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
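The score computed in `_add_confidence_scores` is 1 minus the mean gap between each frame's best token probability and the probability of the token the aligner chose, so it equals 1.0 exactly when the aligned token is the argmax at every frame. A toy numeric check, with values invented for illustration:

    import numpy as np

    # Log-prob emissions for 3 frames over a 4-token vocab.
    log_probs = np.log(np.array([
        [0.7, 0.1, 0.1, 0.1],
        [0.2, 0.6, 0.1, 0.1],
        [0.5, 0.2, 0.2, 0.1],
    ]))
    tokens = np.array([0, 1, 2])                    # aligned token per frame

    probs = np.exp(log_probs)
    aligned = probs[range(3), tokens]               # 0.7, 0.6, 0.2
    diffprobs = np.max(probs, axis=-1) - aligned    # 0.0, 0.0, 0.3
    score = round(1.0 - diffprobs.mean(), ndigits=4)
    print(score)                                    # 0.9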
@@ -539,6 +552,7 @@ def _load_tokenizer(
     model_name: str,
     device: str,
     *,
+    model_hub: Optional[str] = None,
     tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
 ) -> LatticeTokenizer:
     """Instantiate tokenizer with consistent error handling."""
@@ -546,5 +560,6 @@
         client_wrapper=client_wrapper,
         model_path=model_path,
         model_name=model_name,
+        model_hub=model_hub,
         device=device,
     )
lattifai/audio2.py
CHANGED
@@ -36,7 +36,7 @@ class AudioData(namedtuple("AudioData", ["sampling_rate", "ndarray", "path", "st
     @property
     def streaming_mode(self) -> bool:
         """Indicates whether streaming mode is enabled based on streaming_chunk_secs."""
-        if self.streaming_chunk_secs
+        if self.streaming_chunk_secs:
             return self.duration > self.streaming_chunk_secs * 1.1
         return False
 
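The truthy check above covers `streaming_chunk_secs` being `None` or `0`, and the `* 1.1` factor gives a 10% margin so audio only slightly longer than one chunk still goes through batch mode. For example (hypothetical values):

    streaming_chunk_secs = 60.0
    print(64.0 > streaming_chunk_secs * 1.1)   # False: a 64 s file stays in batch mode
    print(70.0 > streaming_chunk_secs * 1.1)   # True: a 70 s file enables streaming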
lattifai/cli/diarization.py
CHANGED
@@ -8,7 +8,7 @@ import nemo_run as run
 from typing_extensions import Annotated
 
 from lattifai.client import LattifAI
-from lattifai.config import CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
+from lattifai.config import AlignmentConfig, CaptionConfig, ClientConfig, DiarizationConfig, MediaConfig
 from lattifai.utils import safe_print
 
 __all__ = ["diarize"]
|
@@ -22,6 +22,7 @@ def diarize(
|
|
|
22
22
|
media: Annotated[Optional[MediaConfig], run.Config[MediaConfig]] = None,
|
|
23
23
|
caption: Annotated[Optional[CaptionConfig], run.Config[CaptionConfig]] = None,
|
|
24
24
|
client: Annotated[Optional[ClientConfig], run.Config[ClientConfig]] = None,
|
|
25
|
+
alignment: Annotated[Optional[AlignmentConfig], run.Config[AlignmentConfig]] = None,
|
|
25
26
|
diarization: Annotated[Optional[DiarizationConfig], run.Config[DiarizationConfig]] = None,
|
|
26
27
|
):
|
|
27
28
|
"""Run speaker diarization on aligned captions and audio."""
|
|
@@ -53,6 +54,7 @@
 
     client_instance = LattifAI(
         client_config=client,
+        alignment_config=alignment,
         caption_config=caption_config,
         diarization_config=diarization_config,
     )