lattifai 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattifai/__init__.py +30 -0
- lattifai/base_client.py +118 -0
- lattifai/bin/__init__.py +2 -0
- lattifai/bin/align.py +42 -0
- lattifai/bin/cli_base.py +14 -0
- lattifai/bin/subtitle.py +32 -0
- lattifai/client.py +131 -0
- lattifai/io/__init__.py +22 -0
- lattifai/io/reader.py +71 -0
- lattifai/io/supervision.py +17 -0
- lattifai/io/writer.py +49 -0
- lattifai/tokenizers/__init__.py +3 -0
- lattifai/tokenizers/phonemizer.py +50 -0
- lattifai/tokenizers/tokenizer.py +143 -0
- lattifai/workers/__init__.py +3 -0
- lattifai/workers/lattice1_alpha.py +119 -0
- lattifai-0.1.4.dist-info/METADATA +467 -0
- lattifai-0.1.4.dist-info/RECORD +22 -0
- lattifai-0.1.4.dist-info/WHEEL +5 -0
- lattifai-0.1.4.dist-info/entry_points.txt +3 -0
- lattifai-0.1.4.dist-info/licenses/LICENSE +21 -0
- lattifai-0.1.4.dist-info/top_level.txt +1 -0
lattifai/tokenizers/tokenizer.py
@@ -0,0 +1,143 @@
+import gzip
+import pickle
+from collections import defaultdict
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from lattifai.base_client import SyncAPIClient
+from lattifai.io import Supervision
+from lattifai.tokenizers.phonemizer import G2Phonemizer
+
+PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
+PUNCTUATION_SPACE = PUNCTUATION + ' '
+STAR_TOKEN = '※'
+
+GROUPING_SEPARATOR = '✹'
+
+MAXIMUM_WORD_LENGTH = 40
+
+
+class LatticeTokenizer:
+    """Tokenizer for converting Lhotse Cut to LatticeGraph."""
+
+    def __init__(self, client_wrapper: SyncAPIClient):
+        self.client_wrapper = client_wrapper
+        self.words: List[str] = []
+        self.g2p_model: Any = None  # Placeholder for G2P model
+        self.dictionaries = defaultdict(lambda: [])
+        self.oov_word = '<unk>'
+
+    @staticmethod
+    def from_pretrained(
+        client_wrapper: SyncAPIClient,
+        model_path: str,
+        g2p_model_path: Optional[str] = None,
+        device: str = 'cpu',
+        compressed: bool = True,
+    ):
+        """Load tokenizer from exported binary file"""
+        if compressed:
+            with gzip.open(model_path, 'rb') as f:
+                data = pickle.load(f)
+        else:
+            with open(model_path, 'rb') as f:
+                data = pickle.load(f)
+
+        tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+        tokenizer.words = data['words']
+        tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
+        tokenizer.oov_word = data['oov_word']
+        if g2p_model_path:
+            tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+        return tokenizer
+
+    def prenormalize(self, texts: List[str], language: Optional[str] = None) -> List[str]:
+        if not self.g2p_model:
+            raise ValueError('G2P model is not loaded, cannot prenormalize texts')
+
+        oov_words = []
+        for text in texts:
+            words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
+            oovs = [w for w in words if w not in self.words]
+            if oovs:
+                oov_words.extend([w for w in oovs if (w not in self.words and len(w) <= MAXIMUM_WORD_LENGTH)])
+
+        oov_words = list(set(oov_words))
+        if oov_words:
+            indexs = []
+            for k, _word in enumerate(oov_words):
+                if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
+                    self.dictionaries[_word] = self.dictionaries[self.oov_word]
+                else:
+                    _word = _word.strip(PUNCTUATION_SPACE)
+                    if not _word or _word in self.words:
+                        indexs.append(k)
+            for idx in sorted(indexs, reverse=True):
+                del oov_words[idx]
+
+            g2p_words = [w for w in oov_words if w not in self.dictionaries]
+            if g2p_words:
+                predictions = self.g2p_model(words=g2p_words, lang=language, batch_size=len(g2p_words), num_prons=4)
+                for _word, _predictions in zip(g2p_words, predictions):
+                    for pronuncation in _predictions:
+                        if pronuncation and pronuncation not in self.dictionaries[_word]:
+                            self.dictionaries[_word].append(pronuncation)
+
+            pronunciation_dictionaries: Dict[str, List[List[str]]] = {
+                w: self.dictionaries[w] for w in oov_words if self.dictionaries[w]
+            }
+            return pronunciation_dictionaries
+
+        return {}
+
+    def tokenize(self, supervisions: List[Supervision]) -> Tuple[str, Dict[str, Any]]:
+        pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+        response = self.client_wrapper.post(
+            'tokenize',
+            json={
+                'supervisions': [s.to_dict() for s in supervisions],
+                'pronunciation_dictionaries': pronunciation_dictionaries,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to tokenize texts: {response.text}')
+        result = response.json()
+        lattice_id = result['id']
+        return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+
+    def detokenize(
+        self,
+        lattice_id: str,
+        lattice_results: Tuple[torch.Tensor, Any, Any, float, float],
+        # return_supervisions: bool = True,
+        # return_details: bool = False,
+    ) -> List[Supervision]:
+        emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+        response = self.client_wrapper.post(
+            'detokenize',
+            json={
+                'lattice_id': lattice_id,
+                'frame_shift': frame_shift,
+                'results': [t.to_dict() for t in results[0]],
+                'labels': labels[0],
+                'offset': offset,
+                'channel': channel,
+                'destroy_lattice': True,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(f'Failed to detokenize lattice: {response.text}')
+        result = response.json()
+        # if return_details:
+        #     raise NotImplementedError("return_details is not implemented yet")
+        return [Supervision.from_dict(s) for s in result['supervisions']]
+
+
+# Compute average score weighted by the span length
+def _score(spans):
+    if not spans:
+        return 0.0
+    # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
+    return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
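For orientation, a minimal usage sketch of the tokenizer API shown above. The SyncAPIClient construction, the model file names, and the supervision-loading helper are assumptions for illustration; they are not part of this diff.

# Sketch only: client arguments, file names, and load_supervisions() are hypothetical.
from lattifai.base_client import SyncAPIClient
from lattifai.tokenizers.tokenizer import LatticeTokenizer

client = SyncAPIClient()  # hypothetical constructor; the real signature lives in lattifai/base_client.py
tokenizer = LatticeTokenizer.from_pretrained(
    client_wrapper=client,
    model_path='lattice_tokenizer.bin.gz',  # assumed name of the exported, gzip-compressed pickle
    g2p_model_path='g2p_model',             # optional; without it prenormalize() raises ValueError
)

supervisions = load_supervisions('transcript.srt')  # hypothetical helper returning List[Supervision]
lattice_id, lattice_graph = tokenizer.tokenize(supervisions)
# lattice_graph is (lattice_graph_str, final_state, acoustic_scale), consumed by the worker below.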
lattifai/workers/lattice1_alpha.py
@@ -0,0 +1,119 @@
+import json
+import time
+from collections import defaultdict
+from typing import Any, BinaryIO, Dict, Tuple, Union
+
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from lhotse import FbankConfig
+from lhotse.features.kaldi.layers import Wav2LogFilterBank
+from lhotse.utils import Pathlike
+
+
+class Lattice1AlphaWorker:
+    """Worker for processing audio with LatticeGraph."""
+
+    def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
+        if device != 'cpu':
+            raise NotImplementedError(f'Only cpu is supported for now, got device={device}.')
+        self.config = json.load(open(f'{model_path}/config.json'))
+
+        # SessionOptions
+        sess_options = ort.SessionOptions()
+        # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess_options.intra_op_num_threads = num_threads  # CPU cores
+        sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
+
+        providers = []
+        if device.startswith('cuda') or ort.get_all_providers().count('CUDAExecutionProvider') > 0:
+            providers.append('CUDAExecutionProvider')
+        self.acoustic_ort = ort.InferenceSession(
+            f'{model_path}/acoustic_opt.onnx',
+            sess_options,
+            providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
+        )
+        config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+        config_dict = config.to_dict()
+        config_dict.pop('device')
+        self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+
+        self.device = torch.device(device)
+        self.timings = defaultdict(lambda: 0.0)
+
+    @torch.inference_mode()
+    def emission(self, audio: torch.Tensor) -> torch.Tensor:
+        _start = time.time()
+        # audio -> features -> emission
+        features = self.extractor(audio)  # (1, T, D)
+        ort_inputs = {
+            'features': features.cpu().numpy(),
+            'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+        }
+        emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+        self.timings['emission'] += time.time() - _start
+        return torch.from_numpy(emission).to(self.device)  # (1, T, vocab_size) torch
+
+    def load_audio(self, audio: Union[Pathlike, BinaryIO]) -> Tuple[torch.Tensor, int]:
+        # load audio
+        waveform, sample_rate = torchaudio.load(audio, channels_first=True)
+        if waveform.size(0) > 1:  # TODO: support choose channel
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+        if sample_rate != self.config['sample_rate']:
+            waveform = torchaudio.functional.resample(waveform, sample_rate, self.config['sample_rate'])
+        return waveform
+
+    def alignment(
+        self, audio: Union[Union[Pathlike, BinaryIO], torch.tensor], lattice_graph: Tuple[str, int, float]
+    ) -> Dict[str, Any]:
+        """Process audio with LatticeGraph.
+
+        Args:
+            audio: Audio file path or binary data
+            lattice_graph: LatticeGraph data
+
+        Returns:
+            Processed LatticeGraph
+        """
+        # load audio
+        if isinstance(audio, torch.Tensor):
+            waveform = audio
+        else:
+            waveform = self.load_audio(audio)  # (1, L)
+
+        _start = time.time()
+        emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+        self.timings['emission'] += time.time() - _start
+
+        import k2
+        from lattifai_core.lattice.decode import align_segments
+
+        lattice_graph_str, final_state, acoustic_scale = lattice_graph
+
+        _start = time.time()
+        # graph
+        decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+        decoding_graph.requires_grad_(False)
+        decoding_graph = k2.arc_sort(decoding_graph)
+        decoding_graph.skip_id = int(final_state)
+        decoding_graph.return_id = int(final_state + 1)
+        self.timings['decoding_graph'] += time.time() - _start
+
+        _start = time.time()
+        results, labels = align_segments(
+            emission.to(self.device) * acoustic_scale,
+            decoding_graph.to(self.device),
+            torch.tensor([emission.shape[1]], dtype=torch.int32),
+            search_beam=100,
+            output_beam=40,
+            min_active_states=200,
+            max_active_states=10000,
+            subsampling_factor=1,
+            reject_low_confidence=False,
+        )
+        self.timings['align_segments'] += time.time() - _start
+
+        channel = 0
+        return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
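Read together with the tokenizer above, the worker suggests the following end-to-end alignment flow. This is a sketch under the assumption that lattifai/client.py (not shown in this excerpt) wires the pieces this way; the model directory and audio paths are hypothetical.

# Sketch only: paths are hypothetical; lattice_id, lattice_graph and tokenizer
# come from the tokenizer sketch above.
from lattifai.workers.lattice1_alpha import Lattice1AlphaWorker

worker = Lattice1AlphaWorker(model_path='models/lattice1-alpha', device='cpu')

# alignment() returns the 6-tuple (emission, results, labels, frame_shift, offset, channel)
# that LatticeTokenizer.detokenize() unpacks as lattice_results.
lattice_results = worker.alignment('audio.wav', lattice_graph)
aligned_supervisions = tokenizer.detokenize(lattice_id, lattice_results)
for sup in aligned_supervisions:
    print(sup.start, sup.duration, sup.text)  # assumed Supervision fields (Lhotse-style)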