lattifai 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,7 +21,7 @@ def resample_audio(
     audio_sr: Tuple[torch.Tensor, int],
     sampling_rate: int,
     device: Optional[str],
-    channel_selector: Optional[ChannelSelectorType] = 'average',
+    channel_selector: Optional[ChannelSelectorType] = "average",
 ) -> torch.Tensor:
     """
     return:
@@ -33,21 +33,21 @@ def resample_audio(
         # keep the original multi-channel signal
         tensor = audio
     elif isinstance(channel_selector, int):
-        assert audio.shape[0] >= channel_selector, f'Invalid channel: {channel_selector}'
+        assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
         tensor = audio[channel_selector : channel_selector + 1].clone()
         del audio
     elif isinstance(channel_selector, str):
-        assert channel_selector == 'average'
+        assert channel_selector == "average"
         tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
         del audio
     else:
         assert isinstance(channel_selector, Iterable)
         num_channels = audio.shape[0]
-        print(f'Selecting channels {channel_selector} from the signal with {num_channels} channels.')
+        print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
         assert isinstance(channel_selector, Iterable)
         if max(channel_selector) >= num_channels:
             raise ValueError(
-                f'Cannot select channel subset {channel_selector} from a signal with {num_channels} channels.'
+                f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
             )
         tensor = audio[channel_selector]

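
Aside from quote style this hunk is unchanged; the channel-selector contract it implements is: None keeps all channels, an int keeps one, the string 'average' downmixes, and any other iterable fancy-indexes a subset. A minimal sketch of those semantics on a dummy stereo tensor (variable names are illustrative, not from the package):

    import torch

    audio = torch.randn(2, 16000)                    # (channels, samples)

    mono = torch.mean(audio, dim=0, keepdim=True)    # 'average' -> (1, 16000)
    first = audio[0:1]                               # int 0 -> (1, 16000)
    subset = audio[[1, 0]]                           # iterable -> (2, 16000), reordered copy

    assert mono.shape == first.shape == (1, 16000)
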
@@ -70,41 +70,41 @@ def resample_audio(
 class Lattice1AlphaWorker:
     """Worker for processing audio with LatticeGraph."""

-    def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
+    def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
         try:
-            self.config = json.load(open(f'{model_path}/config.json'))
+            self.config = json.load(open(f"{model_path}/config.json"))
         except Exception as e:
-            raise ModelLoadError(f'config from {model_path}', original_error=e)
+            raise ModelLoadError(f"config from {model_path}", original_error=e)

         # SessionOptions
         sess_options = ort.SessionOptions()
         # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
         sess_options.intra_op_num_threads = num_threads  # CPU cores
         sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-        sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
+        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

         providers = []
-        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
-            providers.append('CUDAExecutionProvider')
-        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
-            providers.append('MPSExecutionProvider')
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")

         try:
             self.acoustic_ort = ort.InferenceSession(
-                f'{model_path}/acoustic_opt.onnx',
+                f"{model_path}/acoustic_opt.onnx",
                 sess_options,
-                providers=providers + ['CPUExecutionProvider', 'CoreMLExecutionProvider'],
+                providers=providers + ["CPUExecutionProvider", "CoreMLExecutionProvider"],
             )
         except Exception as e:
-            raise ModelLoadError(f'acoustic model from {model_path}', original_error=e)
+            raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)

         try:
             config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
             config_dict = config.to_dict()
-            config_dict.pop('device')
+            config_dict.pop("device")
             self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
         except Exception as e:
-            raise ModelLoadError(f'feature extractor for device {device}', original_error=e)
+            raise ModelLoadError(f"feature extractor for device {device}", original_error=e)

         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
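
The constructor above pins the intra-op thread pool, switches ONNX Runtime to parallel execution, and turns off spin-waiting so idle worker threads yield the CPU. A standalone sketch of the same session recipe, assuming a local model.onnx; note that ort.get_all_providers() lists every provider the build knows of, while ort.get_available_providers() reflects what can actually run and is the more common availability check:

    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = 8                     # cap intra-op CPU threads
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

    providers = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
    providers.append("CPUExecutionProvider")                  # always keep a CPU fallback

    session = ort.InferenceSession("model.onnx", sess_options, providers=providers)
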
@@ -119,8 +119,8 @@ class Lattice1AlphaWorker:
             emissions = []
             for features in features_list:
                 ort_inputs = {
-                    'features': features.cpu().numpy(),
-                    'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                    "features": features.cpu().numpy(),
+                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                 }
                 emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                 emissions.append(emission)
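
The method feeds the exported acoustic model through the session: inputs go in as a dict of numpy arrays keyed by the graph's input names ("features" and "feature_lengths" here, per the hunk). Continuing the hypothetical session from the sketch above, with random filterbank features shaped like the comments in this hunk:

    import numpy as np

    features = np.random.randn(1, 200, 80).astype(np.float32)   # (1, T, num_mel_bins)
    ort_inputs = {
        "features": features,
        "feature_lengths": np.array([features.shape[1]], dtype=np.int64),
    }
    emission = session.run(None, ort_inputs)[0]                  # (1, T, vocab_size)
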
@@ -129,21 +129,21 @@ class Lattice1AlphaWorker:
             )  # (1, T, vocab_size)
         else:
             ort_inputs = {
-                'features': features.cpu().numpy(),
-                'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                "features": features.cpu().numpy(),
+                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
             }
             emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
             emission = torch.from_numpy(emission).to(self.device)

-        self.timings['emission'] += time.time() - _start
+        self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch

     def load_audio(
-        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = 'average'
+        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = "average"
     ) -> Tuple[torch.Tensor, int]:
         # load audio
         try:
-            waveform, sample_rate = sf.read(audio, always_2d=True, dtype='float32')  # numpy array
+            waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
             waveform = waveform.T  # (channels, samples)
         except Exception as primary_error:
             # Fallback to PyAV for formats not supported by soundfile
@@ -151,18 +151,18 @@ class Lattice1AlphaWorker:
                 import av
             except ImportError:
                 raise DependencyError(
-                    'av (PyAV)', install_command='pip install av', context={'primary_error': str(primary_error)}
+                    "av (PyAV)", install_command="pip install av", context={"primary_error": str(primary_error)}
                 )

             try:
                 container = av.open(audio)
-                audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
+                audio_stream = next((s for s in container.streams if s.type == "audio"), None)

                 if audio_stream is None:
-                    raise AudioFormatError(str(audio), 'No audio stream found in file')
+                    raise AudioFormatError(str(audio), "No audio stream found in file")

                 # Resample to target sample rate during decoding
-                audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
+                audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float

                 frames = []
                 for frame in container.decode(audio_stream):
@@ -178,7 +178,7 @@ class Lattice1AlphaWorker:
                 container.close()

                 if not frames:
-                    raise AudioFormatError(str(audio), 'No audio data found in file')
+                    raise AudioFormatError(str(audio), "No audio data found in file")

                 # Concatenate all frames
                 waveform = np.concatenate(frames, axis=1)
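
The loop body that turns decoded frames into arrays falls outside these hunks; a minimal PyAV read along the same lines, assuming an input.wav on disk (frame layout varies by codec, hence the hedge in the comment):

    import av
    import numpy as np

    container = av.open("input.wav")
    stream = next(s for s in container.streams if s.type == "audio")

    frames = []
    for frame in container.decode(stream):
        # to_ndarray() gives (channels, samples) for planar formats;
        # packed formats come back interleaved in a single row
        frames.append(frame.to_ndarray().astype(np.float32))
    container.close()

    waveform = np.concatenate(frames, axis=1)   # (channels, samples)
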
@@ -188,7 +188,7 @@ class Lattice1AlphaWorker:

         return resample_audio(
             (torch.from_numpy(waveform), sample_rate),
-            self.config.get('sampling_rate', 16000),
+            self.config.get("sampling_rate", 16000),
             device=self.device.type,
             channel_selector=channel_selector,
         )
@@ -221,21 +221,21 @@ class Lattice1AlphaWorker:
             emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
         except Exception as e:
             raise AlignmentError(
-                'Failed to compute acoustic features from audio',
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                context={'original_error': str(e)},
+                "Failed to compute acoustic features from audio",
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e)},
             )
-        self.timings['emission'] += time.time() - _start
+        self.timings["emission"] += time.time() - _start

         try:
             import k2
         except ImportError:
-            raise DependencyError('k2', install_command='pip install install-k2 && python -m install_k2')
+            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")

         try:
             from lattifai_core.lattice.decode import align_segments
         except ImportError:
-            raise DependencyError('lattifai_core', install_command='Contact support for lattifai_core installation')
+            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")

         lattice_graph_str, final_state, acoustic_scale = lattice_graph

@@ -249,14 +249,14 @@ class Lattice1AlphaWorker:
             decoding_graph.return_id = int(final_state + 1)
         except Exception as e:
             raise AlignmentError(
-                'Failed to create decoding graph from lattice',
-                context={'original_error': str(e), 'lattice_graph_length': len(lattice_graph_str)},
+                "Failed to create decoding graph from lattice",
+                context={"original_error": str(e), "lattice_graph_length": len(lattice_graph_str)},
             )
-        self.timings['decoding_graph'] += time.time() - _start
+        self.timings["decoding_graph"] += time.time() - _start

         _start = time.time()
-        if self.device.type == 'mps':
-            device = 'cpu'  # k2 does not support mps yet
+        if self.device.type == "mps":
+            device = "cpu"  # k2 does not support mps yet
         else:
             device = self.device

@@ -274,11 +274,11 @@ class Lattice1AlphaWorker:
             )
         except Exception as e:
             raise AlignmentError(
-                'Failed to perform forced alignment',
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                context={'original_error': str(e), 'emission_shape': list(emission.shape), 'device': str(device)},
+                "Failed to perform forced alignment",
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
             )
-        self.timings['align_segments'] += time.time() - _start
+        self.timings["align_segments"] += time.time() - _start

         channel = 0
         return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
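
The constants in this return line (frame_shift=0.02, offset=0.0) are what callers need to map aligned frame indices back to seconds: time = offset + frame_index * frame_shift. A worked check:

    frame_shift, offset = 0.02, 0.0     # 20 ms frames, no leading offset

    def frame_to_seconds(frame_index: int) -> float:
        return offset + frame_index * frame_shift

    assert abs(frame_to_seconds(150) - 3.0) < 1e-9   # frame 150 starts ~3.0 s in
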
@@ -20,15 +20,15 @@ from .base import WorkflowAgent, WorkflowResult, WorkflowStep
 from .file_manager import FileExistenceManager

 __all__ = [
-    'WorkflowAgent',
-    'WorkflowStep',
-    'WorkflowResult',
-    'YouTubeSubtitleAgent',
-    'FileExistenceManager',
-    'GeminiReader',
-    'GeminiWriter',
-    'SUBTITLE_FORMATS',
-    'INPUT_SUBTITLE_FORMATS',
-    'OUTPUT_SUBTITLE_FORMATS',
-    'ALL_SUBTITLE_FORMATS',
+    "WorkflowAgent",
+    "WorkflowStep",
+    "WorkflowResult",
+    "YouTubeSubtitleAgent",
+    "FileExistenceManager",
+    "GeminiReader",
+    "GeminiWriter",
+    "SUBTITLE_FORMATS",
+    "INPUT_SUBTITLE_FORMATS",
+    "OUTPUT_SUBTITLE_FORMATS",
+    "ALL_SUBTITLE_FORMATS",
 ]
@@ -8,3 +8,5 @@ An agentic workflow for processing YouTube(or more) videos through:
 """

 from .youtube import YouTubeSubtitleAgent
+
+__all__ = ["YouTubeSubtitleAgent"]
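
The new __all__ makes the subpackage's public surface explicit: star-imports bind exactly the listed names, and linting and documentation tools treat the list as the canonical export set. A hypothetical consumer (the package path is assumed for illustration):

    from workflows import *            # binds exactly the names in __all__

    agent_cls = YouTubeSubtitleAgent   # exported; unlisted names are not bound
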
@@ -7,20 +7,20 @@ import logging
 import time
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional

 import colorful


 def setup_workflow_logger(name: str) -> logging.Logger:
     """Setup a logger with consistent formatting for workflow modules"""
-    logger = logging.getLogger(f'workflows.{name}')
+    logger = logging.getLogger(f"workflows.{name}")

     # Only add handler if it doesn't exist
     if not logger.handlers:
         handler = logging.StreamHandler()
         formatter = logging.Formatter(
-            '%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
+            "%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
         )
         handler.setFormatter(formatter)
         logger.addHandler(handler)
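
The format string pads the logger name to 17 characters and the level name to 8, so interleaved output from different workflow modules stays column-aligned. Hypothetical usage of the helper defined above (module name, message, and the sample output line are illustrative):

    import logging

    logger = setup_workflow_logger("youtube")   # logs under "workflows.youtube"
    logger.setLevel(logging.INFO)
    logger.info("downloading captions")
    # -> 2024-01-01 12:00:00 - workflows.youtube.py:5    - INFO     - downloading captions
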
@@ -30,17 +30,17 @@ def setup_workflow_logger(name: str) -> logging.Logger:
     return logger


-logger = setup_workflow_logger('base')
+logger = setup_workflow_logger("base")


 class WorkflowStatus(Enum):
     """Workflow execution status"""

-    PENDING = 'pending'
-    RUNNING = 'running'
-    COMPLETED = 'completed'
-    FAILED = 'failed'
-    RETRYING = 'retrying'
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    RETRYING = "retrying"


 @dataclass
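
Because the members carry string values, a status round-trips through serialization with a plain value lookup: WorkflowStatus("failed") returns the member itself. A quick sketch:

    from enum import Enum

    class WorkflowStatus(Enum):
        PENDING = "pending"
        FAILED = "failed"

    assert WorkflowStatus("failed") is WorkflowStatus.FAILED   # lookup by value
    assert WorkflowStatus.FAILED.value == "failed"             # serialize back out
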
@@ -84,7 +84,7 @@ class WorkflowAgent(abc.ABC):
         self.name = name
         self.max_retries = max_retries
         self.steps: List[WorkflowStep] = []
-        self.logger = setup_workflow_logger('agent')
+        self.logger = setup_workflow_logger("agent")

     @abc.abstractmethod
     def define_steps(self) -> List[WorkflowStep]:
@@ -111,11 +111,11 @@ class WorkflowAgent(abc.ABC):
         context = kwargs.copy()
         step_results = []

-        self.logger.info(colorful.bold_white_on_green(f'🚀 Starting workflow: {self.name}'))
+        self.logger.info(colorful.bold_white_on_green(f"🚀 Starting workflow: {self.name}"))

         try:
             for i, step in enumerate(self.steps):
-                step_info = f'📋 Step {i + 1}/{len(self.steps)}: {step.name}'
+                step_info = f"📋 Step {i + 1}/{len(self.steps)}: {step.name}"
                 self.logger.info(colorful.bold_white_on_green(step_info))

                 step_start = time.time()
@@ -123,17 +123,17 @@ class WorkflowAgent(abc.ABC):
                 step_duration = time.time() - step_start

                 step_results.append(
-                    {'step_name': step.name, 'status': 'completed', 'duration': step_duration, 'result': step_result}
+                    {"step_name": step.name, "status": "completed", "duration": step_duration, "result": step_result}
                 )

                 # Update context with step result
-                context[f'step_{i}_result'] = step_result
+                context[f"step_{i}_result"] = step_result
                 context[f'{step.name.lower().replace(" ", "_")}_result'] = step_result

-                self.logger.info(f'✅ Step {i + 1} completed in {step_duration:.2f}s')
+                self.logger.info(f"✅ Step {i + 1} completed in {step_duration:.2f}s")

             execution_time = time.time() - start_time
-            self.logger.info(f'🎉 Workflow completed in {execution_time:.2f}s')
+            self.logger.info(f"🎉 Workflow completed in {execution_time:.2f}s")

             return WorkflowResult(
                 status=WorkflowStatus.COMPLETED, data=context, execution_time=execution_time, step_results=step_results
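
Each step's result is stored under two context keys: a positional one (step_{i}_result) and a slug derived from the human-readable step name. For a hypothetical step named "Download Audio", the derived key works out to:

    step_name = "Download Audio"
    key = f'{step_name.lower().replace(" ", "_")}_result'
    assert key == "download_audio_result"   # stored alongside "step_0_result"
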
@@ -145,9 +145,9 @@ class WorkflowAgent(abc.ABC):
             from lattifai.errors import LattifAIError

             if isinstance(e, LattifAIError):
-                self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}')
+                self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}")
             else:
-                self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: {str(e)}')
+                self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: {str(e)}")

             return WorkflowResult(
                 status=WorkflowStatus.FAILED,
@@ -164,7 +164,7 @@ class WorkflowAgent(abc.ABC):
         for attempt in range(step.max_retries + 1):
             try:
                 if attempt > 0:
-                    self.logger.info(f'🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})')
+                    self.logger.info(f"🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})")

                 result = await self.execute_step(step, context)
                 return result
@@ -176,14 +176,14 @@ class WorkflowAgent(abc.ABC):
                 # For LattifAI errors, show simplified message in logs
                 from lattifai.errors import LattifAIError

-                error_summary = f'[{e.error_code}]' if isinstance(e, LattifAIError) else str(e)[:100]
+                error_summary = f"[{e.error_code}]" if isinstance(e, LattifAIError) else str(e)[:100]

                 if step.should_retry():
-                    self.logger.warning(f'⚠️ Step {step.name} failed: {error_summary}. Retrying...')
+                    self.logger.warning(f"⚠️ Step {step.name} failed: {error_summary}. Retrying...")
                     continue
                 else:
                     self.logger.error(
-                        f'❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}'
+                        f"❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}"
                     )
                     raise e

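
Taken together, the -164 and -176 hunks show the retry envelope around each step: attempt, log a short error summary, retry while step.should_retry() allows, and re-raise once attempts are exhausted. A stripped-down synchronous sketch of that control flow (names and the backoff are illustrative, not from the package):

    import time

    def run_with_retries(fn, max_retries: int = 3, backoff: float = 1.0):
        for attempt in range(max_retries + 1):
            if attempt > 0:
                print(f"retrying (attempt {attempt + 1}/{max_retries + 1})")
            try:
                return fn()
            except Exception:
                if attempt < max_retries:          # mirrors step.should_retry()
                    time.sleep(backoff * attempt)  # simple linear backoff
                    continue
                raise                              # exhausted: propagate the error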