lattifai 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +10 -0
- lattifai/alignment/lattice1_aligner.py +33 -13
- lattifai/alignment/lattice1_worker.py +121 -50
- lattifai/alignment/segmenter.py +3 -2
- lattifai/alignment/tokenizer.py +3 -3
- lattifai/audio2.py +269 -70
- lattifai/caption/caption.py +161 -3
- lattifai/cli/alignment.py +2 -1
- lattifai/cli/app_installer.py +35 -33
- lattifai/cli/caption.py +8 -18
- lattifai/cli/server.py +3 -1
- lattifai/cli/transcribe.py +53 -38
- lattifai/cli/youtube.py +1 -0
- lattifai/client.py +16 -11
- lattifai/config/alignment.py +23 -2
- lattifai/config/caption.py +1 -1
- lattifai/config/media.py +23 -3
- lattifai/errors.py +7 -3
- lattifai/mixin.py +26 -15
- lattifai/server/app.py +2 -1
- lattifai/utils.py +37 -0
- lattifai/workflow/file_manager.py +15 -13
- lattifai/workflow/youtube.py +16 -1
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/METADATA +65 -15
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/RECORD +29 -29
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/licenses/LICENSE +1 -1
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/WHEEL +0 -0
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/entry_points.txt +0 -0
- {lattifai-1.0.4.dist-info → lattifai-1.0.5.dist-info}/top_level.txt +0 -0
lattifai/__init__.py
CHANGED
@@ -1,7 +1,17 @@
+import os
 import sys
 import warnings
 from importlib.metadata import version
 
+# Suppress SWIG deprecation warnings before any imports
+warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+# Suppress PyTorch transformer nested tensor warning
+warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+# Disable tokenizers parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 # Re-export I/O classes
 from .caption import Caption
 
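Note: the filters added above match on a warning category plus a regex over the message, so only the targeted SWIG and nested-tensor warnings are hidden while unrelated warnings still surface, and TOKENIZERS_PARALLELISM must be set before the first tokenizer import. A minimal standalone sketch of that filtering behaviour, using only the standard library (the warning texts below are illustrative, not lattifai output):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Same shape of filter as in lattifai/__init__.py: ignore by category + message regex.
    warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
    warnings.warn("enable_nested_tensor is True, but self.use_nested_tensor is False", UserWarning)  # suppressed
    warnings.warn("some unrelated user warning", UserWarning)  # recorded

assert len(caught) == 1  # only the unrelated warning was captured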
lattifai/alignment/lattice1_aligner.py
CHANGED
@@ -3,6 +3,7 @@
 from typing import Any, List, Optional, Tuple
 
 import colorful
+import numpy as np
 import torch
 
 from lattifai.audio2 import AudioData
@@ -13,7 +14,7 @@ from lattifai.errors import (
     LatticeDecodingError,
     LatticeEncodingError,
 )
-from lattifai.utils import _resolve_model_path
+from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
 from .tokenizer import _load_tokenizer
@@ -37,12 +38,20 @@ class Lattice1Aligner(object):
         model_path = _resolve_model_path(config.model_name)
 
         self.tokenizer = _load_tokenizer(client_wrapper, model_path, config.model_name, config.device)
-        self.worker = _load_worker(model_path, config.device)
+        self.worker = _load_worker(model_path, config.device, config)
 
         self.frame_shift = self.worker.frame_shift
 
-    def emission(self,
-
+    def emission(self, ndarray: np.ndarray) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
+        return self.worker.emission(ndarray)
 
     def alignment(
         self,
@@ -72,23 +81,34 @@ class Lattice1Aligner(object):
         """
         try:
             if verbose:
-
+                safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
             try:
                 supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
                     supervisions, split_sentence=split_sentence
                 )
                 if verbose:
-
+                    safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
             except Exception as e:
                 text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
                 raise LatticeEncodingError(text_content, original_error=e)
 
             if verbose:
-
+                safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+                if audio.streaming_chunk_secs:
+                    safe_print(
+                        colorful.yellow(
+                            f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
+                        )
+                    )
             try:
-                lattice_results = self.worker.alignment(
+                lattice_results = self.worker.alignment(
+                    audio,
+                    lattice_graph,
+                    emission=emission,
+                    offset=offset,
+                )
                 if verbose:
-
+                    safe_print(colorful.green(" ✓ Lattice search completed"))
             except Exception as e:
                 raise AlignmentError(
                     f"Audio alignment failed for {audio}",
@@ -97,18 +117,18 @@ class Lattice1Aligner(object):
                 )
 
             if verbose:
-
+                safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
             try:
                 alignments = self.tokenizer.detokenize(
                     lattice_id, lattice_results, supervisions=supervisions, return_details=return_details
                 )
                 if verbose:
-
+                    safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
             except LatticeDecodingError as e:
-
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                 raise e
             except Exception as e:
-
+                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
                 raise LatticeDecodingError(lattice_id, original_error=e)
 
         return (supervisions, alignments)
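Note: the aligner now routes its verbose output through safe_print from lattifai.utils, whose implementation is not shown in this diff. A hedged sketch of the usual shape of such a helper — falling back to a replacement-character rendering when the console encoding cannot display the emoji used above; the actual lattifai implementation may differ:

import sys


def safe_print(*args, **kwargs):
    """Print that degrades gracefully on consoles without UTF-8 support (illustrative only)."""
    try:
        print(*args, **kwargs)
    except UnicodeEncodeError:
        encoding = sys.stdout.encoding or "ascii"
        cleaned = [str(a).encode(encoding, errors="replace").decode(encoding) for a in args]
        print(*cleaned, **kwargs)


safe_print("🔗 Step 2: Creating lattice graph from segments")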
lattifai/alignment/lattice1_worker.py
CHANGED
@@ -9,6 +9,7 @@ import torch
 from lhotse import FbankConfig
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
+from tqdm import tqdm
 
 from lattifai.audio2 import AudioData
 from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
@@ -17,12 +18,17 @@ from lattifai.errors import AlignmentError, DependencyError, ModelLoadError
 class Lattice1Worker:
     """Worker for processing audio with LatticeGraph."""
 
-    def __init__(
+    def __init__(
+        self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8, config: Optional[Any] = None
+    ) -> None:
         try:
-            self.
+            self.model_config = json.load(open(f"{model_path}/config.json"))
         except Exception as e:
             raise ModelLoadError(f"config from {model_path}", original_error=e)
 
+        # Store alignment config with beam search parameters
+        self.alignment_config = config
+
         # SessionOptions
         sess_options = ort.SessionOptions()
         # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -75,59 +81,74 @@ class Lattice1Worker:
         return 0.02  # 20 ms
 
     @torch.inference_mode()
-    def emission(self,
+    def emission(self, ndarray: np.ndarray, acoustic_scale: float = 1.0, device: Optional[str] = None) -> torch.Tensor:
+        """Generate emission probabilities from audio ndarray.
+
+        Args:
+            ndarray: Audio data as numpy array of shape (1, T) or (C, T)
+
+        Returns:
+            Emission tensor of shape (1, T, vocab_size)
+        """
         _start = time.time()
         if self.extractor is not None:
             # audio -> features -> emission
+            audio = torch.from_numpy(ndarray).to(self.device)
+            if audio.shape[1] < 160:
+                audio = torch.nn.functional.pad(audio, (0, 320 - audio.shape[1]))
             features = self.extractor(audio)  # (1, T, D)
             if features.shape[1] > 6000:
-                features_list = torch.split(features, 6000, dim=1)
                 emissions = []
-                for
+                for start in range(0, features.size(1), 6000):
+                    _features = features[:, start : start + 6000, :]
                     ort_inputs = {
-                        "features":
-                        "feature_lengths": np.array([
+                        "features": _features.cpu().numpy(),
+                        "feature_lengths": np.array([_features.size(1)], dtype=np.int64),
                     }
                     emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                     emissions.append(emission)
                 emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                 )  # (1, T, vocab_size)
+                del emissions
             else:
                 ort_inputs = {
                     "features": features.cpu().numpy(),
                     "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                 }
                 emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
-                emission = torch.from_numpy(emission).to(self.device)
+                emission = torch.from_numpy(emission).to(device or self.device)
         else:
+            if ndarray.shape[1] < 160:
+                ndarray = np.pad(ndarray, ((0, 0), (0, 320 - ndarray.shape[1])), mode="constant")
+
             CHUNK_SIZE = 60 * 16000  # 60 seconds
-            if
-                audio_list = torch.split(audio, CHUNK_SIZE, dim=1)
+            if ndarray.shape[1] > CHUNK_SIZE:
                 emissions = []
-                for
+                for start in range(0, ndarray.shape[1], CHUNK_SIZE):
                     emission = self.acoustic_ort.run(
                         None,
                         {
-                            "audios":
+                            "audios": ndarray[:, start : start + CHUNK_SIZE],
                         },
-                    )
-
-
-                    emissions.append(emission)
+                    )  # (1, T, vocab_size) numpy
+                    emissions.append(emission[0])
+
                 emission = torch.cat(
-                    [torch.from_numpy(emission).to(self.device) for emission in emissions], dim=1
+                    [torch.from_numpy(emission).to(device or self.device) for emission in emissions], dim=1
                 )  # (1, T, vocab_size)
+                del emissions
            else:
                 emission = self.acoustic_ort.run(
                     None,
                     {
-                        "audios":
+                        "audios": ndarray,
                     },
-                )
-
-
-
+                )  # (1, T, vocab_size) numpy
+                emission = torch.from_numpy(emission[0]).to(device or self.device)
+
+        if acoustic_scale != 1.0:
+            emission = emission.mul_(acoustic_scale)
 
         self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch
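Note: the emission path above splits long inputs into fixed-size chunks (6000 feature frames, or 60 s of 16 kHz audio), runs each chunk through the ONNX acoustic model, and concatenates the per-chunk scores along the time axis. A self-contained sketch of that chunk-and-concatenate pattern, with a stand-in run_model function in place of the ONNX Runtime session used by the worker:

import numpy as np
import torch

CHUNK_SIZE = 60 * 16000  # 60 seconds of 16 kHz samples per chunk


def run_model(audio_chunk: np.ndarray) -> np.ndarray:
    """Stand-in for the ONNX Runtime call; returns (1, T_frames, vocab) scores."""
    n_frames = audio_chunk.shape[1] // 320  # 20 ms frame shift at 16 kHz
    return np.zeros((1, n_frames, 32), dtype=np.float32)


def chunked_emission(ndarray: np.ndarray, device: str = "cpu") -> torch.Tensor:
    """Run the model chunk by chunk and concatenate along the time axis."""
    emissions = []
    for start in range(0, ndarray.shape[1], CHUNK_SIZE):
        emissions.append(run_model(ndarray[:, start : start + CHUNK_SIZE]))
    # Concatenating on dim=1 stitches the per-chunk frame sequences back together.
    return torch.cat([torch.from_numpy(e).to(device) for e in emissions], dim=1)


emission = chunked_emission(np.zeros((1, 180 * 16000), dtype=np.float32))
print(emission.shape)  # torch.Size([1, 9000, 32])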
@@ -144,6 +165,9 @@ class Lattice1Worker:
         Args:
             audio: AudioData object
             lattice_graph: LatticeGraph data
+            emission: Pre-computed emission tensor (ignored if streaming=True)
+            offset: Time offset for the audio
+            streaming: If True, use streaming mode for memory-efficient processing
 
         Returns:
             Processed LatticeGraph
@@ -153,16 +177,6 @@ class Lattice1Worker:
             DependencyError: If required dependencies are missing
             AlignmentError: If alignment process fails
         """
-        if emission is None:
-            try:
-                emission = self.emission(audio.tensor.to(self.device))  # (1, T, vocab_size)
-            except Exception as e:
-                raise AlignmentError(
-                    "Failed to compute acoustic features from audio",
-                    media_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
-                    context={"original_error": str(e)},
-                )
-
         try:
             import k2
         except ImportError:
@@ -177,7 +191,7 @@ class Lattice1Worker:
 
         _start = time.time()
         try:
-            # graph
+            # Create decoding graph
             decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
             decoding_graph.requires_grad_(False)
             decoding_graph = k2.arc_sort(decoding_graph)
@@ -190,39 +204,96 @@
             )
         self.timings["decoding_graph"] += time.time() - _start
 
-        _start = time.time()
         if self.device.type == "mps":
             device = "cpu"  # k2 does not support mps yet
         else:
             device = self.device
 
-
+        _start = time.time()
+
+        # Get beam search parameters from config or use defaults
+        search_beam = self.alignment_config.search_beam or 200
+        output_beam = self.alignment_config.output_beam or 80
+        min_active_states = self.alignment_config.min_active_states or 400
+        max_active_states = self.alignment_config.max_active_states or 10000
+
+        if emission is None and audio.streaming_mode:
+            # Streaming mode: pass emission iterator to align_segments
+            # The align_segments function will automatically detect the iterator
+            # and use k2.OnlineDenseIntersecter for memory-efficient processing
+
+            def emission_iterator():
+                """Generate emissions for each audio chunk with progress tracking."""
+                total_duration = audio.duration
+                processed_duration = 0.0
+                total_minutes = int(total_duration / 60.0)
+
+                with tqdm(
+                    total=total_minutes,
+                    desc=f"Processing audio ({total_minutes} min)",
+                    unit="min",
+                    unit_scale=False,
+                    unit_divisor=1,
+                ) as pbar:
+                    for chunk in audio.iter_chunks():
+                        chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale, device=device)
+
+                        # Update progress based on chunk duration in minutes
+                        chunk_duration = int(chunk.duration / 60.0)
+                        pbar.update(chunk_duration)
+                        processed_duration += chunk_duration
+
+                        yield chunk_emission
+
+            # Calculate total frames for supervision_segments
+            total_frames = int(audio.duration / self.frame_shift)
+
             results, labels = align_segments(
-
+                emission_iterator(),  # Pass iterator for streaming
                 decoding_graph.to(device),
-                torch.tensor([
-                search_beam=
-                output_beam=
-                min_active_states=
-                max_active_states=
+                torch.tensor([total_frames], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
                 subsampling_factor=1,
                 reject_low_confidence=False,
             )
-
-
-
-
-
+
+            # For streaming, don't return emission tensor to save memory
+            emission_result = None
+        else:
+            # Batch mode: compute full emission tensor and pass to align_segments
+            if emission is None:
+                emission = self.emission(
+                    audio.ndarray, acoustic_scale=acoustic_scale, device=device
+                )  # (1, T, vocab_size)
+            else:
+                emission = emission.to(device) * acoustic_scale
+
+            results, labels = align_segments(
+                emission,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=search_beam,
+                output_beam=output_beam,
+                min_active_states=min_active_states,
+                max_active_states=max_active_states,
+                subsampling_factor=1,
+                reject_low_confidence=False,
             )
+
+            emission_result = emission
+
         self.timings["align_segments"] += time.time() - _start
 
         channel = 0
-        return
+        return emission_result, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
 
-def _load_worker(model_path: str, device: str) -> Lattice1Worker:
+def _load_worker(model_path: str, device: str, config: Optional[Any] = None) -> Lattice1Worker:
     """Instantiate lattice worker with consistent error handling."""
     try:
-        return Lattice1Worker(model_path, device=device, num_threads=8)
+        return Lattice1Worker(model_path, device=device, num_threads=8, config=config)
     except Exception as e:
         raise ModelLoadError(f"worker from {model_path}", original_error=e)
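Note: the streaming branch above hands align_segments a generator that yields one emission tensor per audio chunk, with tqdm reporting progress in minutes, so the full emission never has to sit in memory. A minimal standalone sketch of that generator pattern, using hypothetical Chunk objects in place of lattifai's AudioData.iter_chunks():

from dataclasses import dataclass
from typing import Iterator, List

import torch
from tqdm import tqdm


@dataclass
class Chunk:
    """Hypothetical stand-in for one streamed audio chunk."""
    emission: torch.Tensor  # (1, T, vocab_size) scores for this chunk
    duration: float  # seconds of audio covered by the chunk


def emission_iterator(chunks: List[Chunk], total_duration: float) -> Iterator[torch.Tensor]:
    """Yield per-chunk emissions lazily so the full tensor never sits in memory."""
    total_minutes = int(total_duration / 60.0)
    with tqdm(total=total_minutes, desc=f"Processing audio ({total_minutes} min)", unit="min") as pbar:
        for chunk in chunks:
            pbar.update(int(chunk.duration / 60.0))
            yield chunk.emission


# A consumer (such as an online intersecter) pulls one chunk at a time.
chunks = [Chunk(torch.zeros(1, 3000, 32), 60.0) for _ in range(3)]
total_frames = sum(c.emission.shape[1] for c in chunks)
for em in emission_iterator(chunks, total_duration=180.0):
    pass  # process each (1, T_chunk, vocab_size) emission incrementally
print(total_frames)  # 9000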
lattifai/alignment/segmenter.py
CHANGED
@@ -7,6 +7,7 @@ import colorful
 from lattifai.audio2 import AudioData
 from lattifai.caption import Caption, Supervision
 from lattifai.config import AlignmentConfig
+from lattifai.utils import safe_print
 
 from .tokenizer import END_PUNCTUATION
 
@@ -153,7 +154,7 @@ class Segmenter:
 
         total_sups = sum(len(sups) if isinstance(sups, list) else 1 for _, _, sups, _ in segments)
 
-
+        safe_print(colorful.cyan(f"📊 Created {len(segments)} alignment segments:"))
         for i, (start, end, sups, _) in enumerate(segments, 1):
             duration = end - start
             print(
@@ -163,4 +164,4 @@
             )
         )
 
-
+        safe_print(colorful.green(f" Total: {total_sups} supervisions across {len(segments)} segments"))
lattifai/alignment/tokenizer.py
CHANGED
@@ -335,7 +335,7 @@ class LatticeTokenizer:
             flush_segment(s, None)
 
         assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
-        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
+        sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace, batch_size=8)
 
         supervisions, remainder = [], ""
         for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
@@ -450,7 +450,7 @@ class LatticeTokenizer:
                 "destroy_lattice": True,
             },
         )
-        if response.status_code ==
+        if response.status_code == 400:
            raise LatticeDecodingError(
                lattice_id,
                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
@@ -466,7 +466,7 @@
 
         alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
 
-        if return_details:
+        if emission is not None and return_details:
             # Add emission confidence scores for segments and word-level alignments
             _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
 
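Note: the tokenizer change above maps a 400 response from the decoding endpoint onto the library's LatticeDecodingError instead of letting it pass silently. A hedged sketch of that pattern with requests; the URL, payload, and error class below are illustrative stand-ins, not lattifai's actual client code:

import requests


class LatticeDecodingError(Exception):
    """Illustrative stand-in for the library's domain error."""

    def __init__(self, lattice_id, original_error=None):
        super().__init__(f"Failed to decode lattice {lattice_id}: {original_error}")


def decode_lattice(url: str, lattice_id: str) -> dict:
    """Raise a domain error on HTTP 400, surface other failures unchanged (sketch)."""
    response = requests.post(url, json={"lattice_id": lattice_id, "destroy_lattice": True})
    if response.status_code == 400:
        raise LatticeDecodingError(lattice_id, original_error=Exception(response.text))
    response.raise_for_status()
    return response.json()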