lattifai 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. lattifai/_init.py +20 -0
  2. lattifai/alignment/__init__.py +9 -1
  3. lattifai/alignment/lattice1_aligner.py +175 -54
  4. lattifai/alignment/lattice1_worker.py +47 -4
  5. lattifai/alignment/punctuation.py +38 -0
  6. lattifai/alignment/segmenter.py +3 -2
  7. lattifai/alignment/text_align.py +441 -0
  8. lattifai/alignment/tokenizer.py +134 -65
  9. lattifai/audio2.py +162 -183
  10. lattifai/cli/__init__.py +2 -1
  11. lattifai/cli/alignment.py +5 -0
  12. lattifai/cli/caption.py +111 -4
  13. lattifai/cli/transcribe.py +2 -6
  14. lattifai/cli/youtube.py +7 -1
  15. lattifai/client.py +72 -123
  16. lattifai/config/__init__.py +28 -0
  17. lattifai/config/alignment.py +14 -0
  18. lattifai/config/caption.py +45 -31
  19. lattifai/config/client.py +16 -0
  20. lattifai/config/event.py +102 -0
  21. lattifai/config/media.py +20 -0
  22. lattifai/config/transcription.py +25 -1
  23. lattifai/data/__init__.py +8 -0
  24. lattifai/data/caption.py +228 -0
  25. lattifai/diarization/__init__.py +41 -1
  26. lattifai/errors.py +78 -53
  27. lattifai/event/__init__.py +65 -0
  28. lattifai/event/lattifai.py +166 -0
  29. lattifai/mixin.py +49 -32
  30. lattifai/transcription/base.py +8 -2
  31. lattifai/transcription/gemini.py +147 -16
  32. lattifai/transcription/lattifai.py +25 -63
  33. lattifai/types.py +1 -1
  34. lattifai/utils.py +7 -13
  35. lattifai/workflow/__init__.py +28 -4
  36. lattifai/workflow/file_manager.py +2 -5
  37. lattifai/youtube/__init__.py +43 -0
  38. lattifai/youtube/client.py +1265 -0
  39. lattifai/youtube/types.py +23 -0
  40. lattifai-1.3.0.dist-info/METADATA +678 -0
  41. lattifai-1.3.0.dist-info/RECORD +57 -0
  42. {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/entry_points.txt +1 -2
  43. lattifai/__init__.py +0 -88
  44. lattifai/alignment/sentence_splitter.py +0 -219
  45. lattifai/caption/__init__.py +0 -20
  46. lattifai/caption/caption.py +0 -1467
  47. lattifai/caption/gemini_reader.py +0 -462
  48. lattifai/caption/gemini_writer.py +0 -173
  49. lattifai/caption/supervision.py +0 -34
  50. lattifai/caption/text_parser.py +0 -145
  51. lattifai/cli/app_installer.py +0 -142
  52. lattifai/cli/server.py +0 -44
  53. lattifai/server/app.py +0 -427
  54. lattifai/workflow/youtube.py +0 -577
  55. lattifai-1.2.1.dist-info/METADATA +0 -1134
  56. lattifai-1.2.1.dist-info/RECORD +0 -58
  57. {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/WHEEL +0 -0
  58. {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/licenses/LICENSE +0 -0
  59. {lattifai-1.2.1.dist-info → lattifai-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,441 @@
1
+ # Align caption.supervisions with transcription
2
+ import logging
3
+ import string # noqa: F401
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict, namedtuple
6
+ from typing import Callable, Dict, List, Optional, Tuple, TypeVar
7
+
8
+ import regex
9
+ from error_align import error_align
10
+ from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
11
+
12
+ from lattifai.caption import Supervision
13
+ from lattifai.data import Caption
14
+ from lattifai.utils import safe_print
15
+
16
+ from .punctuation import PUNCTUATION
17
+
18
+ Symbol = TypeVar("Symbol")
19
+ EPSILON = "`"
20
+
21
+ JOIN_TOKEN = "❄"
22
+ if JOIN_TOKEN not in DELIMITERS:
23
+ DELIMITERS.add(JOIN_TOKEN)
24
+
25
+
26
def custom_tokenizer(text: str) -> list:
    """Tokenize text into regex match objects.

    Recognizes numeric tokens, standard word tokens, and the segment
    separator ``JOIN_TOKEN``, returning the raw ``regex`` match objects
    in order of appearance.

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: Match objects for every token found.

    """
    # JOIN_TOKEN must not be interpreted as a regex metacharacter.
    join = regex.escape(JOIN_TOKEN)
    pattern = rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{join})"
    matches = regex.finditer(pattern, text, regex.UNICODE | regex.VERBOSE)
    return [m for m in matches]
45
+
46
+
47
def equal_ratio(chunk: List[Alignment]):
    """Fraction of ops in ``chunk`` that are exact matches (0.0 for an empty chunk)."""
    denominator = max(len(chunk), 1)
    hits = sum(1 for op in chunk if op.op_type == OpType.MATCH)
    return hits / denominator
49
+
50
+
51
def equal(chunk: List[Alignment]):
    """Return True iff every alignment op in ``chunk`` is an exact match."""
    for op in chunk:
        if op.op_type != OpType.MATCH:
            return False
    return True
53
+
54
+
55
def group_alignments(
    supervisions: List[Supervision],
    transcription: List[Supervision],
    max_silence_gap: float = 10.0,
    mini_num_supervisions: int = 1,
    mini_num_transcription: int = 1,
    equal_threshold: float = 0.5,
    verbose: bool = False,
) -> List[Tuple[Tuple[int, int], Tuple[int, int], List[Alignment]]]:
    """Partition the word alignment of two supervision streams into matched groups.

    Both streams are lowercased and concatenated with JOIN_TOKEN after each
    segment, aligned with ``error_align``, then cut at positions where a
    JOIN_TOKEN on the reference side coincides with one on the hypothesis
    side (a supervision boundary aligned to a transcription boundary).

    Args:
        supervisions: Reference segments (e.g. caption supervisions).
        transcription: Hypothesis segments (e.g. ASR output).
        max_silence_gap: A transcription silence (seconds) above this permits
            a cut even when the minimum-count constraints fail; above twice
            this, a cut is forced unconditionally.
        mini_num_supervisions: Minimum supervisions required per group.
        mini_num_transcription: Minimum transcription segments per group.
        equal_threshold: Minimum match ratio over the first or last 10 ops of
            a candidate chunk for the cut to be accepted.
        verbose: Print each emitted group for debugging.

    Returns:
        List of ``((sup_start, sup_end), (asr_start, asr_end), chunk)`` where
        the index pairs are half-open ranges into the inputs and ``chunk`` is
        the corresponding slice of alignment ops.
    """
    # TABLE = str.maketrans(dict.fromkeys(string.punctuation))
    # sup.text.lower().translate(TABLE)
    ref = "".join(sup.text.lower() + JOIN_TOKEN for sup in supervisions)
    hyp = "".join(sup.text.lower() + JOIN_TOKEN for sup in transcription)
    alignments = error_align(ref, hyp, tokenizer=custom_tokenizer)

    matches = []
    # segment start index / running boundary counters for each stream
    ss_start, ss_idx = 0, 0
    tr_start, tr_idx = 0, 0

    idx_start = 0  # alignment index where the current chunk begins
    for idx, ali in enumerate(alignments):
        # Each JOIN_TOKEN seen marks the end of one segment on that side.
        if ali.ref == JOIN_TOKEN:
            ss_idx += 1
        if ali.hyp == JOIN_TOKEN:
            tr_idx += 1

        if ali.ref == JOIN_TOKEN and ali.hyp == JOIN_TOKEN:
            # Both sides hit a segment boundary on the same op: candidate cut.
            chunk = alignments[idx_start:idx]

            split_at_silence = False
            greater_two_silence_gap = False
            if tr_idx > 0 and tr_idx < len(transcription):
                # Silence between the previous and next transcription segment.
                gap = transcription[tr_idx].start - transcription[tr_idx - 1].end
                if gap > max_silence_gap:
                    split_at_silence = True
                if gap > 2 * max_silence_gap:
                    greater_two_silence_gap = True

            # Accept the cut when a chunk edge looks well matched AND the
            # group is large enough (or a long silence justifies an early
            # cut); a very long silence always forces the cut.
            if (
                (equal_ratio(chunk[:10]) > equal_threshold or equal_ratio(chunk[-10:]) > equal_threshold)
                and (
                    (ss_idx - ss_start >= mini_num_supervisions and tr_idx - tr_start >= mini_num_transcription)
                    or split_at_silence
                )
            ) or greater_two_silence_gap:
                matches.append(((ss_start, ss_idx), (tr_start, tr_idx), chunk))

                if verbose:
                    sub_align = supervisions[ss_start:ss_idx]
                    asr_align = transcription[tr_start:tr_idx]
                    safe_print("========================================================================")
                    safe_print(f"   Caption [{ss_start:>4d}, {ss_idx:>4d}): {[sup.text for sup in sub_align]}")
                    safe_print(f"Transcript [{tr_start:>4d}, {tr_idx:>4d}): {[sup.text for sup in asr_align]}")
                    safe_print("========================================================================\n\n")

                # The next group starts right after this cut point.
                ss_start = ss_idx
                tr_start = tr_idx
                idx_start = idx + 1

        if ss_start == len(supervisions) and tr_start == len(transcription):
            # Both streams fully consumed by emitted groups.
            break

        # remainder: one stream is exhausted — emit everything left as one group.
        if ss_idx == len(supervisions) or tr_idx == len(transcription):
            chunk = alignments[idx_start:]
            matches.append(((ss_start, len(supervisions)), (tr_start, len(transcription)), chunk))
            break

    return matches
125
+
126
+
127
class AlignQuality(namedtuple("AlignQuality", ["FW", "LW", "PREFIX", "SUFFIX", "WER"])):
    """Qualitative summary of one aligned chunk.

    Fields: FW/LW are first/last-word match flags ("FE"/"LE" when matched),
    PREFIX/SUFFIX are edge match ratios, and WER holds the chunk's WER stats.
    """

    def __repr__(self) -> str:
        return f"WORD[{self.FW}][{self.LW}]_WER[{self.WER}]"

    @property
    def info(self) -> str:
        """Human-readable one-line quality summary."""
        return f"WER {self.WER.WER:.4f} accuracy [{self.PREFIX:.2f}, {self.SUFFIX:.2f}] {self.WER}"

    @property
    def first_word_equal(self) -> bool:
        return self.FW == "FE"

    @property
    def last_word_equal(self) -> bool:
        return self.LW == "LE"

    @property
    def wer(self) -> float:
        """Numeric WER extracted from the stats object."""
        return self.WER.WER

    @property
    def qwer(self):
        """Bucketed WER label: WZ (zero), WL (<0.1), WM (<0.32), WH (rest)."""
        # TODO: consider taking ref_len into account when bucketing.
        rate = self.wer
        if rate == 0.0:
            return "WZ"  # zero
        if rate < 0.1:
            return "WL"  # low
        return "WM" if rate < 0.32 else "WH"  # medium / high

    @property
    def qprefix(self) -> str:
        """Bucketed prefix accuracy: PH (>=0.7), PM (>=0.5), PL (rest)."""
        if self.PREFIX >= 0.7:
            return "PH"  # high
        return "PM" if self.PREFIX >= 0.5 else "PL"  # medium / low

    @property
    def qsuffix(self) -> str:
        """Bucketed suffix accuracy: SH (>0.7), SM (>0.5), SL (rest)."""
        if self.SUFFIX > 0.7:
            return "SH"  # high
        return "SM" if self.SUFFIX > 0.5 else "SL"  # medium / low
179
+
180
+
181
class TimestampQuality(namedtuple("TimestampQuality", ["start", "end"])):
    """Boundary-timestamp agreement between the two aligned streams.

    ``start`` and ``end`` are (supervision_time, transcription_time) pairs.
    """

    @property
    def start_diff(self):
        """Absolute disagreement at the group start, in seconds."""
        a, b = self.start
        return abs(a - b)

    @property
    def end_diff(self):
        """Absolute disagreement at the group end, in seconds."""
        a, b = self.end
        return abs(a - b)

    @property
    def diff(self):
        """Worst of the start/end disagreements."""
        return max(self.start_diff, self.end_diff)
193
+
194
+
195
# One per-group alignment result: (supervisions, transcription, align quality,
# timestamp quality, chunk).
# NOTE(review): the final element is declared `int`, but callers in this file
# append the raw alignment-op chunk (a list) in that slot — confirm intent.
TextAlignResult = Tuple[Optional[List[Supervision]], Optional[List[Supervision]], AlignQuality, TimestampQuality, int]
196
+
197
+
198
def align_supervisions_and_transcription(
    caption: Caption,
    max_duration: Optional[float] = None,
    verbose: bool = False,
) -> List[TextAlignResult]:
    """Align caption.supervisions with caption.transcription.

    Groups the two streams via ``group_alignments``, then scores each group
    with ``quality`` (alignment quality + boundary timestamps). For groups
    where one side is empty, the missing side's time span is estimated from
    the previous match's end and the next segment's start.

    Args:
        caption: Caption object containing supervisions and transcription.
        max_duration: Optional cap for estimated end times; defaults to the
            later of the two streams' final end times.
        verbose: Print per-group diagnostics.

    Returns:
        List of ``[sub_align, asr_align, AlignQuality, TimestampQuality, chunk]``
        entries, one per group.
    """
    groups = group_alignments(caption.supervisions, caption.transcription, verbose=False)

    if max_duration is None:
        max_duration = max(caption.transcription[-1].end, caption.supervisions[-1].end)
    else:
        # Never allow the cap to exceed the observed material by more than 10 s.
        max_duration = min(
            max_duration,
            max(caption.transcription[-1].end, caption.supervisions[-1].end) + 10.0,
        )

    def next_start(alignments: List[Supervision], idx: int) -> float:
        # Start time of segment `idx`, or a capped estimate past the last one.
        if idx < len(alignments):
            return alignments[idx].start
        return min(alignments[-1].end + 2.0, max_duration)

    wer_filter = WERFilter()

    matches = []
    for idx, ((sub_start, sub_end), (asr_start, asr_end), chunk) in enumerate(groups):
        sub_align = caption.supervisions[sub_start:sub_end]
        asr_align = caption.transcription[asr_start:asr_end]

        if not sub_align or not asr_align:
            # One side of the group is empty: estimate the missing side's
            # time span so timestamp quality can still be computed.
            if sub_align:
                if matches:
                    _asr_start = matches[-1][-2].end[1]  # previous group's transcript end time
                else:
                    _asr_start = 0.0
                startends = [
                    (sub_align[0].start, sub_align[-1].end),
                    (_asr_start, next_start(caption.transcription, asr_end)),
                ]
            elif asr_align:
                if matches:
                    _sub_start = matches[-1][-2].end[0]  # previous group's supervision end time
                else:
                    _sub_start = 0.0
                startends = [
                    (_sub_start, next_start(caption.supervisions, sub_end)),
                    (asr_align[0].start, asr_align[-1].end),
                ]
            else:
                # group_alignments never emits a group empty on both sides.
                raise ValueError(
                    f"Never Here! subtitles[{len(caption.supervisions)}] {sub_start}-{sub_end} asrs[{len(caption.transcription)}] {asr_start}-{asr_end}"
                )

            if verbose:
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo")
                safe_print(
                    f"   Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{startends[0][0]:>8.2f}, {startends[0][1]:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{startends[1][0]:>8.2f}, {startends[1][1]:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo\n\n")

            aligned, timestamp = quality(chunk, startends[0], startends[1], wer_fn=wer_filter)
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])
            continue
        else:
            # Both sides present: score with the real boundary timestamps.
            aligned, timestamp = quality(
                chunk,
                [sub_align[0].start, sub_align[-1].end],
                [asr_align[0].start, asr_align[-1].end],
                wer_fn=wer_filter,
            )
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])

            if verbose and aligned.wer > 0.0:
                safe_print(
                    f"===================================WER={aligned.wer:>4.2f}====================================="
                )
                safe_print(
                    f"   Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{sub_align[0].start:>8.2f}, {sub_align[-1].end:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{asr_align[0].start:>8.2f}, {asr_align[-1].end:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("========================================================================\n\n")

    return matches
290
+
291
+
292
class AlignFilter(ABC):
    """Base class for alignment-chunk scoring filters.

    Holds the character classes (punctuation, space, epsilon, separator) that
    concrete filters use when deciding which alignment pairs to ignore.
    """

    def __init__(
        self, PUNCTUATION: str = PUNCTUATION, IGNORE: str = "", SPACE=" ", EPSILON=EPSILON, SEPARATOR=JOIN_TOKEN
    ):
        super().__init__()
        self._name = self.__class__.__name__

        # Individual character classes.
        self.PUNCTUATION = PUNCTUATION
        self.IGNORE = IGNORE
        self.SPACE = SPACE
        self.EPSILON = EPSILON
        self.SEPARATOR = SEPARATOR

        # Pre-combined classes, built once for convenient membership tests.
        self.PUNCTUATION_SEPARATOR = PUNCTUATION + SEPARATOR
        self.PUNCTUATION_SPACE = PUNCTUATION + SPACE
        self.PUNCTUATION_SPACE_SEPARATOR = PUNCTUATION + SPACE + SEPARATOR

    @abstractmethod
    def __call__(self, chunk: List[Alignment]) -> str:
        """Score ``chunk``; concrete subclasses define the return value."""

    @property
    def name(self):
        """Concrete subclass name, for logging/reporting."""
        return self._name
317
+
318
+
319
class WERStats(namedtuple("WERStats", ["WER", "ins_errs", "del_errs", "sub_errs", "ref_len"])):
    """Word-error-rate statistics for one aligned chunk.

    Fields: WER (float, rounded to 4 digits), insertion/deletion/substitution
    error counts, and the reference length the rate was computed over.
    (Fix: the namedtuple typename previously read "AlignStats", so repr()
    disagreed with the class name.)
    """

    def to_dict(self):
        """Return the stats as a plain dict (e.g. for serialization)."""
        return {
            "WER": self.WER,
            "ins_errs": self.ins_errs,
            "del_errs": self.del_errs,
            "sub_errs": self.sub_errs,
            "ref_len": self.ref_len,
        }


# Single alignment pairs treated as correct although the surface forms differ,
# e.g. ref "is" aligned to hyp "'s" ("what is" vs "what's"). Built once at
# module load instead of per loop iteration.
_EQUIVALENT_ONE = [
    [("is", "'s")],  # what is -> what's
    [("am", "'m")],  # I am -> I'm
    [("are", "'re")],  # they are -> they're
    [("would", "'d")],  # they would -> they'd
    [("had", "'d")],  # I had -> I'd
    [("will", "'ll")],  # we will -> we'll
    [("have", "'ve")],  # I have -> I've
    [("ok", "okay")],  # ok -> okay
    [("okay", "ok")],  # okay -> ok
]
# Two-pair equivalences (the second pair is consumed together with the first).
_EQUIVALENT_TWO = [
    [("let", "let"), ("us", "'s")],  # let us -> let's
    [("do", "do"), ("not", "n't")],  # do not -> don't
]


# Error accounting adapted from
# https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
def compute_align_stats(
    ali: List[Tuple[str, str]],
    ERR: str = "*",
    IGNORE: str = "",
    enable_log: bool = True,
) -> WERStats:
    """Compute WER statistics over aligned (ref_word, hyp_word) pairs.

    Args:
        ali: Aligned word pairs; ``ERR`` on one side marks an epsilon
            (insertion when in ref position, deletion when in hyp position).
        ERR: Sentinel string used as the epsilon/error marker.
        IGNORE: Pairs where both words fall in this character set are skipped.
        enable_log: Log a one-line WER summary when True.

    Returns:
        WERStats with the overall WER and per-category error counts.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    # corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0

    ref_len = 0
    skip = 0  # upcoming pairs already consumed by a multi-pair equivalence
    for k, (ref_word, hyp_word) in enumerate(ali):
        if skip > 0:
            skip -= 1
            continue

        # Typical call: compute_align_stats(ali, ERR=EPSILON, IGNORE=PUNCTUATION_SPACE_SEPARATOR)
        if ali[k : k + 1] in _EQUIVALENT_ONE:
            # Single-pair equivalence counts as correct and consumes only this
            # pair. (Fix: `skip = 1` previously also swallowed the *next*,
            # unrelated pair, under-counting errors and ref_len.)
            ref_len += 1
            continue
        elif ali[k : k + 2] in _EQUIVALENT_TWO:
            # Two-pair equivalence: count both pairs, then skip only the
            # second one. (Fix: `skip = 2` previously swallowed one extra pair.)
            skip = 1
            ref_len += 2
            continue
        elif (ref_word and ref_word in IGNORE) and (hyp_word and hyp_word in IGNORE):
            # Both sides ignorable (punctuation/space/separator) — not scored.
            continue
        else:
            ref_len += 1

        if ref_word == ERR:  # epsilon on ref side -> insertion
            ins[hyp_word] += 1
            words[hyp_word][3] += 1
        elif hyp_word == ERR:  # epsilon on hyp side -> deletion
            dels[ref_word] += 1
            words[ref_word][4] += 1
        elif hyp_word != ref_word:  # substitution
            subs[(ref_word, hyp_word)] += 1
            words[ref_word][1] += 1
            words[hyp_word][2] += 1
        else:  # exact match
            words[ref_word][0] += 1
            num_corr += 1

    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    stats = WERStats(
        WER=round(tot_errs / max(ref_len, 1), ndigits=4),
        ins_errs=ins_errs,
        del_errs=del_errs,
        sub_errs=sub_errs,
        ref_len=ref_len,
    )

    if enable_log:
        logging.info(
            f"%WER {stats.WER:.4%} "
            f"[{tot_errs} / {max(ref_len, 1)}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    return stats
421
+
422
+
423
class WERFilter(AlignFilter):
    """AlignFilter that scores an alignment chunk with WER statistics."""

    def __call__(self, chunk: List[Alignment]) -> WERStats:
        # Reduce each alignment op to its (ref, hyp) word pair, then score.
        pairs = [(op.ref, op.hyp) for op in chunk]
        return compute_align_stats(
            pairs,
            ERR=JOIN_TOKEN,
            IGNORE=self.PUNCTUATION_SPACE_SEPARATOR,
            enable_log=False,
        )
428
+
429
+
430
def quality(
    chunk: List[Alignment], supervision: Tuple[float, float], transcript: Tuple[float, float], wer_fn: Callable
) -> Tuple[AlignQuality, TimestampQuality]:
    """Summarize alignment and timestamp quality for one matched chunk.

    Args:
        chunk: Alignment ops of the group.
        supervision: (start, end) timestamps on the supervision side.
        transcript: (start, end) timestamps on the transcription side.
        wer_fn: Callable producing WER stats from the chunk.

    Returns:
        (AlignQuality, TimestampQuality) pair.
    """
    first_matched = bool(chunk) and chunk[0].op_type == OpType.MATCH
    last_matched = bool(chunk) and chunk[-1].op_type == OpType.MATCH
    align_quality = AlignQuality(
        FW="FE" if first_matched else "FN",
        LW="LE" if last_matched else "LN",
        PREFIX=equal_ratio(chunk[:4]),
        SUFFIX=equal_ratio(chunk[-4:]),
        WER=wer_fn(chunk),
    )
    stamps = TimestampQuality(
        start=(supervision[0], transcript[0]),
        end=(supervision[1], transcript[1]),
    )
    return align_quality, stamps