lattifai 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/alignment/__init__.py +10 -1
- lattifai/alignment/lattice1_aligner.py +66 -58
- lattifai/alignment/punctuation.py +38 -0
- lattifai/alignment/sentence_splitter.py +152 -21
- lattifai/alignment/text_align.py +440 -0
- lattifai/alignment/tokenizer.py +82 -40
- lattifai/caption/__init__.py +82 -6
- lattifai/caption/caption.py +335 -1141
- lattifai/caption/formats/__init__.py +199 -0
- lattifai/caption/formats/base.py +211 -0
- lattifai/caption/{gemini_reader.py → formats/gemini.py} +320 -60
- lattifai/caption/formats/json.py +194 -0
- lattifai/caption/formats/lrc.py +309 -0
- lattifai/caption/formats/nle/__init__.py +9 -0
- lattifai/caption/formats/nle/audition.py +561 -0
- lattifai/caption/formats/nle/avid.py +423 -0
- lattifai/caption/formats/nle/fcpxml.py +549 -0
- lattifai/caption/formats/nle/premiere.py +589 -0
- lattifai/caption/formats/pysubs2.py +642 -0
- lattifai/caption/formats/sbv.py +147 -0
- lattifai/caption/formats/tabular.py +338 -0
- lattifai/caption/formats/textgrid.py +193 -0
- lattifai/caption/formats/ttml.py +652 -0
- lattifai/caption/formats/vtt.py +469 -0
- lattifai/caption/parsers/__init__.py +9 -0
- lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
- lattifai/caption/standardize.py +636 -0
- lattifai/caption/utils.py +474 -0
- lattifai/cli/__init__.py +2 -1
- lattifai/cli/caption.py +108 -1
- lattifai/cli/transcribe.py +1 -1
- lattifai/cli/youtube.py +4 -1
- lattifai/client.py +33 -113
- lattifai/config/__init__.py +11 -1
- lattifai/config/alignment.py +7 -0
- lattifai/config/caption.py +267 -23
- lattifai/config/media.py +20 -0
- lattifai/diarization/__init__.py +41 -1
- lattifai/mixin.py +27 -15
- lattifai/transcription/base.py +6 -1
- lattifai/transcription/lattifai.py +19 -54
- lattifai/utils.py +7 -13
- lattifai/workflow/__init__.py +28 -4
- lattifai/workflow/file_manager.py +2 -5
- lattifai/youtube/__init__.py +43 -0
- lattifai/youtube/client.py +1170 -0
- lattifai/youtube/types.py +23 -0
- lattifai-1.2.2.dist-info/METADATA +615 -0
- lattifai-1.2.2.dist-info/RECORD +76 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
- lattifai/caption/gemini_writer.py +0 -173
- lattifai/cli/app_installer.py +0 -142
- lattifai/cli/server.py +0 -44
- lattifai/server/app.py +0 -427
- lattifai/workflow/youtube.py +0 -577
- lattifai-1.2.1.dist-info/METADATA +0 -1134
- lattifai-1.2.1.dist-info/RECORD +0 -58
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
# Align caption.supervisions with transcription
|
|
2
|
+
import logging
|
|
3
|
+
import string # noqa: F401
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections import defaultdict, namedtuple
|
|
6
|
+
from typing import Callable, Dict, List, Optional, Tuple, TypeVar
|
|
7
|
+
|
|
8
|
+
import regex
|
|
9
|
+
from error_align import error_align
|
|
10
|
+
from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
|
|
11
|
+
|
|
12
|
+
from lattifai.caption import Caption, Supervision
|
|
13
|
+
from lattifai.utils import safe_print
|
|
14
|
+
|
|
15
|
+
from .punctuation import PUNCTUATION
|
|
16
|
+
|
|
17
|
+
# Generic symbol type used by alignment helpers in this module.
Symbol = TypeVar("Symbol")
# Epsilon / blank placeholder character used in alignment output.
EPSILON = "`"

# Sentinel appended after every supervision's text so segment boundaries
# survive the concatenation into one long ref/hyp string.
JOIN_TOKEN = "❄"
# error_align treats DELIMITERS members as token boundaries; register the
# sentinel once so it is tokenized (guard avoids duplicate registration).
if JOIN_TOKEN not in DELIMITERS:
    DELIMITERS.add(JOIN_TOKEN)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def custom_tokenizer(text: str) -> list:
    """Split *text* into token match objects for ``error_align``.

    Tokens are either numeric runs, standard word tokens, or the
    module-level JOIN_TOKEN boundary sentinel.

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: ``regex`` match objects, one per recognized token.
    """
    # JOIN_TOKEN must be escaped before interpolation into the pattern.
    join = regex.escape(JOIN_TOKEN)
    pattern = rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{join})"
    return [m for m in regex.finditer(pattern, text, regex.UNICODE | regex.VERBOSE)]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def equal_ratio(chunk: List[Alignment]):
    """Fraction of ops in *chunk* that are exact matches (0.0 for empty input)."""
    if not chunk:
        return 0.0
    hits = 0
    for item in chunk:
        if item.op_type == OpType.MATCH:
            hits += 1
    return hits / len(chunk)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def equal(chunk: List[Alignment]):
    """True when every op in *chunk* is an exact match (vacuously True for [])."""
    for item in chunk:
        if item.op_type != OpType.MATCH:
            return False
    return True
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def group_alignments(
    supervisions: List[Supervision],
    transcription: List[Supervision],
    max_silence_gap: float = 10.0,
    mini_num_supervisions: int = 1,
    mini_num_transcription: int = 1,
    equal_threshold: float = 0.5,
    verbose: bool = False,
) -> List[Tuple[Tuple[int, int], Tuple[int, int], List[Alignment]]]:
    """Cut the caption/transcription word alignment into parallel chunks.

    Both sides are lowercased and concatenated with JOIN_TOKEN after every
    segment, then aligned with ``error_align``.  A cut is considered at every
    position where JOIN_TOKEN appears on *both* sides, and accepted when the
    chunk edge match quality and minimum segment counts allow it (or when a
    long silence gap in the transcription forces a split).

    Args:
        supervisions: Caption segments (reference side).
        transcription: ASR segments (hypothesis side).
        max_silence_gap: Silence (seconds) between consecutive transcription
            segments that forces a split; 2x this gap forces it unconditionally.
        mini_num_supervisions: Minimum caption segments per emitted chunk.
        mini_num_transcription: Minimum transcription segments per emitted chunk.
        equal_threshold: Minimum match ratio at either chunk edge (first/last
            10 alignment ops) required to accept a cut.
        verbose: Print each emitted chunk via ``safe_print``.

    Returns:
        List of ``((sup_begin, sup_end), (asr_begin, asr_end), chunk)`` where
        the index ranges are half-open into the input lists and ``chunk`` is
        the slice of alignment ops covering that span.
    """
    # TABLE = str.maketrans(dict.fromkeys(string.punctuation))
    # sup.text.lower().translate(TABLE)
    ref = "".join(sup.text.lower() + JOIN_TOKEN for sup in supervisions)
    hyp = "".join(sup.text.lower() + JOIN_TOKEN for sup in transcription)
    alignments = error_align(ref, hyp, tokenizer=custom_tokenizer)

    matches = []
    # segment start index
    ss_start, ss_idx = 0, 0  # supervisions: emitted-up-to / current boundary count
    tr_start, tr_idx = 0, 0  # transcription: emitted-up-to / current boundary count

    idx_start = 0  # first alignment op of the pending chunk
    for idx, ali in enumerate(alignments):
        # Each JOIN_TOKEN on a side means one full segment of that side was consumed.
        if ali.ref == JOIN_TOKEN:
            ss_idx += 1
        if ali.hyp == JOIN_TOKEN:
            tr_idx += 1

        if ali.ref == JOIN_TOKEN and ali.hyp == JOIN_TOKEN:
            # Both sides hit a segment boundary at the same op: candidate cut point.
            chunk = alignments[idx_start:idx]

            split_at_silence = False
            greater_two_silence_gap = False
            if tr_idx > 0 and tr_idx < len(transcription):
                # Silence between the previous and next transcription segment.
                gap = transcription[tr_idx].start - transcription[tr_idx - 1].end
                if gap > max_silence_gap:
                    split_at_silence = True
                    if gap > 2 * max_silence_gap:
                        greater_two_silence_gap = True

            if (
                (equal_ratio(chunk[:10]) > equal_threshold or equal_ratio(chunk[-10:]) > equal_threshold)
                and (
                    (ss_idx - ss_start >= mini_num_supervisions and tr_idx - tr_start >= mini_num_transcription)
                    or split_at_silence
                )
            ) or greater_two_silence_gap:
                matches.append(((ss_start, ss_idx), (tr_start, tr_idx), chunk))

                if verbose:
                    sub_align = supervisions[ss_start:ss_idx]
                    asr_align = transcription[tr_start:tr_idx]
                    safe_print("========================================================================")
                    safe_print(f" Caption [{ss_start:>4d}, {ss_idx:>4d}): {[sup.text for sup in sub_align]}")
                    safe_print(f"Transcript [{tr_start:>4d}, {tr_idx:>4d}): {[sup.text for sup in asr_align]}")
                    safe_print("========================================================================\n\n")

                # Advance the emitted-up-to markers past this chunk.
                ss_start = ss_idx
                tr_start = tr_idx
                idx_start = idx + 1

        if ss_start == len(supervisions) and tr_start == len(transcription):
            # Everything on both sides has been emitted.
            break

        # remainder: one side is exhausted -- flush everything left as a final chunk.
        if ss_idx == len(supervisions) or tr_idx == len(transcription):
            chunk = alignments[idx_start:]
            matches.append(((ss_start, len(supervisions)), (tr_start, len(transcription)), chunk))
            break

    return matches
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class AlignQuality(namedtuple("AlignQuality", ["FW", "LW", "PREFIX", "SUFFIX", "WER"])):
    """Word-level alignment quality for one caption/transcription chunk.

    Fields: FW/LW are first/last-word markers ("FE"/"LE" = equal), PREFIX and
    SUFFIX are edge match ratios, and WER holds the chunk's WER statistics.
    """

    def __repr__(self) -> str:
        return f"WORD[{self.FW}][{self.LW}]_WER[{self.WER}]"

    @property
    def info(self) -> str:
        return f"WER {self.WER.WER:.4f} accuracy [{self.PREFIX:.2f}, {self.SUFFIX:.2f}] {self.WER}"

    @property
    def first_word_equal(self) -> bool:
        return self.FW == "FE"

    @property
    def last_word_equal(self) -> bool:
        return self.LW == "LE"

    @property
    def wer(self) -> float:
        return self.WER.WER

    @property
    def qwer(self):
        # TODO: take ref_len into account as well
        rate = self.wer
        if rate == 0.0:
            return "WZ"  # zero
        if rate < 0.1:
            return "WL"  # low
        if rate < 0.32:
            return "WM"  # medium
        return "WH"  # high

    @property
    def qprefix(self) -> str:
        if self.PREFIX >= 0.7:
            return "PH"  # high
        if self.PREFIX >= 0.5:
            return "PM"  # medium
        return "PL"  # low

    @property
    def qsuffix(self) -> str:
        if self.SUFFIX > 0.7:
            return "SH"  # high
        if self.SUFFIX > 0.5:
            return "SM"  # medium
        return "SL"  # low
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class TimestampQuality(namedtuple("TimestampQuality", ["start", "end"])):
    """Boundary-timestamp agreement for one chunk.

    ``start`` and ``end`` are (caption_time, transcript_time) pairs; the
    properties expose the absolute discrepancies between the two sources.
    """

    @property
    def start_diff(self):
        caption_t, transcript_t = self.start[0], self.start[1]
        return abs(caption_t - transcript_t)

    @property
    def end_diff(self):
        caption_t, transcript_t = self.end[0], self.end[1]
        return abs(caption_t - transcript_t)

    @property
    def diff(self):
        # Worst of the two boundary discrepancies.
        return max(self.start_diff, self.end_diff)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# One entry per aligned chunk, as built by align_supervisions_and_transcription:
# (caption segments, transcription segments, word-level quality, timestamp quality,
#  raw alignment chunk).  The final element is the List[Alignment] chunk itself;
# it was previously mis-annotated as `int`.
TextAlignResult = Tuple[
    Optional[List[Supervision]], Optional[List[Supervision]], AlignQuality, TimestampQuality, List[Alignment]
]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def align_supervisions_and_transcription(
    caption: Caption,
    max_duration: Optional[float] = None,
    verbose: bool = False,
) -> List[TextAlignResult]:
    """Align caption.supervisions with caption.transcription.

    Chunks the two streams with ``group_alignments`` and scores every chunk
    with ``quality``.  Chunks where one side is empty get boundary timestamps
    estimated from the previous match and the neighbouring segment.

    Args:
        caption: Caption object containing supervisions and transcription.
        max_duration: Optional cap (seconds) for estimated end timestamps;
            defaults to the later of the two streams' last end times.
        verbose: Print per-chunk diagnostics via ``safe_print``.

    Returns:
        One ``[sub_align, asr_align, AlignQuality, TimestampQuality, chunk]``
        entry per chunk.
    """
    groups = group_alignments(caption.supervisions, caption.transcription, verbose=False)

    if max_duration is None:
        max_duration = max(caption.transcription[-1].end, caption.supervisions[-1].end)
    else:
        # Never trust a caller-supplied duration much beyond the observed audio.
        max_duration = min(
            max_duration,
            max(caption.transcription[-1].end, caption.supervisions[-1].end) + 10.0,
        )

    def next_start(alignments: List[Supervision], idx: int) -> float:
        # Start of segment idx, or a small extrapolation past the last segment.
        if idx < len(alignments):
            return alignments[idx].start
        return min(alignments[-1].end + 2.0, max_duration)

    wer_filter = WERFilter()

    matches = []
    for idx, ((sub_start, sub_end), (asr_start, asr_end), chunk) in enumerate(groups):
        sub_align = caption.supervisions[sub_start:sub_end]
        asr_align = caption.transcription[asr_start:asr_end]

        if not sub_align or not asr_align:
            # One side is empty: estimate the missing side's time span.
            # NOTE: matches[-1][-2] is the previous chunk's TimestampQuality;
            # .end is its (caption_end, transcript_end) pair.
            if sub_align:
                if matches:
                    _asr_start = matches[-1][-2].end[1]
                else:
                    _asr_start = 0.0
                startends = [
                    (sub_align[0].start, sub_align[-1].end),
                    (_asr_start, next_start(caption.transcription, asr_end)),
                ]
            elif asr_align:
                if matches:
                    _sub_start = matches[-1][-2].end[0]
                else:
                    _sub_start = 0.0
                startends = [
                    (_sub_start, next_start(caption.supervisions, sub_end)),
                    (asr_align[0].start, asr_align[-1].end),
                ]
            else:
                # group_alignments never emits a chunk empty on both sides.
                raise ValueError(
                    f"Never Here! subtitles[{len(caption.supervisions)}] {sub_start}-{sub_end} asrs[{len(caption.transcription)}] {asr_start}-{asr_end}"
                )

            if verbose:
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo")
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{startends[0][0]:>8.2f}, {startends[0][1]:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{startends[1][0]:>8.2f}, {startends[1][1]:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo\n\n")

            aligned, timestamp = quality(chunk, startends[0], startends[1], wer_fn=wer_filter)
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])
            continue
        else:
            # Both sides present: score with the real segment boundaries.
            aligned, timestamp = quality(
                chunk,
                [sub_align[0].start, sub_align[-1].end],
                [asr_align[0].start, asr_align[-1].end],
                wer_fn=wer_filter,
            )
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])

            if verbose and aligned.wer > 0.0:
                safe_print(
                    f"===================================WER={aligned.wer:>4.2f}====================================="
                )
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{sub_align[0].start:>8.2f}, {sub_align[-1].end:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{asr_align[0].start:>8.2f}, {asr_align[-1].end:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("========================================================================\n\n")

    return matches
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class AlignFilter(ABC):
    """Base class for callables that score a chunk of alignment ops.

    The constructor precomputes punctuation/space/separator character groups
    that subclasses use as IGNORE sets when computing statistics.
    """

    def __init__(
        self, PUNCTUATION: str = PUNCTUATION, IGNORE: str = "", SPACE=" ", EPSILON=EPSILON, SEPARATOR=JOIN_TOKEN
    ):
        super().__init__()
        self._name = type(self).__name__

        self.PUNCTUATION = PUNCTUATION
        self.IGNORE = IGNORE
        self.SPACE = SPACE
        self.EPSILON = EPSILON
        self.SEPARATOR = SEPARATOR

        # Pre-built character groups for convenient membership tests.
        self.PUNCTUATION_SEPARATOR = PUNCTUATION + SEPARATOR
        self.PUNCTUATION_SPACE = PUNCTUATION + SPACE
        self.PUNCTUATION_SPACE_SEPARATOR = PUNCTUATION + SPACE + SEPARATOR

    @abstractmethod
    def __call__(self, chunk: List[Alignment]) -> str:
        """Score *chunk*; implemented by subclasses."""
        pass

    @property
    def name(self):
        # Subclass name, cached at construction time.
        return self._name
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# Fix: the namedtuple typename was "AlignStats" while the class is WERStats,
# which made repr() print a misleading name.
class WERStats(namedtuple("WERStats", ["WER", "ins_errs", "del_errs", "sub_errs", "ref_len"])):
    """Word-error-rate statistics for one aligned chunk.

    Fields:
        WER: total errors / reference length, rounded to 4 digits.
        ins_errs / del_errs / sub_errs: insertion / deletion / substitution counts.
        ref_len: number of reference tokens counted.
    """

    def to_dict(self):
        """Return the statistics as a plain dict (e.g. for serialization)."""
        return {
            "WER": self.WER,
            "ins_errs": self.ins_errs,
            "del_errs": self.del_errs,
            "sub_errs": self.sub_errs,
            "ref_len": self.ref_len,
        }
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
def compute_align_stats(
    ali: List[Tuple[str, str]],
    ERR: str = "*",
    IGNORE: str = "",
    enable_log: bool = True,
) -> WERStats:
    """Count insertion/deletion/substitution errors over a (ref, hyp) alignment.

    English contractions ("what is" / "what's", "do not" / "don't", ...) are
    treated as correct and excluded from the error counts.  Pairs where both
    sides are IGNORE characters (punctuation/space/separator) are skipped
    entirely.

    Args:
        ali: Aligned (ref_word, hyp_word) pairs.
        ERR: Placeholder marking an insertion (ref side) or deletion (hyp side).
        IGNORE: Characters whose pairs are excluded from counting.
        enable_log: Log a WER summary line when True.

    Returns:
        WERStats with the aggregated counts and rounded WER.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0

    ref_len = 0
    skip = 0
    for k, (ref_word, hyp_word) in enumerate(ali):
        if skip > 0:
            skip -= 1
            continue

        # compute_align_stats(ali, ERR=EPSILON, IGNORE=PUNCTUATION_SPACE_SEPARATOR)
        # NOTE(review): both contraction branches skip one pair *beyond* the
        # matched pattern (skip=1 for a 1-pair match, skip=2 for a 2-pair
        # match) -- presumably consuming a trailing delimiter pair; confirm
        # against error_align's output shape.
        if ali[k : k + 1] in [
            [("is", "'s")],  # what is -> what's
            [("am", "'m")],  # I am -> I'm
            [("are", "'re")],  # they are -> they're
            [("would", "'d")],  # they would -> they'd don't
            [("had", "'d")],  # I had -> I'd
            # we will -> we'll
            [("will", "'ll")],
            # I have -> I've
            [("have", "'ve")],
            # ok -> okay
            [("ok", "okay")],
            # okay -> ok
            [("okay", "ok")],
        ]:
            skip = 1
            ref_len += 1
            continue
        elif ali[k : k + 2] in [
            # let us -> let's
            [("let", "let"), ("us", "'s")],
            # do not -> don't
            [("do", "do"), ("not", "n't")],
        ]:
            skip = 2
            ref_len += 2
            continue
        elif (ref_word and ref_word in IGNORE) and (hyp_word and hyp_word in IGNORE):
            # Both sides are ignorable characters: not an error, not counted.
            continue
        else:
            # NOTE(review): ref_len is also incremented for insertions
            # (ref_word == ERR), which slightly dilutes the WER denominator
            # relative to the icefall original -- confirm intended.
            ref_len += 1

        if ref_word == ERR:
            ins[hyp_word] += 1
            words[hyp_word][3] += 1
        elif hyp_word == ERR:
            dels[ref_word] += 1
            words[ref_word][4] += 1
        elif hyp_word != ref_word:
            subs[(ref_word, hyp_word)] += 1
            words[ref_word][1] += 1
            words[hyp_word][2] += 1
        else:
            words[ref_word][0] += 1
            num_corr += 1

    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    stats = WERStats(
        WER=round(tot_errs / max(ref_len, 1), ndigits=4),
        ins_errs=ins_errs,
        del_errs=del_errs,
        sub_errs=sub_errs,
        ref_len=ref_len,
    )

    if enable_log:
        logging.info(
            f"%WER {stats.WER:.4%} "
            f"[{tot_errs} / {max(ref_len, 1)}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    return stats
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class WERFilter(AlignFilter):
    """AlignFilter that computes WER statistics for an alignment chunk."""

    def __call__(self, chunk: List[Alignment]) -> WERStats:
        # JOIN_TOKEN acts as the error placeholder; punctuation, spaces and
        # separators are excluded from the counts.
        pairs = [(item.ref, item.hyp) for item in chunk]
        return compute_align_stats(
            pairs, ERR=JOIN_TOKEN, IGNORE=self.PUNCTUATION_SPACE_SEPARATOR, enable_log=False
        )
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def quality(
    chunk: List[Alignment], supervision: Tuple[float, float], transcript: Tuple[float, float], wer_fn: Callable
) -> Tuple[AlignQuality, TimestampQuality]:
    """Score one aligned chunk.

    Returns an AlignQuality (first/last-word markers, edge match ratios over
    the first/last 4 ops, and WER stats from *wer_fn*) plus a
    TimestampQuality pairing the caption and transcript boundary times.
    """
    first_match = bool(chunk) and chunk[0].op_type == OpType.MATCH
    last_match = bool(chunk) and chunk[-1].op_type == OpType.MATCH
    align_q = AlignQuality(
        FW="FE" if first_match else "FN",
        LW="LE" if last_match else "LN",
        PREFIX=equal_ratio(chunk[:4]),
        SUFFIX=equal_ratio(chunk[-4:]),
        WER=wer_fn(chunk),
    )
    ts_q = TimestampQuality(start=(supervision[0], transcript[0]), end=(supervision[1], transcript[1]))
    return align_q, ts_q
|
lattifai/alignment/tokenizer.py
CHANGED
|
@@ -2,11 +2,13 @@ import gzip
|
|
|
2
2
|
import pickle
|
|
3
3
|
import re
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
|
-
from lattifai.caption import Supervision
|
|
9
|
+
# from lattifai.caption import Supervision
|
|
10
|
+
from lhotse.supervision import SupervisionSegment as Supervision # NOTE: Transcriber SupervisionSegment
|
|
11
|
+
|
|
10
12
|
from lattifai.caption import normalize_text as normalize_html_text
|
|
11
13
|
from lattifai.errors import (
|
|
12
14
|
LATTICE_DECODING_FAILURE_HELP,
|
|
@@ -16,13 +18,9 @@ from lattifai.errors import (
|
|
|
16
18
|
)
|
|
17
19
|
|
|
18
20
|
from .phonemizer import G2Phonemizer
|
|
21
|
+
from .punctuation import PUNCTUATION, PUNCTUATION_SPACE
|
|
19
22
|
from .sentence_splitter import SentenceSplitter
|
|
20
|
-
|
|
21
|
-
PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
|
|
22
|
-
PUNCTUATION_SPACE = PUNCTUATION + " "
|
|
23
|
-
STAR_TOKEN = "※"
|
|
24
|
-
|
|
25
|
-
GROUPING_SEPARATOR = "✹"
|
|
23
|
+
from .text_align import TextAlignResult
|
|
26
24
|
|
|
27
25
|
MAXIMUM_WORD_LENGTH = 40
|
|
28
26
|
|
|
@@ -80,8 +78,11 @@ def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punct
|
|
|
80
78
|
['Kühlschrank']
|
|
81
79
|
>>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
|
|
82
80
|
['Hello,', ' ', 'World!']
|
|
81
|
+
>>> tokenize_multilingual_text("[AED], World!", keep_spaces=False, attach_punctuation=True)
|
|
82
|
+
['[AED],', 'World!']
|
|
83
83
|
"""
|
|
84
84
|
# Regex pattern:
|
|
85
|
+
# - \[[A-Z_]+\] matches bracketed annotations like [APPLAUSE], [MUSIC], [SPEAKER_01]
|
|
85
86
|
# - [a-zA-Z0-9\u00C0-\u024F]+ matches Latin letters (including accented chars like ü, ö, ä, ß, é, etc.)
|
|
86
87
|
# - (?:'[a-zA-Z]{1,2})? optionally matches contractions like 's, 't, 'm, 'll, 're, 've
|
|
87
88
|
# - [\u4e00-\u9fff] matches CJK characters
|
|
@@ -90,7 +91,7 @@ def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punct
|
|
|
90
91
|
# - \u00C0-\u00FF: Latin-1 Supplement (À-ÿ)
|
|
91
92
|
# - \u0100-\u017F: Latin Extended-A
|
|
92
93
|
# - \u0180-\u024F: Latin Extended-B
|
|
93
|
-
pattern = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
|
|
94
|
+
pattern = re.compile(r"(\[[A-Z_]+\]|[a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
|
|
94
95
|
|
|
95
96
|
# filter(None, ...) removes any empty strings from re.findall results
|
|
96
97
|
tokens = list(filter(None, pattern.findall(text)))
|
|
@@ -245,19 +246,37 @@ class LatticeTokenizer:
|
|
|
245
246
|
self.init_sentence_splitter()
|
|
246
247
|
return self.sentence_splitter.split_sentences(supervisions, strip_whitespace=strip_whitespace)
|
|
247
248
|
|
|
248
|
-
def tokenize(
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
"
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
249
|
+
def tokenize(
|
|
250
|
+
self, supervisions: Union[List[Supervision], TextAlignResult], split_sentence: bool = False, boost: float = 0.0
|
|
251
|
+
) -> Tuple[str, Dict[str, Any]]:
|
|
252
|
+
if isinstance(supervisions[0], Supervision):
|
|
253
|
+
if split_sentence:
|
|
254
|
+
supervisions = self.split_sentences(supervisions)
|
|
255
|
+
|
|
256
|
+
pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
|
|
257
|
+
response = self.client_wrapper.post(
|
|
258
|
+
"tokenize",
|
|
259
|
+
json={
|
|
260
|
+
"model_name": self.model_name,
|
|
261
|
+
"supervisions": [s.to_dict() for s in supervisions],
|
|
262
|
+
"pronunciation_dictionaries": pronunciation_dictionaries,
|
|
263
|
+
},
|
|
264
|
+
)
|
|
265
|
+
else:
|
|
266
|
+
pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions[0]])
|
|
267
|
+
pronunciation_dictionaries.update(self.prenormalize([s.text for s in supervisions[1]]))
|
|
268
|
+
|
|
269
|
+
response = self.client_wrapper.post(
|
|
270
|
+
"difftokenize",
|
|
271
|
+
json={
|
|
272
|
+
"model_name": self.model_name,
|
|
273
|
+
"supervisions": [s.to_dict() for s in supervisions[0]],
|
|
274
|
+
"transcription": [s.to_dict() for s in supervisions[1]],
|
|
275
|
+
"pronunciation_dictionaries": pronunciation_dictionaries,
|
|
276
|
+
"boost": boost,
|
|
277
|
+
},
|
|
278
|
+
)
|
|
279
|
+
|
|
261
280
|
if response.status_code == 402:
|
|
262
281
|
raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
|
|
263
282
|
if response.status_code != 200:
|
|
@@ -274,28 +293,47 @@ class LatticeTokenizer:
|
|
|
274
293
|
self,
|
|
275
294
|
lattice_id: str,
|
|
276
295
|
lattice_results: Tuple[np.ndarray, Any, Any, float, float],
|
|
277
|
-
supervisions: List[Supervision],
|
|
296
|
+
supervisions: Union[List[Supervision], TextAlignResult],
|
|
278
297
|
return_details: bool = False,
|
|
279
298
|
start_margin: float = 0.08,
|
|
280
299
|
end_margin: float = 0.20,
|
|
281
300
|
) -> List[Supervision]:
|
|
282
301
|
emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
302
|
+
if isinstance(supervisions[0], Supervision):
|
|
303
|
+
response = self.client_wrapper.post(
|
|
304
|
+
"detokenize",
|
|
305
|
+
json={
|
|
306
|
+
"model_name": self.model_name,
|
|
307
|
+
"lattice_id": lattice_id,
|
|
308
|
+
"frame_shift": frame_shift,
|
|
309
|
+
"results": [t.to_dict() for t in results[0]],
|
|
310
|
+
"labels": labels[0],
|
|
311
|
+
"offset": offset,
|
|
312
|
+
"channel": channel,
|
|
313
|
+
"return_details": False if return_details is None else return_details,
|
|
314
|
+
"destroy_lattice": True,
|
|
315
|
+
"start_margin": start_margin,
|
|
316
|
+
"end_margin": end_margin,
|
|
317
|
+
},
|
|
318
|
+
)
|
|
319
|
+
else:
|
|
320
|
+
response = self.client_wrapper.post(
|
|
321
|
+
"diffdetokenize",
|
|
322
|
+
json={
|
|
323
|
+
"model_name": self.model_name,
|
|
324
|
+
"lattice_id": lattice_id,
|
|
325
|
+
"frame_shift": frame_shift,
|
|
326
|
+
"results": [t.to_dict() for t in results[0]],
|
|
327
|
+
"labels": labels[0],
|
|
328
|
+
"offset": offset,
|
|
329
|
+
"channel": channel,
|
|
330
|
+
"return_details": False if return_details is None else return_details,
|
|
331
|
+
"destroy_lattice": True,
|
|
332
|
+
"start_margin": start_margin,
|
|
333
|
+
"end_margin": end_margin,
|
|
334
|
+
},
|
|
335
|
+
)
|
|
336
|
+
|
|
299
337
|
if response.status_code == 400:
|
|
300
338
|
raise LatticeDecodingError(
|
|
301
339
|
lattice_id,
|
|
@@ -316,7 +354,11 @@ class LatticeTokenizer:
|
|
|
316
354
|
# Add emission confidence scores for segments and word-level alignments
|
|
317
355
|
_add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
|
|
318
356
|
|
|
319
|
-
|
|
357
|
+
if isinstance(supervisions[0], Supervision):
|
|
358
|
+
alignments = _update_alignments_speaker(supervisions, alignments)
|
|
359
|
+
else:
|
|
360
|
+
# NOTE: Text Diff Alignment >> speaker has been handled in the backend service
|
|
361
|
+
pass
|
|
320
362
|
|
|
321
363
|
return alignments
|
|
322
364
|
|