lattifai 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. lattifai/alignment/__init__.py +10 -1
  2. lattifai/alignment/lattice1_aligner.py +66 -58
  3. lattifai/alignment/punctuation.py +38 -0
  4. lattifai/alignment/sentence_splitter.py +152 -21
  5. lattifai/alignment/text_align.py +440 -0
  6. lattifai/alignment/tokenizer.py +82 -40
  7. lattifai/caption/__init__.py +82 -6
  8. lattifai/caption/caption.py +335 -1141
  9. lattifai/caption/formats/__init__.py +199 -0
  10. lattifai/caption/formats/base.py +211 -0
  11. lattifai/caption/{gemini_reader.py → formats/gemini.py} +320 -60
  12. lattifai/caption/formats/json.py +194 -0
  13. lattifai/caption/formats/lrc.py +309 -0
  14. lattifai/caption/formats/nle/__init__.py +9 -0
  15. lattifai/caption/formats/nle/audition.py +561 -0
  16. lattifai/caption/formats/nle/avid.py +423 -0
  17. lattifai/caption/formats/nle/fcpxml.py +549 -0
  18. lattifai/caption/formats/nle/premiere.py +589 -0
  19. lattifai/caption/formats/pysubs2.py +642 -0
  20. lattifai/caption/formats/sbv.py +147 -0
  21. lattifai/caption/formats/tabular.py +338 -0
  22. lattifai/caption/formats/textgrid.py +193 -0
  23. lattifai/caption/formats/ttml.py +652 -0
  24. lattifai/caption/formats/vtt.py +469 -0
  25. lattifai/caption/parsers/__init__.py +9 -0
  26. lattifai/caption/{text_parser.py → parsers/text_parser.py} +4 -2
  27. lattifai/caption/standardize.py +636 -0
  28. lattifai/caption/utils.py +474 -0
  29. lattifai/cli/__init__.py +2 -1
  30. lattifai/cli/caption.py +108 -1
  31. lattifai/cli/transcribe.py +1 -1
  32. lattifai/cli/youtube.py +4 -1
  33. lattifai/client.py +33 -113
  34. lattifai/config/__init__.py +11 -1
  35. lattifai/config/alignment.py +7 -0
  36. lattifai/config/caption.py +267 -23
  37. lattifai/config/media.py +20 -0
  38. lattifai/diarization/__init__.py +41 -1
  39. lattifai/mixin.py +27 -15
  40. lattifai/transcription/base.py +6 -1
  41. lattifai/transcription/lattifai.py +19 -54
  42. lattifai/utils.py +7 -13
  43. lattifai/workflow/__init__.py +28 -4
  44. lattifai/workflow/file_manager.py +2 -5
  45. lattifai/youtube/__init__.py +43 -0
  46. lattifai/youtube/client.py +1170 -0
  47. lattifai/youtube/types.py +23 -0
  48. lattifai-1.2.2.dist-info/METADATA +615 -0
  49. lattifai-1.2.2.dist-info/RECORD +76 -0
  50. {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/entry_points.txt +1 -2
  51. lattifai/caption/gemini_writer.py +0 -173
  52. lattifai/cli/app_installer.py +0 -142
  53. lattifai/cli/server.py +0 -44
  54. lattifai/server/app.py +0 -427
  55. lattifai/workflow/youtube.py +0 -577
  56. lattifai-1.2.1.dist-info/METADATA +0 -1134
  57. lattifai-1.2.1.dist-info/RECORD +0 -58
  58. {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/WHEEL +0 -0
  59. {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/licenses/LICENSE +0 -0
  60. {lattifai-1.2.1.dist-info → lattifai-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,440 @@
1
+ # Align caption.supervisions with transcription
2
+ import logging
3
+ import string # noqa: F401
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict, namedtuple
6
+ from typing import Callable, Dict, List, Optional, Tuple, TypeVar
7
+
8
+ import regex
9
+ from error_align import error_align
10
+ from error_align.utils import DELIMITERS, NUMERIC_TOKEN, STANDARD_TOKEN, Alignment, OpType
11
+
12
+ from lattifai.caption import Caption, Supervision
13
+ from lattifai.utils import safe_print
14
+
15
+ from .punctuation import PUNCTUATION
16
+
17
+ Symbol = TypeVar("Symbol")
18
+ EPSILON = "`"
19
+
20
+ JOIN_TOKEN = "❄"
21
+ if JOIN_TOKEN not in DELIMITERS:
22
+ DELIMITERS.add(JOIN_TOKEN)
23
+
24
+
25
def custom_tokenizer(text: str) -> list:
    """Tokenize ``text`` into regex match objects for error alignment.

    The pattern matches numeric tokens, standard word tokens, and the special
    ``JOIN_TOKEN`` separator used to mark supervision boundaries. (The previous
    docstring claimed whitespace splitting into words, which is inaccurate.)

    Args:
        text (str): The input text to tokenize.

    Returns:
        list: A list of ``regex.Match`` objects, one per token.
    """
    # JOIN_TOKEN is interpolated into the pattern, so it must be escaped.
    escaped_join_token = regex.escape(JOIN_TOKEN)
    return list(
        regex.finditer(
            rf"({NUMERIC_TOKEN})|({STANDARD_TOKEN}|{escaped_join_token})",
            text,
            regex.UNICODE | regex.VERBOSE,
        )
    )
44
+
45
+
46
def equal_ratio(chunk: List[Alignment]):
    """Fraction of ops in ``chunk`` that are exact matches (0.0 for empty input)."""
    denominator = max(len(chunk), 1)
    hits = 0
    for op in chunk:
        if op.op_type == OpType.MATCH:
            hits += 1
    return hits / denominator
48
+
49
+
50
def equal(chunk: List[Alignment]):
    """True iff every op in ``chunk`` is an exact match (vacuously true for empty)."""
    for op in chunk:
        if op.op_type != OpType.MATCH:
            return False
    return True
52
+
53
+
54
def group_alignments(
    supervisions: List[Supervision],
    transcription: List[Supervision],
    max_silence_gap: float = 10.0,
    mini_num_supervisions: int = 1,
    mini_num_transcription: int = 1,
    equal_threshold: float = 0.5,
    verbose: bool = False,
) -> List[Tuple[Tuple[int, int], Tuple[int, int], List[Alignment]]]:
    """Partition the caption/transcription word alignment into matched groups.

    Both texts are lowercased and concatenated with a ``JOIN_TOKEN`` after each
    segment, then word-aligned with ``error_align``. Whenever both streams hit a
    ``JOIN_TOKEN`` at the same alignment position (a shared segment boundary),
    the pending chunk is emitted as a group if it looks trustworthy (the match
    ratio of its first or last 10 ops exceeds ``equal_threshold``) and it spans
    enough segments on both sides, or if the transcription has a silence gap
    longer than ``max_silence_gap`` at that boundary.

    Args:
        supervisions: Caption segments.
        transcription: ASR segments.
        max_silence_gap: Silence (seconds) in the transcription that forces a split;
            a gap over twice this value forces a split unconditionally.
        mini_num_supervisions: Minimum caption segments per group.
        mini_num_transcription: Minimum transcription segments per group.
        equal_threshold: Minimum edge match ratio required to split.
        verbose: Print each emitted group.

    Returns:
        List of ``((sup_start, sup_end), (asr_start, asr_end), chunk)`` where the
        index pairs are half-open ranges into the inputs and ``chunk`` is the
        slice of alignment ops covering the group.
    """
    # Alternative normalization considered (strip punctuation):
    # TABLE = str.maketrans(dict.fromkeys(string.punctuation))
    # sup.text.lower().translate(TABLE)
    ref = "".join(sup.text.lower() + JOIN_TOKEN for sup in supervisions)
    hyp = "".join(sup.text.lower() + JOIN_TOKEN for sup in transcription)
    alignments = error_align(ref, hyp, tokenizer=custom_tokenizer)

    matches = []
    # segment start index / running segment count for each stream
    ss_start, ss_idx = 0, 0
    tr_start, tr_idx = 0, 0

    idx_start = 0  # start of the current chunk within `alignments`
    for idx, ali in enumerate(alignments):
        # Each JOIN_TOKEN marks the end of one segment in its stream.
        if ali.ref == JOIN_TOKEN:
            ss_idx += 1
        if ali.hyp == JOIN_TOKEN:
            tr_idx += 1

        if ali.ref == JOIN_TOKEN and ali.hyp == JOIN_TOKEN:
            # Both streams agree on a segment boundary: candidate split point.
            chunk = alignments[idx_start:idx]

            split_at_silence = False
            greater_two_silence_gap = False
            if tr_idx > 0 and tr_idx < len(transcription):
                gap = transcription[tr_idx].start - transcription[tr_idx - 1].end
                if gap > max_silence_gap:
                    split_at_silence = True
                if gap > 2 * max_silence_gap:
                    greater_two_silence_gap = True

            if (
                (equal_ratio(chunk[:10]) > equal_threshold or equal_ratio(chunk[-10:]) > equal_threshold)
                and (
                    (ss_idx - ss_start >= mini_num_supervisions and tr_idx - tr_start >= mini_num_transcription)
                    or split_at_silence
                )
            ) or greater_two_silence_gap:
                matches.append(((ss_start, ss_idx), (tr_start, tr_idx), chunk))

                if verbose:
                    sub_align = supervisions[ss_start:ss_idx]
                    asr_align = transcription[tr_start:tr_idx]
                    safe_print("========================================================================")
                    safe_print(f" Caption [{ss_start:>4d}, {ss_idx:>4d}): {[sup.text for sup in sub_align]}")
                    safe_print(f"Transcript [{tr_start:>4d}, {tr_idx:>4d}): {[sup.text for sup in asr_align]}")
                    safe_print("========================================================================\n\n")

                # Next group starts right after this boundary.
                ss_start = ss_idx
                tr_start = tr_idx
                idx_start = idx + 1

        if ss_start == len(supervisions) and tr_start == len(transcription):
            break

        # remainder: one stream is exhausted — flush everything left as a final group.
        if ss_idx == len(supervisions) or tr_idx == len(transcription):
            chunk = alignments[idx_start:]
            matches.append(((ss_start, len(supervisions)), (tr_start, len(transcription)), chunk))
            break

    return matches
124
+
125
+
126
class AlignQuality(namedtuple("AlignQuality", ["FW", "LW", "PREFIX", "SUFFIX", "WER"])):
    """Quality summary of one aligned chunk.

    Fields:
        FW: "FE" if the first alignment op matched exactly, else "FN".
        LW: "LE" if the last alignment op matched exactly, else "LN".
        PREFIX: match ratio over the first few alignment ops.
        SUFFIX: match ratio over the last few alignment ops.
        WER: WER statistics object (exposes a numeric ``.WER`` attribute).
    """

    def __repr__(self) -> str:
        quality = f"WORD[{self.FW}][{self.LW}]_WER[{self.WER}]"
        return quality

    @property
    def info(self) -> str:
        # Human-readable one-line summary of the chunk quality.
        info = f"WER {self.WER.WER:.4f} accuracy [{self.PREFIX:.2f}, {self.SUFFIX:.2f}] {self.WER}"
        return info

    @property
    def first_word_equal(self) -> bool:
        return self.FW == "FE"

    @property
    def last_word_equal(self) -> bool:
        return self.LW == "LE"

    @property
    def wer(self) -> float:
        # Unwrap the numeric rate from the WER stats object.
        return self.WER.WER

    @property
    def qwer(self):
        """Bucketed WER label: WZ (zero) / WL (low) / WM (medium) / WH (high)."""
        wer = self.wer
        # TODO: take ref_len into account (short references make WER noisy).
        if wer == 0.0:
            return "WZ"  # zero
        elif wer < 0.1:
            return "WL"  # low
        elif wer < 0.32:
            return "WM"  # medium
        else:
            return "WH"  # high

    @property
    def qprefix(self) -> str:
        # NOTE(review): prefix buckets use >= while qsuffix uses strict > —
        # confirm the boundary asymmetry is intended.
        if self.PREFIX >= 0.7:
            return "PH"  # high
        elif self.PREFIX >= 0.5:
            return "PM"  # medium
        else:
            return "PL"  # low

    @property
    def qsuffix(self) -> str:
        if self.SUFFIX > 0.7:
            return "SH"  # high
        elif self.SUFFIX > 0.5:
            return "SM"  # medium
        else:
            return "SL"  # low
178
+
179
+
180
class TimestampQuality(namedtuple("TimestampQuality", ["start", "end"])):
    """Timing disagreement between caption and transcript for one group.

    ``start`` and ``end`` are ``(caption_time, transcript_time)`` pairs.
    """

    @property
    def start_diff(self):
        """Absolute difference between the two start times."""
        caption_t, transcript_t = self.start
        return abs(caption_t - transcript_t)

    @property
    def end_diff(self):
        """Absolute difference between the two end times."""
        caption_t, transcript_t = self.end
        return abs(caption_t - transcript_t)

    @property
    def diff(self):
        """Largest of the start/end disagreements."""
        return max(self.start_diff, self.end_diff)
192
+
193
+
194
# One entry per group: caption segments, ASR segments, alignment quality,
# timestamp quality, and the raw alignment-op chunk for the group.
# (The final element was previously annotated `int`, but callers store the
# List[Alignment] chunk there.)
TextAlignResult = Tuple[Optional[List[Supervision]], Optional[List[Supervision]], AlignQuality, TimestampQuality, List[Alignment]]
195
+
196
+
197
def align_supervisions_and_transcription(
    caption: Caption,
    max_duration: Optional[float] = None,
    verbose: bool = False,
) -> List[TextAlignResult]:
    """Align caption.supervisions with caption.transcription.

    Groups the two segment streams with ``group_alignments`` and scores each
    group with ``quality``. Groups where one side is empty get their missing
    time span estimated from the previous match and the next segment start.

    Args:
        caption: Caption object containing supervisions and transcription.
        max_duration: Optional cap (seconds) used when extrapolating the end
            time of a one-sided group; defaults to the later of the two final
            segment end times.
        verbose: Print diagnostics for one-sided groups and groups with errors.

    Returns:
        One ``[sub_align, asr_align, align_quality, timestamp_quality, chunk]``
        entry per group.
    """
    groups = group_alignments(caption.supervisions, caption.transcription, verbose=False)

    if max_duration is None:
        max_duration = max(caption.transcription[-1].end, caption.supervisions[-1].end)
    else:
        # Clamp a caller-supplied duration to just past the observed audio end.
        max_duration = min(
            max_duration,
            max(caption.transcription[-1].end, caption.supervisions[-1].end) + 10.0,
        )

    def next_start(alignments: List[Supervision], idx: int) -> float:
        # Start time of segment `idx`, or a small extrapolation past the last
        # segment (capped by max_duration) when idx is out of range.
        if idx < len(alignments):
            return alignments[idx].start
        return min(alignments[-1].end + 2.0, max_duration)

    wer_filter = WERFilter()

    matches = []
    for idx, ((sub_start, sub_end), (asr_start, asr_end), chunk) in enumerate(groups):
        sub_align = caption.supervisions[sub_start:sub_end]
        asr_align = caption.transcription[asr_start:asr_end]

        if not sub_align or not asr_align:
            # One-sided group: estimate the missing side's time span.
            # matches[-1][-2] is the previous group's TimestampQuality;
            # .end is (caption_end, transcript_end).
            if sub_align:
                if matches:
                    _asr_start = matches[-1][-2].end[1]
                else:
                    _asr_start = 0.0
                startends = [
                    (sub_align[0].start, sub_align[-1].end),
                    (_asr_start, next_start(caption.transcription, asr_end)),
                ]
            elif asr_align:
                if matches:
                    _sub_start = matches[-1][-2].end[0]
                else:
                    _sub_start = 0.0
                startends = [
                    (_sub_start, next_start(caption.supervisions, sub_end)),
                    (asr_align[0].start, asr_align[-1].end),
                ]
            else:
                # group_alignments never emits a group that is empty on both sides.
                raise ValueError(
                    f"Never Here! subtitles[{len(caption.supervisions)}] {sub_start}-{sub_end} asrs[{len(caption.transcription)}] {asr_start}-{asr_end}"
                )

            if verbose:
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo")
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{startends[0][0]:>8.2f}, {startends[0][1]:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{startends[1][0]:>8.2f}, {startends[1][1]:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo\n\n")

            aligned, timestamp = quality(chunk, startends[0], startends[1], wer_fn=wer_filter)
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])
            continue
        else:
            # Regular two-sided group: score using the real segment timestamps.
            aligned, timestamp = quality(
                chunk,
                [sub_align[0].start, sub_align[-1].end],
                [asr_align[0].start, asr_align[-1].end],
                wer_fn=wer_filter,
            )
            matches.append([sub_align, asr_align, aligned, timestamp, chunk])

            if verbose and aligned.wer > 0.0:
                safe_print(
                    f"===================================WER={aligned.wer:>4.2f}====================================="
                )
                safe_print(
                    f" Caption idx=[{sub_start:>4d}, {sub_end:>4d}) timestamp=[{sub_align[0].start:>8.2f}, {sub_align[-1].end:>8.2f}]: {[sup.text for sup in sub_align]}"
                )
                safe_print(
                    f"Transcript idx=[{asr_start:>4d}, {asr_end:>4d}) timestamp=[{asr_align[0].start:>8.2f}, {asr_align[-1].end:>8.2f}]: {[sup.text for sup in asr_align]}"
                )
                safe_print("========================================================================\n\n")

    return matches
289
+
290
+
291
class AlignFilter(ABC):
    """Base class for callables that score a chunk of alignment ops.

    Subclasses implement ``__call__`` and may use the character classes
    assembled here (punctuation / space / separator combinations) to decide
    which alignment pairs to ignore.
    """

    def __init__(
        self, PUNCTUATION: str = PUNCTUATION, IGNORE: str = "", SPACE=" ", EPSILON=EPSILON, SEPARATOR=JOIN_TOKEN
    ):
        super().__init__()
        # Filter name defaults to the concrete subclass name (see `.name`).
        self._name = self.__class__.__name__
        self.PUNCTUATION = PUNCTUATION

        self.IGNORE = IGNORE
        self.SPACE = SPACE
        self.EPSILON = EPSILON
        self.SEPARATOR = SEPARATOR

        # Pre-built ignore classes combining punctuation with space/separator.
        self.PUNCTUATION_SEPARATOR = PUNCTUATION + SEPARATOR
        self.PUNCTUATION_SPACE = PUNCTUATION + SPACE
        self.PUNCTUATION_SPACE_SEPARATOR = PUNCTUATION + SPACE + SEPARATOR

    @abstractmethod
    def __call__(self, chunk: List[Alignment]) -> str:
        # NOTE(review): WERFilter returns a WERStats, not a str — this return
        # annotation looks stale; confirm and widen if so.
        pass

    @property
    def name(self):
        """Name of this filter (the concrete subclass's class name)."""
        return self._name
316
+
317
+
318
class WERStats(namedtuple("WERStats", ["WER", "ins_errs", "del_errs", "sub_errs", "ref_len"])):
    """Word-error-rate statistics for one aligned chunk.

    Fields:
        WER: total errors / reference length, rounded to 4 digits.
        ins_errs / del_errs / sub_errs: error counts by type.
        ref_len: number of reference words counted.

    Note: the namedtuple typename previously read "AlignStats", mismatching
    the class name (wrong repr and pickle name); it now matches.
    """

    def to_dict(self):
        """Return the stats as a plain dict (e.g. for JSON serialization)."""
        return {
            "WER": self.WER,
            "ins_errs": self.ins_errs,
            "del_errs": self.del_errs,
            "sub_errs": self.sub_errs,
            "ref_len": self.ref_len,
        }
327
+
328
+
329
# Adapted from https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
def compute_align_stats(
    ali: List[Tuple[str, str]],
    ERR: str = "*",
    IGNORE: str = "",
    enable_log: bool = True,
) -> WERStats:
    """Compute insertion/deletion/substitution counts and WER over an alignment.

    Args:
        ali: Word-level alignment as ``(ref_word, hyp_word)`` pairs; ``ERR`` on
            the ref side marks an insertion, on the hyp side a deletion.
        ERR: Placeholder token for the missing side of ins/del pairs.
        IGNORE: Characters to skip when both sides consist of them
            (e.g. punctuation, spaces, separators).
        enable_log: Emit a one-line summary via ``logging``.

    Returns:
        WERStats with the error counts and the rounded WER.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    # Kept (with num_corr) for parity with the icefall implementation, even
    # though only the aggregate counters feed the returned stats.
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0

    ref_len = 0
    skip = 0  # number of upcoming pairs already consumed by an equivalence rule
    for k, (ref_word, hyp_word) in enumerate(ali):
        if skip > 0:
            skip -= 1
            continue

        # Contraction/spelling equivalences that should not count as errors,
        # e.g. compute_align_stats(ali, ERR=EPSILON, IGNORE=PUNCTUATION_SPACE_SEPARATOR)
        if ali[k : k + 1] in [
            [("is", "'s")],  # what is -> what's
            [("am", "'m")],  # I am -> I'm
            [("are", "'re")],  # they are -> they're
            [("would", "'d")],  # they would -> they'd
            [("had", "'d")],  # I had -> I'd
            [("will", "'ll")],  # we will -> we'll
            [("have", "'ve")],  # I have -> I've
            [("ok", "okay")],  # ok -> okay
            [("okay", "ok")],  # okay -> ok
        ]:
            # BUG FIX: the matched pair is already consumed by this iteration's
            # `continue`; the previous `skip = 1` wrongly swallowed the next,
            # unrelated pair as well.
            ref_len += 1
            continue
        elif ali[k : k + 2] in [
            # let us -> let's
            [("let", "let"), ("us", "'s")],
            # do not -> don't
            [("do", "do"), ("not", "n't")],
        ]:
            # Two pairs form the equivalence: this one plus the next one, so
            # skip exactly one upcoming pair (previously `skip = 2` swallowed
            # an extra pair beyond the equivalence).
            skip = 1
            ref_len += 2
            continue
        elif (ref_word and ref_word in IGNORE) and (hyp_word and hyp_word in IGNORE):
            # Pure punctuation/space/separator on both sides: not scored.
            continue
        else:
            ref_len += 1

        if ref_word == ERR:
            # Insertion: hypothesis word with no reference counterpart.
            ins[hyp_word] += 1
            words[hyp_word][3] += 1
        elif hyp_word == ERR:
            # Deletion: reference word missing from the hypothesis.
            dels[ref_word] += 1
            words[ref_word][4] += 1
        elif hyp_word != ref_word:
            subs[(ref_word, hyp_word)] += 1
            words[ref_word][1] += 1
            words[hyp_word][2] += 1
        else:
            words[ref_word][0] += 1
            num_corr += 1

    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    stats = WERStats(
        WER=round(tot_errs / max(ref_len, 1), ndigits=4),
        ins_errs=ins_errs,
        del_errs=del_errs,
        sub_errs=sub_errs,
        ref_len=ref_len,
    )

    if enable_log:
        logging.info(
            f"%WER {stats.WER:.4%} "
            f"[{tot_errs} / {max(ref_len, 1)}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    return stats
420
+
421
+
422
class WERFilter(AlignFilter):
    """Score an alignment chunk with WER statistics, ignoring punctuation,
    spaces and the join separator."""

    def __call__(self, chunk: List[Alignment]) -> WERStats:
        # NOTE(review): ERR=JOIN_TOKEN while compute_align_stats' example
        # suggests ERR=EPSILON — confirm which placeholder error_align emits.
        pairs = [(op.ref, op.hyp) for op in chunk]
        return compute_align_stats(
            pairs,
            ERR=JOIN_TOKEN,
            IGNORE=self.PUNCTUATION_SPACE_SEPARATOR,
            enable_log=False,
        )
427
+
428
+
429
def quality(
    chunk: List[Alignment], supervision: Tuple[float, float], transcript: Tuple[float, float], wer_fn: Callable
) -> Tuple[AlignQuality, TimestampQuality]:
    """Summarize match quality and timing skew for one aligned group.

    Args:
        chunk: Alignment ops covering the group.
        supervision: (start, end) of the caption side.
        transcript: (start, end) of the transcription side.
        wer_fn: Callable scoring the chunk (e.g. a WERFilter instance).

    Returns:
        (AlignQuality, TimestampQuality) pair.
    """
    if chunk and chunk[0].op_type == OpType.MATCH:
        first_word = "FE"
    else:
        first_word = "FN"
    if chunk and chunk[-1].op_type == OpType.MATCH:
        last_word = "LE"
    else:
        last_word = "LN"

    align_quality = AlignQuality(
        FW=first_word,
        LW=last_word,
        PREFIX=equal_ratio(chunk[:4]),
        SUFFIX=equal_ratio(chunk[-4:]),
        WER=wer_fn(chunk),
    )
    time_quality = TimestampQuality(
        start=(supervision[0], transcript[0]),
        end=(supervision[1], transcript[1]),
    )
    return align_quality, time_quality
@@ -2,11 +2,13 @@ import gzip
2
2
  import pickle
3
3
  import re
4
4
  from collections import defaultdict
5
- from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
5
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
6
6
 
7
7
  import numpy as np
8
8
 
9
- from lattifai.caption import Supervision
9
+ # from lattifai.caption import Supervision
10
+ from lhotse.supervision import SupervisionSegment as Supervision # NOTE: Transcriber SupervisionSegment
11
+
10
12
  from lattifai.caption import normalize_text as normalize_html_text
11
13
  from lattifai.errors import (
12
14
  LATTICE_DECODING_FAILURE_HELP,
@@ -16,13 +18,9 @@ from lattifai.errors import (
16
18
  )
17
19
 
18
20
  from .phonemizer import G2Phonemizer
21
+ from .punctuation import PUNCTUATION, PUNCTUATION_SPACE
19
22
  from .sentence_splitter import SentenceSplitter
20
-
21
- PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
22
- PUNCTUATION_SPACE = PUNCTUATION + " "
23
- STAR_TOKEN = "※"
24
-
25
- GROUPING_SEPARATOR = "✹"
23
+ from .text_align import TextAlignResult
26
24
 
27
25
  MAXIMUM_WORD_LENGTH = 40
28
26
 
@@ -80,8 +78,11 @@ def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punct
80
78
  ['Kühlschrank']
81
79
  >>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
82
80
  ['Hello,', ' ', 'World!']
81
+ >>> tokenize_multilingual_text("[AED], World!", keep_spaces=False, attach_punctuation=True)
82
+ ['[AED],', 'World!']
83
83
  """
84
84
  # Regex pattern:
85
+ # - \[[A-Z_]+\] matches bracketed annotations like [APPLAUSE], [MUSIC], [SPEAKER_01]
85
86
  # - [a-zA-Z0-9\u00C0-\u024F]+ matches Latin letters (including accented chars like ü, ö, ä, ß, é, etc.)
86
87
  # - (?:'[a-zA-Z]{1,2})? optionally matches contractions like 's, 't, 'm, 'll, 're, 've
87
88
  # - [\u4e00-\u9fff] matches CJK characters
@@ -90,7 +91,7 @@ def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punct
90
91
  # - \u00C0-\u00FF: Latin-1 Supplement (À-ÿ)
91
92
  # - \u0100-\u017F: Latin Extended-A
92
93
  # - \u0180-\u024F: Latin Extended-B
93
- pattern = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
94
+ pattern = re.compile(r"(\[[A-Z_]+\]|[a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
94
95
 
95
96
  # filter(None, ...) removes any empty strings from re.findall results
96
97
  tokens = list(filter(None, pattern.findall(text)))
@@ -245,19 +246,37 @@ class LatticeTokenizer:
245
246
  self.init_sentence_splitter()
246
247
  return self.sentence_splitter.split_sentences(supervisions, strip_whitespace=strip_whitespace)
247
248
 
248
- def tokenize(self, supervisions: List[Supervision], split_sentence: bool = False) -> Tuple[str, Dict[str, Any]]:
249
- if split_sentence:
250
- supervisions = self.split_sentences(supervisions)
251
-
252
- pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
253
- response = self.client_wrapper.post(
254
- "tokenize",
255
- json={
256
- "model_name": self.model_name,
257
- "supervisions": [s.to_dict() for s in supervisions],
258
- "pronunciation_dictionaries": pronunciation_dictionaries,
259
- },
260
- )
249
+ def tokenize(
250
+ self, supervisions: Union[List[Supervision], TextAlignResult], split_sentence: bool = False, boost: float = 0.0
251
+ ) -> Tuple[str, Dict[str, Any]]:
252
+ if isinstance(supervisions[0], Supervision):
253
+ if split_sentence:
254
+ supervisions = self.split_sentences(supervisions)
255
+
256
+ pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
257
+ response = self.client_wrapper.post(
258
+ "tokenize",
259
+ json={
260
+ "model_name": self.model_name,
261
+ "supervisions": [s.to_dict() for s in supervisions],
262
+ "pronunciation_dictionaries": pronunciation_dictionaries,
263
+ },
264
+ )
265
+ else:
266
+ pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions[0]])
267
+ pronunciation_dictionaries.update(self.prenormalize([s.text for s in supervisions[1]]))
268
+
269
+ response = self.client_wrapper.post(
270
+ "difftokenize",
271
+ json={
272
+ "model_name": self.model_name,
273
+ "supervisions": [s.to_dict() for s in supervisions[0]],
274
+ "transcription": [s.to_dict() for s in supervisions[1]],
275
+ "pronunciation_dictionaries": pronunciation_dictionaries,
276
+ "boost": boost,
277
+ },
278
+ )
279
+
261
280
  if response.status_code == 402:
262
281
  raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
263
282
  if response.status_code != 200:
@@ -274,28 +293,47 @@ class LatticeTokenizer:
274
293
  self,
275
294
  lattice_id: str,
276
295
  lattice_results: Tuple[np.ndarray, Any, Any, float, float],
277
- supervisions: List[Supervision],
296
+ supervisions: Union[List[Supervision], TextAlignResult],
278
297
  return_details: bool = False,
279
298
  start_margin: float = 0.08,
280
299
  end_margin: float = 0.20,
281
300
  ) -> List[Supervision]:
282
301
  emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
283
- response = self.client_wrapper.post(
284
- "detokenize",
285
- json={
286
- "model_name": self.model_name,
287
- "lattice_id": lattice_id,
288
- "frame_shift": frame_shift,
289
- "results": [t.to_dict() for t in results[0]],
290
- "labels": labels[0],
291
- "offset": offset,
292
- "channel": channel,
293
- "return_details": False if return_details is None else return_details,
294
- "destroy_lattice": True,
295
- "start_margin": start_margin,
296
- "end_margin": end_margin,
297
- },
298
- )
302
+ if isinstance(supervisions[0], Supervision):
303
+ response = self.client_wrapper.post(
304
+ "detokenize",
305
+ json={
306
+ "model_name": self.model_name,
307
+ "lattice_id": lattice_id,
308
+ "frame_shift": frame_shift,
309
+ "results": [t.to_dict() for t in results[0]],
310
+ "labels": labels[0],
311
+ "offset": offset,
312
+ "channel": channel,
313
+ "return_details": False if return_details is None else return_details,
314
+ "destroy_lattice": True,
315
+ "start_margin": start_margin,
316
+ "end_margin": end_margin,
317
+ },
318
+ )
319
+ else:
320
+ response = self.client_wrapper.post(
321
+ "diffdetokenize",
322
+ json={
323
+ "model_name": self.model_name,
324
+ "lattice_id": lattice_id,
325
+ "frame_shift": frame_shift,
326
+ "results": [t.to_dict() for t in results[0]],
327
+ "labels": labels[0],
328
+ "offset": offset,
329
+ "channel": channel,
330
+ "return_details": False if return_details is None else return_details,
331
+ "destroy_lattice": True,
332
+ "start_margin": start_margin,
333
+ "end_margin": end_margin,
334
+ },
335
+ )
336
+
299
337
  if response.status_code == 400:
300
338
  raise LatticeDecodingError(
301
339
  lattice_id,
@@ -316,7 +354,11 @@ class LatticeTokenizer:
316
354
  # Add emission confidence scores for segments and word-level alignments
317
355
  _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
318
356
 
319
- alignments = _update_alignments_speaker(supervisions, alignments)
357
+ if isinstance(supervisions[0], Supervision):
358
+ alignments = _update_alignments_speaker(supervisions, alignments)
359
+ else:
360
+ # NOTE: Text Diff Alignment >> speaker has been handled in the backend service
361
+ pass
320
362
 
321
363
  return alignments
322
364