lattifai 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- lattifai/__init__.py +32 -1
- lattifai/base_client.py +14 -6
- lattifai/bin/__init__.py +1 -0
- lattifai/bin/agent.py +325 -0
- lattifai/bin/align.py +253 -21
- lattifai/bin/cli_base.py +5 -0
- lattifai/bin/subtitle.py +182 -4
- lattifai/client.py +236 -63
- lattifai/errors.py +257 -0
- lattifai/io/__init__.py +21 -1
- lattifai/io/gemini_reader.py +371 -0
- lattifai/io/gemini_writer.py +173 -0
- lattifai/io/reader.py +21 -9
- lattifai/io/supervision.py +16 -0
- lattifai/io/utils.py +15 -0
- lattifai/io/writer.py +58 -17
- lattifai/tokenizer/__init__.py +2 -2
- lattifai/tokenizer/tokenizer.py +221 -40
- lattifai/utils.py +133 -0
- lattifai/workers/lattice1_alpha.py +130 -66
- lattifai-0.4.0.dist-info/METADATA +811 -0
- lattifai-0.4.0.dist-info/RECORD +28 -0
- lattifai-0.4.0.dist-info/entry_points.txt +3 -0
- lattifai-0.2.4.dist-info/METADATA +0 -334
- lattifai-0.2.4.dist-info/RECORD +0 -22
- lattifai-0.2.4.dist-info/entry_points.txt +0 -4
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/WHEEL +0 -0
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/top_level.txt +0 -0
lattifai/io/writer.py
CHANGED
@@ -1,49 +1,90 @@
+import json
 from abc import ABCMeta
-from typing import List
+from typing import Any, List, Optional

+import pysubs2
+from lhotse.supervision import AlignmentItem
 from lhotse.utils import Pathlike

-from .reader import
+from .reader import Supervision


 class SubtitleWriter(ABCMeta):
-    """Class for writing subtitle files."""
+    """Class for writing subtitle files with optional word-level alignment."""

     @classmethod
     def write(cls, alignments: List[Supervision], output_path: Pathlike) -> Pathlike:
         if str(output_path)[-4:].lower() == '.txt':
             with open(output_path, 'w', encoding='utf-8') as f:
                 for sup in alignments:
-
+                    word_items = parse_alignment_from_supervision(sup)
+                    if word_items:
+                        for item in word_items:
+                            f.write(f'[{item.start:.2f}-{item.end:.2f}] {item.symbol}\n')
+                    else:
+                        text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                        f.write(f'[{sup.start:.2f}-{sup.end:.2f}] {text}\n')
+
         elif str(output_path)[-5:].lower() == '.json':
             with open(output_path, 'w', encoding='utf-8') as f:
-
-
-
+                # Enhanced JSON export with word-level alignment
+                json_data = []
+                for sup in alignments:
+                    sup_dict = sup.to_dict()
+                    json_data.append(sup_dict)
+                json.dump(json_data, f, ensure_ascii=False, indent=4)
         elif str(output_path).endswith('.TextGrid') or str(output_path).endswith('.textgrid'):
             from tgt import Interval, IntervalTier, TextGrid, write_to_file

             tg = TextGrid()
             supervisions, words = [], []
             for supervision in sorted(alignments, key=lambda x: x.start):
-
-
-
-
+                text = (
+                    f'{supervision.speaker} {supervision.text}' if supervision.speaker is not None else supervision.text
+                )
+                supervisions.append(Interval(supervision.start, supervision.end, text or ''))
+                # Extract word-level alignment using helper function
+                word_items = parse_alignment_from_supervision(supervision)
+                if word_items:
+                    for item in word_items:
+                        words.append(Interval(item.start, item.end, item.symbol))

             tg.add_tier(IntervalTier(name='utterances', objects=supervisions))
             if words:
                 tg.add_tier(IntervalTier(name='words', objects=words))
             write_to_file(tg, output_path, format='long')
         else:
-            import pysubs2
-
             subs = pysubs2.SSAFile()
             for sup in alignments:
-
-
-
-
+                # Add word-level timing as metadata in the subtitle text
+                word_items = parse_alignment_from_supervision(sup)
+                if word_items:
+                    for word in word_items:
+                        subs.append(
+                            pysubs2.SSAEvent(start=int(word.start * 1000), end=int(word.end * 1000), text=word.symbol)
+                        )
+                else:
+                    text = f'{sup.speaker} {sup.text}' if sup.speaker is not None else sup.text
+                    subs.append(pysubs2.SSAEvent(start=int(sup.start * 1000), end=int(sup.end * 1000), text=text or ''))
             subs.save(output_path)

         return output_path
+
+
+def parse_alignment_from_supervision(supervision: Any) -> Optional[List[AlignmentItem]]:
+    """
+    Extract word-level alignment items from Supervision object.
+
+    Args:
+        supervision: Supervision object with potential alignment data
+
+    Returns:
+        List of AlignmentItem objects, or None if no alignment data present
+    """
+    if not hasattr(supervision, 'alignment') or not supervision.alignment:
+        return None
+
+    if 'word' not in supervision.alignment:
+        return None
+
+    return supervision.alignment['word']
lattifai/tokenizer/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-from .tokenizer import LatticeTokenizer
+from .tokenizer import AsyncLatticeTokenizer, LatticeTokenizer

-__all__ = ['LatticeTokenizer']
+__all__ = ['LatticeTokenizer', 'AsyncLatticeTokenizer']
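With this re-export in place, both variants resolve from the subpackage (AsyncLatticeTokenizer subclasses LatticeTokenizer, as the tokenizer.py diff below shows):

from lattifai.tokenizer import AsyncLatticeTokenizer, LatticeTokenizer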
lattifai/tokenizer/tokenizer.py
CHANGED
@@ -1,13 +1,13 @@
 import gzip
+import inspect
 import pickle
 import re
 from collections import defaultdict
-from
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import torch

-from lattifai.
+from lattifai.errors import LATTICE_DECODING_FAILURE_HELP, LatticeDecodingError
 from lattifai.io import Supervision
 from lattifai.tokenizer.phonemizer import G2Phonemizer

@@ -21,10 +21,13 @@ GROUPING_SEPARATOR = '✹'
 MAXIMUM_WORD_LENGTH = 40


+TokenizerT = TypeVar('TokenizerT', bound='LatticeTokenizer')
+
+
 class LatticeTokenizer:
     """Tokenizer for converting Lhotse Cut to LatticeGraph."""

-    def __init__(self, client_wrapper:
+    def __init__(self, client_wrapper: Any):
         self.client_wrapper = client_wrapper
         self.words: List[str] = []
         self.g2p_model: Any = None  # Placeholder for G2P model
@@ -99,13 +102,14 @@ class LatticeTokenizer:
         # If no special pattern matches, return the original sentence
         return [sentence]

-    @
+    @classmethod
     def from_pretrained(
-
+        cls: Type[TokenizerT],
+        client_wrapper: Any,
         model_path: str,
         device: str = 'cpu',
         compressed: bool = True,
-    ):
+    ) -> TokenizerT:
         """Load tokenizer from exported binary file"""
         from pathlib import Path

@@ -117,7 +121,7 @@ class LatticeTokenizer:
         with open(words_model_path, 'rb') as f:
             data = pickle.load(f)

-        tokenizer =
+        tokenizer = cls(client_wrapper=client_wrapper)
         tokenizer.words = data['words']
         tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
         tokenizer.oov_word = data['oov_word']
@@ -179,53 +183,89 @@ class LatticeTokenizer:
         return {}

     def split_sentences(self, supervisions: List[Supervision], strip_whitespace=True) -> List[str]:
+        """Split supervisions into sentences using the sentence splitter.
+
+        Careful about speaker changes.
+        """
         texts, text_len, sidx = [], 0, 0
+        speakers = []
         for s, supervision in enumerate(supervisions):
             text_len += len(supervision.text)
-            if
-
-
-
-
-
-
-
+            if supervision.speaker:
+                speakers.append(supervision.speaker)
+                if sidx < s:
+                    text = ' '.join([sup.text for sup in supervisions[sidx:s]])
+                    texts.append(text)
+                    sidx = s
+                    text_len = len(supervision.text)
+            else:
+                if text_len >= 2000 or s == len(supervisions) - 1:
+                    if len(speakers) < len(texts) + 1:
+                        speakers.append(None)
+                    text = ' '.join([sup.text for sup in supervisions[sidx : s + 1]])
+                    texts.append(text)
+                    sidx = s + 1
+                    text_len = 0
+
+        assert len(speakers) == len(texts), f'len(speakers)={len(speakers)} != len(texts)={len(texts)}'
         sentences = self.sentence_splitter.split(texts, threshold=0.15, strip_whitespace=strip_whitespace)

         supervisions, remainder = [], ''
-        for _sentences in sentences:
+        for k, (_speaker, _sentences) in enumerate(zip(speakers, sentences)):
+            # Prepend remainder from previous iteration to the first sentence
+            if _sentences and remainder:
+                _sentences[0] = remainder + _sentences[0]
+                remainder = ''
+
+            if not _sentences:
+                continue
+
             # Process and re-split special sentence types
             processed_sentences = []
             for s, _sentence in enumerate(_sentences):
                 if remainder:
                     _sentence = remainder + _sentence
                     remainder = ''
-
                 # Detect and split special sentence types: e.g., '[APPLAUSE] >> MIRA MURATI:' -> ['[APPLAUSE]', '>> MIRA MURATI:']  # noqa: E501
                 resplit_parts = self._resplit_special_sentence_types(_sentence)
                 if any(resplit_parts[-1].endswith(sp) for sp in [':', ':']):
                     if s < len(_sentences) - 1:
                         _sentences[s + 1] = resplit_parts[-1] + ' ' + _sentences[s + 1]
                     else:  # last part
-                        remainder = resplit_parts[-1] + ' '
+                        remainder = resplit_parts[-1] + ' '
                     processed_sentences.extend(resplit_parts[:-1])
                 else:
                     processed_sentences.extend(resplit_parts)
-
             _sentences = processed_sentences

-            if remainder:
-                _sentences[0] = remainder + _sentences[0]
-                remainder = ''
-
             if any(_sentences[-1].endswith(ep) for ep in END_PUNCTUATION):
-                supervisions.extend(
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None)) for s, text in enumerate(_sentences)
+                )
+                _speaker = None  # reset speaker after use
             else:
-                supervisions.extend(
-
+                supervisions.extend(
+                    Supervision(text=text, speaker=(_speaker if s == 0 else None))
+                    for s, text in enumerate(_sentences[:-1])
+                )
+                remainder = _sentences[-1] + ' ' + remainder
+                if k < len(speakers) - 1 and speakers[k + 1] is not None:  # next speaker is set
+                    supervisions.append(
+                        Supervision(text=remainder.strip(), speaker=_speaker if len(_sentences) == 1 else None)
+                    )
+                    remainder = ''
+                elif len(_sentences) == 1:
+                    if k == len(speakers) - 1:
+                        pass  # keep _speaker for the last supervision
+                    else:
+                        assert speakers[k + 1] is None
+                        speakers[k + 1] = _speaker
+                else:
+                    assert len(_sentences) > 1
+                    _speaker = None  # reset speaker if sentence not ended

         if remainder.strip():
-            supervisions.append(Supervision(text=remainder.strip()))
+            supervisions.append(Supervision(text=remainder.strip(), speaker=_speaker))

         return supervisions

@@ -246,14 +286,18 @@ class LatticeTokenizer:
             raise Exception(f'Failed to tokenize texts: {response.text}')
         result = response.json()
         lattice_id = result['id']
-        return
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )

     def detokenize(
         self,
         lattice_id: str,
         lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
-
-
+        supervisions: List[Supervision],
+        return_details: bool = False,
     ) -> List[Supervision]:
         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
         response = self.client_wrapper.post(
@@ -265,20 +309,157 @@ class LatticeTokenizer:
                 'labels': labels[0],
                 'offset': offset,
                 'channel': channel,
+                'return_details': return_details,
                 'destroy_lattice': True,
             },
         )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
         if response.status_code != 200:
             raise Exception(f'Failed to detokenize lattice: {response.text}')
+
         result = response.json()
-
-
-
+        if not result.get('success'):
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments
+

+class AsyncLatticeTokenizer(LatticeTokenizer):
+    async def _post_async(self, endpoint: str, **kwargs):
+        response = self.client_wrapper.post(endpoint, **kwargs)
+        if inspect.isawaitable(response):
+            return await response
+        return response

-
-
-
-
-
-
+    async def tokenize(
+        self, supervisions: List[Supervision], split_sentence: bool = False
+    ) -> Tuple[str, Dict[str, Any]]:
+        if split_sentence:
+            self.init_sentence_splitter()
+            supervisions = self.split_sentences(supervisions)
+
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = await self._post_async(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return (
+            supervisions,
+            lattice_id,
+            (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0)),
+        )
+
+    async def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        supervisions: List[Supervision],
+        return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = await self._post_async(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'return_details': return_details,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code == 422:
+            raise LatticeDecodingError(
+                lattice_id,
+                original_error=Exception(LATTICE_DECODING_FAILURE_HELP),
+            )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+
+        result = response.json()
+        if not result.get('success'):
+            raise Exception('Failed to detokenize the alignment results.')
+
+        alignments = [Supervision.from_dict(s) for s in result['supervisions']]
+
+        if return_details:
+            # Add emission confidence scores for segments and word-level alignments
+            _add_confidence_scores(alignments, emission, labels[0], frame_shift)
+
+        alignments = _update_alignments_speaker(supervisions, alignments)
+
+        return alignments
+
+
+def _add_confidence_scores(
+    supervisions: List[Supervision],
+    emission: torch.Tensor,
+    labels: List[int],
+    frame_shift: float,
+) -> None:
+    """
+    Add confidence scores to supervisions and their word-level alignments.
+
+    This function modifies supervisions in-place by:
+    1. Computing segment-level confidence scores based on emission probabilities
+    2. Computing word-level confidence scores for each aligned word
+
+    Args:
+        supervisions: List of Supervision objects to add scores to (modified in-place)
+        emission: Emission tensor with shape [batch, time, vocab_size]
+        labels: Token labels corresponding to aligned tokens
+        frame_shift: Frame shift in seconds for converting frames to time
+    """
+    tokens = torch.tensor(labels, dtype=torch.int64, device=emission.device)
+
+    for supervision in supervisions:
+        start_frame = int(supervision.start / frame_shift)
+        end_frame = int(supervision.end / frame_shift)
+
+        # Compute segment-level confidence
+        probabilities = emission[0, start_frame:end_frame].softmax(dim=-1)
+        aligned = probabilities[range(0, end_frame - start_frame), tokens[start_frame:end_frame]]
+        diffprobs = (probabilities.max(dim=-1).values - aligned).cpu()
+        supervision.score = round(1.0 - diffprobs.mean().item(), ndigits=4)
+
+        # Compute word-level confidence if alignment exists
+        if hasattr(supervision, 'alignment') and supervision.alignment:
+            words = supervision.alignment.get('word', [])
+            for w, item in enumerate(words):
+                start = int(item.start / frame_shift) - start_frame
+                end = int(item.end / frame_shift) - start_frame
+                words[w] = item._replace(score=round(1.0 - diffprobs[start:end].mean().item(), ndigits=4))
+
+
+def _update_alignments_speaker(supervisions: List[Supervision], alignments: List[Supervision]) -> List[Supervision]:
+    """
+    Update the speaker attribute for a list of supervisions.
+
+    Args:
+        supervisions: List of Supervision objects to get speaker info from
+        alignments: List of aligned Supervision objects to update speaker info to
+    """
+    for supervision, alignment in zip(supervisions, alignments):
+        alignment.speaker = supervision.speaker
+    return alignments
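A minimal sketch of the new async flow under stated assumptions: client_wrapper and the six-element lattice_results tuple come from elsewhere in the SDK (the client and the lattice worker), and the model path is a placeholder; the tokenize/detokenize signatures follow the diff above.

from typing import Any, List

from lattifai.io import Supervision
from lattifai.tokenizer import AsyncLatticeTokenizer


async def align(client_wrapper: Any, supervisions: List[Supervision], lattice_results: tuple):
    # client_wrapper.post may return a plain response or an awaitable;
    # _post_async (above) awaits it only when needed.
    tokenizer = AsyncLatticeTokenizer.from_pretrained(
        client_wrapper=client_wrapper,
        model_path='path/to/tokenizer',  # placeholder
        device='cpu',
    )
    supervisions, lattice_id, _graph = await tokenizer.tokenize(supervisions, split_sentence=True)
    # lattice_results = (emission, results, labels, frame_shift, offset, channel) from the worker
    return await tokenizer.detokenize(lattice_id, lattice_results, supervisions, return_details=True)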
lattifai/utils.py
ADDED
@@ -0,0 +1,133 @@
+"""Shared utility helpers for the LattifAI SDK."""
+
+import os
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Type
+
+from lattifai.errors import ModelLoadError
+from lattifai.tokenizer import LatticeTokenizer
+from lattifai.workers import Lattice1AlphaWorker
+
+
+def _get_cache_marker_path(cache_dir: Path) -> Path:
+    """Get the path for the cache marker file with current date."""
+    today = datetime.now().strftime('%Y%m%d')
+    return cache_dir / f'.done{today}'
+
+
+def _is_cache_valid(cache_dir: Path) -> bool:
+    """Check if cached model is valid (exists and not older than 1 day)."""
+    if not cache_dir.exists():
+        return False
+
+    # Find any .done* marker files
+    marker_files = list(cache_dir.glob('.done*'))
+    if not marker_files:
+        return False
+
+    # Get the most recent marker file
+    latest_marker = max(marker_files, key=lambda p: p.stat().st_mtime)
+
+    # Extract date from marker filename (format: .doneYYYYMMDD)
+    try:
+        date_str = latest_marker.name.replace('.done', '')
+        marker_date = datetime.strptime(date_str, '%Y%m%d')
+        # Check if marker is older than 1 day
+        if datetime.now() - marker_date > timedelta(days=1):
+            return False
+        return True
+    except (ValueError, IndexError):
+        # Invalid marker file format, treat as invalid cache
+        return False
+
+
+def _create_cache_marker(cache_dir: Path) -> None:
+    """Create a cache marker file with current date and clean old markers."""
+    # Remove old marker files
+    for old_marker in cache_dir.glob('.done*'):
+        old_marker.unlink(missing_ok=True)
+
+    # Create new marker file
+    marker_path = _get_cache_marker_path(cache_dir)
+    marker_path.touch()
+
+
+def _resolve_model_path(model_name_or_path: str) -> str:
+    """Resolve model path, downloading from Hugging Face when necessary."""
+    if Path(model_name_or_path).exists():
+        return model_name_or_path
+
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.constants import HF_HUB_CACHE
+    from huggingface_hub.errors import LocalEntryNotFoundError
+
+    # Determine cache directory for this model
+    cache_dir = Path(HF_HUB_CACHE) / f'models--{model_name_or_path.replace("/", "--")}'
+
+    # Check if we have a valid cached version
+    if _is_cache_valid(cache_dir):
+        # Return the snapshot path (latest version)
+        snapshots_dir = cache_dir / 'snapshots'
+        if snapshots_dir.exists():
+            snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
+            if snapshot_dirs:
+                # Return the most recent snapshot
+                latest_snapshot = max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
+                return str(latest_snapshot)
+
+    try:
+        downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+        _create_cache_marker(cache_dir)
+        return downloaded_path
+    except LocalEntryNotFoundError:
+        try:
+            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+            downloaded_path = snapshot_download(repo_id=model_name_or_path, repo_type='model')
+            _create_cache_marker(cache_dir)
+            return downloaded_path
+        except Exception as e:  # pragma: no cover - bubble up for caller context
+            raise ModelLoadError(model_name_or_path, original_error=e)
+    except Exception as e:  # pragma: no cover - unexpected download issue
+        raise ModelLoadError(model_name_or_path, original_error=e)
+
+
+def _select_device(device: Optional[str]) -> str:
+    """Select best available torch device when not explicitly provided."""
+    if device:
+        return device
+
+    import torch
+
+    detected = 'cpu'
+    if torch.backends.mps.is_available():
+        detected = 'mps'
+    elif torch.cuda.is_available():
+        detected = 'cuda'
+    return detected
+
+
+def _load_tokenizer(
+    client_wrapper: Any,
+    model_path: str,
+    device: str,
+    *,
+    tokenizer_cls: Type[LatticeTokenizer] = LatticeTokenizer,
+) -> LatticeTokenizer:
+    """Instantiate tokenizer with consistent error handling."""
+    try:
+        return tokenizer_cls.from_pretrained(
+            client_wrapper=client_wrapper,
+            model_path=model_path,
+            device=device,
+        )
+    except Exception as e:
+        raise ModelLoadError(f'tokenizer from {model_path}', original_error=e)
+
+
+def _load_worker(model_path: str, device: str) -> Lattice1AlphaWorker:
+    """Instantiate lattice worker with consistent error handling."""
+    try:
+        return Lattice1AlphaWorker(model_path, device=device, num_threads=8)
+    except Exception as e:
+        raise ModelLoadError(f'worker from {model_path}', original_error=e)
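A sketch of how these helpers chain together (the repo id is a placeholder and the client wrapper a stand-in; the signatures follow the file above):

from lattifai.utils import _load_tokenizer, _load_worker, _resolve_model_path, _select_device

client_wrapper = None  # stand-in; the SDK passes its HTTP client wrapper here
model_path = _resolve_model_path('org/model-name')  # placeholder repo id; existing local paths short-circuit
device = _select_device(None)  # auto-detect: 'mps', then 'cuda', else 'cpu'
tokenizer = _load_tokenizer(client_wrapper, model_path, device)  # LatticeTokenizer by default
worker = _load_worker(model_path, device)  # Lattice1AlphaWorker(..., num_threads=8)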
|