lattifai 0.4.5__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. lattifai/__init__.py +61 -47
  2. lattifai/alignment/__init__.py +6 -0
  3. lattifai/alignment/lattice1_aligner.py +119 -0
  4. lattifai/alignment/lattice1_worker.py +185 -0
  5. lattifai/{tokenizer → alignment}/phonemizer.py +4 -4
  6. lattifai/alignment/segmenter.py +166 -0
  7. lattifai/{tokenizer → alignment}/tokenizer.py +244 -169
  8. lattifai/audio2.py +211 -0
  9. lattifai/caption/__init__.py +20 -0
  10. lattifai/caption/caption.py +1275 -0
  11. lattifai/{io → caption}/gemini_reader.py +30 -30
  12. lattifai/{io → caption}/gemini_writer.py +17 -17
  13. lattifai/{io → caption}/supervision.py +4 -3
  14. lattifai/caption/text_parser.py +145 -0
  15. lattifai/cli/__init__.py +17 -0
  16. lattifai/cli/alignment.py +153 -0
  17. lattifai/cli/caption.py +204 -0
  18. lattifai/cli/server.py +19 -0
  19. lattifai/cli/transcribe.py +197 -0
  20. lattifai/cli/youtube.py +128 -0
  21. lattifai/client.py +460 -251
  22. lattifai/config/__init__.py +20 -0
  23. lattifai/config/alignment.py +73 -0
  24. lattifai/config/caption.py +178 -0
  25. lattifai/config/client.py +46 -0
  26. lattifai/config/diarization.py +67 -0
  27. lattifai/config/media.py +335 -0
  28. lattifai/config/transcription.py +84 -0
  29. lattifai/diarization/__init__.py +5 -0
  30. lattifai/diarization/lattifai.py +89 -0
  31. lattifai/errors.py +98 -91
  32. lattifai/logging.py +116 -0
  33. lattifai/mixin.py +552 -0
  34. lattifai/server/app.py +420 -0
  35. lattifai/transcription/__init__.py +76 -0
  36. lattifai/transcription/base.py +108 -0
  37. lattifai/transcription/gemini.py +219 -0
  38. lattifai/transcription/lattifai.py +103 -0
  39. lattifai/{workflows → transcription}/prompts/__init__.py +4 -4
  40. lattifai/types.py +30 -0
  41. lattifai/utils.py +16 -44
  42. lattifai/workflow/__init__.py +22 -0
  43. lattifai/workflow/agents.py +6 -0
  44. lattifai/{workflows → workflow}/base.py +22 -22
  45. lattifai/{workflows → workflow}/file_manager.py +239 -215
  46. lattifai/workflow/youtube.py +564 -0
  47. lattifai-1.0.0.dist-info/METADATA +736 -0
  48. lattifai-1.0.0.dist-info/RECORD +52 -0
  49. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
  50. lattifai-1.0.0.dist-info/entry_points.txt +13 -0
  51. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +1 -1
  52. lattifai/base_client.py +0 -126
  53. lattifai/bin/__init__.py +0 -3
  54. lattifai/bin/agent.py +0 -325
  55. lattifai/bin/align.py +0 -296
  56. lattifai/bin/cli_base.py +0 -25
  57. lattifai/bin/subtitle.py +0 -210
  58. lattifai/io/__init__.py +0 -42
  59. lattifai/io/reader.py +0 -85
  60. lattifai/io/text_parser.py +0 -75
  61. lattifai/io/utils.py +0 -15
  62. lattifai/io/writer.py +0 -90
  63. lattifai/tokenizer/__init__.py +0 -3
  64. lattifai/workers/__init__.py +0 -3
  65. lattifai/workers/lattice1_alpha.py +0 -284
  66. lattifai/workflows/__init__.py +0 -34
  67. lattifai/workflows/agents.py +0 -10
  68. lattifai/workflows/gemini.py +0 -167
  69. lattifai/workflows/prompts/README.md +0 -22
  70. lattifai/workflows/prompts/gemini/README.md +0 -24
  71. lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
  72. lattifai/workflows/youtube.py +0 -931
  73. lattifai-0.4.5.dist-info/METADATA +0 -808
  74. lattifai-0.4.5.dist-info/RECORD +0 -39
  75. lattifai-0.4.5.dist-info/entry_points.txt +0 -3
  76. {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,118 @@
1
1
  import gzip
2
- import inspect
3
2
  import pickle
4
3
  import re
5
4
  from collections import defaultdict
6
- from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
5
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
7
6
 
8
7
  import torch
9
8
 
10
- from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
11
- from lattifai.io import Supervision
12
- from lattifai.tokenizer.phonemizer import G2Phonemizer
9
+ from lattifai.alignment.phonemizer import G2Phonemizer
10
+ from lattifai.caption import Supervision
11
+ from lattifai.caption import normalize_text as normalize_html_text
12
+ from lattifai.errors import (
13
+ LATTICE_DECODING_FAILURE_HELP,
14
+ LatticeDecodingError,
15
+ ModelLoadError,
16
+ QuotaExceededError,
17
+ )
13
18
 
14
19
  PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
15
20
  END_PUNCTUATION = '.!?"]。!?”】'
16
- PUNCTUATION_SPACE = PUNCTUATION + ' '
17
- STAR_TOKEN = ''
21
+ PUNCTUATION_SPACE = PUNCTUATION + " "
22
+ STAR_TOKEN = ""
18
23
 
19
- GROUPING_SEPARATOR = ''
24
+ GROUPING_SEPARATOR = ""
20
25
 
21
26
  MAXIMUM_WORD_LENGTH = 40
22
27
 
23
28
 
24
- TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
29
+ TokenizerT = TypeVar("TokenizerT", bound="LatticeTokenizer")
30
+
31
+
32
+ def _is_punctuation(char: str) -> bool:
33
+ """Check if a character is punctuation (not space, not alphanumeric, not CJK)."""
34
+ if len(char) != 1:
35
+ return False
36
+ if char.isspace():
37
+ return False
38
+ if char.isalnum():
39
+ return False
40
+ # Check if it's a CJK character
41
+ if "\u4e00" <= char <= "\u9fff":
42
+ return False
43
+ # Check if it's an accented Latin character
44
+ if "\u00c0" <= char <= "\u024f":
45
+ return False
46
+ return True
47
+
48
+
49
# Compiled once at import time because the pattern never changes between calls.
# Alternatives, tried in order:
#   - [a-zA-Z0-9\u00C0-\u024F]+ : runs of ASCII letters/digits plus the Latin-1
#     Supplement, Latin Extended-A and Latin Extended-B blocks (ü, ö, ß, é, ...)
#   - (?:'[a-zA-Z]{1,2})?       : an optional trailing contraction ('s, 't, 'll, ...)
#   - [\u4e00-\u9fff]            : a single CJK Unified Ideograph
#   - .                          : any other single character
# re.DOTALL lets "." also match "\n". Without it, line breaks would be silently
# dropped instead of being emitted as whitespace tokens.
_MULTILINGUAL_TOKEN_PATTERN = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)", re.DOTALL)


def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punctuation: bool = False) -> list[str]:
    """
    Tokenize a mixed Chinese-English string into individual units.

    Tokenization rules:
    - Chinese characters (CJK) are split individually
    - Consecutive Latin letters (including accented characters) and digits are grouped as one unit
    - English contractions ('s, 't, 'm, 'll, 're, 've) are kept with the preceding word
    - Other characters (punctuation, spaces, line breaks) are split individually by default
    - If attach_punctuation=True, punctuation marks are attached to the preceding token

    Args:
        text: Input string containing mixed Chinese and English text.
        keep_spaces: If True, whitespace characters (including line breaks) are
            included in the output as separate tokens. If False, whitespace-only
            tokens are excluded from the output. Default is True.
        attach_punctuation: If True, punctuation marks (as decided by
            ``_is_punctuation``) are appended to the token immediately before
            them, whatever that token is.
            For example, "Hello, World!" becomes ["Hello,", " ", "World!"].
            Default is False.

    Returns:
        List of tokenized units.

    Examples:
        >>> tokenize_multilingual_text("Hello世界")
        ['Hello', '世', '界']
        >>> tokenize_multilingual_text("I'm fine")
        ["I'm", ' ', 'fine']
        >>> tokenize_multilingual_text("I'm fine", keep_spaces=False)
        ["I'm", 'fine']
        >>> tokenize_multilingual_text("Kühlschrank")
        ['Kühlschrank']
        >>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
        ['Hello,', ' ', 'World!']
    """
    # Every alternative in the pattern consumes at least one character, so
    # findall never yields empty strings.
    tokens = _MULTILINGUAL_TOKEN_PATTERN.findall(text)

    if attach_punctuation and len(tokens) > 1:
        # Merge each punctuation token into the token right before it. A run of
        # consecutive punctuation therefore accumulates onto the same token.
        merged_tokens: list[str] = []
        for token in tokens:
            if merged_tokens and _is_punctuation(token):
                merged_tokens[-1] += token
            else:
                merged_tokens.append(token)
        tokens = merged_tokens

    if not keep_spaces:
        tokens = [t for t in tokens if not t.isspace()]

    return tokens
25
116
 
26
117
 
27
118
  class LatticeTokenizer:
@@ -29,12 +120,13 @@ class LatticeTokenizer:
29
120
 
30
121
  def __init__(self, client_wrapper: Any):
31
122
  self.client_wrapper = client_wrapper
123
+ self.model_name = ""
32
124
  self.words: List[str] = []
33
125
  self.g2p_model: Any = None # Placeholder for G2P model
34
126
  self.dictionaries = defaultdict(lambda: [])
35
- self.oov_word = '<unk>'
127
+ self.oov_word = "<unk>"
36
128
  self.sentence_splitter = None
37
- self.device = 'cpu'
129
+ self.device = "cpu"
38
130
 
39
131
  def init_sentence_splitter(self):
40
132
  if self.sentence_splitter is not None:
@@ -45,14 +137,14 @@ class LatticeTokenizer:
45
137
 
46
138
  providers = []
47
139
  device = self.device
48
- if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
49
- providers.append('CUDAExecutionProvider')
50
- elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
51
- providers.append('MPSExecutionProvider')
140
+ if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
141
+ providers.append("CUDAExecutionProvider")
142
+ elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
143
+ providers.append("MPSExecutionProvider")
52
144
 
53
145
  sat = SaT(
54
- 'sat-3l-sm',
55
- ort_providers=providers + ['CPUExecutionProvider'],
146
+ "sat-3l-sm",
147
+ ort_providers=providers + ["CPUExecutionProvider"],
56
148
  )
57
149
  self.sentence_splitter = sat
58
150
 
@@ -79,23 +171,23 @@ class LatticeTokenizer:
79
171
  # or other forms like [SOMETHING] SPEAKER:
80
172
 
81
173
  # Pattern 1: [mark] HTML-encoded separator speaker:
82
- pattern1 = r'^(\[[^\]]+\])\s+(&gt;&gt;|>>)\s+(.+)$'
174
+ pattern1 = r"^(\[[^\]]+\])\s+(&gt;&gt;|>>)\s+(.+)$"
83
175
  match1 = re.match(pattern1, sentence.strip())
84
176
  if match1:
85
177
  special_mark = match1.group(1)
86
178
  separator = match1.group(2)
87
179
  speaker_part = match1.group(3)
88
- return [special_mark, f'{separator} {speaker_part}']
180
+ return [special_mark, f"{separator} {speaker_part}"]
89
181
 
90
182
  # Pattern 2: [mark] speaker:
91
- pattern2 = r'^(\[[^\]]+\])\s+([^:]+:)(.*)$'
183
+ pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"
92
184
  match2 = re.match(pattern2, sentence.strip())
93
185
  if match2:
94
186
  special_mark = match2.group(1)
95
187
  speaker_label = match2.group(2)
96
188
  remaining = match2.group(3).strip()
97
189
  if remaining:
98
- return [special_mark, f'{speaker_label} {remaining}']
190
+ return [special_mark, f"{speaker_label} {remaining}"]
99
191
  else:
100
192
  return [special_mark, speaker_label]
101
193
 
@@ -107,28 +199,45 @@ class LatticeTokenizer:
107
199
  cls: Type[TokenizerT],
108
200
  client_wrapper: Any,
109
201
  model_path: str,
110
- device: str = 'cpu',
202
+ model_name: str,
203
+ device: str = "cpu",
111
204
  compressed: bool = True,
112
205
  ) -> TokenizerT:
113
206
  """Load tokenizer from exported binary file"""
114
207
  from pathlib import Path
115
208
 
116
- words_model_path = f'{model_path}/words.bin'
117
- if compressed:
118
- with gzip.open(words_model_path, 'rb') as f:
119
- data = pickle.load(f)
120
- else:
121
- with open(words_model_path, 'rb') as f:
122
- data = pickle.load(f)
209
+ words_model_path = f"{model_path}/words.bin"
210
+ try:
211
+ if compressed:
212
+ with gzip.open(words_model_path, "rb") as f:
213
+ data = pickle.load(f)
214
+ else:
215
+ with open(words_model_path, "rb") as f:
216
+ data = pickle.load(f)
217
+ except pickle.UnpicklingError as e:
218
+ del e
219
+ import msgpack
220
+
221
+ if compressed:
222
+ with gzip.open(words_model_path, "rb") as f:
223
+ data = msgpack.unpack(f, raw=False, strict_map_key=False)
224
+ else:
225
+ with open(words_model_path, "rb") as f:
226
+ data = msgpack.unpack(f, raw=False, strict_map_key=False)
123
227
 
124
228
  tokenizer = cls(client_wrapper=client_wrapper)
125
- tokenizer.words = data['words']
126
- tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
127
- tokenizer.oov_word = data['oov_word']
128
-
129
- g2p_model_path = f'{model_path}/g2p.bin' if Path(f'{model_path}/g2p.bin').exists() else None
130
- if g2p_model_path:
131
- tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
229
+ tokenizer.model_name = model_name
230
+ tokenizer.words = data["words"]
231
+ tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
232
+ tokenizer.oov_word = data["oov_word"]
233
+
234
+ g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
235
+ if g2pp_model_path:
236
+ tokenizer.g2p_model = G2Phonemizer(g2pp_model_path, device=device)
237
+ else:
238
+ g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
239
+ if g2p_model_path:
240
+ tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
132
241
 
133
242
  tokenizer.device = device
134
243
  tokenizer.add_special_tokens()
@@ -136,18 +245,22 @@ class LatticeTokenizer:
136
245
 
137
246
  def add_special_tokens(self):
138
247
  tokenizer = self
139
- for special_token in ['&gt;&gt;', '&gt;']:
248
+ for special_token in ["&gt;&gt;", "&gt;"]:
140
249
  if special_token not in tokenizer.dictionaries:
141
250
  tokenizer.dictionaries[special_token] = tokenizer.dictionaries[tokenizer.oov_word]
142
251
  return self
143
252
 
144
253
  def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
145
254
  if not self.g2p_model:
146
- raise ValueError('G2P model is not loaded, cannot prenormalize texts')
255
+ raise ValueError("G2P model is not loaded, cannot prenormalize texts")
147
256
 
148
257
  oov_words = []
149
258
  for text in texts:
150
- words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
259
+ text = normalize_html_text(text)
260
+ # support english, chinese and german tokenization
261
+ words = tokenize_multilingual_text(
262
+ text.lower().replace("-", " ").replace("—", " ").replace("–", " "), keep_spaces=False
263
+ )
151
264
  oovs = [w.strip(PUNCTUATION) for w in words if w not in self.words]
152
265
  if oovs:
153
266
  oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
@@ -156,7 +269,7 @@ class LatticeTokenizer:
156
269
  if oov_words:
157
270
  indexs = []
158
271
  for k, _word in enumerate(oov_words):
159
- if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
272
+ if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [("(", ")"), ("[", "]")]):
160
273
  self.dictionaries[_word] = self.dictionaries[self.oov_word]
161
274
  else:
162
275
  _word = _word.strip(PUNCTUATION_SPACE)
@@ -187,38 +300,49 @@ class LatticeTokenizer:
187
300
 
188
301
  Carefull about speaker changes.
189
302
  """
190
- texts, text_len, sidx = [], 0, 0
191
- speakers = []
303
+ texts, speakers = [], []
304
+ text_len, sidx = 0, 0
305
+
306
+ def flush_segment(end_idx: int, speaker: Optional[str] = None):
307
+ """Flush accumulated text from sidx to end_idx with given speaker."""
308
+ nonlocal text_len, sidx
309
+ if sidx <= end_idx:
310
+ if len(speakers) < len(texts) + 1:
311
+ speakers.append(speaker)
312
+ text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
313
+ texts.append(text)
314
+ sidx = end_idx + 1
315
+ text_len = 0
316
+
192
317
  for s, supervision in enumerate(supervisions):
193
318
  text_len += len(supervision.text)
319
+ is_last = s == len(supervisions) - 1
320
+
194
321
  if supervision.speaker:
322
+ # Flush previous segment without speaker (if any)
195
323
  if sidx < s:
196
- if len(speakers) < len(texts) + 1:
197
- speakers.append(None)
198
- text = ' '.join([sup.text for sup in supervisions[sidx:s]])
199
- texts.append(text)
200
- sidx = s
324
+ flush_segment(s - 1, None)
201
325
  text_len = len(supervision.text)
202
- speakers.append(supervision.speaker)
203
326
 
204
- else:
205
- if text_len >= 2000 or s == len(supervisions) - 1:
206
- if len(speakers) < len(texts) + 1:
207
- speakers.append(None)
208
- text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
209
- texts.append(text)
210
- sidx = s + 1
211
- text_len = 0
212
-
213
- assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
327
+ # Check if we should flush this speaker's segment now
328
+ next_has_speaker = not is_last and supervisions[s + 1].speaker
329
+ if is_last or next_has_speaker:
330
+ flush_segment(s, supervision.speaker)
331
+ else:
332
+ speakers.append(supervision.speaker)
333
+
334
+ elif text_len >= 2000 or is_last:
335
+ flush_segment(s, None)
336
+
337
+ assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
214
338
  sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
215
339
 
216
- supervisions, remainder = [], ''
340
+ supervisions, remainder = [], ""
217
341
  for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
218
342
  # Prepend remainder from previous iteration to the first sentence
219
343
  if _sentences and remainder:
220
344
  _sentences[0] = remainder + _sentences[0]
221
- remainder = ''
345
+ remainder = ""
222
346
 
223
347
  if not _sentences:
224
348
  continue
@@ -228,14 +352,14 @@ class LatticeTokenizer:
228
352
  for s, _sentence in enumerate(_sentences):
229
353
  if remainder:
230
354
  _sentence = remainder + _sentence
231
- remainder = ''
355
+ remainder = ""
232
356
  # Detect and split special sentence types: e.g., '[APPLAUSE] &gt;&gt; MIRA MURATI:' -> ['[APPLAUSE]', '&gt;&gt; MIRA MURATI:'] # noqa: E501
233
357
  resplit_parts = self._resplit_special_sentence_types(_sentence)
234
- if any(resplit_parts[-1].endswith(sp) for sp in [':', '']):
358
+ if any(resplit_parts[-1].endswith(sp) for sp in [":", ""]):
235
359
  if s < len(_sentences) - 1:
236
- _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
360
+ _sentences[s + 1] = resplit_parts[-1] + " " + _sentences[s + 1]
237
361
  else: # last part
238
- remainder = resplit_parts[-1] + ' '
362
+ remainder = resplit_parts[-1] + " "
239
363
  processed_sentences.extend(resplit_parts[:-1])
240
364
  else:
241
365
  processed_sentences.extend(resplit_parts)
@@ -243,7 +367,7 @@ class LatticeTokenizer:
243
367
 
244
368
  if not _sentences:
245
369
  if remainder:
246
- _sentences, remainder = [remainder.strip()], ''
370
+ _sentences, remainder = [remainder.strip()], ""
247
371
  else:
248
372
  continue
249
373
 
@@ -257,12 +381,12 @@ class LatticeTokenizer:
257
381
  Supervision(text=text, speaker=(_speaker if s == 0 else None))
258
382
  for s, text in enumerate(_sentences[:-1])
259
383
  )
260
- remainder = _sentences[-1] + ' ' + remainder
384
+ remainder = _sentences[-1] + " " + remainder
261
385
  if k < len(speakers) - 1 and speakers[k + 1] is not None: # next speaker is set
262
386
  supervisions.append(
263
387
  Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
264
388
  )
265
- remainder = ''
389
+ remainder = ""
266
390
  elif len(_sentences) == 1:
267
391
  if k == len(speakers) - 1:
268
392
  pass # keep _speaker for the last supervision
@@ -285,20 +409,23 @@ class LatticeTokenizer:
285
409
 
286
410
  pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
287
411
  response = self.client_wrapper.post(
288
- 'tokenize',
412
+ "tokenize",
289
413
  json={
290
- 'supervisions': [s.to_dict() for s in supervisions],
291
- 'pronunciation_dictionaries': pronunciation_dictionaries,
414
+ "model_name": self.model_name,
415
+ "supervisions": [s.to_dict() for s in supervisions],
416
+ "pronunciation_dictionaries": pronunciation_dictionaries,
292
417
  },
293
418
  )
419
+ if response.status_code == 402:
420
+ raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
294
421
  if response.status_code != 200:
295
- raise Exception(f'Failed to tokenize texts: {response.text}')
422
+ raise Exception(f"Failed to tokenize texts: {response.text}")
296
423
  result = response.json()
297
- lattice_id = result['id']
424
+ lattice_id = result["id"]
298
425
  return (
299
426
  supervisions,
300
427
  lattice_id,
301
- (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
428
+ (result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
302
429
  )
303
430
 
304
431
  def detokenize(
@@ -310,16 +437,17 @@ class LatticeTokenizer:
310
437
  ) -> List[Supervision]:
311
438
  emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
312
439
  response = self.client_wrapper.post(
313
- 'detokenize',
440
+ "detokenize",
314
441
  json={
315
- 'lattice_id': lattice_id,
316
- 'frame_shift': frame_shift,
317
- 'results': [t.to_dict() for t in results[0]],
318
- 'labels': labels[0],
319
- 'offset': offset,
320
- 'channel': channel,
321
- 'return_details': return_details,
322
- 'destroy_lattice': True,
442
+ "model_name": self.model_name,
443
+ "lattice_id": lattice_id,
444
+ "frame_shift": frame_shift,
445
+ "results": [t.to_dict() for t in results[0]],
446
+ "labels": labels[0],
447
+ "offset": offset,
448
+ "channel": channel,
449
+ "return_details": False if return_details is None else return_details,
450
+ "destroy_lattice": True,
323
451
  },
324
452
  )
325
453
  if response.status_code == 422:
@@ -327,94 +455,20 @@ class LatticeTokenizer:
327
455
  lattice_id,
328
456
  original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
329
457
  )
458
+ if response.status_code == 402:
459
+ raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
330
460
  if response.status_code != 200:
331
- raise Exception(f'Failed to detokenize lattice: {response.text}')
461
+ raise Exception(f"Failed to detokenize lattice: {response.text}")
332
462
 
333
463
  result = response.json()
334
- if not result.get('success'):
335
- raise Exception('Failed to detokenize the alignment results.')
464
+ if not result.get("success"):
465
+ raise Exception("Failed to detokenize the alignment results.")
336
466
 
337
- alignments = [Supervision.from_dict(s) for s in result['supervisions']]
467
+ alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
338
468
 
339
469
  if return_details:
340
470
  # Add emission confidence scores for segments and word-level alignments
341
- _add_confidence_scores(alignments, emission, labels[0], frame_shift)
342
-
343
- alignments = _update_alignments_speaker(supervisions, alignments)
344
-
345
- return alignments
346
-
347
-
348
- class AsyncLatticeTokenizer(LatticeTokenizer):
349
- async def _post_async(self, endpoint: str, **kwargs):
350
- response = self.client_wrapper.post(endpoint, **kwargs)
351
- if inspect.isawaitable(response):
352
- return await response
353
- return response
354
-
355
- async def tokenize(
356
- self, supervisions: List[Supervision], split_sentence: bool = False
357
- ) -> Tuple[str, Dict[str, Any]]:
358
- if split_sentence:
359
- self.init_sentence_splitter()
360
- supervisions = self.split_sentences(supervisions)
361
-
362
- pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
363
- response = await self._post_async(
364
- 'tokenize',
365
- json={
366
- 'supervisions': [s.to_dict() for s in supervisions],
367
- 'pronunciation_dictionaries': pronunciation_dictionaries,
368
- },
369
- )
370
- if response.status_code != 200:
371
- raise Exception(f'Failed to tokenize texts: {response.text}')
372
- result = response.json()
373
- lattice_id = result['id']
374
- return (
375
- supervisions,
376
- lattice_id,
377
- (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
378
- )
379
-
380
- async def detokenize(
381
- self,
382
- lattice_id: str,
383
- lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
384
- supervisions: List[Supervision],
385
- return_details: bool = False,
386
- ) -> List[Supervision]:
387
- emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
388
- response = await self._post_async(
389
- 'detokenize',
390
- json={
391
- 'lattice_id': lattice_id,
392
- 'frame_shift': frame_shift,
393
- 'results': [t.to_dict() for t in results[0]],
394
- 'labels': labels[0],
395
- 'offset': offset,
396
- 'channel': channel,
397
- 'return_details': return_details,
398
- 'destroy_lattice': True,
399
- },
400
- )
401
- if response.status_code == 422:
402
- raise LatticeDecodingError(
403
- lattice_id,
404
- original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
405
- )
406
- if response.status_code != 200:
407
- raise Exception(f'Failed to detokenize lattice: {response.text}')
408
-
409
- result = response.json()
410
- if not result.get('success'):
411
- return Exception('Failed to detokenize the alignment results.')
412
-
413
- alignments = [Supervision.from_dict(s) for s in result['supervisions']]
414
-
415
- if return_details:
416
- # Add emission confidence scores for segments and word-level alignments
417
- _add_confidence_scores(alignments, emission, labels[0], frame_shift)
471
+ _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
418
472
 
419
473
  alignments = _update_alignments_speaker(supervisions, alignments)
420
474
 
@@ -426,6 +480,7 @@ def _add_confidence_scores(
426
480
  emission: torch.Tensor,
427
481
  labels: List[int],
428
482
  frame_shift: float,
483
+ offset: float = 0.0,
429
484
  ) -> None:
430
485
  """
431
486
  Add confidence scores to supervisions and their word-level alignments.
@@ -443,8 +498,8 @@ def _add_confidence_scores(
443
498
  tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
444
499
 
445
500
  for supervision in supervisions:
446
- start_frame = int(supervision.start / frame_shift)
447
- end_frame = int(supervision.end / frame_shift)
501
+ start_frame = int((supervision.start - offset) / frame_shift)
502
+ end_frame = int((supervision.end - offset) / frame_shift)
448
503
 
449
504
  # Compute segment-level confidence
450
505
  probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
@@ -453,11 +508,11 @@ def _add_confidence_scores(
453
508
  supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
454
509
 
455
510
  # Compute word-level confidence if alignment exists
456
- if hasattr(supervision, 'alignment') and supervision.alignment:
457
- words = supervision.alignment.get('word', [])
511
+ if hasattr(supervision, "alignment") and supervision.alignment:
512
+ words = supervision.alignment.get("word", [])
458
513
  for w, item in enumerate(words):
459
- start = int(item.start / frame_shift) - start_frame
460
- end = int(item.end / frame_shift) - start_frame
514
+ start = int((item.start - offset) / frame_shift) - start_frame
515
+ end = int((item.end - offset) / frame_shift) - start_frame
461
516
  words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
462
517
 
463
518
 
@@ -472,3 +527,23 @@ def _update_alignments_speaker(supervisions: List[Supervision], alignments: List
472
527
  for supervision, alignment in zip(supervisions, alignments):
473
528
  alignment.speaker = supervision.speaker
474
529
  return alignments
530
+
531
+
532
def _load_tokenizer(
    client_wrapper: Any,
    model_path: str,
    model_name: str,
    device: str,
    *,
    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
) -> LatticeTokenizer:
    """Instantiate a tokenizer, wrapping any load failure in ``ModelLoadError``.

    Args:
        client_wrapper: Client object handed through to the tokenizer. The
            tokenizer later calls its ``post`` method for the "tokenize" and
            "detokenize" endpoints.
        model_path: Directory that ``from_pretrained`` reads ``words.bin``
            (and optionally ``g2pp.bin`` / ``g2p.bin``) from.
        model_name: Model identifier stored on the tokenizer and sent with
            tokenize/detokenize requests.
        device: Device string stored on the tokenizer and passed to the G2P
            model.
        tokenizer_cls: Tokenizer class to construct. It must expose a
            ``from_pretrained`` classmethod accepting the keyword arguments
            used below.

    Returns:
        The loaded tokenizer instance.

    Raises:
        ModelLoadError: If ``tokenizer_cls.from_pretrained`` raises any
            ``Exception``. The original exception is passed as
            ``original_error`` and chained as ``__cause__``.
    """
    try:
        return tokenizer_cls.from_pretrained(
            client_wrapper=client_wrapper,
            model_path=model_path,
            model_name=model_name,
            device=device,
        )
    except Exception as e:
        # Chain explicitly so the traceback reports the real cause directly,
        # instead of "During handling of the above exception, another
        # exception occurred".
        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e) from e