lattifai 0.2.5__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +5 -0
- lattifai/base_client.py +11 -0
- lattifai/bin/__init__.py +1 -0
- lattifai/bin/agent.py +326 -0
- lattifai/bin/align.py +253 -21
- lattifai/bin/cli_base.py +5 -0
- lattifai/bin/subtitle.py +182 -4
- lattifai/client.py +166 -66
- lattifai/errors.py +45 -7
- lattifai/io/__init__.py +21 -1
- lattifai/io/gemini_reader.py +371 -0
- lattifai/io/gemini_writer.py +173 -0
- lattifai/io/parser.py +75 -0
- lattifai/io/reader.py +25 -10
- lattifai/io/supervision.py +16 -0
- lattifai/io/utils.py +15 -0
- lattifai/io/writer.py +58 -17
- lattifai/tokenizer/__init__.py +2 -2
- lattifai/tokenizer/tokenizer.py +229 -41
- lattifai/utils.py +133 -0
- lattifai-0.4.1.dist-info/METADATA +810 -0
- lattifai-0.4.1.dist-info/RECORD +29 -0
- lattifai-0.4.1.dist-info/entry_points.txt +3 -0
- lattifai-0.2.5.dist-info/METADATA +0 -334
- lattifai-0.2.5.dist-info/RECORD +0 -23
- lattifai-0.2.5.dist-info/entry_points.txt +0 -4
- {lattifai-0.2.5.dist-info → lattifai-0.4.1.dist-info}/WHEEL +0 -0
- {lattifai-0.2.5.dist-info → lattifai-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.2.5.dist-info → lattifai-0.4.1.dist-info}/top_level.txt +0 -0
lattifai/io/supervision.py
CHANGED
@@ -7,6 +7,22 @@ from lhotse.utils import Seconds

 @dataclass
 class Supervision(SupervisionSegment):
+    """
+    Extended SupervisionSegment with simplified initialization.
+
+    Note: The `alignment` field is inherited from SupervisionSegment:
+        alignment: Optional[Dict[str, List[AlignmentItem]]] = None
+
+    Structure of alignment when return_details=True:
+        {
+            'word': [
+                AlignmentItem(symbol='hello', start=0.0, duration=0.5, score=0.95),
+                AlignmentItem(symbol='world', start=0.6, duration=0.4, score=0.92),
+                ...
+            ]
+        }
+    """
+
     text: Optional[str] = None
     id: str = ''
     recording_id: str = ''
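For orientation, a minimal sketch of a `Supervision` carrying the documented `alignment` payload; the `id`/`recording_id`/`start`/`duration` arguments are assumptions about the inherited constructor, which this diff does not show:

    from lhotse.supervision import AlignmentItem
    from lattifai.io import Supervision

    # Hypothetical segment with the word-level alignment structure
    # documented in the docstring above.
    sup = Supervision(
        id='utt1',
        recording_id='rec1',
        start=0.0,
        duration=1.0,
        text='hello world',
        alignment={
            'word': [
                AlignmentItem(symbol='hello', start=0.0, duration=0.5, score=0.95),
                AlignmentItem(symbol='world', start=0.6, duration=0.4, score=0.92),
            ]
        },
    )
    for item in sup.alignment['word']:
        print(f'{item.symbol}: {item.start:.2f}-{item.end:.2f} (score={item.score})')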
lattifai/io/utils.py
ADDED
@@ -0,0 +1,15 @@
+"""
+Utility constants and helper functions for subtitle I/O operations
+"""
+
+# Supported subtitle formats for reading/writing
+SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'md']
+
+# Input subtitle formats (includes special formats like 'auto' and 'gemini')
+INPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'auto', 'gemini']
+
+# Output subtitle formats (includes special formats like 'TextGrid' and 'json')
+OUTPUT_SUBTITLE_FORMATS = ['srt', 'vtt', 'ass', 'ssa', 'sub', 'sbv', 'txt', 'TextGrid', 'json']
+
+# All subtitle formats combined (for file detection)
+ALL_SUBTITLE_FORMATS = list(set(SUBTITLE_FORMATS + ['TextGrid', 'json', 'gemini']))
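A hedged sketch of how these constants might drive extension-based detection; `is_subtitle_file` is a hypothetical helper for illustration, not part of the package:

    from pathlib import Path

    from lattifai.io.utils import ALL_SUBTITLE_FORMATS

    def is_subtitle_file(path: str) -> bool:
        # Compare the extension (without the dot) against the combined
        # format list, ignoring case.
        suffix = Path(path).suffix.lstrip('.').lower()
        return suffix in {fmt.lower() for fmt in ALL_SUBTITLE_FORMATS}

    assert is_subtitle_file('talk.srt')
    assert is_subtitle_file('talk.TextGrid')
    assert not is_subtitle_file('talk.wav')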
lattifai/io/writer.py
CHANGED
@@ -1,49 +1,90 @@
+import json
 from abc import ABCMeta
-from typing import List
+from typing import Any, List, Optional

+import pysubs2
+from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike

-from .reader import
+from .reader import Supervision


 class SubtitleWriter(ABCMeta):
-    """Class for writing subtitle files."""
+    """Class for writing subtitle files with optional word-level alignment."""

     @classmethod
     def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
         if str(output_path)[-4:].lower() == '.txt':
             with open(output_path, 'w', encoding='utf-8') as f:
                 for sup in alignments:
-
+                    word_items = parse_alignment_from_supervision(sup)
+                    if word_items:
+                        for item in word_items:
+                            f.write(f'[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n')
+                    else:
+                        text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                        f.write(f'[{sup.start:.2f}-{sup.end:.2f}] {text}\n')
+
         elif str(output_path)[-5:].lower() == '.json':
             with open(output_path, 'w', encoding='utf-8') as f:
-
-
-
+                # Enhanced JSON export with word-level alignment
+                json_data = []
+                for sup in alignments:
+                    sup_dict = sup.to_dict()
+                    json_data.append(sup_dict)
+                json.dump(json_data, f, ensure_ascii=False, indent=4)
         elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
             from tgt import Interval, IntervalTier, TextGrid, write_to_file

             tg = TextGrid()
             supervisions, words = [], []
             for supervision in sorted(alignments, key=lambda x: x.start):
-
-
-
-
+                text = (
+                    f'{supervision.speaker} {supervision.text}' if supervision.speaker is not None else supervision.text
+                )
+                supervisions.append(Interval(supervision.start, supervision.end, text or ''))
+                # Extract word-level alignment using helper function
+                word_items = parse_alignment_from_supervision(supervision)
+                if word_items:
+                    for item in word_items:
+                        words.append(Interval(item.start, item.end, item.symbol))

             tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
             if words:
                 tg.add_tier(IntervalTier(name='words', objects=words))
             write_to_file(tg, output_path, format='long')
         else:
-            import pysubs2
-
             subs = pysubs2.SSAFile()
             for sup in alignments:
-
-
-
-
+                # Add word-level timing as metadata in the subtitle text
+                word_items = parse_alignment_from_supervision(sup)
+                if word_items:
+                    for word in word_items:
+                        subs.append(
+                            pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
+                        )
+                else:
+                    text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                    subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ''))
             subs.save(output_path)

         return output_path
+
+
+def parse_alignment_from_supervision(supervision: Any) -> Optional[List[AlignmentItem]]:
+    """
+    Extract word-level alignment items from Supervision object.
+
+    Args:
+        supervision: Supervision object with potential alignment data
+
+    Returns:
+        List of AlignmentItem objects, or None if no alignment data present
+    """
+    if not hasattr(supervision, 'alignment') or not supervision.alignment:
+        return None
+
+    if 'word' not in supervision.alignment:
+        return None
+
+    return supervision.alignment['word']
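Putting the new writer together, a minimal usage sketch; the `alignments` list is assumed to come from the aligner, and the output format is chosen purely by file extension, falling back to pysubs2 for subtitle formats:

    from lattifai.io.writer import SubtitleWriter

    # Word-level events where alignment is present, otherwise one event
    # per supervision segment.
    SubtitleWriter.write(alignments, 'out.srt')       # pysubs2 subtitle events
    SubtitleWriter.write(alignments, 'out.json')      # full Supervision dicts
    SubtitleWriter.write(alignments, 'out.TextGrid')  # Praat utterance/word tiers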
lattifai/tokenizer/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-from .tokenizer import LatticeTokenizer
+from .tokenizer import AsyncLatticeTokenizer, LatticeTokenizer

-__all__ = ['LatticeTokenizer']
+__all__ = ['LatticeTokenizer', 'AsyncLatticeTokenizer']
lattifai/tokenizer/tokenizer.py
CHANGED
@@ -1,13 +1,13 @@
 import gzip
+import inspect
 import pickle
 import re
 from collections import defaultdict
-from
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import torch

-from lattifai.
+from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
 from lattifai.io import Supervision
 from lattifai.tokenizer.phonemizer import G2Phonemizer
@@ -21,10 +21,13 @@ GROUPING_SEPARATOR = '✹'
 MAXIMUM_WORD_LENGTH = 40


+TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
+
+
 class LatticeTokenizer:
     """Tokenizer for converting Lhotse Cut to LatticeGraph."""

-    def __init__(self, client_wrapper:
+    def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
@@ -99,13 +102,14 @@ class LatticeTokenizer:
         # If no special pattern matches, return the original sentence
         return [sentence]

-    @
+    @classmethod
     def from_pretrained(
-
+        cls: Type[TokenizerT],
+        client_wrapper: Any,
         model_path: str,
         device: str = 'cpu',
         compressed: bool = True,
-    ):
+    ) -> TokenizerT:
         """Load tokenizer from exported binary file"""
         from pathlib import Path
@@ -117,7 +121,7 @@ class LatticeTokenizer:
         with open(words_model_path, 'rb') as f:
             data = pickle.load(f)

-        tokenizer =
+        tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.words = data['words']
         tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
         tokenizer.oov_word = data['oov_word']
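With `from_pretrained` now a proper `classmethod` returning `cls(...)` and annotated via `TokenizerT`, the same loader works for subclasses; a minimal sketch, assuming `client_wrapper` is the configured HTTP client wrapper and `model_path` points at the exported tokenizer binary:

    from lattifai.tokenizer import AsyncLatticeTokenizer, LatticeTokenizer

    sync_tok = LatticeTokenizer.from_pretrained(
        client_wrapper=client_wrapper, model_path='path/to/model', device='cpu'
    )
    # The TypeVar-bound return type means this is typed as
    # AsyncLatticeTokenizer, not LatticeTokenizer.
    async_tok = AsyncLatticeTokenizer.from_pretrained(
        client_wrapper=client_wrapper, model_path='path/to/model', device='cpu'
    )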
@@ -179,53 +183,98 @@ class LatticeTokenizer:
         return {}

     def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[str]:
+        """Split supervisions into sentences using the sentence splitter.
+
+        Careful about speaker changes.
+        """
         texts, text_len, sidx = [], 0, 0
+        speakers = []
         for s, supervision in enumerate(supervisions):
             text_len += len(supervision.text)
-            if
-
-
-
-
-
-
-
+            if supervision.speaker:
+                if sidx < s:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx:s]])
+                    texts.append(text)
+                    sidx = s
+                    text_len = len(supervision.text)
+                speakers.append(supervision.speaker)
+
+            else:
+                if text_len >= 2000 or s == len(supervisions) - 1:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                    texts.append(text)
+                    sidx = s + 1
+                    text_len = 0
+
+        assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)

         supervisions, remainder = [], ''
-        for _sentences in sentences:
+        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
+            # Prepend remainder from previous iteration to the first sentence
+            if _sentences and remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if not _sentences:
+                continue
+
             # Process and re-split special sentence types
             processed_sentences = []
             for s, _sentence in enumerate(_sentences):
                 if remainder:
                     _sentence = remainder + _sentence
                     remainder = ''
-
                 # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
                 resplit_parts = self._resplit_special_sentence_types(_sentence)
-                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':']):
+                if any(resplit_parts[-1].endswith(sp) for sp in [':', ':', ']']):
                     if s < len(_sentences) - 1:
                         _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
                     else:  # last part
-                        remainder = resplit_parts[-1] + ' '
+                        remainder = resplit_parts[-1] + ' '
                     processed_sentences.extend(resplit_parts[:-1])
                 else:
                     processed_sentences.extend(resplit_parts)
-
             _sentences = processed_sentences

-            if
-
-
+            if not _sentences:
+                if remainder:
+                    _sentences, remainder = [remainder.strip()], ''
+                else:
+                    continue

             if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
-                supervisions.extend(
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None)) for s, text in enumerate(_sentences)
+                )
+                _speaker = None  # reset speaker after use
             else:
-                supervisions.extend(
-
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None))
+                    for s, text in enumerate(_sentences[:-1])
+                )
+                remainder = _sentences[-1] + ' ' + remainder
+                if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
+                    supervisions.append(
+                        Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
+                    )
+                    remainder = ''
+                elif len(_sentences) == 1:
+                    if k == len(speakers) - 1:
+                        pass  # keep _speaker for the last supervision
+                    else:
+                        assert speakers[k + 1] is None
+                        speakers[k + 1] = _speaker
+                else:
+                    assert len(_sentences) > 1
+                    _speaker = None  # reset speaker if sentence not ended

         if remainder.strip():
-            supervisions.append(Supervision(text=remainder.strip()))
+            supervisions.append(Supervision(text=remainder.strip(), speaker=_speaker))

         return supervisions
@@ -246,14 +295,18 @@ class LatticeTokenizer:
             raise Exception(f'Failed to tokenize texts: {response.text}')
         result = response.json()
         lattice_id = result['id']
-        return
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )

     def detokenize(
         self,
         lattice_id: str,
         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-
-
+        supervisions: List[Supervision],
+        return_details: bool = False,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
@@ -265,22 +318,157 @@ class LatticeTokenizer:
                 'labels': labels[0],
                 'offset': offset,
                 'channel': channel,
+                'return_details': return_details,
                 'destroy_lattice': True,
             },
         )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
         if response.status_code != 200:
             raise Exception(f'Failed to detokenize lattice: {response.text}')
+
         result = response.json()
         if not result.get('success'):
-
-
-
-
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments


-
-def
-
-
-
-
+class AsyncLatticeTokenizer(LatticeTokenizer):
+    async def _post_async(self, endpoint: str, **kwargs):
+        response = self.client_wrapper.post(endpoint, **kwargs)
+        if inspect.isawaitable(response):
+            return await response
+        return response
+
+    async def tokenize(
+        self, supervisions: List[Supervision], split_sentence: bool = False
+    ) -> Tuple[str, Dict[str, Any]]:
+        if split_sentence:
+            self.init_sentence_splitter()
+            supervisions = self.split_sentences(supervisions)
+
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = await self._post_async(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )
+
+    async def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        supervisions: List[Supervision],
+        return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = await self._post_async(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'return_details': return_details,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+
+        result = response.json()
+        if not result.get('success'):
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments
+
+
+def _add_confidence_scores(
+    supervisions: List[Supervision],
+    emission: torch.Tensor,
+    labels: List[int],
+    frame_shift: float,
+) -> None:
+    """
+    Add confidence scores to supervisions and their word-level alignments.
+
+    This function modifies supervisions in-place by:
+    1. Computing segment-level confidence scores based on emission probabilities
+    2. Computing word-level confidence scores for each aligned word
+
+    Args:
+        supervisions: List of Supervision objects to add scores to (modified in-place)
+        emission: Emission tensor with shape [batch, time, vocab_size]
+        labels: Token labels corresponding to aligned tokens
+        frame_shift: Frame shift in seconds for converting frames to time
+    """
+    tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
+
+    for supervision in supervisions:
+        start_frame = int(supervision.start / frame_shift)
+        end_frame = int(supervision.end / frame_shift)
+
+        # Compute segment-level confidence
+        probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
+        aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
+        diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
+        supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
+
+        # Compute word-level confidence if alignment exists
+        if hasattr(supervision, 'alignment') and supervision.alignment:
+            words = supervision.alignment.get('word', [])
+            for w, item in enumerate(words):
+                start = int(item.start / frame_shift) - start_frame
+                end = int(item.end / frame_shift) - start_frame
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
+
+
+def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
+    """
+    Update the speaker attribute for a list of supervisions.
+
+    Args:
+        supervisions: List of Supervision objects to get speaker info from
+        alignments: List of aligned Supervision objects to update speaker info to
+    """
+    for supervision, alignment in zip(supervisions, alignments):
+        alignment.speaker = supervision.speaker
+    return alignments
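End to end, the async API mirrors the sync one; a sketch of the round trip, assuming a configured client wrapper and a worker that already produced `lattice_results` (the surrounding function and model path are illustrative, only the tokenizer calls come from this diff):

    from lattifai.errors import LatticeDecodingError
    from lattifai.tokenizer import AsyncLatticeTokenizer

    async def align(client_wrapper, supervisions, lattice_results):
        tokenizer = AsyncLatticeTokenizer.from_pretrained(
            client_wrapper=client_wrapper, model_path='path/to/model'
        )
        supervisions, lattice_id, _graph = await tokenizer.tokenize(
            supervisions, split_sentence=True
        )
        try:
            # return_details=True also attaches segment- and word-level
            # confidence scores via _add_confidence_scores.
            return await tokenizer.detokenize(
                lattice_id, lattice_results, supervisions, return_details=True
            )
        except LatticeDecodingError:
            # A 422 from the server now maps to LatticeDecodingError with
            # LATTICE_DECODING_FAILURE_HELP attached, instead of a bare Exception.
            raise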
lattifai/utils.py
ADDED
@@ -0,0 +1,133 @@
+"""Shared utility helpers for the LattifAI SDK."""
+
+import os
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Type
+
+from lattifai.errors import ModelLoadError
+from lattifai.tokenizer import LatticeTokenizer
+from lattifai.workers import Lattice1AlphaWorker
+
+
+def _get_cache_marker_path(cache_dir: Path) -> Path:
+    """Get the path for the cache marker file with current date."""
+    today = datetime.now().strftime('%Y%m%d')
+    return cache_dir / f'.done{today}'
+
+
+def _is_cache_valid(cache_dir: Path) -> bool:
+    """Check if cached model is valid (exists and not older than 1 day)."""
+    if not cache_dir.exists():
+        return False
+
+    # Find any .done* marker files
+    marker_files = list(cache_dir.glob('.done*'))
+    if not marker_files:
+        return False
+
+    # Get the most recent marker file
+    latest_marker = max(marker_files, key=lambda p: p.stat().st_mtime)
+
+    # Extract date from marker filename (format: .doneYYYYMMDD)
+    try:
+        date_str = latest_marker.name.replace('.done', '')
+        marker_date = datetime.strptime(date_str, '%Y%m%d')
+        # Check if marker is older than 1 day
+        if datetime.now() - marker_date > timedelta(days=1):
+            return False
+        return True
+    except (ValueError, IndexError):
+        # Invalid marker file format, treat as invalid cache
+        return False
+
+
+def _create_cache_marker(cache_dir: Path) -> None:
+    """Create a cache marker file with current date and clean old markers."""
+    # Remove old marker files
+    for old_marker in cache_dir.glob('.done*'):
+        old_marker.unlink(missing_ok=True)
+
+    # Create new marker file
+    marker_path = _get_cache_marker_path(cache_dir)
+    marker_path.touch()
+
+
+def _resolve_model_path(model_name_or_path: str) -> str:
+    """Resolve model path, downloading from Hugging Face when necessary."""
+    if Path(model_name_or_path).exists():
+        return model_name_or_path
+
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.constants import HF_HUB_CACHE
+    from huggingface_hub.errors import LocalEntryNotFoundError
+
+    # Determine cache directory for this model
+    cache_dir = Path(HF_HUB_CACHE) / f'models--{model_name_or_path.replace("/", "--")}'
+
+    # Check if we have a valid cached version
+    if _is_cache_valid(cache_dir):
+        # Return the snapshot path (latest version)
+        snapshots_dir = cache_dir / 'snapshots'
+        if snapshots_dir.exists():
+            snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
+            if snapshot_dirs:
+                # Return the most recent snapshot
+                latest_snapshot = max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
+                return str(latest_snapshot)
+
+    try:
+        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+        _create_cache_marker(cache_dir)
+        return downloaded_path
+    except LocalEntryNotFoundError:
+        try:
+            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            _create_cache_marker(cache_dir)
+            return downloaded_path
+        except Exception as e:  # pragma: no cover - bubble up for caller context
+            raise ModelLoadError(model_name_or_path, original_error=e)
+    except Exception as e:  # pragma: no cover - unexpected download issue
+        raise ModelLoadError(model_name_or_path, original_error=e)
+
+
+def _select_device(device: Optional[str]) -> str:
+    """Select best available torch device when not explicitly provided."""
+    if device:
+        return device
+
+    import torch
+
+    detected = 'cpu'
+    if torch.backends.mps.is_available():
+        detected = 'mps'
+    elif torch.cuda.is_available():
+        detected = 'cuda'
+    return detected
+
+
+def _load_tokenizer(
+    client_wrapper: Any,
+    model_path: str,
+    device: str,
+    *,
+    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+) -> LatticeTokenizer:
+    """Instantiate tokenizer with consistent error handling."""
+    try:
+        return tokenizer_cls.from_pretrained(
+            client_wrapper=client_wrapper,
+            model_path=model_path,
+            device=device,
+        )
+    except Exception as e:
+        raise ModelLoadError(f'tokenizer from {model_path}', original_error=e)
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1AlphaWorker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f'worker from {model_path}', original_error=e)