lattifai 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,59 +1,110 @@
  import json
  import time
  from collections import defaultdict
- from typing import Any, BinaryIO, Dict, Tuple, Union
+ from typing import Any, BinaryIO, Dict, Iterable, Optional, Tuple, Union

  import numpy as np
  import onnxruntime as ort
- import resampy
  import soundfile as sf
  import torch
  from lhotse import FbankConfig
- from lhotse.audio import read_audio
+ from lhotse.augmentation import get_or_create_resampler
  from lhotse.features.kaldi.layers import Wav2LogFilterBank
  from lhotse.utils import Pathlike

  from lattifai.errors import AlignmentError, AudioFormatError, AudioLoadError, DependencyError, ModelLoadError

+ ChannelSelectorType = Union[int, Iterable[int], str]
+
+
+ def resample_audio(
+     audio_sr: Tuple[torch.Tensor, int],
+     sampling_rate: int,
+     device: Optional[str],
+     channel_selector: Optional[ChannelSelectorType] = "average",
+ ) -> torch.Tensor:
+     """
+     return:
+         (1, T)
+     """
+     audio, sr = audio_sr
+
+     if channel_selector is None:
+         # keep the original multi-channel signal
+         tensor = audio
+     elif isinstance(channel_selector, int):
+         assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
+         tensor = audio[channel_selector : channel_selector + 1].clone()
+         del audio
+     elif isinstance(channel_selector, str):
+         assert channel_selector == "average"
+         tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
+         del audio
+     else:
+         assert isinstance(channel_selector, Iterable)
+         num_channels = audio.shape[0]
+         print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
+         if max(channel_selector) >= num_channels:
+             raise ValueError(
+                 f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
+             )
+         tensor = audio[channel_selector]
+
+     tensor = tensor.to(device)
+     if sr != sampling_rate:
+         resampler = get_or_create_resampler(sr, sampling_rate).to(device=device)
+         length = tensor.size(-1)
+         chunk_size = sampling_rate * 3600
+         if length > chunk_size:
+             resampled_chunks = []
+             for i in range(0, length, chunk_size):
+                 resampled_chunks.append(resampler(tensor[..., i : i + chunk_size]))
+             tensor = torch.cat(resampled_chunks, dim=-1)
+         else:
+             tensor = resampler(tensor)
+
+     return tensor
+

  class Lattice1AlphaWorker:
      """Worker for processing audio with LatticeGraph."""

-     def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
+     def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
          try:
-             self.config = json.load(open(f'{model_path}/config.json'))
+             self.config = json.load(open(f"{model_path}/config.json"))
          except Exception as e:
-             raise ModelLoadError(f'config from {model_path}', original_error=e)
+             raise ModelLoadError(f"config from {model_path}", original_error=e)

          # SessionOptions
          sess_options = ort.SessionOptions()
          # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
          sess_options.intra_op_num_threads = num_threads  # CPU cores
          sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-         sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
+         sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

          providers = []
-         if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
-             providers.append('CUDAExecutionProvider')
-         elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
-             providers.append('MPSExecutionProvider')
+         if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+             providers.append("CUDAExecutionProvider")
+         elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+             providers.append("MPSExecutionProvider")

          try:
              self.acoustic_ort = ort.InferenceSession(
-                 f'{model_path}/acoustic_opt.onnx',
+                 f"{model_path}/acoustic_opt.onnx",
                  sess_options,
-                 providers=providers + ['CoreMLExecutionProvider', 'CPUExecutionProvider'],
+                 providers=providers + ["CPUExecutionProvider", "CoreMLExecutionProvider"],
              )
          except Exception as e:
-             raise ModelLoadError(f'acoustic model from {model_path}', original_error=e)
+             raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)

          try:
              config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
              config_dict = config.to_dict()
-             config_dict.pop('device')
+             config_dict.pop("device")
              self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
          except Exception as e:
-             raise ModelLoadError(f'feature extractor for device {device}', original_error=e)
+             raise ModelLoadError(f"feature extractor for device {device}", original_error=e)

          self.device = torch.device(device)
          self.timings = defaultdict(lambda: 0.0)
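
The new `resample_audio` helper above replaces the old resampy path: it selects or averages channels first, then resamples with lhotse's cached `get_or_create_resampler`, processing signals longer than `sampling_rate * 3600` samples in chunks to bound peak memory. A minimal usage sketch (the file name and shapes are illustrative, not from the package):

```python
import soundfile as sf
import torch

# soundfile returns (samples, channels) with always_2d=True; transpose to channel-first.
wav, sr = sf.read("speech.wav", always_2d=True, dtype="float32")
audio = torch.from_numpy(wav.T)  # (channels, samples)

mono = resample_audio((audio, sr), 16000, device="cpu")                          # average -> (1, T)
first = resample_audio((audio, sr), 16000, device="cpu", channel_selector=0)     # first channel -> (1, T)
multi = resample_audio((audio, sr), 16000, device="cpu", channel_selector=None)  # keep all -> (C, T)
```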
@@ -68,8 +119,8 @@ class Lattice1AlphaWorker:
            emissions = []
            for features in features_list:
                ort_inputs = {
-                     'features': features.cpu().numpy(),
-                     'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                     "features": features.cpu().numpy(),
+                     "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                }
                emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                emissions.append(emission)
@@ -78,44 +129,40 @@ class Lattice1AlphaWorker:
            )  # (1, T, vocab_size)
        else:
            ort_inputs = {
-                 'features': features.cpu().numpy(),
-                 'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                 "features": features.cpu().numpy(),
+                 "feature_lengths": np.array([features.size(1)], dtype=np.int64),
            }
            emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
            emission = torch.from_numpy(emission).to(self.device)

-         self.timings['emission'] += time.time() - _start
+         self.timings["emission"] += time.time() - _start
        return emission  # (1, T, vocab_size) torch

-     def load_audio(self, audio: Union[Pathlike, BinaryIO]) -> Tuple[torch.Tensor, int]:
+     def load_audio(
+         self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = "average"
+     ) -> Tuple[torch.Tensor, int]:
        # load audio
        try:
-             waveform, sample_rate = read_audio(audio)  # numpy array
-             if len(waveform.shape) == 1:
-                 waveform = waveform.reshape([1, -1])  # (1, L)
-             else:  # make sure channel first
-                 if waveform.shape[0] > waveform.shape[1]:
-                     waveform = waveform.transpose(0, 1)
-                 # average multiple channels
-                 waveform = np.mean(waveform, axis=0, keepdims=True)  # (1, L)
+             waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
+             waveform = waveform.T  # (channels, samples)
        except Exception as primary_error:
            # Fallback to PyAV for formats not supported by soundfile
            try:
                import av
            except ImportError:
                raise DependencyError(
-                     'av (PyAV)', install_command='pip install av', context={'primary_error': str(primary_error)}
+                     "av (PyAV)", install_command="pip install av", context={"primary_error": str(primary_error)}
                )

            try:
                container = av.open(audio)
-                 audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
+                 audio_stream = next((s for s in container.streams if s.type == "audio"), None)

                if audio_stream is None:
-                     raise AudioFormatError(str(audio), 'No audio stream found in file')
+                     raise AudioFormatError(str(audio), "No audio stream found in file")

                # Resample to target sample rate during decoding
-                 audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
+                 audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float

                frames = []
                for frame in container.decode(audio_stream):
@@ -131,27 +178,20 @@ class Lattice1AlphaWorker:
                container.close()

                if not frames:
-                     raise AudioFormatError(str(audio), 'No audio data found in file')
+                     raise AudioFormatError(str(audio), "No audio data found in file")

                # Concatenate all frames
                waveform = np.concatenate(frames, axis=1)
-                 # Average multiple channels to mono
-                 if waveform.shape[0] > 1:
-                     waveform = np.mean(waveform, axis=0, keepdims=True)
-
                sample_rate = audio_stream.codec_context.sample_rate
            except Exception as e:
                raise AudioLoadError(str(audio), original_error=e)

-         try:
-             if sample_rate != self.config['sample_rate']:
-                 waveform = resampy.resample(waveform, sample_rate, self.config['sample_rate'], axis=1)
-         except Exception:
-             raise AudioFormatError(
-                 str(audio), f'Failed to resample from {sample_rate}Hz to {self.config["sample_rate"]}Hz'
-             )
-
-         return torch.from_numpy(waveform).to(self.device)  # (1, L)
+         return resample_audio(
+             (torch.from_numpy(waveform), sample_rate),
+             self.config.get("sampling_rate", 16000),
+             device=self.device.type,
+             channel_selector=channel_selector,
+         )

    def alignment(
        self, audio: Union[Union[Pathlike, BinaryIO], torch.tensor], lattice_graph: Tuple[str, int, float]
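
`load_audio` now loads with soundfile (falling back to PyAV) and defers channel handling and resampling to `resample_audio`, so callers choose the channel policy per call. A hedged sketch (the model directory and file paths are illustrative; note that the rewritten body returns a single tensor, so the `Tuple[torch.Tensor, int]` annotation reads as a holdover from 0.4.4):

```python
worker = Lattice1AlphaWorker("models/lattice1-alpha", device="cpu")

mono = worker.load_audio("interview.wav")                          # default "average" -> (1, T)
multi = worker.load_audio("interview.wav", channel_selector=None)  # keep all channels
second = worker.load_audio("interview.wav", channel_selector=1)    # second channel only
```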
@@ -181,21 +221,21 @@ class Lattice1AlphaWorker:
            emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
        except Exception as e:
            raise AlignmentError(
-                 'Failed to compute acoustic features from audio',
-                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                 context={'original_error': str(e)},
+                 "Failed to compute acoustic features from audio",
+                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                 context={"original_error": str(e)},
            )
-         self.timings['emission'] += time.time() - _start
+         self.timings["emission"] += time.time() - _start

        try:
            import k2
        except ImportError:
-             raise DependencyError('k2', install_command='pip install install-k2 && python -m install_k2')
+             raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")

        try:
            from lattifai_core.lattice.decode import align_segments
        except ImportError:
-             raise DependencyError('lattifai_core', install_command='Contact support for lattifai_core installation')
+             raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")

        lattice_graph_str, final_state, acoustic_scale = lattice_graph

@@ -209,14 +249,14 @@ class Lattice1AlphaWorker:
            decoding_graph.return_id = int(final_state + 1)
        except Exception as e:
            raise AlignmentError(
-                 'Failed to create decoding graph from lattice',
-                 context={'original_error': str(e), 'lattice_graph_length': len(lattice_graph_str)},
+                 "Failed to create decoding graph from lattice",
+                 context={"original_error": str(e), "lattice_graph_length": len(lattice_graph_str)},
            )
-         self.timings['decoding_graph'] += time.time() - _start
+         self.timings["decoding_graph"] += time.time() - _start

        _start = time.time()
-         if self.device.type == 'mps':
-             device = 'cpu'  # k2 does not support mps yet
+         if self.device.type == "mps":
+             device = "cpu"  # k2 does not support mps yet
        else:
            device = self.device

@@ -234,11 +274,11 @@ class Lattice1AlphaWorker:
            )
        except Exception as e:
            raise AlignmentError(
-                 'Failed to perform forced alignment',
-                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                 context={'original_error': str(e), 'emission_shape': list(emission.shape), 'device': str(device)},
+                 "Failed to perform forced alignment",
+                 audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                 context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
            )
-         self.timings['align_segments'] += time.time() - _start
+         self.timings["align_segments"] += time.time() - _start

        channel = 0
        return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
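
The trailing constants encode a 20 ms frame shift and a 0.0 s offset, so callers can map aligned frame indices to seconds. A sketch of the conversion (the layout of `results` is not shown in this diff; the frame values are illustrative):

```python
emission, results, labels, frame_shift, offset, channel = worker.alignment(audio, lattice_graph)

start_frame, end_frame = 100, 150             # e.g. one aligned token span
start_s = offset + start_frame * frame_shift  # 0.0 + 100 * 0.02 = 2.0 s
end_s = offset + end_frame * frame_shift      # 0.0 + 150 * 0.02 = 3.0 s
```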
@@ -20,15 +20,15 @@ from .base import WorkflowAgent, WorkflowResult, WorkflowStep
  from .file_manager import FileExistenceManager

  __all__ = [
-     'WorkflowAgent',
-     'WorkflowStep',
-     'WorkflowResult',
-     'YouTubeSubtitleAgent',
-     'FileExistenceManager',
-     'GeminiReader',
-     'GeminiWriter',
-     'SUBTITLE_FORMATS',
-     'INPUT_SUBTITLE_FORMATS',
-     'OUTPUT_SUBTITLE_FORMATS',
-     'ALL_SUBTITLE_FORMATS',
+     "WorkflowAgent",
+     "WorkflowStep",
+     "WorkflowResult",
+     "YouTubeSubtitleAgent",
+     "FileExistenceManager",
+     "GeminiReader",
+     "GeminiWriter",
+     "SUBTITLE_FORMATS",
+     "INPUT_SUBTITLE_FORMATS",
+     "OUTPUT_SUBTITLE_FORMATS",
+     "ALL_SUBTITLE_FORMATS",
  ]
@@ -8,3 +8,5 @@ An agentic workflow for processing YouTube(or more) videos through:
  """

  from .youtube import YouTubeSubtitleAgent
+
+ __all__ = ["YouTubeSubtitleAgent"]
@@ -7,20 +7,20 @@ import logging
  import time
  from dataclasses import dataclass
  from enum import Enum
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

  import colorful


  def setup_workflow_logger(name: str) -> logging.Logger:
      """Setup a logger with consistent formatting for workflow modules"""
-     logger = logging.getLogger(f'workflows.{name}')
+     logger = logging.getLogger(f"workflows.{name}")

      # Only add handler if it doesn't exist
      if not logger.handlers:
          handler = logging.StreamHandler()
          formatter = logging.Formatter(
-             '%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
+             "%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
          )
          handler.setFormatter(formatter)
          logger.addHandler(handler)
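
With the format string above, each record renders the logger name right-aligned to 17 characters, the line number left-aligned to 4, and the level left-aligned to 8. A quick sketch (timestamp and line number are illustrative):

```python
logger = setup_workflow_logger("youtube")
logger.info("Downloading subtitles...")
# 2025-01-01 12:00:00 - workflows.youtube.py:42   - INFO     - Downloading subtitles...
```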
@@ -30,17 +30,17 @@ def setup_workflow_logger(name: str) -> logging.Logger:
      return logger


- logger = setup_workflow_logger('base')
+ logger = setup_workflow_logger("base")


  class WorkflowStatus(Enum):
      """Workflow execution status"""

-     PENDING = 'pending'
-     RUNNING = 'running'
-     COMPLETED = 'completed'
-     FAILED = 'failed'
-     RETRYING = 'retrying'
+     PENDING = "pending"
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+     RETRYING = "retrying"


  @dataclass
@@ -84,7 +84,7 @@ class WorkflowAgent(abc.ABC):
        self.name = name
        self.max_retries = max_retries
        self.steps: List[WorkflowStep] = []
-         self.logger = setup_workflow_logger('agent')
+         self.logger = setup_workflow_logger("agent")

    @abc.abstractmethod
    def define_steps(self) -> List[WorkflowStep]:
@@ -111,11 +111,11 @@ class WorkflowAgent(abc.ABC):
        context = kwargs.copy()
        step_results = []

-         self.logger.info(colorful.bold_white_on_green(f'🚀 Starting workflow: {self.name}'))
+         self.logger.info(colorful.bold_white_on_green(f"🚀 Starting workflow: {self.name}"))

        try:
            for i, step in enumerate(self.steps):
-                 step_info = f'📋 Step {i + 1}/{len(self.steps)}: {step.name}'
+                 step_info = f"📋 Step {i + 1}/{len(self.steps)}: {step.name}"
                self.logger.info(colorful.bold_white_on_green(step_info))

                step_start = time.time()
@@ -123,17 +123,17 @@ class WorkflowAgent(abc.ABC):
                step_duration = time.time() - step_start

                step_results.append(
-                     {'step_name': step.name, 'status': 'completed', 'duration': step_duration, 'result': step_result}
+                     {"step_name": step.name, "status": "completed", "duration": step_duration, "result": step_result}
                )

                # Update context with step result
-                 context[f'step_{i}_result'] = step_result
+                 context[f"step_{i}_result"] = step_result
                context[f'{step.name.lower().replace(" ", "_")}_result'] = step_result

-                 self.logger.info(f'✅ Step {i + 1} completed in {step_duration:.2f}s')
+                 self.logger.info(f"✅ Step {i + 1} completed in {step_duration:.2f}s")

            execution_time = time.time() - start_time
-             self.logger.info(f'🎉 Workflow completed in {execution_time:.2f}s')
+             self.logger.info(f"🎉 Workflow completed in {execution_time:.2f}s")

            return WorkflowResult(
                status=WorkflowStatus.COMPLETED, data=context, execution_time=execution_time, step_results=step_results
@@ -145,9 +145,9 @@ class WorkflowAgent(abc.ABC):
            from lattifai.errors import LattifAIError

            if isinstance(e, LattifAIError):
-                 self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}')
+                 self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}")
            else:
-                 self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: {str(e)}')
+                 self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: {str(e)}")

            return WorkflowResult(
                status=WorkflowStatus.FAILED,
@@ -164,7 +164,7 @@ class WorkflowAgent(abc.ABC):
        for attempt in range(step.max_retries + 1):
            try:
                if attempt > 0:
-                     self.logger.info(f'🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})')
+                     self.logger.info(f"🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})")

                result = await self.execute_step(step, context)
                return result
@@ -176,14 +176,14 @@ class WorkflowAgent(abc.ABC):
                # For LattifAI errors, show simplified message in logs
                from lattifai.errors import LattifAIError

-                 error_summary = f'[{e.error_code}]' if isinstance(e, LattifAIError) else str(e)[:100]
+                 error_summary = f"[{e.error_code}]" if isinstance(e, LattifAIError) else str(e)[:100]

                if step.should_retry():
-                     self.logger.warning(f'⚠️ Step {step.name} failed: {error_summary}. Retrying...')
+                     self.logger.warning(f"⚠️ Step {step.name} failed: {error_summary}. Retrying...")
                    continue
                else:
                    self.logger.error(
-                         f'❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}'
+                         f"❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}"
                    )
                    raise e
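
Taken together, the retry loop above runs each step up to `max_retries + 1` times, warning between attempts and re-raising the last error once retries are exhausted. A hedged sketch of the behavior, assuming `should_retry()` tracks attempts against `max_retries` (the `WorkflowStep` constructor fields shown are assumptions):

```python
step = WorkflowStep(name="Download Audio", max_retries=2)

# Attempt 1 fails -> "⚠️ Step Download Audio failed: [...]. Retrying..."
# Attempt 2 fails -> same warning, one more try
# Attempt 3 fails -> "❌ Step Download Audio failed after 3 attempts: [...]" and the error is re-raised
```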