lattifai 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. lattifai/__init__.py +0 -24
  2. lattifai/alignment/__init__.py +10 -1
  3. lattifai/alignment/lattice1_aligner.py +66 -58
  4. lattifai/alignment/lattice1_worker.py +1 -6
  5. lattifai/alignment/punctuation.py +38 -0
  6. lattifai/alignment/segmenter.py +1 -1
  7. lattifai/alignment/sentence_splitter.py +350 -0
  8. lattifai/alignment/text_align.py +440 -0
  9. lattifai/alignment/tokenizer.py +91 -220
  10. lattifai/caption/__init__.py +82 -6
  11. lattifai/caption/caption.py +335 -1143
  12. lattifai/caption/formats/__init__.py +199 -0
  13. lattifai/caption/formats/base.py +211 -0
  14. lattifai/caption/formats/gemini.py +722 -0
  15. lattifai/caption/formats/json.py +194 -0
  16. lattifai/caption/formats/lrc.py +309 -0
  17. lattifai/caption/formats/nle/__init__.py +9 -0
  18. lattifai/caption/formats/nle/audition.py +561 -0
  19. lattifai/caption/formats/nle/avid.py +423 -0
  20. lattifai/caption/formats/nle/fcpxml.py +549 -0
  21. lattifai/caption/formats/nle/premiere.py +589 -0
  22. lattifai/caption/formats/pysubs2.py +642 -0
  23. lattifai/caption/formats/sbv.py +147 -0
  24. lattifai/caption/formats/tabular.py +338 -0
  25. lattifai/caption/formats/textgrid.py +193 -0
  26. lattifai/caption/formats/ttml.py +652 -0
  27. lattifai/caption/formats/vtt.py +469 -0
  28. lattifai/caption/parsers/__init__.py +9 -0
  29. lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
  30. lattifai/caption/standardize.py +636 -0
  31. lattifai/caption/utils.py +474 -0
  32. lattifai/cli/__init__.py +2 -1
  33. lattifai/cli/caption.py +108 -1
  34. lattifai/cli/transcribe.py +4 -9
  35. lattifai/cli/youtube.py +4 -1
  36. lattifai/client.py +48 -84
  37. lattifai/config/__init__.py +11 -1
  38. lattifai/config/alignment.py +9 -2
  39. lattifai/config/caption.py +267 -23
  40. lattifai/config/media.py +20 -0
  41. lattifai/diarization/__init__.py +41 -1
  42. lattifai/mixin.py +36 -18
  43. lattifai/transcription/base.py +6 -1
  44. lattifai/transcription/lattifai.py +19 -54
  45. lattifai/utils.py +81 -13
  46. lattifai/workflow/__init__.py +28 -4
  47. lattifai/workflow/file_manager.py +2 -5
  48. lattifai/youtube/__init__.py +43 -0
  49. lattifai/youtube/client.py +1170 -0
  50. lattifai/youtube/types.py +23 -0
  51. lattifai-1.2.2.dist-info/METADATA +615 -0
  52. lattifai-1.2.2.dist-info/RECORD +76 -0
  53. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
  54. lattifai/caption/gemini_reader.py +0 -371
  55. lattifai/caption/gemini_writer.py +0 -173
  56. lattifai/cli/app_installer.py +0 -142
  57. lattifai/cli/server.py +0 -44
  58. lattifai/server/app.py +0 -427
  59. lattifai/workflow/youtube.py +0 -577
  60. lattifai-1.2.0.dist-info/METADATA +0 -1133
  61. lattifai-1.2.0.dist-info/RECORD +0 -57
  62. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
  63. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
  64. {lattifai-1.2.0.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,440 @@
# Align caption.supervisions with transcription
import logging
import string  # noqa: F401
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from typing import Callable, Dict, List, Optional, Tuple, TypeVar

import regex
from error_align import error_align
from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType

from lattifai.caption import Caption, Supervision
from lattifai.utils import safe_print

from .punctuation import PUNCTUATION

# Generic symbol type variable (not referenced elsewhere in this file — TODO confirm intent).
Symbol = TypeVar("Symbol")
# Placeholder character for gaps/errors (see AlignFilter's EPSILON default and
# the ERR parameter of compute_align_stats).
EPSILON = "`"

# Sentinel appended after every supervision/transcription text so that segment
# boundaries survive the token-level alignment (see group_alignments).
JOIN_TOKEN = "❄"
# Register the sentinel as a delimiter so error_align's tokenization treats it
# as a boundary token rather than part of a word.
if JOIN_TOKEN not in DELIMITERS:
    DELIMITERS.add(JOIN_TOKEN)
25
def custom_tokenizer(text: str) -> list:
    """Tokenize *text* into regex match objects.

    Numeric tokens are matched first; otherwise standard word tokens or the
    JOIN_TOKEN sentinel, using the token patterns provided by ``error_align``.

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: Match objects for every token found, in input order.

    """
    # JOIN_TOKEN must be escaped before embedding it in the pattern.
    join_pattern = regex.escape(JOIN_TOKEN)
    pattern = rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{join_pattern})"
    flags = regex.UNICODE | regex.VERBOSE
    return [match for match in regex.finditer(pattern, text, flags)]
44
+
45
+
46
def equal_ratio(chunk: List[Alignment]):
    """Fraction of alignment ops in *chunk* that are exact matches (0.0 for an empty chunk)."""
    n_match = sum(1 for item in chunk if item.op_type == OpType.MATCH)
    return n_match / max(len(chunk), 1)
48
+
49
+
50
def equal(chunk: List[Alignment]):
    """True when every alignment op in *chunk* is an exact match (vacuously True when empty)."""
    for item in chunk:
        if item.op_type != OpType.MATCH:
            return False
    return True
52
+
53
+
54
def group_alignments(
    supervisions: List[Supervision],
    transcription: List[Supervision],
    max_silence_gap: float = 10.0,
    mini_num_supervisions: int = 1,
    mini_num_transcription: int = 1,
    equal_threshold: float = 0.5,
    verbose: bool = False,
) -> List[Tuple[Tuple[int, int], Tuple[int, int], List[Alignment]]]:
    """Partition the caption/transcription token alignment into groups.

    Both sides are lowercased and concatenated with a JOIN_TOKEN sentinel after
    every segment, then aligned with ``error_align``.  A group boundary is
    emitted where a JOIN_TOKEN appears on BOTH sides (a caption boundary that
    coincides with a transcription boundary) and the local alignment quality /
    silence-gap conditions are met.

    Args:
        supervisions: Caption supervisions (reference side).
        transcription: ASR supervisions (hypothesis side).
        max_silence_gap: Silence (seconds) between consecutive transcription
            segments that allows a split; more than twice this value forces a
            split regardless of alignment quality.
        mini_num_supervisions: Minimum supervisions per emitted group.
        mini_num_transcription: Minimum transcription segments per group.
        equal_threshold: Minimum match ratio near either chunk edge for a split.
        verbose: Print each emitted group via safe_print.

    Returns:
        List of ((sup_start, sup_end), (asr_start, asr_end), chunk) tuples with
        half-open index ranges into the inputs and the alignment ops covered.

    """
    # TABLE = str.maketrans(dict.fromkeys(string.punctuation))
    # sup.text.lower().translate(TABLE)
    ref = "".join(sup.text.lower() + JOIN_TOKEN for sup in supervisions)
    hyp = "".join(sup.text.lower() + JOIN_TOKEN for sup in transcription)
    alignments = error_align(ref, hyp, tokenizer=custom_tokenizer)

    matches = []
    # segment start index
    ss_start, ss_idx = 0, 0  # supervisions: group start / boundaries seen so far
    tr_start, tr_idx = 0, 0  # transcription: group start / boundaries seen so far

    idx_start = 0  # first alignment op belonging to the current chunk
    for idx, ali in enumerate(alignments):
        # Each JOIN_TOKEN marks the end of one input segment on that side.
        if ali.ref == JOIN_TOKEN:
            ss_idx += 1
        if ali.hyp == JOIN_TOKEN:
            tr_idx += 1

        if ali.ref == JOIN_TOKEN and ali.hyp == JOIN_TOKEN:
            # Candidate split point: both sides end a segment here.
            chunk = alignments[idx_start:idx]

            split_at_silence = False
            greater_two_silence_gap = False
            if tr_idx > 0 and tr_idx < len(transcription):
                # Silence between the transcription segment just finished and the next one.
                gap = transcription[tr_idx].start - transcription[tr_idx - 1].end
                if gap > max_silence_gap:
                    split_at_silence = True
                if gap > 2 * max_silence_gap:
                    greater_two_silence_gap = True

            # Split when either chunk edge aligns well AND the group is large
            # enough (or a silence gap justifies it); a very long silence gap
            # forces a split unconditionally.
            if (
                (equal_ratio(chunk[:10]) > equal_threshold or equal_ratio(chunk[-10:]) > equal_threshold)
                and (
                    (ss_idx - ss_start >= mini_num_supervisions and tr_idx - tr_start >= mini_num_transcription)
                    or split_at_silence
                )
            ) or greater_two_silence_gap:
                matches.append(((ss_start, ss_idx), (tr_start, tr_idx), chunk))

                if verbose:
                    sub_align = supervisions[ss_start:ss_idx]
                    asr_align = transcription[tr_start:tr_idx]
                    safe_print("========================================================================")
                    safe_print(f" Caption [{ss_start:>4d}, {ss_idx:>4d}): {[sup.text for sup in sub_align]}")
                    safe_print(f"Transcript [{tr_start:>4d}, {tr_idx:>4d}): {[sup.text for sup in asr_align]}")
                    safe_print("========================================================================\n\n")

                ss_start = ss_idx
                tr_start = tr_idx
                idx_start = idx + 1

        if ss_start == len(supervisions) and tr_start == len(transcription):
            # Both sides fully consumed and emitted: done.
            break

        # remainder
        if ss_idx == len(supervisions) or tr_idx == len(transcription):
            # One side is exhausted, so no further double boundaries can occur;
            # emit everything left as a single final group (one side may be empty).
            chunk = alignments[idx_start:]
            matches.append(((ss_start, len(supervisions)), (tr_start, len(transcription)), chunk))
            break

    return matches
124
+
125
+
126
class AlignQuality(namedtuple("AlignQuality", ["FW", "LW", "PREFIX", "SUFFIX", "WER"])):
    """Word-level alignment quality summary for one aligned group.

    Fields:
        FW: "FE" when the first alignment op matched exactly, else "FN".
        LW: "LE" when the last alignment op matched exactly, else "LN".
        PREFIX: match ratio over the first few alignment ops.
        SUFFIX: match ratio over the last few alignment ops.
        WER: WER statistics object for the whole chunk (exposes ``.WER``).
    """

    def __repr__(self) -> str:
        return f"WORD[{self.FW}][{self.LW}]_WER[{self.WER}]"

    @property
    def info(self) -> str:
        """One-line human-readable summary of the quality metrics."""
        return f"WER {self.WER.WER:.4f} accuracy [{self.PREFIX:.2f}, {self.SUFFIX:.2f}] {self.WER}"

    @property
    def first_word_equal(self) -> bool:
        """Whether the group's first word matched exactly."""
        return self.FW == "FE"

    @property
    def last_word_equal(self) -> bool:
        """Whether the group's last word matched exactly."""
        return self.LW == "LE"

    @property
    def wer(self) -> float:
        """The numeric word error rate of the group."""
        return self.WER.WER

    @property
    def qwer(self):
        """Coarse WER bucket label (ref_len is not yet taken into account)."""
        rate = self.wer
        if rate == 0.0:
            return "WZ"  # zero
        if rate < 0.1:
            return "WL"  # low
        if rate < 0.32:
            return "WM"  # medium
        return "WH"  # high

    @property
    def qprefix(self) -> str:
        """Coarse bucket label for the prefix match ratio."""
        if self.PREFIX >= 0.7:
            return "PH"  # high
        if self.PREFIX >= 0.5:
            return "PM"  # medium
        return "PL"  # low

    @property
    def qsuffix(self) -> str:
        """Coarse bucket label for the suffix match ratio."""
        if self.SUFFIX > 0.7:
            return "SH"  # high
        if self.SUFFIX > 0.5:
            return "SM"  # medium
        return "SL"  # low
178
+
179
+
180
class TimestampQuality(namedtuple("TimestampQuality", ["start", "end"])):
    """Timestamp disagreement between the caption side and the transcript side.

    Fields:
        start: (caption_start, transcript_start) pair, in seconds.
        end: (caption_end, transcript_end) pair, in seconds.
    """

    @property
    def start_diff(self):
        """Absolute disagreement of the start timestamps."""
        caption_start, transcript_start = self.start
        return abs(caption_start - transcript_start)

    @property
    def end_diff(self):
        """Absolute disagreement of the end timestamps."""
        caption_end, transcript_end = self.end
        return abs(caption_end - transcript_end)

    @property
    def diff(self):
        """Worst-case disagreement across start and end."""
        return max(self.start_diff, self.end_diff)
192
+
193
+
194
# One aligned group as produced by align_supervisions_and_transcription:
# (caption supervisions, transcript supervisions, word quality, timestamp quality, chunk).
# NOTE(review): the function appends the raw alignment chunk (a List[Alignment])
# as the fifth element, so the trailing `int` here looks wrong — confirm.
TextAlignResult = Tuple[Optional[List[Supervision]], Optional[List[Supervision]], AlignQuality, TimestampQuality, int]
195
+
196
+
197
def align_supervisions_and_transcription(
    caption: Caption,
    max_duration: Optional[float] = None,
    verbose: bool = False,
) -> List[TextAlignResult]:
    """Align caption.supervisions with caption.transcription.

    Groups the two segment lists via ``group_alignments`` and scores each group
    with word-level (AlignQuality) and timestamp-level (TimestampQuality)
    quality measures.

    Args:
        caption: Caption object containing supervisions and transcription.
        max_duration: Optional upper bound (seconds) for synthesized end
            timestamps of one-sided groups; clamped near the latest observed
            end time.
        verbose: Print per-group diagnostics via safe_print.

    Returns:
        One [sub_align, asr_align, AlignQuality, TimestampQuality, chunk] entry
        per group (a group may have an empty side).

    """
    groups = group_alignments(caption.supervisions, caption.transcription, verbose=False)

    # Clamp/derive max_duration from the latest end time seen on either side.
    if max_duration is None:
        max_duration = max(caption.transcription[-1].end, caption.supervisions[-1].end)
    else:
        max_duration = min(
            max_duration,
            max(caption.transcription[-1].end, caption.supervisions[-1].end) + 10.0,
        )

    def next_start(alignments: List[Supervision], idx: int) -> float:
        # Start of the segment at idx, or a synthetic end just past the last
        # segment when idx runs off the list.
        if idx < len(alignments):
            return alignments[idx].start
        return min(alignments[-1].end + 2.0, max_duration)

    wer_filter = WERFilter()

    matches = []
    for idx, ((sub_start, sub_end), (asr_start, asr_end), chunk) in enumerate(groups):
        sub_align = caption.supervisions[sub_start:sub_end]
        asr_align = caption.transcription[asr_start:asr_end]

        if not sub_align or not asr_align:
            # One-sided group: synthesize the missing side's time span from the
            # previous match's end and the next segment's start.
            if sub_align:
                if matches:
                    _asr_start = matches[-1][-2].end[1]
                else:
                    _asr_start = 0.0
                startends = [
                    (sub_align[0].start, sub_align[-1].end),
                    (_asr_start, next_start(caption.transcription, asr_end)),
                ]
            elif asr_align:
                if matches:
                    _sub_start = matches[-1][-2].end[0]
                else:
                    _sub_start = 0.0
                startends = [
                    (_sub_start, next_start(caption.supervisions, sub_end)),
                    (asr_align[0].start, asr_align[-1].end),
                ]
            else:
                # group_alignments never emits a group with both sides empty.
                raise ValueError(
                    f"Never Here! subtitles[{len(caption.supervisions)}] {sub_start}-{sub_end} asrs[{len(caption.transcription)}] {asr_start}-{asr_end}"
                )

            if verbose:
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo")
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{startends[0][0]:>8.2f}, {startends[0][1]:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{startends[1][0]:>8.2f}, {startends[1][1]:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo\n\n")

            aligned, timestamp = quality(chunk, startends[0], startends[1], wer_fn=wer_filter)
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])
            continue
        else:
            # Two-sided group: real timestamps are available on both sides.
            aligned, timestamp = quality(
                chunk,
                [sub_align[0].start, sub_align[-1].end],
                [asr_align[0].start, asr_align[-1].end],
                wer_fn=wer_filter,
            )
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])

            if verbose and aligned.wer > 0.0:
                safe_print(
                    f"===================================WER={aligned.wer:>4.2f}====================================="
                )
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{sub_align[0].start:>8.2f}, {sub_align[-1].end:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{asr_align[0].start:>8.2f}, {asr_align[-1].end:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("========================================================================\n\n")

    return matches
289
+
290
+
291
class AlignFilter(ABC):
    """Base class for callables that score a chunk of alignment ops.

    Holds the character classes (punctuation, space, separator) that concrete
    filters need, plus pre-concatenated combinations of them.
    """

    def __init__(
        self, PUNCTUATION: str = PUNCTUATION, IGNORE: str = "", SPACE=" ", EPSILON=EPSILON, SEPARATOR=JOIN_TOKEN
    ):
        super().__init__()
        # The filter's name defaults to the concrete subclass name.
        self._name = type(self).__name__

        self.PUNCTUATION = PUNCTUATION
        self.IGNORE = IGNORE
        self.SPACE = SPACE
        self.EPSILON = EPSILON
        self.SEPARATOR = SEPARATOR

        # Pre-built character sets, typically passed as IGNORE to WER computation.
        self.PUNCTUATION_SEPARATOR = PUNCTUATION + SEPARATOR
        self.PUNCTUATION_SPACE = PUNCTUATION + SPACE
        self.PUNCTUATION_SPACE_SEPARATOR = PUNCTUATION + SPACE + SEPARATOR

    @abstractmethod
    def __call__(self, chunk: List[Alignment]) -> str:
        """Score or summarize *chunk*; implemented by subclasses."""

    @property
    def name(self):
        """Human-readable filter name (the subclass name)."""
        return self._name
316
+
317
+
318
class WERStats(namedtuple("WERStats", ["WER", "ins_errs", "del_errs", "sub_errs", "ref_len"])):
    """Word-error-rate statistics for one aligned chunk.

    Fields:
        WER: total errors divided by max(ref_len, 1), rounded to 4 digits.
        ins_errs: number of insertions.
        del_errs: number of deletions.
        sub_errs: number of substitutions.
        ref_len: number of reference tokens counted.
    """
    # Fix: the namedtuple typename was "AlignStats", which made repr() report
    # the wrong class name; it now matches the class, "WERStats".

    def to_dict(self):
        """Return the statistics as a plain dict (e.g. for serialization)."""
        return {
            "WER": self.WER,
            "ins_errs": self.ins_errs,
            "del_errs": self.del_errs,
            "sub_errs": self.sub_errs,
            "ref_len": self.ref_len,
        }
328
+
329
# https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
def compute_align_stats(
    ali: List[Tuple[str, str]],
    ERR: str = "*",
    IGNORE: str = "",
    enable_log: bool = True,
) -> WERStats:
    """Compute WER statistics from a word-level alignment.

    Adapted from icefall's utils (URL above).  Common English contraction
    variants (e.g. "is" -> "'s", "do not" -> "don't", "ok" <-> "okay") are
    forgiven and counted as correct reference words instead of substitutions.

    Args:
        ali: (ref_word, hyp_word) pairs; ``ERR`` on one side marks a gap.
        ERR: Placeholder token denoting insertion (on the ref side) or
            deletion (on the hyp side).
        IGNORE: Tokens skipped entirely when BOTH sides are in this set
            (e.g. punctuation, space, the JOIN_TOKEN separator).
        enable_log: Emit a one-line WER summary via ``logging.info``.

    Returns:
        WERStats with the overall WER and per-category error counts.

    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    # corr, ref_sub, hyp_sub, ins, dels
    # NOTE(review): `words` and `num_corr` are accumulated but never returned;
    # apparently kept for parity with the icefall original.
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0

    ref_len = 0
    skip = 0  # number of upcoming alignment entries to skip
    for k, (ref_word, hyp_word) in enumerate(ali):
        if skip > 0:
            skip -= 1
            continue

        # compute_align_stats(ali, ERR=EPSILON, IGNORE=PUNCTUATION_SPACE_SEPARATOR)
        # Forgive contraction variants: count them as correct reference words.
        # NOTE(review): skip=1 here also consumes the entry AFTER the matched
        # single pair (and skip=2 the entry after the matched two) — presumably
        # an adjacent delimiter token; confirm against the tokenizer's output.
        if ali[k : k + 1] in [
            [("is", "'s")],  # what is -> what's
            [("am", "'m")],  # I am -> I'm
            [("are", "'re")],  # they are -> they're
            [("would", "'d")],  # they would -> they'd don't
            [("had", "'d")],  # I had -> I'd
            # we will -> we'll
            [("will", "'ll")],
            # I have -> I've
            [("have", "'ve")],
            # ok -> okay
            [("ok", "okay")],
            # okay -> ok
            [("okay", "ok")],
        ]:
            skip = 1
            ref_len += 1
            continue
        elif ali[k : k + 2] in [
            # let us -> let's
            [("let", "let"), ("us", "'s")],
            # do not -> don't
            [("do", "do"), ("not", "n't")],
        ]:
            skip = 2
            ref_len += 2
            continue
        elif (ref_word and ref_word in IGNORE) and (hyp_word and hyp_word in IGNORE):
            # Both sides ignorable (punctuation/space/separator): skip entirely.
            continue
        else:
            ref_len += 1

        # Classify the alignment op and update per-word bookkeeping.
        if ref_word == ERR:
            ins[hyp_word] += 1
            words[hyp_word][3] += 1
        elif hyp_word == ERR:
            dels[ref_word] += 1
            words[ref_word][4] += 1
        elif hyp_word != ref_word:
            subs[(ref_word, hyp_word)] += 1
            words[ref_word][1] += 1
            words[hyp_word][2] += 1
        else:
            words[ref_word][0] += 1
            num_corr += 1

    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    stats = WERStats(
        WER=round(tot_errs / max(ref_len, 1), ndigits=4),
        ins_errs=ins_errs,
        del_errs=del_errs,
        sub_errs=sub_errs,
        ref_len=ref_len,
    )

    if enable_log:
        logging.info(
            f"%WER {stats.WER:.4%} "
            f"[{tot_errs} / {max(ref_len, 1)}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    return stats
420
+
421
+
422
class WERFilter(AlignFilter):
    """AlignFilter that computes WER statistics over an alignment chunk."""

    def __call__(self, chunk: List[Alignment]) -> WERStats:
        """Return WERStats for *chunk*, ignoring punctuation/space/separator tokens."""
        pairs = [(item.ref, item.hyp) for item in chunk]
        return compute_align_stats(
            pairs,
            ERR=JOIN_TOKEN,
            IGNORE=self.PUNCTUATION_SPACE_SEPARATOR,
            enable_log=False,
        )
427
+
428
+
429
def quality(
    chunk: List[Alignment], supervision: Tuple[float, float], transcript: Tuple[float, float], wer_fn: Callable
) -> Tuple[AlignQuality, TimestampQuality]:
    """Summarize alignment quality and timestamp agreement for one group.

    Args:
        chunk: Alignment ops covering the group.
        supervision: (start, end) seconds of the caption side.
        transcript: (start, end) seconds of the transcript side.
        wer_fn: Callable mapping the chunk to WER statistics.

    Returns:
        (AlignQuality, TimestampQuality) for the group.

    """
    first_is_match = bool(chunk) and chunk[0].op_type == OpType.MATCH
    last_is_match = bool(chunk) and chunk[-1].op_type == OpType.MATCH
    word_quality = AlignQuality(
        FW="FE" if first_is_match else "FN",
        LW="LE" if last_is_match else "LN",
        PREFIX=equal_ratio(chunk[:4]),
        SUFFIX=equal_ratio(chunk[-4:]),
        WER=wer_fn(chunk),
    )
    time_quality = TimestampQuality(
        start=(supervision[0], transcript[0]),
        end=(supervision[1], transcript[1]),
    )
    return word_quality, time_quality