lattifai 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- lattifai/__init__.py +26 -27
- lattifai/base_client.py +7 -7
- lattifai/bin/agent.py +90 -91
- lattifai/bin/align.py +110 -111
- lattifai/bin/cli_base.py +3 -3
- lattifai/bin/subtitle.py +45 -45
- lattifai/client.py +56 -56
- lattifai/errors.py +73 -73
- lattifai/io/__init__.py +12 -11
- lattifai/io/gemini_reader.py +30 -30
- lattifai/io/gemini_writer.py +17 -17
- lattifai/io/reader.py +13 -12
- lattifai/io/supervision.py +3 -3
- lattifai/io/text_parser.py +43 -16
- lattifai/io/utils.py +4 -4
- lattifai/io/writer.py +31 -19
- lattifai/tokenizer/__init__.py +1 -1
- lattifai/tokenizer/phonemizer.py +3 -3
- lattifai/tokenizer/tokenizer.py +83 -82
- lattifai/utils.py +15 -15
- lattifai/workers/__init__.py +1 -1
- lattifai/workers/lattice1_alpha.py +46 -46
- lattifai/workflows/__init__.py +11 -11
- lattifai/workflows/agents.py +2 -0
- lattifai/workflows/base.py +22 -22
- lattifai/workflows/file_manager.py +182 -182
- lattifai/workflows/gemini.py +29 -29
- lattifai/workflows/prompts/__init__.py +4 -4
- lattifai/workflows/youtube.py +233 -233
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/METADATA +7 -9
- lattifai-0.4.6.dist-info/RECORD +39 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/licenses/LICENSE +1 -1
- lattifai-0.4.5.dist-info/RECORD +0 -39
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/WHEEL +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/entry_points.txt +0 -0
- {lattifai-0.4.5.dist-info → lattifai-0.4.6.dist-info}/top_level.txt +0 -0
lattifai/workers/lattice1_alpha.py
CHANGED

@@ -21,7 +21,7 @@ def resample_audio(
     audio_sr: Tuple[torch.Tensor, int],
     sampling_rate: int,
     device: Optional[str],
-    channel_selector: Optional[ChannelSelectorType] = 'average',
+    channel_selector: Optional[ChannelSelectorType] = "average",
 ) -> torch.Tensor:
     """
     return:
@@ -33,21 +33,21 @@ def resample_audio(
         # keep the original multi-channel signal
         tensor = audio
     elif isinstance(channel_selector, int):
-        assert audio.shape[0] >= channel_selector, f'Invalid channel: {channel_selector}'
+        assert audio.shape[0] >= channel_selector, f"Invalid channel: {channel_selector}"
         tensor = audio[channel_selector : channel_selector + 1].clone()
         del audio
     elif isinstance(channel_selector, str):
-        assert channel_selector == 'average'
+        assert channel_selector == "average"
         tensor = torch.mean(audio.to(device), dim=0, keepdim=True)
         del audio
     else:
         assert isinstance(channel_selector, Iterable)
         num_channels = audio.shape[0]
-        print(f'Selecting channels {channel_selector} from the signal with {num_channels} channels.')
+        print(f"Selecting channels {channel_selector} from the signal with {num_channels} channels.")
         assert isinstance(channel_selector, Iterable)
         if max(channel_selector) >= num_channels:
             raise ValueError(
-                f'Cannot select channel subset {channel_selector} from a signal with {num_channels} channels.'
+                f"Cannot select channel subset {channel_selector} from a signal with {num_channels} channels."
             )
         tensor = audio[channel_selector]
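The hunk above only changes quoting, but it documents the channel_selector contract: None keeps all channels, an int picks one channel, "average" downmixes, and any other iterable selects a channel subset. A self-contained sketch of that branching (simplified and hypothetical; the real resample_audio also resamples and moves tensors to a device):

from typing import Iterable, Optional, Union

import torch


def select_channels(audio: torch.Tensor, selector: Optional[Union[int, str, Iterable[int]]]) -> torch.Tensor:
    # Simplified mirror of the branches in resample_audio above.
    if selector is None:
        return audio  # keep the original multi-channel signal
    if isinstance(selector, int):
        assert audio.shape[0] > selector, f"Invalid channel: {selector}"
        return audio[selector : selector + 1].clone()
    if isinstance(selector, str):
        assert selector == "average"
        return torch.mean(audio, dim=0, keepdim=True)
    indices = list(selector)
    if max(indices) >= audio.shape[0]:
        raise ValueError(f"Cannot select channels {indices} from a signal with {audio.shape[0]} channels.")
    return audio[indices]


stereo = torch.randn(2, 16000)             # (channels, samples)
mono = select_channels(stereo, "average")  # shape (1, 16000)
left = select_channels(stereo, 0)          # shape (1, 16000)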
@@ -70,41 +70,41 @@ def resample_audio(
 class Lattice1AlphaWorker:
     """Worker for processing audio with LatticeGraph."""
 
-    def __init__(self, model_path: Pathlike, device: str = 'cpu', num_threads: int = 8) -> None:
+    def __init__(self, model_path: Pathlike, device: str = "cpu", num_threads: int = 8) -> None:
         try:
-            self.config = json.load(open(f'{model_path}/config.json'))
+            self.config = json.load(open(f"{model_path}/config.json"))
         except Exception as e:
-            raise ModelLoadError(f'config from {model_path}', original_error=e)
+            raise ModelLoadError(f"config from {model_path}", original_error=e)
 
         # SessionOptions
         sess_options = ort.SessionOptions()
         # sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
         sess_options.intra_op_num_threads = num_threads  # CPU cores
         sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
-        sess_options.add_session_config_entry('session.intra_op.allow_spinning', '0')
+        sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
 
         providers = []
-        if device.startswith('cuda') and ort.get_all_providers().count('CUDAExecutionProvider') > 0:
-            providers.append('CUDAExecutionProvider')
-        elif device.startswith('mps') and ort.get_all_providers().count('MPSExecutionProvider') > 0:
-            providers.append('MPSExecutionProvider')
+        if device.startswith("cuda") and ort.get_all_providers().count("CUDAExecutionProvider") > 0:
+            providers.append("CUDAExecutionProvider")
+        elif device.startswith("mps") and ort.get_all_providers().count("MPSExecutionProvider") > 0:
+            providers.append("MPSExecutionProvider")
 
         try:
             self.acoustic_ort = ort.InferenceSession(
-                f'{model_path}/acoustic_opt.onnx',
+                f"{model_path}/acoustic_opt.onnx",
                 sess_options,
-                providers=providers + ['CPUExecutionProvider', 'CoreMLExecutionProvider'],
+                providers=providers + ["CPUExecutionProvider", "CoreMLExecutionProvider"],
             )
         except Exception as e:
-            raise ModelLoadError(f'acoustic model from {model_path}', original_error=e)
+            raise ModelLoadError(f"acoustic model from {model_path}", original_error=e)
 
         try:
             config = FbankConfig(num_mel_bins=80, device=device, snip_edges=False)
             config_dict = config.to_dict()
-            config_dict.pop('device')
+            config_dict.pop("device")
             self.extractor = Wav2LogFilterBank(**config_dict).to(device).eval()
         except Exception as e:
-            raise ModelLoadError(f'feature extractor for device {device}', original_error=e)
+            raise ModelLoadError(f"feature extractor for device {device}", original_error=e)
 
         self.device = torch.device(device)
         self.timings = defaultdict(lambda: 0.0)
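The constructor follows the usual onnxruntime recipe: configure a SessionOptions, then hand InferenceSession an ordered provider list. A minimal sketch under the assumption of a placeholder model.onnx; note it checks ort.get_available_providers(), which reports the providers usable in the installed build — a stricter test than the get_all_providers() call in the diff:

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 8  # threads available to each op
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
sess_options.add_session_config_entry("session.intra_op.allow_spinning", "0")

# Prefer CUDA when this build actually supports it, then fall back to CPU.
providers = []
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers.append("CUDAExecutionProvider")
providers.append("CPUExecutionProvider")

# "model.onnx" is a placeholder path, not a file shipped with this package.
session = ort.InferenceSession("model.onnx", sess_options, providers=providers)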
@@ -119,8 +119,8 @@ class Lattice1AlphaWorker:
             emissions = []
             for features in features_list:
                 ort_inputs = {
-                    'features': features.cpu().numpy(),
-                    'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                    "features": features.cpu().numpy(),
+                    "feature_lengths": np.array([features.size(1)], dtype=np.int64),
                 }
                 emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
                 emissions.append(emission)
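The keys in ort_inputs ("features", "feature_lengths") have to match the input names baked into the exported acoustic ONNX graph. A dummy-data sketch of the convention; the session.run call is left as a comment since it needs a real model file:

import numpy as np

# Dummy batch: 1 utterance, 200 frames, 80 mel bins.
features = np.random.randn(1, 200, 80).astype(np.float32)
ort_inputs = {
    "features": features,
    "feature_lengths": np.array([features.shape[1]], dtype=np.int64),
}
# With a real session: emission = session.run(None, ort_inputs)[0]
print({name: arr.shape for name, arr in ort_inputs.items()})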
@@ -129,21 +129,21 @@ class Lattice1AlphaWorker:
             )  # (1, T, vocab_size)
         else:
             ort_inputs = {
-                'features': features.cpu().numpy(),
-                'feature_lengths': np.array([features.size(1)], dtype=np.int64),
+                "features": features.cpu().numpy(),
+                "feature_lengths": np.array([features.size(1)], dtype=np.int64),
             }
             emission = self.acoustic_ort.run(None, ort_inputs)[0]  # (1, T, vocab_size) numpy
             emission = torch.from_numpy(emission).to(self.device)
 
-        self.timings['emission'] += time.time() - _start
+        self.timings["emission"] += time.time() - _start
         return emission  # (1, T, vocab_size) torch
 
     def load_audio(
-        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = 'average'
+        self, audio: Union[Pathlike, BinaryIO], channel_selector: Optional[ChannelSelectorType] = "average"
     ) -> Tuple[torch.Tensor, int]:
         # load audio
         try:
-            waveform, sample_rate = sf.read(audio, always_2d=True, dtype='float32')  # numpy array
+            waveform, sample_rate = sf.read(audio, always_2d=True, dtype="float32")  # numpy array
             waveform = waveform.T  # (channels, samples)
         except Exception as primary_error:
             # Fallback to PyAV for formats not supported by soundfile
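soundfile returns samples as (frames, channels); the transpose converts to the (channels, samples) layout the worker uses everywhere else. A sketch assuming a placeholder example.wav:

import soundfile as sf

# always_2d guarantees a (frames, channels) array even for mono input.
waveform, sample_rate = sf.read("example.wav", always_2d=True, dtype="float32")
waveform = waveform.T  # -> (channels, samples)
print(waveform.shape, sample_rate)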
@@ -151,18 +151,18 @@ class Lattice1AlphaWorker:
                 import av
             except ImportError:
                 raise DependencyError(
-                    'av (PyAV)', install_command='pip install av', context={'primary_error': str(primary_error)}
+                    "av (PyAV)", install_command="pip install av", context={"primary_error": str(primary_error)}
                 )
 
             try:
                 container = av.open(audio)
-                audio_stream = next((s for s in container.streams if s.type == 'audio'), None)
+                audio_stream = next((s for s in container.streams if s.type == "audio"), None)
 
                 if audio_stream is None:
-                    raise AudioFormatError(str(audio), 'No audio stream found in file')
+                    raise AudioFormatError(str(audio), "No audio stream found in file")
 
                 # Resample to target sample rate during decoding
-                audio_stream.codec_context.format = av.AudioFormat('flt')  # 32-bit float
+                audio_stream.codec_context.format = av.AudioFormat("flt")  # 32-bit float
 
                 frames = []
                 for frame in container.decode(audio_stream):
@@ -178,7 +178,7 @@ class Lattice1AlphaWorker:
                 container.close()
 
                 if not frames:
-                    raise AudioFormatError(str(audio), 'No audio data found in file')
+                    raise AudioFormatError(str(audio), "No audio data found in file")
 
                 # Concatenate all frames
                 waveform = np.concatenate(frames, axis=1)
@@ -188,7 +188,7 @@ class Lattice1AlphaWorker:
 
         return resample_audio(
             (torch.from_numpy(waveform), sample_rate),
-            self.config.get('sampling_rate', 16000),
+            self.config.get("sampling_rate", 16000),
             device=self.device.type,
             channel_selector=channel_selector,
        )
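The PyAV fallback decodes anything FFmpeg can read into float32 frames. A condensed sketch of the same flow, assuming pip install av and a placeholder input path; note the frame layout depends on the sample format (planar formats decode to (channels, samples)):

import av
import numpy as np

container = av.open("example.mp4")  # placeholder path
audio_stream = next((s for s in container.streams if s.type == "audio"), None)
if audio_stream is None:
    raise ValueError("No audio stream found in file")

frames = []
for frame in container.decode(audio_stream):
    # to_ndarray() yields the decoded samples as a numpy array.
    frames.append(frame.to_ndarray().astype(np.float32))
container.close()

waveform = np.concatenate(frames, axis=1)  # (channels, total_samples)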
@@ -221,21 +221,21 @@ class Lattice1AlphaWorker:
             emission = self.emission(waveform.to(self.device))  # (1, T, vocab_size)
         except Exception as e:
             raise AlignmentError(
-                'Failed to compute acoustic features from audio',
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                context={'original_error': str(e)},
+                "Failed to compute acoustic features from audio",
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e)},
             )
-        self.timings['emission'] += time.time() - _start
+        self.timings["emission"] += time.time() - _start
 
         try:
             import k2
         except ImportError:
-            raise DependencyError('k2', install_command='pip install install-k2 && python -m install_k2')
+            raise DependencyError("k2", install_command="pip install install-k2 && python -m install_k2")
 
         try:
             from lattifai_core.lattice.decode import align_segments
         except ImportError:
-            raise DependencyError('lattifai_core', install_command='Contact support for lattifai_core installation')
+            raise DependencyError("lattifai_core", install_command="Contact support for lattifai_core installation")
 
         lattice_graph_str, final_state, acoustic_scale = lattice_graph
 
@@ -249,14 +249,14 @@ class Lattice1AlphaWorker:
             decoding_graph.return_id = int(final_state + 1)
         except Exception as e:
             raise AlignmentError(
-                'Failed to create decoding graph from lattice',
-                context={'original_error': str(e), 'lattice_graph_length': len(lattice_graph_str)},
+                "Failed to create decoding graph from lattice",
+                context={"original_error": str(e), "lattice_graph_length": len(lattice_graph_str)},
             )
-        self.timings['decoding_graph'] += time.time() - _start
+        self.timings["decoding_graph"] += time.time() - _start
 
         _start = time.time()
-        if self.device.type == 'mps':
-            device = 'cpu'  # k2 does not support mps yet
+        if self.device.type == "mps":
+            device = "cpu"  # k2 does not support mps yet
         else:
             device = self.device
 
@@ -274,11 +274,11 @@ class Lattice1AlphaWorker:
             )
         except Exception as e:
             raise AlignmentError(
-                'Failed to perform forced alignment',
-                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else 'tensor',
-                context={'original_error': str(e), 'emission_shape': list(emission.shape), 'device': str(device)},
+                "Failed to perform forced alignment",
+                audio_path=str(audio) if not isinstance(audio, torch.Tensor) else "tensor",
+                context={"original_error": str(e), "emission_shape": list(emission.shape), "device": str(device)},
             )
-        self.timings['align_segments'] += time.time() - _start
+        self.timings["align_segments"] += time.time() - _start
 
         channel = 0
         return emission, results, labels, 0.02, 0.0, channel  # frame_shift=20ms, offset=0.0s
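The self.timings updates sprinkled through this file are all one pattern: a defaultdict accumulator keyed by pipeline stage. A standalone sketch of that bookkeeping:

import time
from collections import defaultdict

timings = defaultdict(lambda: 0.0)

_start = time.time()
time.sleep(0.01)  # stand-in for the "emission" stage
timings["emission"] += time.time() - _start

_start = time.time()
time.sleep(0.01)  # stand-in for the "align_segments" stage
timings["align_segments"] += time.time() - _start

for stage, seconds in timings.items():
    print(f"{stage}: {seconds:.3f}s")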
lattifai/workflows/__init__.py
CHANGED

@@ -20,15 +20,15 @@ from .base import WorkflowAgent, WorkflowResult, WorkflowStep
 from .file_manager import FileExistenceManager
 
 __all__ = [
-    'WorkflowAgent',
-    'WorkflowStep',
-    'WorkflowResult',
-    'YouTubeSubtitleAgent',
-    'FileExistenceManager',
-    'GeminiReader',
-    'GeminiWriter',
-    'SUBTITLE_FORMATS',
-    'INPUT_SUBTITLE_FORMATS',
-    'OUTPUT_SUBTITLE_FORMATS',
-    'ALL_SUBTITLE_FORMATS',
+    "WorkflowAgent",
+    "WorkflowStep",
+    "WorkflowResult",
+    "YouTubeSubtitleAgent",
+    "FileExistenceManager",
+    "GeminiReader",
+    "GeminiWriter",
+    "SUBTITLE_FORMATS",
+    "INPUT_SUBTITLE_FORMATS",
+    "OUTPUT_SUBTITLE_FORMATS",
+    "ALL_SUBTITLE_FORMATS",
 ]
lattifai/workflows/agents.py
CHANGED
lattifai/workflows/base.py
CHANGED

@@ -7,20 +7,20 @@ import logging
 import time
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import colorful
 
 
 def setup_workflow_logger(name: str) -> logging.Logger:
     """Setup a logger with consistent formatting for workflow modules"""
-    logger = logging.getLogger(f'workflows.{name}')
+    logger = logging.getLogger(f"workflows.{name}")
 
     # Only add handler if it doesn't exist
     if not logger.handlers:
         handler = logging.StreamHandler()
         formatter = logging.Formatter(
-            '%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
+            "%(asctime)s - %(name)+17s.py:%(lineno)-4d - %(levelname)-8s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
         )
         handler.setFormatter(formatter)
         logger.addHandler(handler)
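The if not logger.handlers guard makes setup_workflow_logger idempotent: calling it twice for the same name reuses the existing handler instead of duplicating every log line. The same pattern in isolation:

import logging


def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(f"workflows.{name}")
    if not logger.handlers:  # attach a handler only once per logger
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


log = get_logger("demo")
log.info("configured once")
get_logger("demo").info("same handler, no duplicate lines")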
@@ -30,17 +30,17 @@ def setup_workflow_logger(name: str) -> logging.Logger:
     return logger
 
 
-logger = setup_workflow_logger('base')
+logger = setup_workflow_logger("base")
 
 
 class WorkflowStatus(Enum):
     """Workflow execution status"""
 
-    PENDING = 'pending'
-    RUNNING = 'running'
-    COMPLETED = 'completed'
-    FAILED = 'failed'
-    RETRYING = 'retrying'
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    RETRYING = "retrying"
 
 
 @dataclass
@@ -84,7 +84,7 @@ class WorkflowAgent(abc.ABC):
         self.name = name
         self.max_retries = max_retries
         self.steps: List[WorkflowStep] = []
-        self.logger = setup_workflow_logger('agent')
+        self.logger = setup_workflow_logger("agent")
 
     @abc.abstractmethod
     def define_steps(self) -> List[WorkflowStep]:
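Going by this diff, a concrete agent subclasses WorkflowAgent and implements define_steps(). The sketch below is a self-contained miniature with stand-in WorkflowStep and WorkflowAgent classes — the real ones in lattifai/workflows/base.py carry retry state, logging, and timing, so treat the field names beyond name and max_retries as guesses:

import abc
import asyncio
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class WorkflowStep:
    name: str
    max_retries: int = 0


class WorkflowAgent(abc.ABC):
    def __init__(self, name: str) -> None:
        self.name = name
        self.steps: List[WorkflowStep] = self.define_steps()

    @abc.abstractmethod
    def define_steps(self) -> List[WorkflowStep]: ...

    async def run(self, **kwargs: Any) -> Dict[str, Any]:
        context: Dict[str, Any] = kwargs.copy()
        for i, step in enumerate(self.steps):
            context[f"step_{i}_result"] = f"ran {step.name}"  # real agents execute the step here
        return context


class DemoAgent(WorkflowAgent):
    def define_steps(self) -> List[WorkflowStep]:
        return [WorkflowStep("Download audio", max_retries=2), WorkflowStep("Align subtitles")]


print(asyncio.run(DemoAgent("demo").run(url="https://example.com")))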
@@ -111,11 +111,11 @@ class WorkflowAgent(abc.ABC):
         context = kwargs.copy()
         step_results = []
 
-        self.logger.info(colorful.bold_white_on_green(f'🚀 Starting workflow: {self.name}'))
+        self.logger.info(colorful.bold_white_on_green(f"🚀 Starting workflow: {self.name}"))
 
         try:
             for i, step in enumerate(self.steps):
-                step_info = f'📋 Step {i + 1}/{len(self.steps)}: {step.name}'
+                step_info = f"📋 Step {i + 1}/{len(self.steps)}: {step.name}"
                 self.logger.info(colorful.bold_white_on_green(step_info))
 
                 step_start = time.time()
@@ -123,17 +123,17 @@ class WorkflowAgent(abc.ABC):
                 step_duration = time.time() - step_start
 
                 step_results.append(
-                    {'step_name': step.name, 'status': 'completed', 'duration': step_duration, 'result': step_result}
+                    {"step_name": step.name, "status": "completed", "duration": step_duration, "result": step_result}
                 )
 
                 # Update context with step result
-                context[f'step_{i}_result'] = step_result
+                context[f"step_{i}_result"] = step_result
                 context[f'{step.name.lower().replace(" ", "_")}_result'] = step_result
 
-                self.logger.info(f'✅ Step {i + 1} completed in {step_duration:.2f}s')
+                self.logger.info(f"✅ Step {i + 1} completed in {step_duration:.2f}s")
 
             execution_time = time.time() - start_time
-            self.logger.info(f'🎉 Workflow completed in {execution_time:.2f}s')
+            self.logger.info(f"🎉 Workflow completed in {execution_time:.2f}s")
 
             return WorkflowResult(
                 status=WorkflowStatus.COMPLETED, data=context, execution_time=execution_time, step_results=step_results
@@ -145,9 +145,9 @@ class WorkflowAgent(abc.ABC):
             from lattifai.errors import LattifAIError
 
             if isinstance(e, LattifAIError):
-                self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}')
+                self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: [{e.error_code}] {e.message}")
             else:
-                self.logger.error(f'❌ Workflow failed after {execution_time:.2f}s: {str(e)}')
+                self.logger.error(f"❌ Workflow failed after {execution_time:.2f}s: {str(e)}")
 
             return WorkflowResult(
                 status=WorkflowStatus.FAILED,
@@ -164,7 +164,7 @@ class WorkflowAgent(abc.ABC):
         for attempt in range(step.max_retries + 1):
             try:
                 if attempt > 0:
-                    self.logger.info(f'🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})')
+                    self.logger.info(f"🔄 Retrying step {step.name} (attempt {attempt + 1}/{step.max_retries + 1})")
 
                 result = await self.execute_step(step, context)
                 return result
@@ -176,14 +176,14 @@ class WorkflowAgent(abc.ABC):
                 # For LattifAI errors, show simplified message in logs
                 from lattifai.errors import LattifAIError
 
-                error_summary = f'[{e.error_code}]' if isinstance(e, LattifAIError) else str(e)[:100]
+                error_summary = f"[{e.error_code}]" if isinstance(e, LattifAIError) else str(e)[:100]
 
                 if step.should_retry():
-                    self.logger.warning(f'⚠️ Step {step.name} failed: {error_summary}. Retrying...')
+                    self.logger.warning(f"⚠️ Step {step.name} failed: {error_summary}. Retrying...")
                     continue
                 else:
                     self.logger.error(
-                        f'❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}'
+                        f"❌ Step {step.name} failed after {step.max_retries + 1} attempts: {error_summary}"
                     )
                     raise e
 
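The retry helper's control flow is a plain attempt counter: up to max_retries + 1 tries, warn and continue on failure, re-raise after the last attempt (the real code delegates the retry decision to step.should_retry()). A standalone async sketch of that flow:

import asyncio


async def run_with_retry(make_attempt, max_retries: int):
    for attempt in range(max_retries + 1):
        try:
            if attempt > 0:
                print(f"Retrying (attempt {attempt + 1}/{max_retries + 1})")
            return await make_attempt()
        except Exception as e:
            if attempt < max_retries:
                print(f"Failed: {str(e)[:100]}. Retrying...")
                continue
            raise  # exhausted all attempts


async def flaky(state={"calls": 0}):  # mutable default keeps state across attempts
    state["calls"] += 1
    if state["calls"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


print(asyncio.run(run_with_retry(flaky, max_retries=3)))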