lattifai 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +42 -27
- lattifai/alignment/__init__.py +6 -0
- lattifai/alignment/lattice1_aligner.py +119 -0
- lattifai/{workers/lattice1_alpha.py → alignment/lattice1_worker.py} +33 -132
- lattifai/{tokenizer → alignment}/phonemizer.py +1 -1
- lattifai/alignment/segmenter.py +166 -0
- lattifai/{tokenizer → alignment}/tokenizer.py +186 -112
- lattifai/audio2.py +211 -0
- lattifai/caption/__init__.py +20 -0
- lattifai/caption/caption.py +1275 -0
- lattifai/{io → caption}/supervision.py +1 -0
- lattifai/{io → caption}/text_parser.py +53 -10
- lattifai/cli/__init__.py +17 -0
- lattifai/cli/alignment.py +153 -0
- lattifai/cli/caption.py +204 -0
- lattifai/cli/server.py +19 -0
- lattifai/cli/transcribe.py +197 -0
- lattifai/cli/youtube.py +128 -0
- lattifai/client.py +455 -246
- lattifai/config/__init__.py +20 -0
- lattifai/config/alignment.py +73 -0
- lattifai/config/caption.py +178 -0
- lattifai/config/client.py +46 -0
- lattifai/config/diarization.py +67 -0
- lattifai/config/media.py +335 -0
- lattifai/config/transcription.py +84 -0
- lattifai/diarization/__init__.py +5 -0
- lattifai/diarization/lattifai.py +89 -0
- lattifai/errors.py +41 -34
- lattifai/logging.py +116 -0
- lattifai/mixin.py +552 -0
- lattifai/server/app.py +420 -0
- lattifai/transcription/__init__.py +76 -0
- lattifai/transcription/base.py +108 -0
- lattifai/transcription/gemini.py +219 -0
- lattifai/transcription/lattifai.py +103 -0
- lattifai/types.py +30 -0
- lattifai/utils.py +3 -31
- lattifai/workflow/__init__.py +22 -0
- lattifai/workflow/agents.py +6 -0
- lattifai/{workflows → workflow}/file_manager.py +81 -57
- lattifai/workflow/youtube.py +564 -0
- lattifai-1.0.0.dist-info/METADATA +736 -0
- lattifai-1.0.0.dist-info/RECORD +52 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/WHEEL +1 -1
- lattifai-1.0.0.dist-info/entry_points.txt +13 -0
- lattifai/base_client.py +0 -126
- lattifai/bin/__init__.py +0 -3
- lattifai/bin/agent.py +0 -324
- lattifai/bin/align.py +0 -295
- lattifai/bin/cli_base.py +0 -25
- lattifai/bin/subtitle.py +0 -210
- lattifai/io/__init__.py +0 -43
- lattifai/io/reader.py +0 -86
- lattifai/io/utils.py +0 -15
- lattifai/io/writer.py +0 -102
- lattifai/tokenizer/__init__.py +0 -3
- lattifai/workers/__init__.py +0 -3
- lattifai/workflows/__init__.py +0 -34
- lattifai/workflows/agents.py +0 -12
- lattifai/workflows/gemini.py +0 -167
- lattifai/workflows/prompts/README.md +0 -22
- lattifai/workflows/prompts/gemini/README.md +0 -24
- lattifai/workflows/prompts/gemini/transcription_gem.txt +0 -81
- lattifai/workflows/youtube.py +0 -931
- lattifai-0.4.6.dist-info/METADATA +0 -806
- lattifai-0.4.6.dist-info/RECORD +0 -39
- lattifai-0.4.6.dist-info/entry_points.txt +0 -3
- /lattifai/{io → caption}/gemini_reader.py +0 -0
- /lattifai/{io → caption}/gemini_writer.py +0 -0
- /lattifai/{workflows → transcription}/prompts/__init__.py +0 -0
- /lattifai/{workflows → workflow}/base.py +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.4.6.dist-info → lattifai-1.0.0.dist-info}/top_level.txt +0 -0
lattifai/{tokenizer → alignment}/tokenizer.py
@@ -1,5 +1,4 @@
 import gzip
-import inspect
 import pickle
 import re
 from collections import defaultdict
@@ -7,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

 import torch

-from lattifai.
-from lattifai.
-from lattifai.
+from lattifai.alignment.phonemizer import G2Phonemizer
+from lattifai.caption import Supervision
+from lattifai.caption import normalize_text as normalize_html_text
+from lattifai.errors import (
+    LATTICE_DECODING_FAILURE_HELP,
+    LatticeDecodingError,
+    ModelLoadError,
+    QuotaExceededError,
+)

 PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
 END_PUNCTUATION = '.!?"]。!?”】'
@@ -24,11 +29,98 @@ MAXIMUM_WORD_LENGTH = 40
 TokenizerT = TypeVar("TokenizerT", bound="LatticeTokenizer")


+def _is_punctuation(char: str) -> bool:
+    """Check if a character is punctuation (not space, not alphanumeric, not CJK)."""
+    if len(char) != 1:
+        return False
+    if char.isspace():
+        return False
+    if char.isalnum():
+        return False
+    # Check if it's a CJK character
+    if "\u4e00" <= char <= "\u9fff":
+        return False
+    # Check if it's an accented Latin character
+    if "\u00c0" <= char <= "\u024f":
+        return False
+    return True
+
+
+def tokenize_multilingual_text(text: str, keep_spaces: bool = True, attach_punctuation: bool = False) -> list[str]:
+    """
+    Tokenize a mixed Chinese-English string into individual units.
+
+    Tokenization rules:
+    - Chinese characters (CJK) are split individually
+    - Consecutive Latin letters (including accented characters) and digits are grouped as one unit
+    - English contractions ('s, 't, 'm, 'll, 're, 've) are kept with the preceding word
+    - Other characters (punctuation, spaces) are split individually by default
+    - If attach_punctuation=True, punctuation marks are attached to the preceding token
+
+    Args:
+        text: Input string containing mixed Chinese and English text
+        keep_spaces: If True, spaces are included in the output as separate tokens.
+            If False, spaces are excluded from the output. Default is True.
+        attach_punctuation: If True, punctuation marks are attached to the preceding token.
+            For example, "Hello, World!" becomes ["Hello,", " ", "World!"].
+            Default is False.
+
+    Returns:
+        List of tokenized units
+
+    Examples:
+        >>> tokenize_multilingual_text("Hello世界")
+        ['Hello', '世', '界']
+        >>> tokenize_multilingual_text("I'm fine")
+        ["I'm", ' ', 'fine']
+        >>> tokenize_multilingual_text("I'm fine", keep_spaces=False)
+        ["I'm", 'fine']
+        >>> tokenize_multilingual_text("Kühlschrank")
+        ['Kühlschrank']
+        >>> tokenize_multilingual_text("Hello, World!", attach_punctuation=True)
+        ['Hello,', ' ', 'World!']
+    """
+    # Regex pattern:
+    # - [a-zA-Z0-9\u00C0-\u024F]+ matches Latin letters (including accented chars like ü, ö, ä, ß, é, etc.)
+    # - (?:'[a-zA-Z]{1,2})? optionally matches contractions like 's, 't, 'm, 'll, 're, 've
+    # - [\u4e00-\u9fff] matches CJK characters
+    # - . matches any other single character
+    # Unicode ranges:
+    # - \u00C0-\u00FF: Latin-1 Supplement (À-ÿ)
+    # - \u0100-\u017F: Latin Extended-A
+    # - \u0180-\u024F: Latin Extended-B
+    pattern = re.compile(r"([a-zA-Z0-9\u00C0-\u024F]+(?:'[a-zA-Z]{1,2})?|[\u4e00-\u9fff]|.)")
+
+    # filter(None, ...) removes any empty strings from re.findall results
+    tokens = list(filter(None, pattern.findall(text)))
+
+    if attach_punctuation and len(tokens) > 1:
+        # Attach punctuation to the preceding token
+        # Punctuation characters (excluding spaces) are merged with the previous token
+        merged_tokens = []
+        i = 0
+        while i < len(tokens):
+            token = tokens[i]
+            # Look ahead to collect consecutive punctuation (non-space, non-alphanumeric, non-CJK)
+            if merged_tokens and _is_punctuation(token):
+                merged_tokens[-1] = merged_tokens[-1] + token
+            else:
+                merged_tokens.append(token)
+            i += 1
+        tokens = merged_tokens
+
+    if not keep_spaces:
+        tokens = [t for t in tokens if not t.isspace()]
+
+    return tokens
+
+
 class LatticeTokenizer:
     """Tokenizer for converting Lhotse Cut to LatticeGraph."""

     def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
+        self.model_name = ""
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
         self.dictionaries = defaultdict(lambda: [])
@@ -107,6 +199,7 @@ class LatticeTokenizer:
         cls: Type[TokenizerT],
         client_wrapper: Any,
         model_path: str,
+        model_name: str,
         device: str = "cpu",
         compressed: bool = True,
     ) -> TokenizerT:
@@ -114,21 +207,37 @@
         from pathlib import Path

         words_model_path = f"{model_path}/words.bin"
-
-
-
-
-
-
+        try:
+            if compressed:
+                with gzip.open(words_model_path, "rb") as f:
+                    data = pickle.load(f)
+            else:
+                with open(words_model_path, "rb") as f:
+                    data = pickle.load(f)
+        except pickle.UnpicklingError as e:
+            del e
+            import msgpack
+
+            if compressed:
+                with gzip.open(words_model_path, "rb") as f:
+                    data = msgpack.unpack(f, raw=False, strict_map_key=False)
+            else:
+                with open(words_model_path, "rb") as f:
+                    data = msgpack.unpack(f, raw=False, strict_map_key=False)

         tokenizer = cls(client_wrapper=client_wrapper)
+        tokenizer.model_name = model_name
         tokenizer.words = data["words"]
         tokenizer.dictionaries = defaultdict(list, data["dictionaries"])
         tokenizer.oov_word = data["oov_word"]

-
-        if
-            tokenizer.g2p_model = G2Phonemizer(
+        g2pp_model_path = f"{model_path}/g2pp.bin" if Path(f"{model_path}/g2pp.bin").exists() else None
+        if g2pp_model_path:
+            tokenizer.g2p_model = G2Phonemizer(g2pp_model_path, device=device)
+        else:
+            g2p_model_path = f"{model_path}/g2p.bin" if Path(f"{model_path}/g2p.bin").exists() else None
+            if g2p_model_path:
+                tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)

         tokenizer.device = device
         tokenizer.add_special_tokens()
@@ -148,7 +257,10 @@ class LatticeTokenizer:
         oov_words = []
         for text in texts:
             text = normalize_html_text(text)
-
+            # support english, chinese and german tokenization
+            words = tokenize_multilingual_text(
+                text.lower().replace("-", " ").replace("—", " ").replace("–", " "), keep_spaces=False
+            )
             oovs = [w.strip(PUNCTUATION) for w in words if w not in self.words]
             if oovs:
                 oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
@@ -188,28 +300,39 @@

         Carefull about speaker changes.
         """
-        texts,
-
+        texts, speakers = [], []
+        text_len, sidx = 0, 0
+
+        def flush_segment(end_idx: int, speaker: Optional[str] = None):
+            """Flush accumulated text from sidx to end_idx with given speaker."""
+            nonlocal text_len, sidx
+            if sidx <= end_idx:
+                if len(speakers) < len(texts) + 1:
+                    speakers.append(speaker)
+                text = " ".join(sup.text for sup in supervisions[sidx : end_idx + 1])
+                texts.append(text)
+                sidx = end_idx + 1
+                text_len = 0
+
         for s, supervision in enumerate(supervisions):
             text_len += len(supervision.text)
+            is_last = s == len(supervisions) - 1
+
             if supervision.speaker:
+                # Flush previous segment without speaker (if any)
                 if sidx < s:
-
-                    speakers.append(None)
-                    text = " ".join([sup.text for sup in supervisions[sidx:s]])
-                    texts.append(text)
-                    sidx = s
+                    flush_segment(s - 1, None)
                     text_len = len(supervision.text)
-                speakers.append(supervision.speaker)

-
-
-
-
-
-
-
-
+                # Check if we should flush this speaker's segment now
+                next_has_speaker = not is_last and supervisions[s + 1].speaker
+                if is_last or next_has_speaker:
+                    flush_segment(s, supervision.speaker)
+                else:
+                    speakers.append(supervision.speaker)
+
+            elif text_len >= 2000 or is_last:
+                flush_segment(s, None)

         assert len(speakers) == len(texts), f"len(speakers)={len(speakers)} != len(texts)={len(texts)}"
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)
@@ -288,10 +411,13 @@ class LatticeTokenizer:
         response = self.client_wrapper.post(
             "tokenize",
             json={
+                "model_name": self.model_name,
                 "supervisions": [s.to_dict() for s in supervisions],
                 "pronunciation_dictionaries": pronunciation_dictionaries,
             },
         )
+        if response.status_code == 402:
+            raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
         if response.status_code != 200:
             raise Exception(f"Failed to tokenize texts: {response.text}")
         result = response.json()
@@ -313,13 +439,14 @@ class LatticeTokenizer:
         response = self.client_wrapper.post(
             "detokenize",
             json={
+                "model_name": self.model_name,
                 "lattice_id": lattice_id,
                 "frame_shift": frame_shift,
                 "results": [t.to_dict() for t in results[0]],
                 "labels": labels[0],
                 "offset": offset,
                 "channel": channel,
-                "return_details": return_details,
+                "return_details": False if return_details is None else return_details,
                 "destroy_lattice": True,
             },
         )
@@ -328,6 +455,8 @@
                 lattice_id,
                 original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
             )
+        if response.status_code == 402:
+            raise QuotaExceededError(response.json().get("detail", "Quota exceeded"))
         if response.status_code != 200:
             raise Exception(f"Failed to detokenize lattice: {response.text}")

@@ -339,83 +468,7 @@ class LatticeTokenizer:

         if return_details:
             # Add emission confidence scores for segments and word-level alignments
-            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
-
-        alignments = _update_alignments_speaker(supervisions, alignments)
-
-        return alignments
-
-
-class AsyncLatticeTokenizer(LatticeTokenizer):
-    async def _post_async(self, endpoint: str, **kwargs):
-        response = self.client_wrapper.post(endpoint, **kwargs)
-        if inspect.isawaitable(response):
-            return await response
-        return response
-
-    async def tokenize(
-        self, supervisions: List[Supervision], split_sentence: bool = False
-    ) -> Tuple[str, Dict[str, Any]]:
-        if split_sentence:
-            self.init_sentence_splitter()
-            supervisions = self.split_sentences(supervisions)
-
-        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
-        response = await self._post_async(
-            "tokenize",
-            json={
-                "supervisions": [s.to_dict() for s in supervisions],
-                "pronunciation_dictionaries": pronunciation_dictionaries,
-            },
-        )
-        if response.status_code != 200:
-            raise Exception(f"Failed to tokenize texts: {response.text}")
-        result = response.json()
-        lattice_id = result["id"]
-        return (
-            supervisions,
-            lattice_id,
-            (result["lattice_graph"], result["final_state"], result.get("acoustic_scale", 1.0)),
-        )
-
-    async def detokenize(
-        self,
-        lattice_id: str,
-        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-        supervisions: List[Supervision],
-        return_details: bool = False,
-    ) -> List[Supervision]:
-        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
-        response = await self._post_async(
-            "detokenize",
-            json={
-                "lattice_id": lattice_id,
-                "frame_shift": frame_shift,
-                "results": [t.to_dict() for t in results[0]],
-                "labels": labels[0],
-                "offset": offset,
-                "channel": channel,
-                "return_details": return_details,
-                "destroy_lattice": True,
-            },
-        )
-        if response.status_code == 422:
-            raise LatticeDecodingError(
-                lattice_id,
-                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
-            )
-        if response.status_code != 200:
-            raise Exception(f"Failed to detokenize lattice: {response.text}")
-
-        result = response.json()
-        if not result.get("success"):
-            return Exception("Failed to detokenize the alignment results.")
-
-        alignments = [Supervision.from_dict(s) for s in result["supervisions"]]
-
-        if return_details:
-            # Add emission confidence scores for segments and word-level alignments
-            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift, offset)

         alignments = _update_alignments_speaker(supervisions, alignments)

@@ -427,6 +480,7 @@ def _add_confidence_scores(
     emission: torch.Tensor,
     labels: List[int],
     frame_shift: float,
+    offset: float = 0.0,
 ) -> None:
     """
     Add confidence scores to supervisions and their word-level alignments.
@@ -444,8 +498,8 @@
     tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)

     for supervision in supervisions:
-        start_frame = int(supervision.start / frame_shift)
-        end_frame = int(supervision.end / frame_shift)
+        start_frame = int((supervision.start - offset) / frame_shift)
+        end_frame = int((supervision.end - offset) / frame_shift)

         # Compute segment-level confidence
         probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
@@ -457,8 +511,8 @@
         if hasattr(supervision, "alignment") and supervision.alignment:
             words = supervision.alignment.get("word", [])
             for w, item in enumerate(words):
-                start = int(item.start / frame_shift) - start_frame
-                end = int(item.end / frame_shift) - start_frame
+                start = int((item.start - offset) / frame_shift) - start_frame
+                end = int((item.end - offset) / frame_shift) - start_frame
                 words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))


@@ -473,3 +527,23 @@ def _update_alignments_speaker(supervisions: List[Supervision], alignments: List
     for supervision, alignment in zip(supervisions, alignments):
         alignment.speaker = supervision.speaker
     return alignments
+
+
+def _load_tokenizer(
+    client_wrapper: Any,
+    model_path: str,
+    model_name: str,
+    device: str,
+    *,
+    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+) -> LatticeTokenizer:
+    """Instantiate tokenizer with consistent error handling."""
+    try:
+        return tokenizer_cls.from_pretrained(
+            client_wrapper=client_wrapper,
+            model_path=model_path,
+            model_name=model_name,
+            device=device,
+        )
+    except Exception as e:
+        raise ModelLoadError(f"tokenizer from {model_path}", original_error=e)
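Editor's note: a minimal usage sketch for the new module-level tokenize_multilingual_text helper shown above. The import path lattifai.alignment.tokenizer is an assumption based on the tokenizer.py rename in the file list, and the expected outputs are taken from the function's own docstring examples, not verified against the released wheel.

    # Assumed import path; tokenizer.py moved from lattifai/tokenizer/ to lattifai/alignment/
    from lattifai.alignment.tokenizer import tokenize_multilingual_text

    tokenize_multilingual_text("Hello世界")                             # ['Hello', '世', '界'] per the docstring
    tokenize_multilingual_text("I'm fine", keep_spaces=False)           # ["I'm", 'fine']
    tokenize_multilingual_text("Hello, World!", attach_punctuation=True)  # ['Hello,', ' ', 'World!']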
lattifai/audio2.py
ADDED
@@ -0,0 +1,211 @@
+"""Audio loading and resampling utilities."""
+
+from collections import namedtuple
+from pathlib import Path
+from typing import BinaryIO, Iterable, Optional, Tuple, Union
+
+import numpy as np
+import soundfile as sf
+import torch
+from lhotse.augmentation import get_or_create_resampler
+from lhotse.utils import Pathlike
+
+from lattifai.errors import AudioLoadError
+
+# ChannelSelectorType = Union[int, Iterable[int], str]
+ChannelSelectorType = Union[int, str]
+
+
+class AudioData(namedtuple("AudioData", ["sampling_rate", "ndarray", "tensor", "device", "path"])):
+    """Audio data container with sampling rate, numpy array, tensor, and device information."""
+
+    def __str__(self) -> str:
+        return self.path
+
+    @property
+    def duration(self) -> float:
+        """Duration of the audio in seconds."""
+        return self.ndarray.shape[-1] / self.sampling_rate
+
+
+class AudioLoader:
+    """Load and preprocess audio files into AudioData format."""
+
+    def __init__(
+        self,
+        device: str = "cpu",
+    ):
+        """Initialize AudioLoader.
+
+        Args:
+            device: Device to load audio tensors on (default: "cpu").
+        """
+        self.device = device
+        self._resampler_cache = {}
+
+    def _resample_audio(
+        self,
+        audio_sr: Tuple[torch.Tensor, int],
+        sampling_rate: int,
+        device: Optional[str],
+        channel_selector: Optional[ChannelSelectorType],
+    ) -> torch.Tensor:
+        """Resample audio to target sampling rate with channel selection.
+
+        Args:
+            audio_sr: Tuple of (audio_tensor, original_sample_rate).
+            sampling_rate: Target sampling rate.
+            device: Device to perform resampling on.
+            channel_selector: How to select channels.
+
+        Returns:
+            Resampled audio tensor of shape (1, T) or (C, T).
+        """
+        audio, sr = audio_sr
+
+        if channel_selector is None:
+            # keep the original multi-channel signal
+            tensor = audio
+        elif isinstance(channel_selector, int):
+            assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
+            tensor = audio[channel_selector : channel_selector + 1].clone()
+            del audio
+        elif isinstance(channel_selector, str):
+            assert channel_selector == "average"
+            tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
+            del audio
+        else:
+            raise ValueError(f"Unsupported channel_selector: {channel_selector}")
+        # assert isinstance(channel_selector, Iterable)
+        # num_channels = audio.shape[0]
+        # print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
+        # if max(channel_selector) >= num_channels:
+        #     raise ValueError(
+        #         f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
+        #     )
+        # tensor = audio[channel_selector]
+
+        tensor = tensor.to(device)
+        if sr != sampling_rate:
+            cache_key = (sr, sampling_rate, device)
+            if cache_key not in self._resampler_cache:
+                self._resampler_cache[cache_key] = get_or_create_resampler(sr, sampling_rate).to(device=device)
+            resampler = self._resampler_cache[cache_key]
+
+            length = tensor.size(-1)
+            chunk_size = sampling_rate * 3600
+            if length > chunk_size:
+                resampled_chunks = []
+                for i in range(0, length, chunk_size):
+                    resampled_chunks.append(resampler(tensor[..., i : i + chunk_size]))
+                tensor = torch.cat(resampled_chunks, dim=-1)
+            else:
+                tensor = resampler(tensor)
+
+        return tensor
+
+    def _load_audio(
+        self,
+        audio: Union[Pathlike, BinaryIO],
+        sampling_rate: int,
+        channel_selector: Optional[ChannelSelectorType],
+    ) -> torch.Tensor:
+        """Load audio from file or binary stream and resample to target rate.
+
+        Args:
+            audio: Path to audio file or binary stream.
+            sampling_rate: Target sampling rate.
+            channel_selector: How to select channels.
+
+        Returns:
+            Resampled audio tensor.
+
+        Raises:
+            ImportError: If PyAV is needed but not installed.
+            ValueError: If no audio stream found.
+            RuntimeError: If audio loading fails.
+        """
+        if isinstance(audio, Pathlike):
+            audio = str(Path(str(audio)).expanduser())
+
+        # load audio
+        try:
+            waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
+            waveform = waveform.T  # (channels, samples)
+        except Exception as primary_error:
+            # Fallback to PyAV for formats not supported by soundfile
+            try:
+                import av
+            except ImportError:
+                raise AudioLoadError(
+                    "PyAV (av) is required for loading certain audio formats. "
+                    f"Install it with: pip install av\n"
+                    f"Primary error was: {primary_error}"
+                )
+
+            try:
+                container = av.open(audio)
+                audio_stream = next((s for s in container.streams if s.type == "audio"), None)
+
+                if audio_stream is None:
+                    raise ValueError(f"No audio stream found in file: {audio}")
+
+                # Resample to target sample rate during decoding
+                audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float
+
+                frames = []
+                for frame in container.decode(audio_stream):
+                    # Convert frame to numpy array
+                    array = frame.to_ndarray()
+                    # Ensure shape is (channels, samples)
+                    if array.ndim == 1:
+                        array = array.reshape(1, -1)
+                    elif array.ndim == 2 and array.shape[0] > array.shape[1]:
+                        array = array.T
+                    frames.append(array)
+
+                container.close()
+
+                if not frames:
+                    raise ValueError(f"No audio data found in file: {audio}")
+
+                # Concatenate all frames
+                waveform = np.concatenate(frames, axis=1).astype(np.float32)  # (channels, samples)
+                sample_rate = audio_stream.codec_context.sample_rate
+            except Exception as e:
+                raise RuntimeError(f"Failed to load audio file {audio}: {e}")
+
+        return self._resample_audio(
+            (torch.from_numpy(waveform), sample_rate),
+            sampling_rate,
+            device=self.device,
+            channel_selector=channel_selector,
+        )
+
+    def __call__(
+        self,
+        audio: Union[Pathlike, BinaryIO],
+        sampling_rate: int = 16000,
+        channel_selector: Optional[ChannelSelectorType] = "average",
+    ) -> AudioData:
+        """
+        Args:
+            audio: Path to audio file or binary stream.
+            channel_selector: How to select channels (default: "average").
+            sampling_rate: Target sampling rate (default: use instance sampling_rate).
+
+        Returns:
+            AudioData namedtuple with sampling_rate, ndarray, and tensor fields.
+        """
+        tensor = self._load_audio(audio, sampling_rate, channel_selector)
+
+        # tensor is (1, T) or (C, T)
+        ndarray = tensor.cpu().numpy()
+
+        return AudioData(
+            sampling_rate=sampling_rate,
+            ndarray=ndarray,
+            tensor=tensor,
+            device=self.device,
+            path=str(audio) if isinstance(audio, Pathlike) else "<BinaryIO>",
+        )
lattifai/caption/__init__.py
ADDED
@@ -0,0 +1,20 @@
+from typing import List, Optional
+
+from lhotse.utils import Pathlike
+
+from ..config.caption import InputCaptionFormat
+from .caption import Caption
+from .gemini_reader import GeminiReader, GeminiSegment
+from .gemini_writer import GeminiWriter
+from .supervision import Supervision
+from .text_parser import normalize_text
+
+__all__ = [
+    "Caption",
+    "Supervision",
+    "GeminiReader",
+    "GeminiWriter",
+    "GeminiSegment",
+    "normalize_text",
+    "InputCaptionFormat",
+]
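Editor's note: a minimal usage sketch for the new AudioLoader in lattifai/audio2.py and the re-exports from lattifai/caption/__init__.py shown above. The file name speech.wav is a placeholder, and the described behaviour is read directly from the added code rather than from package documentation.

    from lattifai.audio2 import AudioLoader
    from lattifai.caption import Supervision, normalize_text  # re-exported by the new caption package

    loader = AudioLoader(device="cpu")
    # Decodes via soundfile (with a PyAV fallback), averages channels to mono,
    # and resamples to 16 kHz; recordings longer than an hour are resampled in chunks.
    data = loader("speech.wav", sampling_rate=16000, channel_selector="average")
    print(data.duration, data.tensor.shape)  # duration in seconds, tensor of shape (1, T)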