lattifai-1.2.2-py3-none-any.whl → lattifai-1.3.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- lattifai/_init.py +20 -0
- lattifai/alignment/__init__.py +2 -3
- lattifai/alignment/lattice1_aligner.py +117 -4
- lattifai/alignment/lattice1_worker.py +47 -4
- lattifai/alignment/segmenter.py +3 -2
- lattifai/alignment/text_align.py +2 -1
- lattifai/alignment/tokenizer.py +56 -29
- lattifai/audio2.py +162 -183
- lattifai/cli/alignment.py +5 -0
- lattifai/cli/caption.py +6 -6
- lattifai/cli/transcribe.py +1 -5
- lattifai/cli/youtube.py +3 -0
- lattifai/client.py +41 -12
- lattifai/config/__init__.py +21 -3
- lattifai/config/alignment.py +7 -0
- lattifai/config/caption.py +13 -243
- lattifai/config/client.py +16 -0
- lattifai/config/event.py +102 -0
- lattifai/config/transcription.py +25 -1
- lattifai/data/__init__.py +8 -0
- lattifai/data/caption.py +228 -0
- lattifai/errors.py +78 -53
- lattifai/event/__init__.py +65 -0
- lattifai/event/lattifai.py +166 -0
- lattifai/mixin.py +22 -17
- lattifai/transcription/base.py +2 -1
- lattifai/transcription/gemini.py +147 -16
- lattifai/transcription/lattifai.py +8 -11
- lattifai/types.py +1 -1
- lattifai/youtube/client.py +143 -48
- {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/METADATA +129 -58
- lattifai-1.3.1.dist-info/RECORD +57 -0
- lattifai/__init__.py +0 -88
- lattifai/alignment/sentence_splitter.py +0 -350
- lattifai/caption/__init__.py +0 -96
- lattifai/caption/caption.py +0 -661
- lattifai/caption/formats/__init__.py +0 -199
- lattifai/caption/formats/base.py +0 -211
- lattifai/caption/formats/gemini.py +0 -722
- lattifai/caption/formats/json.py +0 -194
- lattifai/caption/formats/lrc.py +0 -309
- lattifai/caption/formats/nle/__init__.py +0 -9
- lattifai/caption/formats/nle/audition.py +0 -561
- lattifai/caption/formats/nle/avid.py +0 -423
- lattifai/caption/formats/nle/fcpxml.py +0 -549
- lattifai/caption/formats/nle/premiere.py +0 -589
- lattifai/caption/formats/pysubs2.py +0 -642
- lattifai/caption/formats/sbv.py +0 -147
- lattifai/caption/formats/tabular.py +0 -338
- lattifai/caption/formats/textgrid.py +0 -193
- lattifai/caption/formats/ttml.py +0 -652
- lattifai/caption/formats/vtt.py +0 -469
- lattifai/caption/parsers/__init__.py +0 -9
- lattifai/caption/parsers/text_parser.py +0 -147
- lattifai/caption/standardize.py +0 -636
- lattifai/caption/supervision.py +0 -34
- lattifai/caption/utils.py +0 -474
- lattifai-1.2.2.dist-info/RECORD +0 -76
- {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/WHEEL +0 -0
- {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/entry_points.txt +0 -0
- {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.2.dist-info → lattifai-1.3.1.dist-info}/top_level.txt +0 -0
lattifai/_init.py
ADDED
@@ -0,0 +1,20 @@
+"""Environment configuration for LattifAI.
+
+Import this module early to suppress warnings before other imports.
+
+Usage:
+    import lattifai._init  # noqa: F401
+    from lattifai.client import LattifAI
+"""
+
+import os
+import warnings
+
+# Suppress SWIG deprecation warnings before any imports
+warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+# Suppress PyTorch transformer nested tensor warning
+warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+# Disable tokenizers parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
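The new `_init.py` relies entirely on the standard library's warning filters. A minimal, stdlib-only sketch of the mechanism (the warning texts below are fabricated for illustration): a `message` regex silences only warnings whose text matches, leaving others visible.

import warnings

# Same kind of filter as in _init.py: the regex is matched against the message.
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")

warnings.warn("SwigPyObject is deprecated", DeprecationWarning)  # silenced
warnings.warn("another API is deprecated", DeprecationWarning)   # still emitted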
lattifai/alignment/__init__.py
CHANGED
@@ -1,15 +1,14 @@
 """Alignment module for LattifAI forced alignment."""
 
+from lattifai.caption import SentenceSplitter
+
 from .lattice1_aligner import Lattice1Aligner
 from .segmenter import Segmenter
-from .sentence_splitter import SentenceSplitter
-from .text_align import align_supervisions_and_transcription
 from .tokenizer import tokenize_multilingual_text
 
 __all__ = [
     "Lattice1Aligner",
     "Segmenter",
     "SentenceSplitter",
-    "align_supervisions_and_transcription",
     "tokenize_multilingual_text",
 ]
lattifai/alignment/lattice1_aligner.py
CHANGED

@@ -1,6 +1,6 @@
 """Lattice-1 Aligner implementation."""
 
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import colorful
 import numpy as np

@@ -118,7 +118,10 @@ class Lattice1Aligner(object):
         safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
         try:
             supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
-                supervisions,
+                supervisions,
+                split_sentence=split_sentence,
+                boost=self.config.boost,
+                transition_penalty=self.config.transition_penalty,
             )
             if verbose:
                 safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))

@@ -162,12 +165,34 @@ class Lattice1Aligner(object):
                 return_details=return_details,
                 start_margin=self.config.start_margin,
                 end_margin=self.config.end_margin,
+                check_sanity=True,
             )
             if verbose:
                 safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
-        except LatticeDecodingError:
+        except LatticeDecodingError as e:
             safe_print(colorful.red(" x Failed to decode lattice alignment results"))
-
+            _alignments = self.tokenizer.detokenize(
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
+                check_sanity=False,
+            )
+            # Check for score anomalies (media-text mismatch)
+            anomaly = _detect_score_anomalies(_alignments)
+            if anomaly:
+                anomaly_str = _format_anomaly_warning(anomaly)
+                del _alignments
+                raise LatticeDecodingError(
+                    lattice_id,
+                    message=colorful.yellow("Score anomaly detected - media and text mismatch:\n" + anomaly_str),
+                    skip_help=True,  # anomaly info is more specific than default help
+                )
+            else:
+                del _alignments
+                raise e
         except Exception as e:
             safe_print(colorful.red(" x Failed to decode lattice alignment results"))
             raise LatticeDecodingError(lattice_id, original_error=e)

@@ -177,3 +202,91 @@ class Lattice1Aligner(object):
     def profile(self) -> None:
         """Print profiling statistics."""
         self.worker.profile()
+
+
+def _detect_score_anomalies(
+    alignments: List[Supervision],
+    drop_threshold: float = 0.08,
+    window_size: int = 5,
+) -> Optional[Dict[str, Any]]:
+    """Detect score anomalies indicating alignment mismatch.
+
+    Compares average of window_size segments before vs after each position.
+    When the drop is significant, it indicates the audio doesn't match
+    the text starting at that position.
+
+    Args:
+        alignments: List of aligned supervisions with scores
+        drop_threshold: Minimum drop between before/after averages to trigger
+        window_size: Number of segments to average on each side
+
+    Returns:
+        Dict with anomaly info if found, None otherwise
+    """
+    scores = [s.score for s in alignments if s.score is not None]
+    if len(scores) < window_size * 2:
+        return None
+
+    for i in range(window_size, len(scores) - window_size):
+        before_avg = np.mean(scores[i - window_size : i])
+        after_avg = np.mean(scores[i : i + window_size])
+        drop = before_avg - after_avg
+
+        # Trigger: significant drop between before and after windows
+        if drop > drop_threshold:
+            # Find the exact mutation point (largest single-step drop)
+            max_drop = 0
+            mutation_idx = i
+            for j in range(i - 1, min(i + window_size, len(scores) - 1)):
+                single_drop = scores[j] - scores[j + 1]
+                if single_drop > max_drop:
+                    max_drop = single_drop
+                    mutation_idx = j + 1
+
+            # Segments: last normal + anomaly segments
+            last_normal = alignments[mutation_idx - 1] if mutation_idx > 0 else None
+            anomaly_segments = [
+                alignments[j] for j in range(mutation_idx, min(mutation_idx + window_size, len(alignments)))
+            ]
+
+            return {
+                "mutation_index": mutation_idx,
+                "before_avg": round(before_avg, 4),
+                "after_avg": round(after_avg, 4),
+                "window_drop": round(drop, 4),
+                "mutation_drop": round(max_drop, 4),
+                "last_normal": last_normal,
+                "segments": anomaly_segments,
+            }
+
+    return None
+
+
+def _format_anomaly_warning(anomaly: Dict[str, Any]) -> str:
+    """Format anomaly detection result as warning message."""
+    lines = [
+        f"⚠️ Score anomaly detected at segment #{anomaly['mutation_index']}",
+        f" Window avg: {anomaly['before_avg']:.4f} → {anomaly['after_avg']:.4f} (drop: {anomaly['window_drop']:.4f})",  # noqa: E501
+        f" Mutation drop: {anomaly['mutation_drop']:.4f}",
+        "",
+    ]
+
+    # Show last normal segment
+    if anomaly.get("last_normal"):
+        seg = anomaly["last_normal"]
+        text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+        lines.append(f' [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+    # Separator - mutation point
+    lines.append(" " + "─" * 60)
+    lines.append(f" ⬇️ MUTATION: The following {len(anomaly['segments'])}+ segments don't match audio")
+    lines.append(" " + "─" * 60)
+
+    # Show anomaly segments
+    for seg in anomaly["segments"]:
+        text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+        lines.append(f' [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+    lines.append("")
+    lines.append(" Possible causes: Transcription error, missing content, or wrong audio region")
+    return "\n".join(lines)
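To make the new anomaly check concrete, here is a toy, self-contained sketch of the window comparison `_detect_score_anomalies` performs. The scores are synthetic; the real function walks `Supervision` objects and additionally pinpoints the largest single-step drop.

import numpy as np

# Scores stay high while audio matches text, then collapse at a mismatch.
scores = [0.95, 0.94, 0.96, 0.95, 0.94, 0.93, 0.60, 0.58, 0.61, 0.59, 0.60, 0.62]
window_size, drop_threshold = 5, 0.08

for i in range(window_size, len(scores) - window_size):
    before_avg = np.mean(scores[i - window_size : i])
    after_avg = np.mean(scores[i : i + window_size])
    if before_avg - after_avg > drop_threshold:
        print(f"anomaly near segment {i}: {before_avg:.3f} -> {after_avg:.3f}")
        break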
lattifai/alignment/lattice1_worker.py
CHANGED

@@ -35,6 +35,8 @@ class Lattice1Worker:
         sess_options.intra_op_num_threads = num_threads  # CPU cores
         sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
         sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
+        # Suppress CoreMLExecutionProvider warnings about partial graph support
+        sess_options.log_severity_level = 3  # ERROR level only
 
         acoustic_model_path = f"{model_path}/acoustic_opt.onnx"
 

@@ -191,12 +193,17 @@ class Lattice1Worker:
                 float(output_beam),
                 int(min_active_states),
                 int(max_active_states),
+                allow_partial=True,
             )
 
-            # Streaming mode
+            # Streaming mode with confidence score accumulation
             total_duration = audio.duration
             total_minutes = int(total_duration / 60.0)
 
+            max_probs = []
+            aligned_probs = []
+            prev_labels_len = 0
+
            with tqdm(
                total=total_minutes,
                desc=f"Processing audio ({total_minutes} min)",

@@ -208,13 +215,43 @@ class Lattice1Worker:
                     chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
                     intersecter.decode(chunk_emission[0])
 
+                    __start = time.time()
+                    # Get partial labels and compute confidence stats for this chunk
+                    partial_labels = intersecter.get_partial_labels()
+                    chunk_len = chunk_emission.shape[1]
+
+                    # Get labels for current chunk (new labels since last chunk)
+                    chunk_labels = partial_labels[prev_labels_len : prev_labels_len + chunk_len]
+                    prev_labels_len = len(partial_labels)
+
+                    # Compute emission-based confidence stats
+                    probs = np.exp(chunk_emission[0])  # [T, V]
+                    max_probs.append(np.max(probs, axis=-1))  # [T]
+
+                    # Handle case where chunk_labels length might differ from chunk_len
+                    if len(chunk_labels) == chunk_len:
+                        aligned_probs.append(probs[np.arange(chunk_len), chunk_labels])
+                    else:
+                        # Fallback: use max probs as aligned probs (approximate)
+                        aligned_probs.append(np.max(probs, axis=-1))
+
+                    del chunk_emission, probs  # Free memory
+                    self.timings["align_>labels"] += time.time() - __start
+
                     # Update progress
                     chunk_duration = int(chunk.duration / 60.0)
                     pbar.update(chunk_duration)
 
-
+            # Build emission_stats for confidence calculation
+            emission_stats = {
+                "max_probs": np.concatenate(max_probs),
+                "aligned_probs": np.concatenate(aligned_probs),
+            }
+
             # Get results from intersecter
+            __start = time.time()
             results, labels = intersecter.finish()
+            self.timings["align_>finish"] += time.time() - __start
         else:
             # Batch mode
             if emission is None:

@@ -230,13 +267,19 @@ class Lattice1Worker:
                 float(output_beam),
                 int(min_active_states),
                 int(max_active_states),
+                allow_partial=True,
             )
-
+            # Compute emission_stats from full emission (same format as streaming)
+            probs = np.exp(emission[0])  # [T, V]
+            emission_stats = {
+                "max_probs": np.max(probs, axis=-1),  # [T]
+                "aligned_probs": probs[np.arange(probs.shape[0]), labels[0]],  # [T]
+            }
 
         self.timings["align_segments"] += time.time() - _start
 
         channel = 0
-        return
+        return emission_stats, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
     def profile(self) -> None:
         """Print formatted profiling statistics."""
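Both branches now reduce each emission frame to two numbers: the probability of the acoustically best token (`max_probs`) and the probability of the token the alignment actually followed (`aligned_probs`). A small self-contained sketch of how the dict is built, with fabricated values (T=4 frames, V=3 tokens):

import numpy as np

# Fabricated log-probabilities for 4 frames over a 3-token vocabulary.
log_probs = np.log(np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.5, 0.3, 0.2],
    [0.2, 0.2, 0.6],
]))
labels = np.array([0, 1, 1, 2])  # aligned token per frame

probs = np.exp(log_probs)  # [T, V]
emission_stats = {
    "max_probs": np.max(probs, axis=-1),                     # best token per frame
    "aligned_probs": probs[np.arange(len(labels)), labels],  # chosen token per frame
}
# Frames where the alignment disagrees with the model's top choice show a gap:
print(emission_stats["max_probs"] - emission_stats["aligned_probs"])  # ≈ [0, 0, 0.2, 0]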
lattifai/alignment/segmenter.py
CHANGED
@@ -5,11 +5,12 @@ from typing import List, Optional, Tuple
 import colorful
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import
+from lattifai.caption import Supervision
 from lattifai.config import AlignmentConfig
+from lattifai.data import Caption
 from lattifai.utils import safe_print
 
-from .
+from .punctuation import END_PUNCTUATION
 
 
 class Segmenter:
lattifai/alignment/text_align.py
CHANGED
@@ -9,7 +9,8 @@ import regex
 from error_align import error_align
 from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
 
-from lattifai.caption import
+from lattifai.caption import Supervision
+from lattifai.data import Caption
 from lattifai.utils import safe_print
 
 from .punctuation import PUNCTUATION
lattifai/alignment/tokenizer.py
CHANGED
@@ -6,9 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
-
-from lhotse.supervision import SupervisionSegment as Supervision  # NOTE: Transcriber SupervisionSegment
-
+from lattifai.caption import SentenceSplitter, Supervision
 from lattifai.caption import normalize_text as normalize_html_text
 from lattifai.errors import (
     LATTICE_DECODING_FAILURE_HELP,

@@ -17,9 +15,7 @@ from lattifai.errors import (
     QuotaExceededError,
 )
 
-from .phonemizer import G2Phonemizer
 from .punctuation import PUNCTUATION, PUNCTUATION_SPACE
-from .sentence_splitter import SentenceSplitter
 from .text_align import TextAlignResult
 
 MAXIMUM_WORD_LENGTH = 40

@@ -174,13 +170,16 @@ class LatticeTokenizer:
         tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
         tokenizer.oov_word = data["oov_word"]
 
+        # Lazy load G2P model only if it exists (avoids PyTorch dependency)
         g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
-        if
-
+        g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
+        g2p_path = g2pp_model_path or g2p_model_path
+        if g2p_path:
+            from .phonemizer import G2Phonemizer
+
+            tokenizer.g2p_model = G2Phonemizer(g2p_path, device=device)
         else:
-
-            if g2p_model_path:
-                tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+            tokenizer.g2p_model = None
 
         tokenizer.device = device
         tokenizer.add_special_tokens()

@@ -246,9 +245,24 @@ class LatticeTokenizer:
             self.init_sentence_splitter()
         return self.sentence_splitter.split_sentences(supervisions, strip_whitespace=strip_whitespace)
 
+    def _get_client_info(self) -> Dict[str, Optional[str]]:
+        """Get client identification info for usage tracking."""
+        try:
+            from importlib.metadata import version
+
+            return {"client_name": "python-sdk", "client_version": version("lattifai")}
+        except Exception:
+            return {"client_name": "python-sdk", "client_version": "unknown"}
+
     def tokenize(
-        self,
+        self,
+        supervisions: Union[List[Supervision], TextAlignResult],
+        split_sentence: bool = False,
+        boost: float = 0.0,
+        transition_penalty: Optional[float] = 0.0,
     ) -> Tuple[str, Dict[str, Any]]:
+        client_info = self._get_client_info()
+
         if isinstance(supervisions[0], Supervision):
             if split_sentence:
                 supervisions = self.split_sentences(supervisions)

@@ -260,6 +274,8 @@ class LatticeTokenizer:
                     "model_name": self.model_name,
                     "supervisions": [s.to_dict() for s in supervisions],
                     "pronunciation_dictionaries": pronunciation_dictionaries,
+                    **client_info,
+                    "transition_penalty": transition_penalty,
                 },
             )
         else:

@@ -274,6 +290,7 @@ class LatticeTokenizer:
                     "transcription": [s.to_dict() for s in supervisions[1]],
                     "pronunciation_dictionaries": pronunciation_dictionaries,
                     "boost": boost,
+                    **client_info,
                 },
             )
 

@@ -297,8 +314,10 @@ class LatticeTokenizer:
         return_details: bool = False,
         start_margin: float = 0.08,
         end_margin: float = 0.20,
+        check_sanity: bool = True,
     ) -> List[Supervision]:
-
+        emission_stats, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        # emission_stats is a dict with 'max_probs' and 'aligned_probs' (unified for batch and streaming)
         if isinstance(supervisions[0], Supervision):
             response = self.client_wrapper.post(
                 "detokenize",

@@ -314,6 +333,7 @@ class LatticeTokenizer:
                     "destroy_lattice": True,
                     "start_margin": start_margin,
                     "end_margin": end_margin,
+                    "check_sanity": check_sanity,
                 },
             )
         else:

@@ -331,6 +351,7 @@ class LatticeTokenizer:
                     "destroy_lattice": True,
                     "start_margin": start_margin,
                     "end_margin": end_margin,
+                    "check_sanity": check_sanity,
                 },
             )
 

@@ -350,9 +371,8 @@ class LatticeTokenizer:
 
         alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
 
-
-
-        _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
+        # Add emission confidence scores for segments and word-level alignments
+        _add_confidence_scores(alignments, emission_stats, frame_shift, offset)
 
         if isinstance(supervisions[0], Supervision):
             alignments = _update_alignments_speaker(supervisions, alignments)

@@ -365,8 +385,7 @@
 
 def _add_confidence_scores(
     supervisions: List[Supervision],
-
-    labels: List[int],
+    emission_stats: Dict[str, np.ndarray],
     frame_shift: float,
     offset: float = 0.0,
 ) -> None:

@@ -379,29 +398,37 @@ def _add_confidence_scores(
 
     Args:
         supervisions: List of Supervision objects to add scores to (modified in-place)
-
-        labels: Token labels corresponding to aligned tokens
+        emission_stats: Dict with 'max_probs' and 'aligned_probs' arrays
         frame_shift: Frame shift in seconds for converting frames to time
+        offset: Time offset in seconds
     """
-
+    max_probs = emission_stats["max_probs"]
+    aligned_probs = emission_stats["aligned_probs"]
+    diffprobs_full = max_probs - aligned_probs
 
     for supervision in supervisions:
         start_frame = int((supervision.start - offset) / frame_shift)
         end_frame = int((supervision.end - offset) / frame_shift)
 
-        #
-
-
-
-
+        # Clamp to valid range
+        start_frame = max(0, min(start_frame, len(diffprobs_full) - 1))
+        end_frame = max(start_frame + 1, min(end_frame, len(diffprobs_full)))
+
+        diffprobs = diffprobs_full[start_frame:end_frame]
+        if len(diffprobs) > 0:
+            supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
 
-        #
+        # Word-level confidence
         if hasattr(supervision, "alignment") and supervision.alignment:
             words = supervision.alignment.get("word", [])
             for w, item in enumerate(words):
-                start = int((item.start - offset) / frame_shift)
-                end = int((item.end - offset) / frame_shift)
-
+                start = int((item.start - offset) / frame_shift)
+                end = int((item.end - offset) / frame_shift)
+                start = max(0, min(start, len(diffprobs_full) - 1))
+                end = max(start + 1, min(end, len(diffprobs_full)))
+                word_diffprobs = diffprobs_full[start:end]
+                if len(word_diffprobs) > 0:
+                    words[w] = item._replace(score=round(1.0 - word_diffprobs.mean().item(), ndigits=4))
 
 
 def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]: