lattifai 0.2.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- lattifai/__init__.py +32 -1
- lattifai/base_client.py +14 -6
- lattifai/bin/__init__.py +1 -0
- lattifai/bin/agent.py +325 -0
- lattifai/bin/align.py +253 -21
- lattifai/bin/cli_base.py +5 -0
- lattifai/bin/subtitle.py +182 -4
- lattifai/client.py +236 -63
- lattifai/errors.py +257 -0
- lattifai/io/__init__.py +21 -1
- lattifai/io/gemini_reader.py +371 -0
- lattifai/io/gemini_writer.py +173 -0
- lattifai/io/reader.py +21 -9
- lattifai/io/supervision.py +16 -0
- lattifai/io/utils.py +15 -0
- lattifai/io/writer.py +58 -17
- lattifai/tokenizer/__init__.py +2 -2
- lattifai/tokenizer/tokenizer.py +221 -40
- lattifai/utils.py +133 -0
- lattifai/workers/lattice1_alpha.py +130 -66
- lattifai-0.4.0.dist-info/METADATA +811 -0
- lattifai-0.4.0.dist-info/RECORD +28 -0
- lattifai-0.4.0.dist-info/entry_points.txt +3 -0
- lattifai-0.2.4.dist-info/METADATA +0 -334
- lattifai-0.2.4.dist-info/RECORD +0 -22
- lattifai-0.2.4.dist-info/entry_points.txt +0 -4
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/WHEEL +0 -0
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lattifai-0.2.4.dist-info → lattifai-0.4.0.dist-info}/top_level.txt +0 -0
lattifai/workers/lattice1_alpha.py

```diff
@@ -13,12 +13,23 @@ from lhotse.audio import read_audio
 from lhotse.features.kaldi.layers import Wav2LogFilterBank
 from lhotse.utils import Pathlike
 
+from lattifai.errors import (
+    AlignmentError,
+    AudioFormatError,
+    AudioLoadError,
+    DependencyError,
+    ModelLoadError,
+)
+
 
 class Lattice1AlphaWorker:
     """Worker for processing audio with LatticeGraph."""
 
     def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
-        self.config = json.load(open(f'{model_path}/config.json'))
+        try:
+            self.config = json.load(open(f'{model_path}/config.json'))
+        except Exception as e:
+            raise ModelLoadError(f'config from {model_path}', original_error=e)
 
         # SessionOptions
         sess_options = ort.SessionOptions()
```
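The constructor changes above replace bare exceptions with a typed `ModelLoadError` that carries the underlying cause. A minimal caller-side sketch of the new behavior; the model directory path is a hypothetical placeholder:

```python
from lattifai.errors import ModelLoadError
from lattifai.workers.lattice1_alpha import Lattice1AlphaWorker

try:
    worker = Lattice1AlphaWorker('models/lattice1-alpha', device='cpu')  # hypothetical path
except ModelLoadError as exc:
    # One typed error now covers a missing config.json, a failed ONNX
    # session, or a broken feature-extractor setup (see the hunks below).
    print(f'model load failed: {exc}')
    raise
```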
```diff
@@ -33,15 +44,22 @@ class Lattice1AlphaWorker:
         elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
             providers.append('MPSExecutionProvider')
 
-        [9 removed lines not captured in this view: the previous unguarded setup of self.acoustic_ort and self.extractor]
+        try:
+            self.acoustic_ort = ort.InferenceSession(
+                f'{model_path}/acoustic_opt.onnx',
+                sess_options,
+                providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
+            )
+        except Exception as e:
+            raise ModelLoadError(f'acoustic model from {model_path}', original_error=e)
+
+        try:
+            config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+            config_dict = config.to_dict()
+            config_dict.pop('device')
+            self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+        except Exception as e:
+            raise ModelLoadError(f'feature extractor for device {device}', original_error=e)
 
         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
```
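The session is now built with an explicit provider list that always ends in `CPUExecutionProvider`. A standalone sketch of ONNX Runtime's ordered-fallback provider selection ('model.onnx' is a hypothetical path):

```python
import onnxruntime as ort

# ONNX Runtime tries execution providers in list order and falls back to
# the next entry when one is unavailable, so ending the list with
# 'CPUExecutionProvider' guarantees the session can always be created.
preferred = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
available = ort.get_available_providers()

sess = ort.InferenceSession(
    'model.onnx',  # hypothetical model path
    providers=[p for p in preferred if p in available],
)
print(sess.get_providers())  # providers actually in use, in priority order
```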
```diff
@@ -86,45 +104,59 @@ class Lattice1AlphaWorker:
             waveform = waveform.transpose(0, 1)
             # average multiple channels
             waveform = np.mean(waveform, axis=0, keepdims=True)  # (1, L)
-        except Exception:
+        except Exception as primary_error:
             # Fallback to PyAV for formats not supported by soundfile
-            [23 removed lines not captured in this view: the previous PyAV fallback implementation]
+            try:
+                import av
+            except ImportError:
+                raise DependencyError(
+                    'av (PyAV)', install_command='pip install av', context={'primary_error': str(primary_error)}
+                )
+
+            try:
+                container = av.open(audio)
+                audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
+
+                if audio_stream is None:
+                    raise AudioFormatError(str(audio), 'No audio stream found in file')
+
+                # Resample to target sample rate during decoding
+                audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
+
+                frames = []
+                for frame in container.decode(audio_stream):
+                    # Convert frame to numpy array
+                    array = frame.to_ndarray()
+                    # Ensure shape is (channels, samples)
+                    if array.ndim == 1:
+                        array = array.reshape(1, -1)
+                    elif array.ndim == 2 and array.shape[0] > array.shape[1]:
+                        array = array.T
+                    frames.append(array)
+
+                container.close()
+
+                if not frames:
+                    raise AudioFormatError(str(audio), 'No audio data found in file')
+
+                # Concatenate all frames
+                waveform = np.concatenate(frames, axis=1)
+                # Average multiple channels to mono
+                if waveform.shape[0] > 1:
+                    waveform = np.mean(waveform, axis=0, keepdims=True)
+
+                sample_rate = audio_stream.codec_context.sample_rate
+            except Exception as e:
+                raise AudioLoadError(str(audio), original_error=e)
 
-            [7 removed lines not captured in this view]
-            waveform = np.mean(waveform, axis=0, keepdims=True)
-
-            sample_rate = audio_stream.codec_context.sample_rate
+        try:
+            if sample_rate != self.config['sample_rate']:
+                waveform = resampy.resample(waveform, sample_rate, self.config['sample_rate'], axis=1)
+        except Exception:
+            raise AudioFormatError(
+                str(audio), f'Failed to resample from {sample_rate}Hz to {self.config["sample_rate"]}Hz'
+            )
 
-        if sample_rate != self.config['sample_rate']:
-            waveform = resampy.resample(waveform, sample_rate, self.config['sample_rate'], axis=1)
-
         return torch.from_numpy(waveform).to(self.device)  # (1, L)
 
     def alignment(
```
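The new fallback decodes with PyAV frame by frame and only then averages to mono. A self-contained sketch of the same decoding pattern outside the worker ('example.m4a' is a hypothetical input):

```python
import av
import numpy as np

container = av.open('example.m4a')  # hypothetical input file
stream = next(s for s in container.streams if s.type == 'audio')

frames = []
for frame in container.decode(stream):
    array = frame.to_ndarray()  # shape/layout depend on the sample format
    if array.ndim == 1:
        array = array.reshape(1, -1)  # normalize to (channels, samples)
    frames.append(array)
container.close()

waveform = np.concatenate(frames, axis=1)    # (channels, L)
mono = waveform.mean(axis=0, keepdims=True)  # (1, L)
print(mono.shape, stream.codec_context.sample_rate)
```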
```diff
@@ -138,6 +170,11 @@ class Lattice1AlphaWorker:
 
         Returns:
             Processed LatticeGraph
+
+        Raises:
+            AudioLoadError: If audio cannot be loaded
+            DependencyError: If required dependencies are missing
+            AlignmentError: If alignment process fails
         """
         # load audio
         if isinstance(audio, torch.Tensor):
```
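The docstring now spells out the failure taxonomy, so callers can branch on error type rather than catching a bare `Exception`. A hedged sketch, with `worker` and the inputs as placeholders and `lattice_graph` standing in for the `(graph_str, final_state, acoustic_scale)` tuple unpacked later in this diff:

```python
from lattifai.errors import AlignmentError, AudioLoadError, DependencyError

try:
    result = worker.alignment('episode.mp3', lattice_graph)  # hypothetical inputs
except AudioLoadError:
    ...  # neither soundfile nor the PyAV fallback could read the file
except DependencyError:
    ...  # an optional dependency (PyAV, k2, lattifai_core) is missing
except AlignmentError:
    ...  # feature extraction, graph building, or decoding failed
```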
```diff
@@ -146,21 +183,41 @@
             waveform = self.load_audio(audio)  # (1, L)
 
         _start = time.time()
-        emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+        try:
+            emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+        except Exception as e:
+            raise AlignmentError(
+                'Failed to compute acoustic features from audio',
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
+                context={'original_error': str(e)},
+            )
         self.timings['emission'] += time.time() - _start
 
-        import k2
-        from lattifai_core.lattice.decode import align_segments
+        try:
+            import k2
+        except ImportError:
+            raise DependencyError('k2', install_command='pip install install-k2 && python -m install_k2')
+
+        try:
+            from lattifai_core.lattice.decode import align_segments
+        except ImportError:
+            raise DependencyError('lattifai_core', install_command='Contact support for lattifai_core installation')
 
         lattice_graph_str, final_state, acoustic_scale = lattice_graph
 
         _start = time.time()
-        # graph
-        decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
-        decoding_graph.requires_grad_(False)
-        decoding_graph = k2.arc_sort(decoding_graph)
-        decoding_graph.skip_id = int(final_state)
-        decoding_graph.return_id = int(final_state + 1)
+        try:
+            # graph
+            decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+            decoding_graph.requires_grad_(False)
+            decoding_graph = k2.arc_sort(decoding_graph)
+            decoding_graph.skip_id = int(final_state)
+            decoding_graph.return_id = int(final_state + 1)
+        except Exception as e:
+            raise AlignmentError(
+                'Failed to create decoding graph from lattice',
+                context={'original_error': str(e), 'lattice_graph_length': len(lattice_graph_str)},
+            )
         self.timings['decoding_graph'] += time.time() - _start
 
         _start = time.time()
```
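`k2` and `lattifai_core` are now imported lazily inside `alignment()`, so `import lattifai` succeeds even when the optional alignment stack is absent. A generic sketch of the same lazy-import pattern; the `require()` helper is a hypothetical illustration, not lattifai API:

```python
import importlib

def require(module_name: str, install_command: str):
    """Import an optional dependency, raising a typed error if it is absent."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        from lattifai.errors import DependencyError
        raise DependencyError(module_name, install_command=install_command)

k2 = require('k2', 'pip install install-k2 && python -m install_k2')
```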
```diff
@@ -169,17 +226,24 @@
         else:
             device = self.device
 
-        results, labels = align_segments(
-            emission.to(device) * acoustic_scale,
-            decoding_graph.to(device),
-            torch.tensor([emission.shape[1]], dtype=torch.int32),
-            search_beam=100,
-            output_beam=40,
-            min_active_states=200,
-            max_active_states=10000,
-            subsampling_factor=1,
-            reject_low_confidence=False,
-        )
+        try:
+            results, labels = align_segments(
+                emission.to(device) * acoustic_scale,
+                decoding_graph.to(device),
+                torch.tensor([emission.shape[1]], dtype=torch.int32),
+                search_beam=100,
+                output_beam=40,
+                min_active_states=200,
+                max_active_states=10000,
+                subsampling_factor=1,
+                reject_low_confidence=False,
+            )
+        except Exception as e:
+            raise AlignmentError(
+                'Failed to perform forced alignment',
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
+                context={'original_error': str(e), 'emission_shape': list(emission.shape), 'device': str(device)},
+            )
         self.timings['align_segments'] += time.time() - _start
 
         channel = 0
```
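Each `AlignmentError` raised above attaches site-specific diagnostics (emission shape, device, lattice size) via the `context` kwarg. A hedged sketch of consuming them; `exc.context` is an assumed attribute mirroring that kwarg, and the inputs are placeholders:

```python
from lattifai.errors import AlignmentError

try:
    worker.alignment(audio, lattice_graph)  # hypothetical inputs
except AlignmentError as exc:
    print(exc)          # human-readable failure message
    print(exc.context)  # assumed attribute: diagnostics captured at the failure site
```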