lattifai 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/_init.py +20 -0
- lattifai/alignment/__init__.py +9 -1
- lattifai/alignment/lattice1_aligner.py +175 -54
- lattifai/alignment/lattice1_worker.py +47 -4
- lattifai/alignment/punctuation.py +38 -0
- lattifai/alignment/segmenter.py +3 -2
- lattifai/alignment/text_align.py +441 -0
- lattifai/alignment/tokenizer.py +134 -65
- lattifai/audio2.py +162 -183
- lattifai/cli/__init__.py +2 -1
- lattifai/cli/alignment.py +5 -0
- lattifai/cli/caption.py +111 -4
- lattifai/cli/transcribe.py +2 -6
- lattifai/cli/youtube.py +7 -1
- lattifai/client.py +72 -123
- lattifai/config/__init__.py +28 -0
- lattifai/config/alignment.py +14 -0
- lattifai/config/caption.py +45 -31
- lattifai/config/client.py +16 -0
- lattifai/config/event.py +102 -0
- lattifai/config/media.py +20 -0
- lattifai/config/transcription.py +25 -1
- lattifai/data/__init__.py +8 -0
- lattifai/data/caption.py +228 -0
- lattifai/diarization/__init__.py +41 -1
- lattifai/errors.py +78 -53
- lattifai/event/__init__.py +65 -0
- lattifai/event/lattifai.py +166 -0
- lattifai/mixin.py +49 -32
- lattifai/transcription/base.py +8 -2
- lattifai/transcription/gemini.py +147 -16
- lattifai/transcription/lattifai.py +25 -63
- lattifai/types.py +1 -1
- lattifai/utils.py +7 -13
- lattifai/workflow/__init__.py +28 -4
- lattifai/workflow/file_manager.py +2 -5
- lattifai/youtube/__init__.py +43 -0
- lattifai/youtube/client.py +1265 -0
- lattifai/youtube/types.py +23 -0
- lattifai-1.3.0.dist-info/METADATA +678 -0
- lattifai-1.3.0.dist-info/RECORD +57 -0
- {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/entry_points.txt +1 -2
- lattifai/__init__.py +0 -88
- lattifai/alignment/sentence_splitter.py +0 -219
- lattifai/caption/__init__.py +0 -20
- lattifai/caption/caption.py +0 -1467
- lattifai/caption/gemini_reader.py +0 -462
- lattifai/caption/gemini_writer.py +0 -173
- lattifai/caption/supervision.py +0 -34
- lattifai/caption/text_parser.py +0 -145
- lattifai/cli/app_installer.py +0 -142
- lattifai/cli/server.py +0 -44
- lattifai/server/app.py +0 -427
- lattifai/workflow/youtube.py +0 -577
- lattifai-1.2.1.dist-info/METADATA +0 -1134
- lattifai-1.2.1.dist-info/RECORD +0 -58
- {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/WHEEL +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/top_level.txt +0 -0
lattifai/_init.py
ADDED
@@ -0,0 +1,20 @@
+"""Environment configuration for LattifAI.
+
+Import this module early to suppress warnings before other imports.
+
+Usage:
+    import lattifai._init  # noqa: F401
+    from lattifai.client import LattifAI
+"""
+
+import os
+import warnings
+
+# Suppress SWIG deprecation warnings before any imports
+warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*SwigPy.*")
+
+# Suppress PyTorch transformer nested tensor warning
+warnings.filterwarnings("ignore", category=UserWarning, message=".*enable_nested_tensor.*")
+
+# Disable tokenizers parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
lattifai/alignment/__init__.py
CHANGED
@@ -1,6 +1,14 @@
 """Alignment module for LattifAI forced alignment."""
 
+from lattifai.caption import SentenceSplitter
+
 from .lattice1_aligner import Lattice1Aligner
 from .segmenter import Segmenter
+from .tokenizer import tokenize_multilingual_text
 
-__all__ = [
+__all__ = [
+    "Lattice1Aligner",
+    "Segmenter",
+    "SentenceSplitter",
+    "tokenize_multilingual_text",
+]
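With the expanded __all__, downstream code can pull the new helpers straight from the subpackage:

from lattifai.alignment import (
    Lattice1Aligner,
    Segmenter,
    SentenceSplitter,           # re-exported from lattifai.caption
    tokenize_multilingual_text,
)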
lattifai/alignment/lattice1_aligner.py
CHANGED

@@ -1,6 +1,6 @@
 """Lattice-1 Aligner implementation."""
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import colorful
 import numpy as np
@@ -16,11 +16,22 @@ from lattifai.errors import (
 from lattifai.utils import _resolve_model_path, safe_print
 
 from .lattice1_worker import _load_worker
+from .text_align import TextAlignResult
 from .tokenizer import _load_tokenizer
 
 ClientType = Any
 
 
+def _extract_text_for_error(supervisions: Union[list, tuple]) -> str:
+    """Extract text from supervisions for error messages."""
+    if not supervisions:
+        return ""
+    # TextAlignResult is a tuple: (caption_sups, transcript_sups, ...)
+    if isinstance(supervisions, tuple):
+        supervisions = supervisions[0] or supervisions[1] or []
+    return " ".join(s.text for s in supervisions if s and s.text)
+
+
 class Lattice1Aligner(object):
     """Synchronous LattifAI client with config-driven architecture."""
 
@@ -79,7 +90,7 @@ class Lattice1Aligner(object):
     def alignment(
         self,
         audio: AudioData,
-        supervisions: List[Supervision],
+        supervisions: Union[List[Supervision], TextAlignResult],
         split_sentence: Optional[bool] = False,
        return_details: Optional[bool] = False,
         emission: Optional[np.ndarray] = None,
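The supervisions parameter now accepts either a plain list of Supervision objects or a TextAlignResult (per the _extract_text_for_error helper above, a tuple whose first two slots hold caption and transcript supervisions). A type-level sketch of the widened call; the aligner's construction is not shown in this diff, and align_either is a hypothetical wrapper:

from typing import List, Union

from lattifai.alignment.text_align import TextAlignResult  # new module in 1.3.0
from lattifai.caption import Supervision  # import path used elsewhere in this release

def align_either(aligner, audio, sups: Union[List[Supervision], TextAlignResult]):
    # alignment() returns a (supervisions, alignments) pair per the hunk below.
    supervisions, alignments = aligner.alignment(audio, sups)
    return alignments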
@@ -102,70 +113,180 @@
             AlignmentError: If audio alignment fails
             LatticeDecodingError: If lattice decoding fails
         """
+        # Step 2: Create lattice graph
+        if verbose:
+            safe_print(colorful.cyan("🔗 Step 2: Creating lattice graph from segments"))
         try:
+            supervisions, lattice_id, lattice_graph = self.tokenizer.tokenize(
+                supervisions,
+                split_sentence=split_sentence,
+                boost=self.config.boost,
+                transition_penalty=self.config.transition_penalty,
+            )
             if verbose:
-                safe_print(colorful.
-
-
-
-                )
-            if verbose:
-                safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
-        except Exception as e:
-            text_content = " ".join([sup.text for sup in supervisions]) if supervisions else ""
-            raise LatticeEncodingError(text_content, original_error=e)
+                safe_print(colorful.green(f" ✓ Generated lattice graph with ID: {lattice_id}"))
+        except Exception as e:
+            text_content = _extract_text_for_error(supervisions)
+            raise LatticeEncodingError(text_content, original_error=e)
 
-
-
-
-
-
-
-                )
+        # Step 3: Search lattice graph
+        if verbose:
+            safe_print(colorful.cyan(f"🔍 Step 3: Searching lattice graph with media: {audio}"))
+            if audio.streaming_mode:
+                safe_print(
+                    colorful.yellow(
+                        f" ⚡Using streaming mode with {audio.streaming_chunk_secs}s (chunk duration)"
                     )
-            try:
-                lattice_results = self.worker.alignment(
-                    audio,
-                    lattice_graph,
-                    emission=emission,
-                    offset=offset,
-                )
-                if verbose:
-                    safe_print(colorful.green(" ✓ Lattice search completed"))
-            except Exception as e:
-                raise AlignmentError(
-                    f"Audio alignment failed for {audio}",
-                    media_path=str(audio),
-                    context={"original_error": str(e)},
                 )
+        try:
+            lattice_results = self.worker.alignment(
+                audio,
+                lattice_graph,
+                emission=emission,
+                offset=offset,
+            )
+            if verbose:
+                safe_print(colorful.green(" ✓ Lattice search completed"))
+        except Exception as e:
+            raise AlignmentError(
+                f"Audio alignment failed for {audio}",
+                media_path=str(audio),
+                context={"original_error": str(e)},
+            )
 
+        # Step 4: Decode lattice results
+        if verbose:
+            safe_print(colorful.cyan("🎯 Step 4: Decoding lattice results to aligned segments"))
+        try:
+            alignments = self.tokenizer.detokenize(
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
+                check_sanity=True,
+            )
             if verbose:
-                safe_print(colorful.
-
-
+                safe_print(colorful.green(f" ✓ Successfully aligned {len(alignments)} segments"))
+        except LatticeDecodingError as e:
+            safe_print(colorful.red(" x Failed to decode lattice alignment results"))
+            _alignments = self.tokenizer.detokenize(
+                lattice_id,
+                lattice_results,
+                supervisions=supervisions,
+                return_details=return_details,
+                start_margin=self.config.start_margin,
+                end_margin=self.config.end_margin,
+                check_sanity=False,
+            )
+            # Check for score anomalies (media-text mismatch)
+            anomaly = _detect_score_anomalies(_alignments)
+            if anomaly:
+                anomaly_str = _format_anomaly_warning(anomaly)
+                del _alignments
+                raise LatticeDecodingError(
                     lattice_id,
-
-
-                    return_details=return_details,
-                    start_margin=self.config.start_margin,
-                    end_margin=self.config.end_margin,
+                    message=colorful.yellow("Score anomaly detected - media and text mismatch:\n" + anomaly_str),
+                    skip_help=True,  # anomaly info is more specific than default help
                 )
-
-
-            except LatticeDecodingError as e:
-                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
+            else:
+                del _alignments
             raise e
-            except Exception as e:
-                safe_print(colorful.red(" x Failed to decode lattice alignment results"))
-                raise LatticeDecodingError(lattice_id, original_error=e)
-
-            return (supervisions, alignments)
-
-        except (LatticeEncodingError, AlignmentError, LatticeDecodingError):
-            raise
         except Exception as e:
-
+            safe_print(colorful.red(" x Failed to decode lattice alignment results"))
+            raise LatticeDecodingError(lattice_id, original_error=e)
+
+        return (supervisions, alignments)
 
     def profile(self) -> None:
         """Print profiling statistics."""
         self.worker.profile()
+
+
+def _detect_score_anomalies(
+    alignments: List[Supervision],
+    drop_threshold: float = 0.08,
+    window_size: int = 5,
+) -> Optional[Dict[str, Any]]:
+    """Detect score anomalies indicating alignment mismatch.
+
+    Compares average of window_size segments before vs after each position.
+    When the drop is significant, it indicates the audio doesn't match
+    the text starting at that position.
+
+    Args:
+        alignments: List of aligned supervisions with scores
+        drop_threshold: Minimum drop between before/after averages to trigger
+        window_size: Number of segments to average on each side
+
+    Returns:
+        Dict with anomaly info if found, None otherwise
+    """
+    scores = [s.score for s in alignments if s.score is not None]
+    if len(scores) < window_size * 2:
+        return None
+
+    for i in range(window_size, len(scores) - window_size):
+        before_avg = np.mean(scores[i - window_size : i])
+        after_avg = np.mean(scores[i : i + window_size])
+        drop = before_avg - after_avg
+
+        # Trigger: significant drop between before and after windows
+        if drop > drop_threshold:
+            # Find the exact mutation point (largest single-step drop)
+            max_drop = 0
+            mutation_idx = i
+            for j in range(i - 1, min(i + window_size, len(scores) - 1)):
+                single_drop = scores[j] - scores[j + 1]
+                if single_drop > max_drop:
+                    max_drop = single_drop
+                    mutation_idx = j + 1
+
+            # Segments: last normal + anomaly segments
+            last_normal = alignments[mutation_idx - 1] if mutation_idx > 0 else None
+            anomaly_segments = [
+                alignments[j] for j in range(mutation_idx, min(mutation_idx + window_size, len(alignments)))
+            ]
+
+            return {
+                "mutation_index": mutation_idx,
+                "before_avg": round(before_avg, 4),
+                "after_avg": round(after_avg, 4),
+                "window_drop": round(drop, 4),
+                "mutation_drop": round(max_drop, 4),
+                "last_normal": last_normal,
+                "segments": anomaly_segments,
+            }
+
+    return None
+
+
+def _format_anomaly_warning(anomaly: Dict[str, Any]) -> str:
+    """Format anomaly detection result as warning message."""
+    lines = [
+        f"⚠️ Score anomaly detected at segment #{anomaly['mutation_index']}",
+        f" Window avg: {anomaly['before_avg']:.4f} → {anomaly['after_avg']:.4f} (drop: {anomaly['window_drop']:.4f})",  # noqa: E501
+        f" Mutation drop: {anomaly['mutation_drop']:.4f}",
+        "",
+    ]
+
+    # Show last normal segment
+    if anomaly.get("last_normal"):
+        seg = anomaly["last_normal"]
+        text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+        lines.append(f' [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+    # Separator - mutation point
+    lines.append(" " + "─" * 60)
+    lines.append(f" ⬇️ MUTATION: The following {len(anomaly['segments'])}+ segments don't match audio")
+    lines.append(" " + "─" * 60)
+
+    # Show anomaly segments
+    for seg in anomaly["segments"]:
+        text_preview = seg.text[:50] + "..." if len(seg.text) > 50 else seg.text
+        lines.append(f' [{seg.start:.2f}s-{seg.end:.2f}s] score={seg.score:.4f} "{text_preview}"')
+
+    lines.append("")
+    lines.append(" Possible causes: Transcription error, missing content, or wrong audio region")
+    return "\n".join(lines)
lattifai/alignment/lattice1_worker.py
CHANGED

@@ -35,6 +35,8 @@ class Lattice1Worker:
         sess_options.intra_op_num_threads = num_threads  # CPU cores
         sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
         sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
+        # Suppress CoreMLExecutionProvider warnings about partial graph support
+        sess_options.log_severity_level = 3  # ERROR level only
 
         acoustic_model_path = f"{model_path}/acoustic_opt.onnx"
 
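For context, onnxruntime's log_severity_level runs 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL, so 3 keeps real errors while dropping the CoreML partition warnings. A standalone sketch of the same setting ("model.onnx" is a placeholder path):

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.log_severity_level = 3  # suppress VERBOSE/INFO/WARNING, keep ERROR and FATAL
session = ort.InferenceSession("model.onnx", sess_options)  # placeholder model path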
@@ -191,12 +193,17 @@ class Lattice1Worker:
                float(output_beam),
                int(min_active_states),
                int(max_active_states),
+               allow_partial=True,
            )
 
-           # Streaming mode
+           # Streaming mode with confidence score accumulation
            total_duration = audio.duration
            total_minutes = int(total_duration / 60.0)
 
+           max_probs = []
+           aligned_probs = []
+           prev_labels_len = 0
+
            with tqdm(
                total=total_minutes,
                desc=f"Processing audio ({total_minutes} min)",
@@ -208,13 +215,43 @@ class Lattice1Worker:
                    chunk_emission = self.emission(chunk.ndarray, acoustic_scale=acoustic_scale)
                    intersecter.decode(chunk_emission[0])
 
+                   __start = time.time()
+                   # Get partial labels and compute confidence stats for this chunk
+                   partial_labels = intersecter.get_partial_labels()
+                   chunk_len = chunk_emission.shape[1]
+
+                   # Get labels for current chunk (new labels since last chunk)
+                   chunk_labels = partial_labels[prev_labels_len : prev_labels_len + chunk_len]
+                   prev_labels_len = len(partial_labels)
+
+                   # Compute emission-based confidence stats
+                   probs = np.exp(chunk_emission[0])  # [T, V]
+                   max_probs.append(np.max(probs, axis=-1))  # [T]
+
+                   # Handle case where chunk_labels length might differ from chunk_len
+                   if len(chunk_labels) == chunk_len:
+                       aligned_probs.append(probs[np.arange(chunk_len), chunk_labels])
+                   else:
+                       # Fallback: use max probs as aligned probs (approximate)
+                       aligned_probs.append(np.max(probs, axis=-1))
+
+                   del chunk_emission, probs  # Free memory
+                   self.timings["align_>labels"] += time.time() - __start
+
                    # Update progress
                    chunk_duration = int(chunk.duration / 60.0)
                    pbar.update(chunk_duration)
 
-
+           # Build emission_stats for confidence calculation
+           emission_stats = {
+               "max_probs": np.concatenate(max_probs),
+               "aligned_probs": np.concatenate(aligned_probs),
+           }
+
            # Get results from intersecter
+           __start = time.time()
            results, labels = intersecter.finish()
+           self.timings["align_>finish"] += time.time() - __start
        else:
            # Batch mode
            if emission is None:
@@ -230,13 +267,19 @@ class Lattice1Worker:
                float(output_beam),
                int(min_active_states),
                int(max_active_states),
+               allow_partial=True,
            )
-
+           # Compute emission_stats from full emission (same format as streaming)
+           probs = np.exp(emission[0])  # [T, V]
+           emission_stats = {
+               "max_probs": np.max(probs, axis=-1),  # [T]
+               "aligned_probs": probs[np.arange(probs.shape[0]), labels[0]],  # [T]
+           }
 
        self.timings["align_segments"] += time.time() - _start
 
        channel = 0
-       return
+       return emission_stats, results, labels, self.frame_shift, offset, channel  # frame_shift=20ms
 
    def profile(self) -> None:
        """Print formatted profiling statistics."""
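The worker now returns emission_stats alongside the lattice results. One plausible way to consume it downstream (the package's actual confidence formula lives in detokenize, which this diff does not show, so confidence_from_stats is an illustrative name):

import numpy as np

def confidence_from_stats(emission_stats: dict) -> float:
    # Ratio of the forced-alignment label's probability to the frame-wise
    # maximum; equals 1.0 wherever the forced path is also the greedy path.
    aligned = emission_stats["aligned_probs"]
    peak = emission_stats["max_probs"]
    return float(np.mean(aligned / np.maximum(peak, 1e-9)))

stats = {
    "max_probs": np.array([0.90, 0.80, 0.95]),
    "aligned_probs": np.array([0.90, 0.60, 0.95]),
}
print(confidence_from_stats(stats))  # ≈ 0.917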
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Multilingual punctuation characters (no duplicates)
|
|
2
|
+
PUNCTUATION = (
|
|
3
|
+
# ASCII punctuation
|
|
4
|
+
"!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
|
5
|
+
# Chinese/CJK punctuation
|
|
6
|
+
"。,、?!:;·¥…—~“”‘’"
|
|
7
|
+
"《》〈〉【】〔〕〖〗()" # Brackets
|
|
8
|
+
# Japanese punctuation (unique only)
|
|
9
|
+
"「」『』・"
|
|
10
|
+
# CJK punctuation (unique only)
|
|
11
|
+
"〃〆〇〒〓〘〙〚〛〜〝〞〟"
|
|
12
|
+
# Arabic punctuation
|
|
13
|
+
"،؛؟"
|
|
14
|
+
# Thai punctuation
|
|
15
|
+
"๏๚๛"
|
|
16
|
+
# Hebrew punctuation
|
|
17
|
+
"־׀׃׆"
|
|
18
|
+
# Other common punctuation
|
|
19
|
+
"¡¿" # Spanish inverted marks
|
|
20
|
+
"«»‹›" # Guillemets
|
|
21
|
+
"‐‑‒–―" # Dashes (excluding — already above)
|
|
22
|
+
"‚„" # Low quotation marks
|
|
23
|
+
"†‡•‣" # Daggers and bullets
|
|
24
|
+
"′″‴" # Prime marks
|
|
25
|
+
"‰‱" # Per mille
|
|
26
|
+
)
|
|
27
|
+
PUNCTUATION_SPACE = PUNCTUATION + " "
|
|
28
|
+
STAR_TOKEN = "※"
|
|
29
|
+
|
|
30
|
+
# End of sentence punctuation marks (multilingual)
|
|
31
|
+
# - ASCII: .!?"'])
|
|
32
|
+
# - Chinese/CJK: 。!?"】」』〗〙〛 (including right double quote U+201D)
|
|
33
|
+
# - Japanese: 。 (halfwidth period)
|
|
34
|
+
# - Arabic: ؟
|
|
35
|
+
# - Ellipsis: …
|
|
36
|
+
END_PUNCTUATION = ".!?\"'])。!?\u201d】」』〗〙〛。؟…"
|
|
37
|
+
|
|
38
|
+
GROUPING_SEPARATOR = "✹"
|
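A short sketch of how the new constants might be consumed; ends_sentence is a hypothetical helper, not part of the package (the real consumer is Segmenter, whose body this diff does not show):

from lattifai.alignment.punctuation import END_PUNCTUATION, PUNCTUATION_SPACE

def ends_sentence(text: str) -> bool:
    stripped = text.rstrip()
    return bool(stripped) and stripped[-1] in END_PUNCTUATION

print(ends_sentence("你好。"))          # True: the CJK full stop is in END_PUNCTUATION
print(ends_sentence("Hello, world"))   # False: a comma is not sentence-final
print("  -- Hi! --  ".strip(PUNCTUATION_SPACE))  # "Hi": strips punctuation and spaces from both ends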
lattifai/alignment/segmenter.py
CHANGED
@@ -5,11 +5,12 @@ from typing import List, Optional, Tuple
 import colorful
 
 from lattifai.audio2 import AudioData
-from lattifai.caption import
+from lattifai.caption import Supervision
 from lattifai.config import AlignmentConfig
+from lattifai.data import Caption
 from lattifai.utils import safe_print
 
-from .
+from .punctuation import END_PUNCTUATION
 
 
 class Segmenter: