lattifai 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +0 -24
- lattifai/alignment/__init__.py +10 -1
- lattifai/alignment/lattice1_aligner.py +66 -58
- lattifai/alignment/lattice1_worker.py +1 -6
- lattifai/alignment/punctuation.py +38 -0
- lattifai/alignment/segmenter.py +1 -1
- lattifai/alignment/sentence_splitter.py +350 -0
- lattifai/alignment/text_align.py +440 -0
- lattifai/alignment/tokenizer.py +91 -220
- lattifai/caption/__init__.py +82 -6
- lattifai/caption/caption.py +335 -1143
- lattifai/caption/formats/__init__.py +199 -0
- lattifai/caption/formats/base.py +211 -0
- lattifai/caption/formats/gemini.py +722 -0
- lattifai/caption/formats/json.py +194 -0
- lattifai/caption/formats/lrc.py +309 -0
- lattifai/caption/formats/nle/__init__.py +9 -0
- lattifai/caption/formats/nle/audition.py +561 -0
- lattifai/caption/formats/nle/avid.py +423 -0
- lattifai/caption/formats/nle/fcpxml.py +549 -0
- lattifai/caption/formats/nle/premiere.py +589 -0
- lattifai/caption/formats/pysubs2.py +642 -0
- lattifai/caption/formats/sbv.py +147 -0
- lattifai/caption/formats/tabular.py +338 -0
- lattifai/caption/formats/textgrid.py +193 -0
- lattifai/caption/formats/ttml.py +652 -0
- lattifai/caption/formats/vtt.py +469 -0
- lattifai/caption/parsers/__init__.py +9 -0
- lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
- lattifai/caption/standardize.py +636 -0
- lattifai/caption/utils.py +474 -0
- lattifai/cli/__init__.py +2 -1
- lattifai/cli/caption.py +108 -1
- lattifai/cli/transcribe.py +4 -9
- lattifai/cli/youtube.py +4 -1
- lattifai/client.py +48 -84
- lattifai/config/__init__.py +11 -1
- lattifai/config/alignment.py +9 -2
- lattifai/config/caption.py +267 -23
- lattifai/config/media.py +20 -0
- lattifai/diarization/__init__.py +41 -1
- lattifai/mixin.py +36 -18
- lattifai/transcription/base.py +6 -1
- lattifai/transcription/lattifai.py +19 -54
- lattifai/utils.py +81 -13
- lattifai/workflow/__init__.py +28 -4
- lattifai/workflow/file_manager.py +2 -5
- lattifai/youtube/__init__.py +43 -0
- lattifai/youtube/client.py +1170 -0
- lattifai/youtube/types.py +23 -0
- lattifai-1.2.2.dist-info/METADATA +615 -0
- lattifai-1.2.2.dist-info/RECORD +76 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
- lattifai/caption/gemini_reader.py +0 -371
- lattifai/caption/gemini_writer.py +0 -173
- lattifai/cli/app_installer.py +0 -142
- lattifai/cli/server.py +0 -44
- lattifai/server/app.py +0 -427
- lattifai/workflow/youtube.py +0 -577
- lattifai-1.2.0.dist-info/METADATA +0 -1133
- lattifai-1.2.0.dist-info/RECORD +0 -57
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
- {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
# Align caption.supervisions with transcription
|
|
2
|
+
import logging
|
|
3
|
+
import string # noqa: F401
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections import defaultdict, namedtuple
|
|
6
|
+
from typing import Callable, Dict, List, Optional, Tuple, TypeVar
|
|
7
|
+
|
|
8
|
+
import regex
|
|
9
|
+
from error_align import error_align
|
|
10
|
+
from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
|
|
11
|
+
|
|
12
|
+
from lattifai.caption import Caption, Supervision
|
|
13
|
+
from lattifai.utils import safe_print
|
|
14
|
+
|
|
15
|
+
from .punctuation import PUNCTUATION
|
|
16
|
+
|
|
17
|
+
# Generic TypeVar; not referenced anywhere else in the visible part of this module.
Symbol = TypeVar("Symbol")
# Default EPSILON symbol stored by AlignFilter; presumably meant as an alignment-gap
# marker — confirm against how error_align represents gaps.
EPSILON = "`"

# Sentinel appended after every supervision / transcript text (see group_alignments)
# so that segment boundaries appear as tokens in the word alignment.
JOIN_TOKEN = "❄"
# Import-time side effect: register the sentinel in error_align's shared DELIMITERS
# collection (only once). Presumably this makes error_align treat it as a delimiter
# rather than word content — confirm against error_align.
if JOIN_TOKEN not in DELIMITERS:
    DELIMITERS.add(JOIN_TOKEN)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def custom_tokenizer(text: str) -> list:
    """Tokenize ``text`` for ``error_align``, keeping ``JOIN_TOKEN`` as a token.

    This is not a whitespace split: ``text`` is scanned with a regex built from
    error_align's ``NUMERIC_TOKEN`` and ``STANDARD_TOKEN`` patterns, plus an
    extra alternative that matches the literal ``JOIN_TOKEN`` so segment
    boundaries inserted by ``group_alignments`` can be matched as tokens.

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: The ``regex`` match objects produced by ``regex.finditer`` for each
        token, in order of appearance (match objects, not plain strings).
    """
    # JOIN_TOKEN is a literal character; escape it so it can never be read as regex syntax.
    escaped_join_token = regex.escape(JOIN_TOKEN)
    pattern = rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{escaped_join_token})"
    # VERBOSE is kept as in the original; presumably error_align's token patterns
    # are written for verbose mode — confirm before changing the flags.
    return list(regex.finditer(pattern, text, regex.UNICODE | regex.VERBOSE))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def equal_ratio(chunk: List[Alignment]):
    """Return the fraction of alignments in ``chunk`` whose ``op_type`` is ``OpType.MATCH``.

    The denominator is clamped to 1, so an empty chunk yields ``0.0``.
    """
    match_count = sum(1 for item in chunk if item.op_type == OpType.MATCH)
    denominator = max(len(chunk), 1)
    return match_count / denominator
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def equal(chunk: List[Alignment]):
    """Return True when every alignment in ``chunk`` is a match (True for an empty chunk)."""
    for item in chunk:
        if not item.op_type == OpType.MATCH:
            return False
    return True
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def group_alignments(
    supervisions: List[Supervision],
    transcription: List[Supervision],
    max_silence_gap: float = 10.0,
    mini_num_supervisions: int = 1,
    mini_num_transcription: int = 1,
    equal_threshold: float = 0.5,
    verbose: bool = False,
) -> List[Tuple[Tuple[int, int], Tuple[int, int], List[Alignment]]]:
    """Split a word alignment of captions vs. transcript into matched groups.

    Each supervision / transcript text is lower-cased and terminated with
    ``JOIN_TOKEN``; the two concatenated strings are aligned with
    ``error_align`` using ``custom_tokenizer``. While walking the alignment,
    a group boundary is considered wherever a reference-side ``JOIN_TOKEN`` is
    aligned with a hypothesis-side ``JOIN_TOKEN`` (a caption segment and a
    transcript segment end at the same point).

    A candidate boundary is accepted when
      * the first or the last 10 alignments of the pending chunk have a match
        ratio above ``equal_threshold``, AND
      * the pending group spans at least ``mini_num_supervisions`` captions and
        ``mini_num_transcription`` transcript segments, or the transcript gap at
        this point exceeds ``max_silence_gap``;
    or unconditionally when that gap exceeds ``2 * max_silence_gap``.

    As soon as the number of ``JOIN_TOKEN``s seen on either side reaches the
    number of segments on that side, everything not yet grouped is emitted as
    a final group and the walk stops.

    NOTE(review): if that count is never reached (e.g. the aligner drops a
    ``JOIN_TOKEN``), alignments after the last accepted group are not returned.

    Args:
        supervisions: Caption segments (reference side of the alignment).
        transcription: Transcript segments (hypothesis side of the alignment).
        max_silence_gap: Transcript gap (next ``start`` minus previous ``end``)
            above which a boundary may be accepted without the minimum sizes.
        mini_num_supervisions: Minimum captions per group (presumably "minimum").
        mini_num_transcription: Minimum transcript segments per group.
        equal_threshold: Match ratio the chunk's head or tail must exceed.
        verbose: Print every accepted group.

    Returns:
        A list of ``((sup_start, sup_end), (tr_start, tr_end), chunk)`` tuples;
        the index pairs are half-open ranges into ``supervisions`` and
        ``transcription``, and ``chunk`` is the group's slice of alignments
        (the accepted boundary ``JOIN_TOKEN`` pair itself is excluded).
    """
    # TABLE = str.maketrans(dict.fromkeys(string.punctuation))
    # sup.text.lower().translate(TABLE)
    # Terminate every segment with JOIN_TOKEN so segment boundaries become tokens.
    ref = "".join(sup.text.lower() + JOIN_TOKEN for sup in supervisions)
    hyp = "".join(sup.text.lower() + JOIN_TOKEN for sup in transcription)
    alignments = error_align(ref, hyp, tokenizer=custom_tokenizer)

    matches = []
    # segment start index
    # *_start: first segment of the pending group; *_idx: number of segments whose
    # terminating JOIN_TOKEN has been seen so far on that side.
    ss_start, ss_idx = 0, 0
    tr_start, tr_idx = 0, 0

    # Position in `alignments` where the pending group's chunk begins.
    idx_start = 0
    for idx, ali in enumerate(alignments):
        if ali.ref == JOIN_TOKEN:
            ss_idx += 1
        if ali.hyp == JOIN_TOKEN:
            tr_idx += 1

        # Candidate boundary: a caption segment and a transcript segment end together.
        if ali.ref == JOIN_TOKEN and ali.hyp == JOIN_TOKEN:
            chunk = alignments[idx_start:idx]

            split_at_silence = False
            greater_two_silence_gap = False
            # Silence between the transcript segment just closed and the next one.
            if tr_idx > 0 and tr_idx < len(transcription):
                gap = transcription[tr_idx].start - transcription[tr_idx - 1].end
                if gap > max_silence_gap:
                    split_at_silence = True
                if gap > 2 * max_silence_gap:
                    greater_two_silence_gap = True

            # Accept if the chunk's head or tail is mostly matches AND the group is big
            # enough (or sits at a long silence); always accept at a very long silence.
            if (
                (equal_ratio(chunk[:10]) > equal_threshold or equal_ratio(chunk[-10:]) > equal_threshold)
                and (
                    (ss_idx - ss_start >= mini_num_supervisions and tr_idx - tr_start >= mini_num_transcription)
                    or split_at_silence
                )
            ) or greater_two_silence_gap:
                matches.append(((ss_start, ss_idx), (tr_start, tr_idx), chunk))

                if verbose:
                    sub_align = supervisions[ss_start:ss_idx]
                    asr_align = transcription[tr_start:tr_idx]
                    safe_print("========================================================================")
                    safe_print(f" Caption [{ss_start:>4d}, {ss_idx:>4d}): {[sup.text for sup in sub_align]}")
                    safe_print(f"Transcript [{tr_start:>4d}, {tr_idx:>4d}): {[sup.text for sup in asr_align]}")
                    safe_print("========================================================================\n\n")

                ss_start = ss_idx
                tr_start = tr_idx
                # The next chunk starts after the boundary JOIN_TOKEN pair.
                idx_start = idx + 1

            # Both sides fully covered by accepted groups: nothing is left over.
            if ss_start == len(supervisions) and tr_start == len(transcription):
                break

        # remainder
        # One side has run out of segments: emit everything not yet grouped as the last group.
        if ss_idx == len(supervisions) or tr_idx == len(transcription):
            chunk = alignments[idx_start:]
            matches.append(((ss_start, len(supervisions)), (tr_start, len(transcription)), chunk))
            break

    return matches
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class AlignQuality(namedtuple("AlignQuality", ["FW", "LW", "PREFIX", "SUFFIX", "WER"])):
    """Summary of how well one caption group matches its transcript group.

    Fields (as filled in by ``quality()``):
        FW: ``"FE"`` when the chunk's first alignment is a match, else ``"FN"``.
        LW: ``"LE"`` when the chunk's last alignment is a match, else ``"LN"``.
        PREFIX: Match ratio over the leading alignments of the chunk.
        SUFFIX: Match ratio over the trailing alignments of the chunk.
        WER: A stats object exposing a numeric ``WER`` attribute.
    """

    # Keep instances as light as the underlying tuple (no per-instance __dict__).
    __slots__ = ()

    def __repr__(self) -> str:
        quality = f"WORD[{self.FW}][{self.LW}]_WER[{self.WER}]"
        return quality

    @property
    def info(self) -> str:
        """One-line summary: WER, prefix/suffix match ratios and the raw WER stats."""
        info = f"WER {self.WER.WER:.4f} accuracy [{self.PREFIX:.2f}, {self.SUFFIX:.2f}] {self.WER}"
        return info

    @property
    def first_word_equal(self) -> bool:
        """True when the first alignment of the chunk was a match."""
        return self.FW == "FE"

    @property
    def last_word_equal(self) -> bool:
        """True when the last alignment of the chunk was a match."""
        return self.LW == "LE"

    @property
    def wer(self) -> float:
        """The numeric WER taken from the ``WER`` stats field."""
        return self.WER.WER

    @property
    def qwer(self) -> str:
        """Bucket the WER: ``WZ`` (0), ``WL`` (< 0.1), ``WM`` (< 0.32), otherwise ``WH``."""
        wer = self.wer
        # TODO: take the reference length (ref_len) into account when bucketing.
        if wer == 0.0:
            return "WZ"  # zero
        elif wer < 0.1:
            return "WL"  # low
        elif wer < 0.32:
            return "WM"  # medium
        else:
            return "WH"  # high

    @property
    def qprefix(self) -> str:
        """Bucket PREFIX: ``PH`` (>= 0.7), ``PM`` (>= 0.5), otherwise ``PL``."""
        if self.PREFIX >= 0.7:
            return "PH"  # high
        elif self.PREFIX >= 0.5:
            return "PM"  # medium
        else:
            return "PL"  # low

    @property
    def qsuffix(self) -> str:
        """Bucket SUFFIX: ``SH`` (> 0.7), ``SM`` (> 0.5), otherwise ``SL``.

        NOTE(review): unlike ``qprefix`` these comparisons are strict, so a ratio
        of exactly 0.5 is ``PM`` as a prefix but ``SL`` as a suffix — confirm the
        asymmetry is intended.
        """
        if self.SUFFIX > 0.7:
            return "SH"  # high
        elif self.SUFFIX > 0.5:
            return "SM"  # medium
        else:
            return "SL"  # low
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class TimestampQuality(namedtuple("TimestampQuality", ["start", "end"])):
    """Caption vs. transcript time spans of one aligned group.

    ``start`` and ``end`` are each ``(caption_time, transcript_time)`` pairs,
    as built by ``quality()``.
    """

    # Keep instances as light as the underlying tuple (no per-instance __dict__).
    __slots__ = ()

    @property
    def start_diff(self) -> float:
        """Absolute difference between caption and transcript start times."""
        return abs(self.start[0] - self.start[1])

    @property
    def end_diff(self) -> float:
        """Absolute difference between caption and transcript end times."""
        return abs(self.end[0] - self.end[1])

    @property
    def diff(self) -> float:
        """The larger of ``start_diff`` and ``end_diff``."""
        return max(self.start_diff, self.end_diff)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# One entry returned by align_supervisions_and_transcription():
# (caption segments, transcript segments, alignment quality, timestamp quality,
#  the group's error_align alignment chunk).
# NOTE: at runtime each entry is built as a 5-element list, not a tuple.
TextAlignResult = Tuple[
    Optional[List[Supervision]],
    Optional[List[Supervision]],
    AlignQuality,
    TimestampQuality,
    List[Alignment],
]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def align_supervisions_and_transcription(
    caption: Caption,
    max_duration: Optional[float] = None,
    verbose: bool = False,
) -> List[TextAlignResult]:
    """Align caption.supervisions with caption.transcription.

    The two segment lists are grouped by ``group_alignments``. Each group is then
    scored with ``quality`` (match ratios plus WER via ``WERFilter``) and paired
    with its caption / transcript time spans.

    Args:
        caption: Caption object containing supervisions and transcription.
        max_duration: Upper bound for synthesized times of groups that run past
            the last segment of one side. It defaults to the later of the last
            supervision end and the last transcript end. A caller-supplied value
            is capped at that time plus 10 seconds.
        verbose: Print every one-sided group and every two-sided group whose
            WER is non-zero.

    Returns:
        One ``[sub_align, asr_align, quality, timestamp, chunk]`` list per group.
        When one side of a group is empty, that side's time span runs from the
        previous group's end on that side to the start of the next segment.
        An empty list is returned when both supervisions and transcription are
        empty.

    Raises:
        ValueError: If a group is empty on both sides (not expected to happen).
    """
    supervisions = caption.supervisions
    transcription = caption.transcription
    # Nothing to align; also avoids indexing [-1] of empty lists below.
    if not supervisions and not transcription:
        return []

    # Grouping always runs quietly; ``verbose`` only controls this function's output.
    groups = group_alignments(supervisions, transcription, verbose=False)

    # Latest end time over whichever sides are non-empty.
    last_end = max(segments[-1].end for segments in (transcription, supervisions) if segments)
    if max_duration is None:
        max_duration = last_end
    else:
        max_duration = min(max_duration, last_end + 10.0)

    def next_start(alignments: List[Supervision], idx: int) -> float:
        """Start of ``alignments[idx]``, or a synthesized time once ``idx`` runs off the end."""
        if idx < len(alignments):
            return alignments[idx].start
        if not alignments:
            return max_duration
        # Past the last segment: allow up to 2 s after it, capped at max_duration.
        return min(alignments[-1].end + 2.0, max_duration)

    wer_filter = WERFilter()

    matches = []
    for (sub_start, sub_end), (asr_start, asr_end), chunk in groups:
        sub_align = supervisions[sub_start:sub_end]
        asr_align = transcription[asr_start:asr_end]

        if not sub_align or not asr_align:
            # One side is empty: synthesize its span from where the previous group
            # ended on that side up to the start of the next segment on that side.
            # matches[-1][-2] is the previous TimestampQuality; end = (caption, transcript).
            if sub_align:
                _asr_start = matches[-1][-2].end[1] if matches else 0.0
                startends = [
                    (sub_align[0].start, sub_align[-1].end),
                    (_asr_start, next_start(transcription, asr_end)),
                ]
            elif asr_align:
                _sub_start = matches[-1][-2].end[0] if matches else 0.0
                startends = [
                    (_sub_start, next_start(supervisions, sub_end)),
                    (asr_align[0].start, asr_align[-1].end),
                ]
            else:
                raise ValueError(
                    f"Never Here! subtitles[{len(caption.supervisions)}] {sub_start}-{sub_end} asrs[{len(caption.transcription)}] {asr_start}-{asr_end}"
                )

            if verbose:
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo")
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{startends[0][0]:>8.2f}, {startends[0][1]:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{startends[1][0]:>8.2f}, {startends[1][1]:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo\n\n")

            aligned, timestamp = quality(chunk, startends[0], startends[1], wer_fn=wer_filter)
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])
            continue

        aligned, timestamp = quality(
            chunk,
            [sub_align[0].start, sub_align[-1].end],
            [asr_align[0].start, asr_align[-1].end],
            wer_fn=wer_filter,
        )
        matches.append([sub_align, asr_align, aligned, timestamp, chunk])

        if verbose and aligned.wer > 0.0:
            safe_print(
                f"===================================WER={aligned.wer:>4.2f}====================================="
            )
            safe_print(
                f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{sub_align[0].start:>8.2f}, {sub_align[-1].end:>8.2f}]: {[sup.text for sup in sub_align]}"
            )
            safe_print(
                f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{asr_align[0].start:>8.2f}, {asr_align[-1].end:>8.2f}]: {[sup.text for sup in asr_align]}"
            )
            safe_print("========================================================================\n\n")

    return matches
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class AlignFilter(ABC):
    """Base class for callables that score a chunk of ``error_align`` alignments.

    Stores the character sets that subclasses use to decide which tokens to
    ignore, and precomputes common concatenations of them.
    """

    def __init__(
        self,
        PUNCTUATION: str = PUNCTUATION,
        IGNORE: str = "",
        SPACE: str = " ",
        EPSILON: str = EPSILON,
        SEPARATOR: str = JOIN_TOKEN,
    ):
        super().__init__()
        self._name = self.__class__.__name__
        self.PUNCTUATION = PUNCTUATION

        self.IGNORE = IGNORE
        self.SPACE = SPACE
        self.EPSILON = EPSILON
        self.SEPARATOR = SEPARATOR

        # Pre-joined character sets, e.g. passed as IGNORE to compute_align_stats.
        self.PUNCTUATION_SEPARATOR = PUNCTUATION + SEPARATOR
        self.PUNCTUATION_SPACE = PUNCTUATION + SPACE
        self.PUNCTUATION_SPACE_SEPARATOR = PUNCTUATION + SPACE + SEPARATOR

    @abstractmethod
    def __call__(self, chunk: List[Alignment]) -> object:
        """Score ``chunk``; the result type is defined by the subclass (``WERFilter`` returns ``WERStats``)."""

    @property
    def name(self) -> str:
        """The concrete filter's class name."""
        return self._name
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class WERStats(namedtuple("WERStats", ["WER", "ins_errs", "del_errs", "sub_errs", "ref_len"])):
    """Word-error statistics for one alignment chunk (see ``compute_align_stats``).

    Fields:
        WER: Total errors divided by ``max(ref_len, 1)``, rounded to 4 decimals.
        ins_errs / del_errs / sub_errs: Error counts by type.
        ref_len: Number of alignment pairs counted towards the denominator.
    """

    # Keep instances as light as the underlying tuple (no per-instance __dict__).
    __slots__ = ()

    def to_dict(self):
        """Return the fields as a plain ``dict`` keyed by field name."""
        return dict(self._asdict())


# Single alignment pairs treated as correct: contractions and spelling variants
# of the same reference word.
_EQUIVALENT_PAIRS = frozenset(
    {
        ("is", "'s"),  # what is -> what's
        ("am", "'m"),  # I am -> I'm
        ("are", "'re"),  # they are -> they're
        ("would", "'d"),  # they would -> they'd
        ("had", "'d"),  # I had -> I'd
        ("will", "'ll"),  # we will -> we'll
        ("have", "'ve"),  # I have -> I've
        ("ok", "okay"),  # ok -> okay
        ("okay", "ok"),  # okay -> ok
    }
)

# Two consecutive alignment pairs treated together as two correct words.
_EQUIVALENT_BIGRAMS = frozenset(
    {
        (("let", "let"), ("us", "'s")),  # let us -> let's
        (("do", "do"), ("not", "n't")),  # do not -> don't
    }
)


# https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
def compute_align_stats(
    ali: List[Tuple[str, str]],
    ERR: str = "*",
    IGNORE: str = "",
    enable_log: bool = True,
) -> WERStats:
    """Compute WER statistics over a list of ``(ref_word, hyp_word)`` pairs.

    Pairs listed in ``_EQUIVALENT_PAIRS`` / ``_EQUIVALENT_BIGRAMS`` count as correct.
    A pair whose ref and hyp words are both non-empty and both contained in
    ``IGNORE`` (a substring test) is skipped entirely.

    NOTE: every non-skipped pair, including insertions, counts towards
    ``ref_len``; this differs from textbook WER, whose denominator counts only
    reference words.

    Args:
        ali: Aligned ``(ref_word, hyp_word)`` pairs.
        ERR: Marker for a missing word; ``ref == ERR`` counts as an insertion and
            ``hyp == ERR`` counts as a deletion.
        IGNORE: Characters (as one string) whose pairs are ignored.
        enable_log: Log a one-line summary via ``logging.info``.

    Returns:
        The resulting ``WERStats``.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    ref_len = 0
    # Number of upcoming pairs already consumed by a multi-pair equivalence.
    skip = 0
    for k, (ref_word, hyp_word) in enumerate(ali):
        if skip > 0:
            skip -= 1
            continue

        pair = (ref_word, hyp_word)
        if pair in _EQUIVALENT_PAIRS:
            # One correct reference word; only this pair is consumed.
            ref_len += 1
            continue
        if k + 1 < len(ali) and (pair, tuple(ali[k + 1])) in _EQUIVALENT_BIGRAMS:
            # Two correct reference words; the next pair is consumed as well.
            ref_len += 2
            skip = 1
            continue
        if (ref_word and ref_word in IGNORE) and (hyp_word and hyp_word in IGNORE):
            continue

        ref_len += 1
        if ref_word == ERR:
            ins[hyp_word] += 1
        elif hyp_word == ERR:
            dels[ref_word] += 1
        elif hyp_word != ref_word:
            subs[(ref_word, hyp_word)] += 1

    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    stats = WERStats(
        WER=round(tot_errs / max(ref_len, 1), ndigits=4),
        ins_errs=ins_errs,
        del_errs=del_errs,
        sub_errs=sub_errs,
        ref_len=ref_len,
    )

    if enable_log:
        logging.info(
            f"%WER {stats.WER:.4%} "
            f"[{tot_errs} / {max(ref_len, 1)}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    return stats
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class WERFilter(AlignFilter):
    """Alignment filter that scores a chunk by its word error rate."""

    def __call__(self, chunk: List[Alignment]) -> WERStats:
        """Return ``WERStats`` for ``chunk``.

        Pairs whose ref and hyp both fall within ``PUNCTUATION_SPACE_SEPARATOR``
        are ignored.

        NOTE(review): ``ERR`` is ``JOIN_TOKEN`` here although ``self.EPSILON`` also
        exists — confirm which gap marker error_align actually emits.
        """
        word_pairs = [(item.ref, item.hyp) for item in chunk]
        ignore_chars = self.PUNCTUATION_SPACE_SEPARATOR
        return compute_align_stats(
            word_pairs,
            ERR=JOIN_TOKEN,
            IGNORE=ignore_chars,
            enable_log=False,
        )
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def quality(
    chunk: List[Alignment], supervision: Tuple[float, float], transcript: Tuple[float, float], wer_fn: Callable
) -> Tuple[AlignQuality, TimestampQuality]:
    """Score an alignment chunk and pair up its caption / transcript time spans.

    Args:
        chunk: The group's alignments.
        supervision: ``(start, end)`` of the caption side.
        transcript: ``(start, end)`` of the transcript side.
        wer_fn: Callable mapping ``chunk`` to a WER stats object.

    Returns:
        ``(AlignQuality, TimestampQuality)`` for the group.
    """
    head_is_match = bool(chunk) and chunk[0].op_type == OpType.MATCH
    tail_is_match = bool(chunk) and chunk[-1].op_type == OpType.MATCH
    align_quality = AlignQuality(
        FW="FE" if head_is_match else "FN",
        LW="LE" if tail_is_match else "LN",
        PREFIX=equal_ratio(chunk[:4]),
        SUFFIX=equal_ratio(chunk[-4:]),
        WER=wer_fn(chunk),
    )
    start_pair = (supervision[0], transcript[0])
    end_pair = (supervision[1], transcript[1])
    return align_quality, TimestampQuality(start=start_pair, end=end_pair)
|