lattifai-0.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,143 @@
+ import gzip
+ import pickle
+ from collections import defaultdict
+ from itertools import chain
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+
+ from lattifai.base_client import SyncAPIClient
+ from lattifai.io import Supervision
+ from lattifai.tokenizers.phonemizer import G2Phonemizer
+
+ PUNCTUATION = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~'
+ PUNCTUATION_SPACE = PUNCTUATION + ' '
+ STAR_TOKEN = '※'
+
+ GROUPING_SEPARATOR = '✹'
+
+ MAXIMUM_WORD_LENGTH = 40
+
+
+ class LatticeTokenizer:
+     """Tokenizer for converting Lhotse Cut to LatticeGraph."""
+
+     def __init__(self, client_wrapper: SyncAPIClient):
+         self.client_wrapper = client_wrapper
+         self.words: List[str] = []
+         self.g2p_model: Any = None  # Placeholder for G2P model
+         self.dictionaries = defaultdict(list)
+         self.oov_word = '<unk>'
+     @staticmethod
+     def from_pretrained(
+         client_wrapper: SyncAPIClient,
+         model_path: str,
+         g2p_model_path: Optional[str] = None,
+         device: str = 'cpu',
+         compressed: bool = True,
+     ) -> 'LatticeTokenizer':
+         """Load a tokenizer from an exported binary file."""
+         if compressed:
+             with gzip.open(model_path, 'rb') as f:
+                 data = pickle.load(f)
+         else:
+             with open(model_path, 'rb') as f:
+                 data = pickle.load(f)
+
+         tokenizer = LatticeTokenizer(client_wrapper=client_wrapper)
+         tokenizer.words = data['words']
+         tokenizer.dictionaries = defaultdict(list, data['dictionaries'])
+         tokenizer.oov_word = data['oov_word']
+         if g2p_model_path:
+             tokenizer.g2p_model = G2Phonemizer(g2p_model_path, device=device)
+         return tokenizer
+
+     def prenormalize(self, texts: List[str], language: Optional[str] = None) -> Dict[str, List[List[str]]]:
+         """Collect candidate pronunciations for out-of-vocabulary words found in ``texts``."""
+         if not self.g2p_model:
+             raise ValueError('G2P model is not loaded, cannot prenormalize texts')
+
+         oov_words = []
+         for text in texts:
+             words = text.lower().replace('-', ' ').replace('—', ' ').replace('–', ' ').split()
+             oovs = [w for w in words if w not in self.words]
+             if oovs:
+                 oov_words.extend(w for w in oovs if len(w) <= MAXIMUM_WORD_LENGTH)
+
+         oov_words = list(set(oov_words))
+         if oov_words:
+             indices = []
+             for k, _word in enumerate(oov_words):
+                 if any(_word.startswith(p) and _word.endswith(q) for (p, q) in [('(', ')'), ('[', ']')]):
+                     self.dictionaries[_word] = self.dictionaries[self.oov_word]
+                 else:
+                     _word = _word.strip(PUNCTUATION_SPACE)
+                     if not _word or _word in self.words:
+                         indices.append(k)
+             for idx in sorted(indices, reverse=True):
+                 del oov_words[idx]
+
+             g2p_words = [w for w in oov_words if w not in self.dictionaries]
+             if g2p_words:
+                 predictions = self.g2p_model(words=g2p_words, lang=language, batch_size=len(g2p_words), num_prons=4)
+                 for _word, _predictions in zip(g2p_words, predictions):
+                     for pronunciation in _predictions:
+                         if pronunciation and pronunciation not in self.dictionaries[_word]:
+                             self.dictionaries[_word].append(pronunciation)
+
+             pronunciation_dictionaries: Dict[str, List[List[str]]] = {
+                 w: self.dictionaries[w] for w in oov_words if self.dictionaries[w]
+             }
+             return pronunciation_dictionaries
+
+         return {}
+
+     def tokenize(self, supervisions: List[Supervision]) -> Tuple[str, Tuple[str, int, float]]:
+         """Tokenize supervisions into a LatticeGraph via the API; returns (lattice_id, graph tuple)."""
+         pronunciation_dictionaries = self.prenormalize([s.text for s in supervisions])
+         response = self.client_wrapper.post(
+             'tokenize',
+             json={
+                 'supervisions': [s.to_dict() for s in supervisions],
+                 'pronunciation_dictionaries': pronunciation_dictionaries,
+             },
+         )
+         if response.status_code != 200:
+             raise Exception(f'Failed to tokenize texts: {response.text}')
+         result = response.json()
+         lattice_id = result['id']
+         return lattice_id, (result['lattice_graph'], result['final_state'], result.get('acoustic_scale', 1.0))
+
+     def detokenize(
+         self,
+         lattice_id: str,
+         lattice_results: Tuple[torch.Tensor, Any, Any, float, float, int],
+         # return_supervisions: bool = True,
+         # return_details: bool = False,
+     ) -> List[Supervision]:
+         """Convert alignment results back into Supervision segments via the API."""
+         emission, results, labels, frame_shift, offset, channel = lattice_results  # noqa: F841
+         response = self.client_wrapper.post(
+             'detokenize',
+             json={
+                 'lattice_id': lattice_id,
+                 'frame_shift': frame_shift,
+                 'results': [t.to_dict() for t in results[0]],
+                 'labels': labels[0],
+                 'offset': offset,
+                 'channel': channel,
+                 'destroy_lattice': True,
+             },
+         )
+         if response.status_code != 200:
+             raise Exception(f'Failed to detokenize lattice: {response.text}')
+         result = response.json()
+         # if return_details:
+         #     raise NotImplementedError("return_details is not implemented yet")
+         return [Supervision.from_dict(s) for s in result['supervisions']]
+
+
+ # Compute average score weighted by the span length
+ def _score(spans):
+     if not spans:
+         return 0.0
+     # TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
+     return round(sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans), ndigits=4)
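Taken together, the tokenizer is driven in three steps: `tokenize` sends supervisions (plus any G2P pronunciations collected by `prenormalize`) to the API and returns a lattice id and a graph tuple, a worker aligns audio against that graph, and `detokenize` turns the alignment back into supervisions. A minimal usage sketch follows; the client constructor, file paths, and `Supervision` fields are illustrative assumptions, not taken from this diff:

```python
# Illustrative sketch only; the client construction, file paths and Supervision
# fields below are assumptions, not part of the released package.
from lattifai.base_client import SyncAPIClient
from lattifai.io import Supervision

client = SyncAPIClient()  # assumed to be configured with API credentials
tokenizer = LatticeTokenizer.from_pretrained(
    client,
    model_path='lattice_tokenizer.bin.gz',  # hypothetical path, compressed=True by default
    g2p_model_path='g2p.onnx',              # hypothetical path; required for prenormalize()
)

supervisions = [Supervision(text='hello world')]  # assumed constructor
lattice_id, lattice_graph = tokenizer.tokenize(supervisions)
# lattice_graph is (graph_string, final_state, acoustic_scale), the tuple that
# Lattice1AlphaWorker.alignment() below expects as its second argument.
```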
@@ -0,0 +1,3 @@
+ from .lattice1_alpha import Lattice1AlphaWorker
+
+ __all__ = ['Lattice1AlphaWorker']
@@ -0,0 +1,119 @@
+ import json
+ import time
+ from collections import defaultdict
+ from typing import Any, BinaryIO, Dict, Tuple, Union
+
+ import numpy as np
+ import onnxruntime as ort
+ import torch
+ import torchaudio
+ from lhotse import FbankConfig
+ from lhotse.features.kaldi.layers import Wav2LogFilterBank
+ from lhotse.utils import Pathlike
+
+
+ class Lattice1AlphaWorker:
+     """Worker for processing audio with a LatticeGraph."""
+
+     def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
+         if device != 'cpu':
+             raise NotImplementedError(f'Only cpu is supported for now, got device={device}.')
+         with open(f'{model_path}/config.json') as f:
+             self.config = json.load(f)
+
+         # SessionOptions
+         sess_options = ort.SessionOptions()
+         # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+         sess_options.intra_op_num_threads = num_threads  # CPU cores
+         sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+         sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
+
+         providers = []
+         if device.startswith('cuda') or 'CUDAExecutionProvider' in ort.get_available_providers():
+             providers.append('CUDAExecutionProvider')
+         self.acoustic_ort = ort.InferenceSession(
+             f'{model_path}/acoustic_opt.onnx',
+             sess_options,
+             providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
+         )
+         config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+         config_dict = config.to_dict()
+         config_dict.pop('device')
+         self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+
+         self.device = torch.device(device)
+         self.timings = defaultdict(lambda: 0.0)
+     @torch.inference_mode()
+     def emission(self, audio: torch.Tensor) -> torch.Tensor:
+         _start = time.time()
+         # audio -> features -> emission
+         features = self.extractor(audio)  # (1, T, D)
+         ort_inputs = {
+             'features': features.cpu().numpy(),
+             'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+         }
+         emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
+         self.timings['emission'] += time.time() - _start
+         return torch.from_numpy(emission).to(self.device)  # (1, T, vocab_size) torch
+
+     def load_audio(self, audio: Union[Pathlike, BinaryIO]) -> torch.Tensor:
+         # load audio, downmix to mono, and resample to the model's sample rate
+         waveform, sample_rate = torchaudio.load(audio, channels_first=True)
+         if waveform.size(0) > 1:  # TODO: support choosing a channel
+             waveform = torch.mean(waveform, dim=0, keepdim=True)
+         if sample_rate != self.config['sample_rate']:
+             waveform = torchaudio.functional.resample(waveform, sample_rate, self.config['sample_rate'])
+         return waveform
+
+     def alignment(
+         self, audio: Union[Pathlike, BinaryIO, torch.Tensor], lattice_graph: Tuple[str, int, float]
+     ) -> Tuple[torch.Tensor, Any, Any, float, float, int]:
+         """Align audio against a LatticeGraph.
+
+         Args:
+             audio: Audio file path, binary stream, or a (1, L) waveform tensor
+             lattice_graph: Tuple of (lattice graph string, final state, acoustic scale)
+
+         Returns:
+             Tuple of (emission, results, labels, frame_shift, offset, channel)
+         """
+         # load audio
+         if isinstance(audio, torch.Tensor):
+             waveform = audio
+         else:
+             waveform = self.load_audio(audio)  # (1, L)
+
+         _start = time.time()
+         emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+         self.timings['emission'] += time.time() - _start
+
+         import k2
+         from lattifai_core.lattice.decode import align_segments
+
+         lattice_graph_str, final_state, acoustic_scale = lattice_graph
+
+         _start = time.time()
+         # build the decoding graph from the serialized lattice
+         decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+         decoding_graph.requires_grad_(False)
+         decoding_graph = k2.arc_sort(decoding_graph)
+         decoding_graph.skip_id = int(final_state)
+         decoding_graph.return_id = int(final_state + 1)
+         self.timings['decoding_graph'] += time.time() - _start
+
+         _start = time.time()
+         results, labels = align_segments(
+             emission.to(self.device) * acoustic_scale,
+             decoding_graph.to(self.device),
+             torch.tensor([emission.shape[1]], dtype=torch.int32),
+             search_beam=100,
+             output_beam=40,
+             min_active_states=200,
+             max_active_states=10000,
+             subsampling_factor=1,
+             reject_low_confidence=False,
+         )
+         self.timings['align_segments'] += time.time() - _start
+
+         channel = 0
+         return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
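For reference, the tuple returned by `alignment` lines up positionally with what `LatticeTokenizer.detokenize` unpacks (emission, results, labels, frame_shift, offset, channel), so an end-to-end pass could look roughly like this — a sketch under the same assumptions as the earlier example, with error handling omitted and a hypothetical model directory:

```python
# Sketch only: wires the worker output into detokenize(); the paths are hypothetical.
worker = Lattice1AlphaWorker(model_path='lattice1_alpha_model', device='cpu')
lattice_results = worker.alignment('speech.wav', lattice_graph)
# lattice_results == (emission, results, labels, 0.02, 0.0, 0)

aligned = tokenizer.detokenize(lattice_id, lattice_results)
for sup in aligned:
    print(sup.to_dict())  # Supervision exposes to_dict()/from_dict() per the tokenizer code
```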