lattifai-0.2.4-py3-none-any.whl → lattifai-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,12 +13,23 @@ from lhotse.audio import read_audio
  from lhotse.features.kaldi.layers import Wav2LogFilterBank
  from lhotse.utils import Pathlike

+ from lattifai.errors import (
+     AlignmentError,
+     AudioFormatError,
+     AudioLoadError,
+     DependencyError,
+     ModelLoadError,
+ )
+

  class Lattice1AlphaWorker:
      """Worker for processing audio with LatticeGraph."""

      def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
-         self.config = json.load(open(f'{model_path}/config.json'))
+         try:
+             self.config = json.load(open(f'{model_path}/config.json'))
+         except Exception as e:
+             raise ModelLoadError(f'config from {model_path}', original_error=e)

          # SessionOptions
          sess_options = ort.SessionOptions()
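
Callers that previously had to catch bare json/OS errors from __init__ can now catch the typed ModelLoadError. A minimal sketch of the new contract, assuming the worker is importable from the package root (the diff does not show its module path):

    from lattifai import Lattice1AlphaWorker  # hypothetical import path
    from lattifai.errors import ModelLoadError

    try:
        worker = Lattice1AlphaWorker('/path/to/model', device='cpu')
    except ModelLoadError as exc:
        # __init__ now wraps the underlying failure via original_error,
        # so the root cause stays available for logging.
        print(f'model load failed: {exc}')
        raise
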
@@ -33,15 +44,22 @@ class Lattice1AlphaWorker:
          elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
              providers.append('MPSExecutionProvider')

-         self.acoustic_ort = ort.InferenceSession(
-             f'{model_path}/acoustic_opt.onnx',
-             sess_options,
-             providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
-         )
-         config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
-         config_dict = config.to_dict()
-         config_dict.pop('device')
-         self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+         try:
+             self.acoustic_ort = ort.InferenceSession(
+                 f'{model_path}/acoustic_opt.onnx',
+                 sess_options,
+                 providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
+             )
+         except Exception as e:
+             raise ModelLoadError(f'acoustic model from {model_path}', original_error=e)
+
+         try:
+             config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
+             config_dict = config.to_dict()
+             config_dict.pop('device')
+             self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
+         except Exception as e:
+             raise ModelLoadError(f'feature extractor for device {device}', original_error=e)

          self.device = torch.device(device)
          self.timings = defaultdict(lambda: 0.0)
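
Because each initialization stage now raises its own ModelLoadError, a caller can tell from the message whether the config, the ONNX acoustic model, or the filterbank extractor failed. A hedged sketch; the attribute holding the wrapped exception is assumed to mirror the original_error keyword used above:

    import logging

    from lattifai.errors import ModelLoadError

    log = logging.getLogger(__name__)
    model_path = '/path/to/model'  # hypothetical

    try:
        worker = Lattice1AlphaWorker(model_path)  # assumes the worker is already imported
    except ModelLoadError as exc:
        # Message prefixes seen in this diff: 'config from ...',
        # 'acoustic model from ...', 'feature extractor for device ...'
        root_cause = getattr(exc, 'original_error', None)
        log.error('worker init failed: %s (root cause: %r)', exc, root_cause)
        raise
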
@@ -86,45 +104,59 @@ class Lattice1AlphaWorker:
              waveform = waveform.transpose(0, 1)
              # average multiple channels
              waveform = np.mean(waveform, axis=0, keepdims=True)  # (1, L)
-         except Exception:
+         except Exception as primary_error:
              # Fallback to PyAV for formats not supported by soundfile
-             import av
-
-             container = av.open(audio)
-             audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
-
-             if audio_stream is None:
-                 raise ValueError(f'No audio stream found in {audio}')
-
-             # Resample to target sample rate during decoding
-             audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
-
-             frames = []
-             for frame in container.decode(audio_stream):
-                 # Convert frame to numpy array
-                 array = frame.to_ndarray()
-                 # Ensure shape is (channels, samples)
-                 if array.ndim == 1:
-                     array = array.reshape(1, -1)
-                 elif array.ndim == 2 and array.shape[0] > array.shape[1]:
-                     array = array.T
-                 frames.append(array)
-
-             container.close()
+             try:
+                 import av
+             except ImportError:
+                 raise DependencyError(
+                     'av (PyAV)', install_command='pip install av', context={'primary_error': str(primary_error)}
+                 )
+
+             try:
+                 container = av.open(audio)
+                 audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
+
+                 if audio_stream is None:
+                     raise AudioFormatError(str(audio), 'No audio stream found in file')
+
+                 # Resample to target sample rate during decoding
+                 audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
+
+                 frames = []
+                 for frame in container.decode(audio_stream):
+                     # Convert frame to numpy array
+                     array = frame.to_ndarray()
+                     # Ensure shape is (channels, samples)
+                     if array.ndim == 1:
+                         array = array.reshape(1, -1)
+                     elif array.ndim == 2 and array.shape[0] > array.shape[1]:
+                         array = array.T
+                     frames.append(array)
+
+                 container.close()
+
+                 if not frames:
+                     raise AudioFormatError(str(audio), 'No audio data found in file')
+
+                 # Concatenate all frames
+                 waveform = np.concatenate(frames, axis=1)
+                 # Average multiple channels to mono
+                 if waveform.shape[0] > 1:
+                     waveform = np.mean(waveform, axis=0, keepdims=True)
+
+                 sample_rate = audio_stream.codec_context.sample_rate
+             except Exception as e:
+                 raise AudioLoadError(str(audio), original_error=e)

-             if not frames:
-                 raise ValueError(f'No audio data found in {audio}')
-
-             # Concatenate all frames
-             waveform = np.concatenate(frames, axis=1)
-             # Average multiple channels to mono
-             if waveform.shape[0] > 1:
-                 waveform = np.mean(waveform, axis=0, keepdims=True)
-
-             sample_rate = audio_stream.codec_context.sample_rate
+         try:
+             if sample_rate != self.config['sample_rate']:
+                 waveform = resampy.resample(waveform, sample_rate, self.config['sample_rate'], axis=1)
+         except Exception:
+             raise AudioFormatError(
+                 str(audio), f'Failed to resample from {sample_rate}Hz to {self.config["sample_rate"]}Hz'
+             )

-         if sample_rate != self.config['sample_rate']:
-             waveform = resampy.resample(waveform, sample_rate, self.config['sample_rate'], axis=1)
          return torch.from_numpy(waveform).to(self.device)  # (1, L)

      def alignment(
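
load_audio now reports failures along its whole fallback chain: DependencyError when PyAV is missing, AudioFormatError for files with no usable audio stream or that cannot be resampled, and AudioLoadError when decoding itself fails. A usage sketch, assuming load_audio is part of the worker's public surface:

    from lattifai.errors import AudioFormatError, AudioLoadError, DependencyError

    try:
        waveform = worker.load_audio('clip.webm')  # hypothetical input file
    except DependencyError:
        # PyAV missing; the error carries install_command='pip install av'
        raise
    except (AudioFormatError, AudioLoadError) as exc:
        print(f'skipping unreadable audio: {exc}')
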
@@ -138,6 +170,11 @@ class Lattice1AlphaWorker:

          Returns:
              Processed LatticeGraph
+
+         Raises:
+             AudioLoadError: If audio cannot be loaded
+             DependencyError: If required dependencies are missing
+             AlignmentError: If alignment process fails
          """
          # load audio
          if isinstance(audio, torch.Tensor):
@@ -146,21 +183,41 @@ class Lattice1AlphaWorker:
          waveform = self.load_audio(audio)  # (1, L)

          _start = time.time()
-         emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+         try:
+             emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
+         except Exception as e:
+             raise AlignmentError(
+                 'Failed to compute acoustic features from audio',
+                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
+                 context={'original_error': str(e)},
+             )
          self.timings['emission'] += time.time() - _start

-         import k2
-         from lattifai_core.lattice.decode import align_segments
+         try:
+             import k2
+         except ImportError:
+             raise DependencyError('k2', install_command='pip install install-k2 && python -m install_k2')
+
+         try:
+             from lattifai_core.lattice.decode import align_segments
+         except ImportError:
+             raise DependencyError('lattifai_core', install_command='Contact support for lattifai_core installation')

          lattice_graph_str, final_state, acoustic_scale = lattice_graph

          _start = time.time()
-         # graph
-         decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
-         decoding_graph.requires_grad_(False)
-         decoding_graph = k2.arc_sort(decoding_graph)
-         decoding_graph.skip_id = int(final_state)
-         decoding_graph.return_id = int(final_state + 1)
+         try:
+             # graph
+             decoding_graph = k2.Fsa.from_str(lattice_graph_str, acceptor=False)
+             decoding_graph.requires_grad_(False)
+             decoding_graph = k2.arc_sort(decoding_graph)
+             decoding_graph.skip_id = int(final_state)
+             decoding_graph.return_id = int(final_state + 1)
+         except Exception as e:
+             raise AlignmentError(
+                 'Failed to create decoding graph from lattice',
+                 context={'original_error': str(e), 'lattice_graph_length': len(lattice_graph_str)},
+             )
          self.timings['decoding_graph'] += time.time() - _start

          _start = time.time()
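
The optional k2 and lattifai_core imports are now guarded, so a missing dependency surfaces as a DependencyError with an install hint rather than a bare ImportError partway through alignment. A sketch, assuming the install_command keyword is also exposed as an attribute of the same name:

    from lattifai.errors import DependencyError

    try:
        graph = worker.alignment(audio, lattice_graph)  # hypothetical arguments
    except DependencyError as exc:
        hint = getattr(exc, 'install_command', None)
        print(f'missing dependency: {exc}' + (f' (try: {hint})' if hint else ''))
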
@@ -169,17 +226,24 @@ class Lattice1AlphaWorker:
          else:
              device = self.device

-         results, labels = align_segments(
-             emission.to(device) * acoustic_scale,
-             decoding_graph.to(device),
-             torch.tensor([emission.shape[1]], dtype=torch.int32),
-             search_beam=100,
-             output_beam=40,
-             min_active_states=200,
-             max_active_states=10000,
-             subsampling_factor=1,
-             reject_low_confidence=False,
-         )
+         try:
+             results, labels = align_segments(
+                 emission.to(device) * acoustic_scale,
+                 decoding_graph.to(device),
+                 torch.tensor([emission.shape[1]], dtype=torch.int32),
+                 search_beam=100,
+                 output_beam=40,
+                 min_active_states=200,
+                 max_active_states=10000,
+                 subsampling_factor=1,
+                 reject_low_confidence=False,
+             )
+         except Exception as e:
+             raise AlignmentError(
+                 'Failed to perform forced alignment',
+                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
+                 context={'original_error': str(e), 'emission_shape': list(emission.shape), 'device': str(device)},
+             )
          self.timings['align_segments'] += time.time() - _start

          channel = 0
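
Together with the Raises: section added to the docstring, this lets callers separate alignment failures from audio problems. A closing sketch; the context mapping is assumed to be stored on the exception under the same name as the keyword used in the raises above:

    from lattifai.errors import AlignmentError, AudioLoadError

    try:
        graph = worker.alignment(audio, lattice_graph)  # hypothetical arguments
    except AudioLoadError as exc:
        print(f'could not load audio: {exc}')
    except AlignmentError as exc:
        details = getattr(exc, 'context', {}) or {}
        print(f'alignment failed: {exc}; emission_shape={details.get("emission_shape")}')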