lattifai 0.4.5__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +61 -47
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/alignment/lattice1_worker.py +185 -0
- lattifai/{tokenizer → alignment}/phonemizer.py +4 -4
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +244 -169
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/gemini_reader.py +30 -30
- lattifai/{io → caption}/gemini_writer.py +17 -17
- lattifai/{io → caption}/supervision.py +4 -3
- lattifai/caption/text_parser.py +145 -0
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +460 -251
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +98 -91
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/{workflows → transcription}/prompts/__init__.py +4 -4
- lattifai/types.py +30 -0
- lattifai/utils.py +16 -44
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/base.py +22 -22
- lattifai/{workflows → workflow}/file_manager.py +239 -215
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +1 -1
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -325
- lattifai/bin/align.py +0 -296
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -42
- lattifai/io/reader.py +0 -85
- lattifai/io/text_parser.py +0 -75
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -90
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workers/lattice1_alpha.py +0 -284
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -10
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.5.dist-info/METADATA +0 -808
- lattifai-0.4.5.dist-info/RECORD +0 -39
- lattifai-0.4.5.dist-info/entry_points.txt +0 -3
- {lattifai-0.4.5.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,27 +1,118 @@
|
|
|
1
1
|
import gzip
|
|
2
|
-
import inspect
|
|
3
2
|
import pickle
|
|
4
3
|
import re
|
|
5
4
|
from collections import defaultdict
|
|
6
|
-
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
|
|
7
6
|
|
|
8
7
|
import torch
|
|
9
8
|
|
|
10
|
-
from lattifai.
|
|
11
|
-
from lattifai.
|
|
12
|
-
from lattifai.
|
|
9
|
+
from lattifai.alignment.phonemizer import G2Phonemizer
|
|
10
|
+
from lattifai.caption import Supervision
|
|
11
|
+
from lattifai.caption import normalize_text as normalize_html_text
|
|
12
|
+
from lattifai.errors import (
|
|
13
|
+
LATTICE_DECODING_FAILURE_HELP,
|
|
14
|
+
LatticeDecodingError,
|
|
15
|
+
ModelLoadError,
|
|
16
|
+
QuotaExceededError,
|
|
17
|
+
)
|
|
13
18
|
|
|
14
19
|
PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
|
|
15
20
|
END_PUNCTUATION = '.!?"]。!?”】'
|
|
16
|
-
PUNCTUATION_SPACE = PUNCTUATION +
|
|
17
|
-
STAR_TOKEN =
|
|
21
|
+
PUNCTUATION_SPACE = PUNCTUATION + " "
|
|
22
|
+
STAR_TOKEN = "※"
|
|
18
23
|
|
|
19
|
-
GROUPING_SEPARATOR =
|
|
24
|
+
GROUPING_SEPARATOR = "✹"
|
|
20
25
|
|
|
21
26
|
MAXIMUM_WORD_LENGTH = 40
|
|
22
27
|
|
|
23
28
|
|
|
24
|
-
TokenizerT = TypeVar(
|
|
29
|
+
TokenizerT = TypeVar("TokenizerT", bound="LatticeTokenizer")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _is_punctuation(char: str) -> bool:
|
|
33
|
+
"""Check if a character is punctuation (not space, not alphanumeric, not CJK)."""
|
|
34
|
+
if len(char) != 1:
|
|
35
|
+
return False
|
|
36
|
+
if char.isspace():
|
|
37
|
+
return False
|
|
38
|
+
if char.isalnum():
|
|
39
|
+
return False
|
|
40
|
+
# Check if it's a CJK character
|
|
41
|
+
if "\u4e00" <= char <= "\u9fff":
|
|
42
|
+
return False
|
|
43
|
+
# Check if it's an accented Latin character
|
|
44
|
+
if "\u00c0" <= char <= "\u024f":
|
|
45
|
+
return False
|
|
46
|
+
return True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punctuation: bool = False) -> list[str]:
    r"""
    Tokenize a mixed Chinese-English string into individual units.

    Tokenization rules:
    - Chinese characters (CJK) are split individually
    - Consecutive Latin letters (including accented characters) and digits are grouped as one unit
    - English contractions ('s, 't, 'm, 'll, 're, 've) are kept with the preceding word
    - Every other character (punctuation, spaces, newlines, ...) is its own token by default
    - If attach_punctuation=True, punctuation marks are attached to the preceding token

    Args:
        text: Input string containing mixed Chinese and English text
        keep_spaces: If True, whitespace characters (including newlines) are included in the
            output as separate tokens. If False, whitespace-only tokens are excluded.
            Default is True.
        attach_punctuation: If True, punctuation marks are attached to the preceding token.
            For example, "Hello, World!" becomes ["Hello,", " ", "World!"].
            Default is False.

    Returns:
        List of tokenized units

    Examples:
        >>> tokenize_multilingual_text("Hello世界")
        ['Hello', '世', '界']
        >>> tokenize_multilingual_text("I'm fine")
        ["I'm", ' ', 'fine']
        >>> tokenize_multilingual_text("I'm fine", keep_spaces=False)
        ["I'm", 'fine']
        >>> tokenize_multilingual_text("Kühlschrank")
        ['Kühlschrank']
        >>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
        ['Hello,', ' ', 'World!']
        >>> tokenize_multilingual_text("a\nb")
        ['a', '\n', 'b']
    """
    # Regex pattern:
    # - [a-zA-Z0-9\u00C0-\u024F]+ matches Latin letters (including accented chars like ü, ö, ä, ß, é, etc.)
    # - (?:'[a-zA-Z]{1,2})? optionally matches contractions like 's, 't, 'm, 'll, 're, 've
    # - [\u4e00-\u9fff] matches CJK characters
    # - . matches any other single character
    # Unicode ranges:
    # - \u00C0-\u00FF: Latin-1 Supplement (À-ÿ)
    # - \u0100-\u017F: Latin Extended-A
    # - \u0180-\u024F: Latin Extended-B
    # re.DOTALL makes the catch-all "." also match "\n". Without it, findall silently
    # skipped newlines, so keep_spaces=True dropped line breaks from the output.
    pattern = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)", re.DOTALL)

    # Every alternative consumes at least one character, so findall never yields "".
    tokens = pattern.findall(text)

    if attach_punctuation and len(tokens) > 1:
        # Merge each punctuation token (see _is_punctuation) into the token before it;
        # a leading punctuation token has nothing to attach to and stays on its own.
        merged_tokens: list[str] = []
        for token in tokens:
            if merged_tokens and _is_punctuation(token):
                merged_tokens[-1] += token
            else:
                merged_tokens.append(token)
        tokens = merged_tokens

    if not keep_spaces:
        tokens = [t for t in tokens if not t.isspace()]

    return tokens
|
|
25
116
|
|
|
26
117
|
|
|
27
118
|
class LatticeTokenizer:
|
|
@@ -29,12 +120,13 @@ class LatticeTokenizer:
|
|
|
29
120
|
|
|
30
121
|
def __init__(self, client_wrapper: Any):
|
|
31
122
|
self.client_wrapper = client_wrapper
|
|
123
|
+
self.model_name = ""
|
|
32
124
|
self.words: List[str] = []
|
|
33
125
|
self.g2p_model: Any = None # Placeholder for G2P model
|
|
34
126
|
self.dictionaries = defaultdict(lambda: [])
|
|
35
|
-
self.oov_word =
|
|
127
|
+
self.oov_word = "<unk>"
|
|
36
128
|
self.sentence_splitter = None
|
|
37
|
-
self.device =
|
|
129
|
+
self.device = "cpu"
|
|
38
130
|
|
|
39
131
|
def init_sentence_splitter(self):
|
|
40
132
|
if self.sentence_splitter is not None:
|
|
@@ -45,14 +137,14 @@ class LatticeTokenizer:
|
|
|
45
137
|
|
|
46
138
|
providers = []
|
|
47
139
|
device = self.device
|
|
48
|
-
if device.startswith(
|
|
49
|
-
providers.append(
|
|
50
|
-
elif device.startswith(
|
|
51
|
-
providers.append(
|
|
140
|
+
if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
|
|
141
|
+
providers.append("CUDAExecutionProvider")
|
|
142
|
+
elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
|
|
143
|
+
providers.append("MPSExecutionProvider")
|
|
52
144
|
|
|
53
145
|
sat = SaT(
|
|
54
|
-
|
|
55
|
-
ort_providers=providers + [
|
|
146
|
+
"sat-3l-sm",
|
|
147
|
+
ort_providers=providers + ["CPUExecutionProvider"],
|
|
56
148
|
)
|
|
57
149
|
self.sentence_splitter = sat
|
|
58
150
|
|
|
@@ -79,23 +171,23 @@ class LatticeTokenizer:
|
|
|
79
171
|
# or other forms like [SOMETHING] SPEAKER:
|
|
80
172
|
|
|
81
173
|
# Pattern 1: [mark] HTML-encoded separator speaker:
|
|
82
|
-
pattern1 = r
|
|
174
|
+
pattern1 = r"^(\[[^\]]+\])\s+(>>|>>)\s+(.+)$"
|
|
83
175
|
match1 = re.match(pattern1, sentence.strip())
|
|
84
176
|
if match1:
|
|
85
177
|
special_mark = match1.group(1)
|
|
86
178
|
separator = match1.group(2)
|
|
87
179
|
speaker_part = match1.group(3)
|
|
88
|
-
return [special_mark, f
|
|
180
|
+
return [special_mark, f"{separator} {speaker_part}"]
|
|
89
181
|
|
|
90
182
|
# Pattern 2: [mark] speaker:
|
|
91
|
-
pattern2 = r
|
|
183
|
+
pattern2 = r"^(\[[^\]]+\])\s+([^:]+:)(.*)$"
|
|
92
184
|
match2 = re.match(pattern2, sentence.strip())
|
|
93
185
|
if match2:
|
|
94
186
|
special_mark = match2.group(1)
|
|
95
187
|
speaker_label = match2.group(2)
|
|
96
188
|
remaining = match2.group(3).strip()
|
|
97
189
|
if remaining:
|
|
98
|
-
return [special_mark, f
|
|
190
|
+
return [special_mark, f"{speaker_label} {remaining}"]
|
|
99
191
|
else:
|
|
100
192
|
return [special_mark, speaker_label]
|
|
101
193
|
|
|
@@ -107,28 +199,45 @@ class LatticeTokenizer:
|
|
|
107
199
|
cls: Type[TokenizerT],
|
|
108
200
|
client_wrapper: Any,
|
|
109
201
|
model_path: str,
|
|
110
|
-
|
|
202
|
+
model_name: str,
|
|
203
|
+
device: str = "cpu",
|
|
111
204
|
compressed: bool = True,
|
|
112
205
|
) -> TokenizerT:
|
|
113
206
|
"""Load tokenizer from exported binary file"""
|
|
114
207
|
from pathlib import Path
|
|
115
208
|
|
|
116
|
-
words_model_path = f
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
209
|
+
words_model_path = f"{model_path}/words.bin"
|
|
210
|
+
try:
|
|
211
|
+
if compressed:
|
|
212
|
+
with gzip.open(words_model_path, "rb") as f:
|
|
213
|
+
data = pickle.load(f)
|
|
214
|
+
else:
|
|
215
|
+
with open(words_model_path, "rb") as f:
|
|
216
|
+
data = pickle.load(f)
|
|
217
|
+
except pickle.UnpicklingError as e:
|
|
218
|
+
del e
|
|
219
|
+
import msgpack
|
|
220
|
+
|
|
221
|
+
if compressed:
|
|
222
|
+
with gzip.open(words_model_path, "rb") as f:
|
|
223
|
+
data = msgpack.unpack(f, raw=False, strict_map_key=False)
|
|
224
|
+
else:
|
|
225
|
+
with open(words_model_path, "rb") as f:
|
|
226
|
+
data = msgpack.unpack(f, raw=False, strict_map_key=False)
|
|
123
227
|
|
|
124
228
|
tokenizer = cls(client_wrapper=client_wrapper)
|
|
125
|
-
tokenizer.
|
|
126
|
-
tokenizer.
|
|
127
|
-
tokenizer.
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
if
|
|
131
|
-
|
|
229
|
+
tokenizer.model_name = model_name
|
|
230
|
+
tokenizer.words = data["words"]
|
|
231
|
+
tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
|
|
232
|
+
tokenizer.oov_word = data["oov_word"]
|
|
233
|
+
|
|
234
|
+
g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
|
|
235
|
+
if g2pp_model_path:
|
|
236
|
+
tokenizer.g2p_model = G2Phonemizer(g2pp_model_path, device=device)
|
|
237
|
+
else:
|
|
238
|
+
g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
|
|
239
|
+
if g2p_model_path:
|
|
240
|
+
tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
|
|
132
241
|
|
|
133
242
|
tokenizer.device = device
|
|
134
243
|
tokenizer.add_special_tokens()
|
|
@@ -136,18 +245,22 @@ class LatticeTokenizer:
|
|
|
136
245
|
|
|
137
246
|
def add_special_tokens(self):
|
|
138
247
|
tokenizer = self
|
|
139
|
-
for special_token in [
|
|
248
|
+
for special_token in [">>", ">"]:
|
|
140
249
|
if special_token not in tokenizer.dictionaries:
|
|
141
250
|
tokenizer.dictionaries[special_token] = tokenizer.dictionaries[tokenizer.oov_word]
|
|
142
251
|
return self
|
|
143
252
|
|
|
144
253
|
def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
|
|
145
254
|
if not self.g2p_model:
|
|
146
|
-
raise ValueError(
|
|
255
|
+
raise ValueError("G2P model is not loaded, cannot prenormalize texts")
|
|
147
256
|
|
|
148
257
|
oov_words = []
|
|
149
258
|
for text in texts:
|
|
150
|
-
|
|
259
|
+
text = normalize_html_text(text)
|
|
260
|
+
# support english, chinese and german tokenization
|
|
261
|
+
words = tokenize_multilingual_text(
|
|
262
|
+
text.lower().replace("-", " ").replace("—", " ").replace("–", " "), keep_spaces=False
|
|
263
|
+
)
|
|
151
264
|
oovs = [w.strip(PUNCTUATION) for w in words if w not in self.words]
|
|
152
265
|
if oovs:
|
|
153
266
|
oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
|
|
@@ -156,7 +269,7 @@ class LatticeTokenizer:
|
|
|
156
269
|
if oov_words:
|
|
157
270
|
indexs = []
|
|
158
271
|
for k, _word in enumerate(oov_words):
|
|
159
|
-
if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [(
|
|
272
|
+
if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [("(", ")"), ("[", "]")]):
|
|
160
273
|
self.dictionaries[_word] = self.dictionaries[self.oov_word]
|
|
161
274
|
else:
|
|
162
275
|
_word = _word.strip(PUNCTUATION_SPACE)
|
|
@@ -187,38 +300,49 @@ class LatticeTokenizer:
|
|
|
187
300
|
|
|
188
301
|
Carefull about speaker changes.
|
|
189
302
|
"""
|
|
190
|
-
texts,
|
|
191
|
-
|
|
303
|
+
texts, speakers = [], []
|
|
304
|
+
text_len, sidx = 0, 0
|
|
305
|
+
|
|
306
|
+
def flush_segment(end_idx: int, speaker: Optional[str] = None):
|
|
307
|
+
"""Flush accumulated text from sidx to end_idx with given speaker."""
|
|
308
|
+
nonlocal text_len, sidx
|
|
309
|
+
if sidx <= end_idx:
|
|
310
|
+
if len(speakers) < len(texts) + 1:
|
|
311
|
+
speakers.append(speaker)
|
|
312
|
+
text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
|
|
313
|
+
texts.append(text)
|
|
314
|
+
sidx = end_idx + 1
|
|
315
|
+
text_len = 0
|
|
316
|
+
|
|
192
317
|
for s, supervision in enumerate(supervisions):
|
|
193
318
|
text_len += len(supervision.text)
|
|
319
|
+
is_last = s == len(supervisions) - 1
|
|
320
|
+
|
|
194
321
|
if supervision.speaker:
|
|
322
|
+
# Flush previous segment without speaker (if any)
|
|
195
323
|
if sidx < s:
|
|
196
|
-
|
|
197
|
-
speakers.append(None)
|
|
198
|
-
text = ' '.join([sup.text for sup in supervisions[sidx:s]])
|
|
199
|
-
texts.append(text)
|
|
200
|
-
sidx = s
|
|
324
|
+
flush_segment(s - 1, None)
|
|
201
325
|
text_len = len(supervision.text)
|
|
202
|
-
speakers.append(supervision.speaker)
|
|
203
326
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
327
|
+
# Check if we should flush this speaker's segment now
|
|
328
|
+
next_has_speaker = not is_last and supervisions[s + 1].speaker
|
|
329
|
+
if is_last or next_has_speaker:
|
|
330
|
+
flush_segment(s, supervision.speaker)
|
|
331
|
+
else:
|
|
332
|
+
speakers.append(supervision.speaker)
|
|
333
|
+
|
|
334
|
+
elif text_len >= 2000 or is_last:
|
|
335
|
+
flush_segment(s, None)
|
|
336
|
+
|
|
337
|
+
assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
|
|
214
338
|
sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
|
|
215
339
|
|
|
216
|
-
supervisions, remainder = [],
|
|
340
|
+
supervisions, remainder = [], ""
|
|
217
341
|
for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
|
|
218
342
|
# Prepend remainder from previous iteration to the first sentence
|
|
219
343
|
if _sentences and remainder:
|
|
220
344
|
_sentences[0] = remainder + _sentences[0]
|
|
221
|
-
remainder =
|
|
345
|
+
remainder = ""
|
|
222
346
|
|
|
223
347
|
if not _sentences:
|
|
224
348
|
continue
|
|
@@ -228,14 +352,14 @@ class LatticeTokenizer:
|
|
|
228
352
|
for s, _sentence in enumerate(_sentences):
|
|
229
353
|
if remainder:
|
|
230
354
|
_sentence = remainder + _sentence
|
|
231
|
-
remainder =
|
|
355
|
+
remainder = ""
|
|
232
356
|
# Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:'] # noqa: E501
|
|
233
357
|
resplit_parts = self._resplit_special_sentence_types(_sentence)
|
|
234
|
-
if any(resplit_parts[-1].endswith(sp) for sp in [
|
|
358
|
+
if any(resplit_parts[-1].endswith(sp) for sp in [":", ":"]):
|
|
235
359
|
if s < len(_sentences) - 1:
|
|
236
|
-
_sentences[s + 1] = resplit_parts[-1] +
|
|
360
|
+
_sentences[s + 1] = resplit_parts[-1] + " " + _sentences[s + 1]
|
|
237
361
|
else: # last part
|
|
238
|
-
remainder = resplit_parts[-1] +
|
|
362
|
+
remainder = resplit_parts[-1] + " "
|
|
239
363
|
processed_sentences.extend(resplit_parts[:-1])
|
|
240
364
|
else:
|
|
241
365
|
processed_sentences.extend(resplit_parts)
|
|
@@ -243,7 +367,7 @@ class LatticeTokenizer:
|
|
|
243
367
|
|
|
244
368
|
if not _sentences:
|
|
245
369
|
if remainder:
|
|
246
|
-
_sentences, remainder = [remainder.strip()],
|
|
370
|
+
_sentences, remainder = [remainder.strip()], ""
|
|
247
371
|
else:
|
|
248
372
|
continue
|
|
249
373
|
|
|
@@ -257,12 +381,12 @@ class LatticeTokenizer:
|
|
|
257
381
|
Supervision(text=text, speaker=(_speaker if s == 0 else None))
|
|
258
382
|
for s, text in enumerate(_sentences[:-1])
|
|
259
383
|
)
|
|
260
|
-
remainder = _sentences[-1] +
|
|
384
|
+
remainder = _sentences[-1] + " " + remainder
|
|
261
385
|
if k < len(speakers) - 1 and speakers[k + 1] is not None: # next speaker is set
|
|
262
386
|
supervisions.append(
|
|
263
387
|
Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
|
|
264
388
|
)
|
|
265
|
-
remainder =
|
|
389
|
+
remainder = ""
|
|
266
390
|
elif len(_sentences) == 1:
|
|
267
391
|
if k == len(speakers) - 1:
|
|
268
392
|
pass # keep _speaker for the last supervision
|
|
@@ -285,20 +409,23 @@ class LatticeTokenizer:
|
|
|
285
409
|
|
|
286
410
|
pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
|
|
287
411
|
response = self.client_wrapper.post(
|
|
288
|
-
|
|
412
|
+
"tokenize",
|
|
289
413
|
json={
|
|
290
|
-
|
|
291
|
-
|
|
414
|
+
"model_name": self.model_name,
|
|
415
|
+
"supervisions": [s.to_dict() for s in supervisions],
|
|
416
|
+
"pronunciation_dictionaries": pronunciation_dictionaries,
|
|
292
417
|
},
|
|
293
418
|
)
|
|
419
|
+
if response.status_code == 402:
|
|
420
|
+
raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
|
|
294
421
|
if response.status_code != 200:
|
|
295
|
-
raise Exception(f
|
|
422
|
+
raise Exception(f"Failed to tokenize texts: {response.text}")
|
|
296
423
|
result = response.json()
|
|
297
|
-
lattice_id = result[
|
|
424
|
+
lattice_id = result["id"]
|
|
298
425
|
return (
|
|
299
426
|
supervisions,
|
|
300
427
|
lattice_id,
|
|
301
|
-
(result[
|
|
428
|
+
(result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
|
|
302
429
|
)
|
|
303
430
|
|
|
304
431
|
def detokenize(
|
|
@@ -310,16 +437,17 @@ class LatticeTokenizer:
|
|
|
310
437
|
) -> List[Supervision]:
|
|
311
438
|
emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
|
|
312
439
|
response = self.client_wrapper.post(
|
|
313
|
-
|
|
440
|
+
"detokenize",
|
|
314
441
|
json={
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
442
|
+
"model_name": self.model_name,
|
|
443
|
+
"lattice_id": lattice_id,
|
|
444
|
+
"frame_shift": frame_shift,
|
|
445
|
+
"results": [t.to_dict() for t in results[0]],
|
|
446
|
+
"labels": labels[0],
|
|
447
|
+
"offset": offset,
|
|
448
|
+
"channel": channel,
|
|
449
|
+
"return_details": False if return_details is None else return_details,
|
|
450
|
+
"destroy_lattice": True,
|
|
323
451
|
},
|
|
324
452
|
)
|
|
325
453
|
if response.status_code == 422:
|
|
@@ -327,94 +455,20 @@ class LatticeTokenizer:
|
|
|
327
455
|
lattice_id,
|
|
328
456
|
original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
|
|
329
457
|
)
|
|
458
|
+
if response.status_code == 402:
|
|
459
|
+
raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
|
|
330
460
|
if response.status_code != 200:
|
|
331
|
-
raise Exception(f
|
|
461
|
+
raise Exception(f"Failed to detokenize lattice: {response.text}")
|
|
332
462
|
|
|
333
463
|
result = response.json()
|
|
334
|
-
if not result.get(
|
|
335
|
-
raise Exception(
|
|
464
|
+
if not result.get("success"):
|
|
465
|
+
raise Exception("Failed to detokenize the alignment results.")
|
|
336
466
|
|
|
337
|
-
alignments = [Supervision.from_dict(s) for s in result[
|
|
467
|
+
alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
|
|
338
468
|
|
|
339
469
|
if return_details:
|
|
340
470
|
# Add emission confidence scores for segments and word-level alignments
|
|
341
|
-
_add_confidence_scores(alignments, emission, labels[0], frame_shift)
|
|
342
|
-
|
|
343
|
-
alignments = _update_alignments_speaker(supervisions, alignments)
|
|
344
|
-
|
|
345
|
-
return alignments
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
class AsyncLatticeTokenizer(LatticeTokenizer):
|
|
349
|
-
async def _post_async(self, endpoint: str, **kwargs):
|
|
350
|
-
response = self.client_wrapper.post(endpoint, **kwargs)
|
|
351
|
-
if inspect.isawaitable(response):
|
|
352
|
-
return await response
|
|
353
|
-
return response
|
|
354
|
-
|
|
355
|
-
async def tokenize(
|
|
356
|
-
self, supervisions: List[Supervision], split_sentence: bool = False
|
|
357
|
-
) -> Tuple[str, Dict[str, Any]]:
|
|
358
|
-
if split_sentence:
|
|
359
|
-
self.init_sentence_splitter()
|
|
360
|
-
supervisions = self.split_sentences(supervisions)
|
|
361
|
-
|
|
362
|
-
pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
|
|
363
|
-
response = await self._post_async(
|
|
364
|
-
'tokenize',
|
|
365
|
-
json={
|
|
366
|
-
'supervisions': [s.to_dict() for s in supervisions],
|
|
367
|
-
'pronunciation_dictionaries': pronunciation_dictionaries,
|
|
368
|
-
},
|
|
369
|
-
)
|
|
370
|
-
if response.status_code != 200:
|
|
371
|
-
raise Exception(f'Failed to tokenize texts: {response.text}')
|
|
372
|
-
result = response.json()
|
|
373
|
-
lattice_id = result['id']
|
|
374
|
-
return (
|
|
375
|
-
supervisions,
|
|
376
|
-
lattice_id,
|
|
377
|
-
(result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
|
|
378
|
-
)
|
|
379
|
-
|
|
380
|
-
async def detokenize(
|
|
381
|
-
self,
|
|
382
|
-
lattice_id: str,
|
|
383
|
-
lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
|
|
384
|
-
supervisions: List[Supervision],
|
|
385
|
-
return_details: bool = False,
|
|
386
|
-
) -> List[Supervision]:
|
|
387
|
-
emission, results, labels, frame_shift, offset, channel = lattice_results # noqa: F841
|
|
388
|
-
response = await self._post_async(
|
|
389
|
-
'detokenize',
|
|
390
|
-
json={
|
|
391
|
-
'lattice_id': lattice_id,
|
|
392
|
-
'frame_shift': frame_shift,
|
|
393
|
-
'results': [t.to_dict() for t in results[0]],
|
|
394
|
-
'labels': labels[0],
|
|
395
|
-
'offset': offset,
|
|
396
|
-
'channel': channel,
|
|
397
|
-
'return_details': return_details,
|
|
398
|
-
'destroy_lattice': True,
|
|
399
|
-
},
|
|
400
|
-
)
|
|
401
|
-
if response.status_code == 422:
|
|
402
|
-
raise LatticeDecodingError(
|
|
403
|
-
lattice_id,
|
|
404
|
-
original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
|
|
405
|
-
)
|
|
406
|
-
if response.status_code != 200:
|
|
407
|
-
raise Exception(f'Failed to detokenize lattice: {response.text}')
|
|
408
|
-
|
|
409
|
-
result = response.json()
|
|
410
|
-
if not result.get('success'):
|
|
411
|
-
return Exception('Failed to detokenize the alignment results.')
|
|
412
|
-
|
|
413
|
-
alignments = [Supervision.from_dict(s) for s in result['supervisions']]
|
|
414
|
-
|
|
415
|
-
if return_details:
|
|
416
|
-
# Add emission confidence scores for segments and word-level alignments
|
|
417
|
-
_add_confidence_scores(alignments, emission, labels[0], frame_shift)
|
|
471
|
+
_add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)
|
|
418
472
|
|
|
419
473
|
alignments = _update_alignments_speaker(supervisions, alignments)
|
|
420
474
|
|
|
@@ -426,6 +480,7 @@ def _add_confidence_scores(
|
|
|
426
480
|
emission: torch.Tensor,
|
|
427
481
|
labels: List[int],
|
|
428
482
|
frame_shift: float,
|
|
483
|
+
offset: float = 0.0,
|
|
429
484
|
) -> None:
|
|
430
485
|
"""
|
|
431
486
|
Add confidence scores to supervisions and their word-level alignments.
|
|
@@ -443,8 +498,8 @@ def _add_confidence_scores(
|
|
|
443
498
|
tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
|
|
444
499
|
|
|
445
500
|
for supervision in supervisions:
|
|
446
|
-
start_frame = int(supervision.start / frame_shift)
|
|
447
|
-
end_frame = int(supervision.end / frame_shift)
|
|
501
|
+
start_frame = int((supervision.start - offset) / frame_shift)
|
|
502
|
+
end_frame = int((supervision.end - offset) / frame_shift)
|
|
448
503
|
|
|
449
504
|
# Compute segment-level confidence
|
|
450
505
|
probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
|
|
@@ -453,11 +508,11 @@ def _add_confidence_scores(
|
|
|
453
508
|
supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
|
|
454
509
|
|
|
455
510
|
# Compute word-level confidence if alignment exists
|
|
456
|
-
if hasattr(supervision,
|
|
457
|
-
words = supervision.alignment.get(
|
|
511
|
+
if hasattr(supervision, "alignment") and supervision.alignment:
|
|
512
|
+
words = supervision.alignment.get("word", [])
|
|
458
513
|
for w, item in enumerate(words):
|
|
459
|
-
start = int(item.start / frame_shift) - start_frame
|
|
460
|
-
end = int(item.end / frame_shift) - start_frame
|
|
514
|
+
start = int((item.start - offset) / frame_shift) - start_frame
|
|
515
|
+
end = int((item.end - offset) / frame_shift) - start_frame
|
|
461
516
|
words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
|
|
462
517
|
|
|
463
518
|
|
|
@@ -472,3 +527,23 @@ def _update_alignments_speaker(supervisions: List[Supervision], alignments: List
|
|
|
472
527
|
for supervision, alignment in zip(supervisions, alignments):
|
|
473
528
|
alignment.speaker = supervision.speaker
|
|
474
529
|
return alignments
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def _load_tokenizer(
    client_wrapper: Any,
    model_path: str,
    model_name: str,
    device: str,
    *,
    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
) -> LatticeTokenizer:
    """Instantiate a tokenizer via ``tokenizer_cls.from_pretrained`` with consistent error handling.

    Args:
        client_wrapper: Client object passed through to the tokenizer.
        model_path: Directory the tokenizer assets are loaded from.
        model_name: Model name stored on the tokenizer.
        device: Device string forwarded to ``from_pretrained``.
        tokenizer_cls: Tokenizer class to construct (keyword-only); defaults to
            ``LatticeTokenizer``.

    Returns:
        The tokenizer returned by ``tokenizer_cls.from_pretrained``.

    Raises:
        ModelLoadError: If ``from_pretrained`` raises any ``Exception``. The original
            exception is passed as ``original_error`` and chained as ``__cause__``.
    """
    try:
        return tokenizer_cls.from_pretrained(
            client_wrapper=client_wrapper,
            model_path=model_path,
            model_name=model_name,
            device=device,
        )
    except Exception as e:
        # Explicit chaining records the root cause as the direct cause in tracebacks,
        # instead of the implicit "during handling of the above exception" context.
        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e) from e
|