zero-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentz/agent/base.py +262 -0
- agentz/artifacts/__init__.py +5 -0
- agentz/artifacts/artifact_writer.py +538 -0
- agentz/artifacts/reporter.py +235 -0
- agentz/artifacts/terminal_writer.py +100 -0
- agentz/context/__init__.py +6 -0
- agentz/context/context.py +91 -0
- agentz/context/conversation.py +205 -0
- agentz/context/data_store.py +208 -0
- agentz/llm/llm_setup.py +156 -0
- agentz/mcp/manager.py +142 -0
- agentz/mcp/patches.py +88 -0
- agentz/mcp/servers/chrome_devtools/server.py +14 -0
- agentz/profiles/base.py +108 -0
- agentz/profiles/data/data_analysis.py +38 -0
- agentz/profiles/data/data_loader.py +35 -0
- agentz/profiles/data/evaluation.py +43 -0
- agentz/profiles/data/model_training.py +47 -0
- agentz/profiles/data/preprocessing.py +47 -0
- agentz/profiles/data/visualization.py +47 -0
- agentz/profiles/manager/evaluate.py +51 -0
- agentz/profiles/manager/memory.py +62 -0
- agentz/profiles/manager/observe.py +48 -0
- agentz/profiles/manager/routing.py +66 -0
- agentz/profiles/manager/writer.py +51 -0
- agentz/profiles/mcp/browser.py +21 -0
- agentz/profiles/mcp/chrome.py +21 -0
- agentz/profiles/mcp/notion.py +21 -0
- agentz/runner/__init__.py +74 -0
- agentz/runner/base.py +28 -0
- agentz/runner/executor.py +320 -0
- agentz/runner/hooks.py +110 -0
- agentz/runner/iteration.py +142 -0
- agentz/runner/patterns.py +215 -0
- agentz/runner/tracker.py +188 -0
- agentz/runner/utils.py +45 -0
- agentz/runner/workflow.py +250 -0
- agentz/tools/__init__.py +20 -0
- agentz/tools/data_tools/__init__.py +17 -0
- agentz/tools/data_tools/data_analysis.py +152 -0
- agentz/tools/data_tools/data_loading.py +92 -0
- agentz/tools/data_tools/evaluation.py +175 -0
- agentz/tools/data_tools/helpers.py +120 -0
- agentz/tools/data_tools/model_training.py +192 -0
- agentz/tools/data_tools/preprocessing.py +229 -0
- agentz/tools/data_tools/visualization.py +281 -0
- agentz/utils/__init__.py +69 -0
- agentz/utils/config.py +708 -0
- agentz/utils/helpers.py +10 -0
- agentz/utils/parsers.py +142 -0
- agentz/utils/printer.py +539 -0
- pipelines/base.py +972 -0
- pipelines/data_scientist.py +97 -0
- pipelines/data_scientist_memory.py +151 -0
- pipelines/experience_learner.py +0 -0
- pipelines/prompt_generator.py +0 -0
- pipelines/simple.py +78 -0
- pipelines/simple_browser.py +145 -0
- pipelines/simple_chrome.py +75 -0
- pipelines/simple_notion.py +103 -0
- pipelines/tool_builder.py +0 -0
- zero_agent-0.1.0.dist-info/METADATA +269 -0
- zero_agent-0.1.0.dist-info/RECORD +66 -0
- zero_agent-0.1.0.dist-info/WHEEL +5 -0
- zero_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- zero_agent-0.1.0.dist-info/top_level.txt +2 -0
pipelines/base.py
ADDED
@@ -0,0 +1,972 @@
|
|
1
|
+
import asyncio
|
2
|
+
import time
|
3
|
+
from contextlib import contextmanager
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
6
|
+
|
7
|
+
from loguru import logger
|
8
|
+
from rich.console import Console
|
9
|
+
|
10
|
+
from agents.tracing.create import function_span
|
11
|
+
from agentz.utils.config import BaseConfig, resolve_config
|
12
|
+
from agentz.runner import (
|
13
|
+
AgentExecutor,
|
14
|
+
RuntimeTracker,
|
15
|
+
HookRegistry,
|
16
|
+
execute_tool_plan,
|
17
|
+
execute_tools,
|
18
|
+
run_manager_tool_loop,
|
19
|
+
record_structured_payload,
|
20
|
+
serialize_output,
|
21
|
+
)
|
22
|
+
from agentz.artifacts import RunReporter
|
23
|
+
from agentz.utils import Printer, get_experiment_timestamp
|
24
|
+
from pydantic import BaseModel
|
25
|
+
|
26
|
+
|
27
|
+
class BasePipeline:
|
28
|
+
"""Base class for all pipelines with common configuration and setup."""
|
29
|
+
|
30
|
+
# Constants for iteration group IDs
|
31
|
+
ITERATION_GROUP_PREFIX = "iter"
|
32
|
+
FINAL_GROUP_ID = "iter-final"
|
33
|
+
|
34
|
+
def __init__(self, config: Union[str, Path, Mapping[str, Any], BaseConfig]):
    """Initialize the pipeline using a single configuration input.

    Args:
        config: Configuration specification:
            - str/Path: Load YAML/JSON file
            - dict with 'config_path': Load file, then deep-merge dict on top (dict wins)
            - dict without 'config_path': Use as-is
            - BaseConfig: Use as-is

    Examples:
        # Load from file
        BasePipeline("pipelines/configs/data_science.yaml")

        # Dict without config_path
        BasePipeline({"provider": "openai", "data": {"path": "data.csv"}})

        # Dict that patches a file (use 'config_path')
        BasePipeline({
            "config_path": "pipelines/configs/data_science.yaml",
            "data": {"path": "data/banana_quality.csv"},
            "user_prompt": "Custom prompt..."
        })

        # BaseConfig object
        BasePipeline(BaseConfig(provider="openai", data={"path": "data.csv"}))
    """
    self.console = Console()
    self._printer: Optional[Printer] = None
    self.reporter: Optional[RunReporter] = None

    # Resolve configuration using the new unified API
    self.config = resolve_config(config)

    # Generic pipeline settings
    self.experiment_id = get_experiment_timestamp()

    pipeline_settings = self.config.pipeline
    # Slug priority: explicit slug > name > class-name-derived default
    default_slug = self.__class__.__name__.replace("Pipeline", "").lower()
    self.pipeline_slug = (
        pipeline_settings.get("slug")
        or pipeline_settings.get("name")
        or default_slug
    )
    self.workflow_name = (
        pipeline_settings.get("workflow_name")
        or pipeline_settings.get("name")
    )
    if not self.workflow_name:
        # Default pattern: use class name + experiment_id
        pipeline_name = self.__class__.__name__.replace("Pipeline", "").lower()
        self.workflow_name = f"{pipeline_name}_{self.experiment_id}"

    self.verbose = pipeline_settings.get("verbose", True)
    self.max_iterations = pipeline_settings.get("max_iterations", 5)
    self.max_time_minutes = pipeline_settings.get("max_time_minutes", 10)

    # Research workflow name (optional, for pipelines with research components)
    self.research_workflow_name = pipeline_settings.get(
        "research_workflow_name",
        f"researcher_{self.experiment_id}",
    )

    # Iterative pipeline state
    self.iteration = 0
    self.start_time: Optional[float] = None
    self.should_continue = True
    self.constraint_reason = ""

    # Setup tracing configuration and logging
    self._setup_tracing()

    # Initialize runtime tracker and executor (lazily built via properties)
    self._runtime_tracker: Optional[RuntimeTracker] = None
    self._executor: Optional[AgentExecutor] = None

    # Initialize hook registry
    self._hook_registry = HookRegistry()
114
|
+
# ============================================
|
115
|
+
# Core Properties
|
116
|
+
# ============================================
|
117
|
+
|
118
|
+
@property
def enable_tracing(self) -> bool:
    """Whether tracing is enabled for this pipeline (config default: True)."""
    settings = self.config.pipeline
    return settings.get("enable_tracing", True)
|
122
|
+
|
123
|
+
@property
def trace_sensitive(self) -> bool:
    """Whether sensitive data may appear in traces (config default: False)."""
    settings = self.config.pipeline
    return settings.get("trace_include_sensitive_data", False)
|
127
|
+
|
128
|
+
@property
def state(self) -> Optional[Any]:
    """Pipeline state from the context, or None when no context/state exists."""
    # Equivalent to the hasattr chain: missing attributes fall through to None.
    context = getattr(self, 'context', None)
    return getattr(context, 'state', None)
|
134
|
+
|
135
|
+
@property
def printer(self) -> Optional[Printer]:
    """The live printer if one has been started, else None."""
    return self._printer
|
138
|
+
|
139
|
+
@property
def runtime_tracker(self) -> RuntimeTracker:
    """Lazily build the runtime tracker, refreshing mutable fields on reuse.

    Returns:
        The cached (or newly created) RuntimeTracker, with iteration,
        printer, and reporter kept in sync with the pipeline.
    """
    tracker = self._runtime_tracker
    if tracker is None:
        tracker = RuntimeTracker(
            printer=self.printer,
            enable_tracing=self.enable_tracing,
            trace_sensitive=self.trace_sensitive,
            iteration=self.iteration,
            experiment_id=self.experiment_id,
            reporter=self.reporter,
        )
        self._runtime_tracker = tracker
    else:
        # Keep the cached tracker aligned with current pipeline state.
        tracker.iteration = self.iteration
        tracker.printer = self.printer
        tracker.reporter = self.reporter
    return tracker
|
157
|
+
|
158
|
+
@property
def executor(self) -> AgentExecutor:
    """Get or create the agent executor.

    Created lazily on first access and reused afterwards. Reading
    ``runtime_tracker`` first refreshes iteration/printer/reporter state so
    the executor always observes current values.
    """
    # Refresh runtime tracker so iteration/printer stay in sync across loops
    tracker = self.runtime_tracker

    if self._executor is None:
        self._executor = AgentExecutor(tracker)
    else:
        # Executor holds a reference to the tracker; update it in case it changed
        self._executor.context = tracker
    return self._executor
|
170
|
+
|
171
|
+
# ============================================
|
172
|
+
# Printer & Reporter Management
|
173
|
+
# ============================================
|
174
|
+
|
175
|
+
def start_printer(self) -> Printer:
    """Return the live printer, creating one lazily on first use."""
    printer = self._printer
    if printer is None:
        printer = Printer(self.console)
        self._printer = printer
    return printer
|
179
|
+
|
180
|
+
def stop_printer(self) -> None:
    """Shut down the live printer and, if present, finalize the reporter."""
    printer = self._printer
    if printer is not None:
        printer.end()
        self._printer = None
    reporter = self.reporter
    if reporter is not None:
        reporter.finalize()
        reporter.print_terminal_report()
|
188
|
+
|
189
|
+
def start_group(
    self,
    group_id: str,
    *,
    title: Optional[str] = None,
    border_style: Optional[str] = None,
    iteration: Optional[int] = None,
) -> None:
    """Open a printer group and mirror the event to the reporter.

    Args:
        group_id: Unique identifier for the group.
        title: Optional display title.
        border_style: Optional border color/style for the panel.
        iteration: Optional iteration index recorded by the reporter.
    """
    reporter = self.reporter
    if reporter:
        reporter.record_group_start(
            group_id=group_id,
            title=title,
            border_style=border_style,
            iteration=iteration,
        )
    printer = self.printer
    if printer:
        printer.start_group(group_id, title=title, border_style=border_style)
|
211
|
+
|
212
|
+
def end_group(
    self,
    group_id: str,
    *,
    is_done: bool = True,
    title: Optional[str] = None,
) -> None:
    """Close a printer group and mirror the event to the reporter.

    Args:
        group_id: Identifier of the group to close.
        is_done: Whether the group completed successfully.
        title: Optional replacement title for the closed group.
    """
    reporter = self.reporter
    if reporter:
        reporter.record_group_end(group_id=group_id, is_done=is_done, title=title)
    printer = self.printer
    if printer:
        printer.end_group(group_id, is_done=is_done, title=title)
|
232
|
+
|
233
|
+
# ============================================
|
234
|
+
# Initialization & Setup
|
235
|
+
# ============================================
|
236
|
+
|
237
|
+
def _initialize_run(self, additional_logging=None):
    """Initialize a pipeline run with logging, printer, and tracing.

    Creates the run reporter on first use, starts the live printer, and
    builds the trace context for the workflow.

    Args:
        additional_logging: Optional callable for pipeline-specific logging

    Returns:
        Trace context manager for the workflow
    """
    # Basic logging
    logger.info(
        f"Running {self.__class__.__name__} with experiment_id: {self.experiment_id}"
    )

    # Pipeline-specific logging
    if additional_logging:
        additional_logging()

    outputs_dir = Path(self.config.pipeline.get("outputs_dir", "outputs"))
    if self.reporter is None:
        self.reporter = RunReporter(
            base_dir=outputs_dir,
            pipeline_slug=self.pipeline_slug,
            workflow_name=self.workflow_name,
            experiment_id=self.experiment_id,
            console=self.console,
        )
        self.reporter.start(self.config)

    # Start printer and update workflow
    self.start_printer()
    if self.printer:
        self.printer.update_item(
            "workflow",
            f"Workflow: {self.workflow_name}",
            is_done=True,
            hide_checkmark=True,
        )

    # Create trace context
    trace_metadata = {
        "experiment_id": self.experiment_id,
        "includes_sensitive_data": "true" if self.trace_sensitive else "false",
    }
    return self.trace_context(self.workflow_name, metadata=trace_metadata)
|
282
|
+
|
283
|
+
def _setup_tracing(self) -> None:
    """Print the startup banner describing tracing configuration.

    Provider/model lines are printed in both modes; tracing details are
    appended only when tracing is enabled. Output text and order are
    identical to the previous duplicated-branch implementation.

    Subclasses can override this method to add pipeline-specific information.
    """
    pipeline_name = self.__class__.__name__.replace("Pipeline", "")
    # Banner title varies with tracing mode; remaining lines are shared.
    if self.enable_tracing:
        self.console.print(f"🍌 Starting {pipeline_name} Pipeline with Tracing")
    else:
        self.console.print(f"🍌 Starting {pipeline_name} Pipeline")
    self.console.print(f"🔧 Provider: {self.config.provider}")
    self.console.print(f"🤖 Model: {self.config.llm.model_name}")
    if self.enable_tracing:
        self.console.print("🔍 Tracing: Enabled")
        self.console.print(
            f"🔒 Sensitive Data in Traces: {'Yes' if self.trace_sensitive else 'No'}"
        )
        self.console.print(f"🏷️ Workflow: {self.workflow_name}")
|
303
|
+
|
304
|
+
def trace_context(self, name: str, metadata: Optional[Dict[str, Any]] = None):
    """Create a trace context - delegates to RuntimeTracker.

    Args:
        name: Name for the trace (typically the workflow name).
        metadata: Optional key/value metadata attached to the trace.
    """
    return self.runtime_tracker.trace_context(name, metadata=metadata)
|
307
|
+
|
308
|
+
def span_context(self, span_factory, **kwargs):
    """Create a span context - delegates to RuntimeTracker.

    Args:
        span_factory: Callable that constructs the span (e.g. function_span).
        **kwargs: Forwarded unchanged to the span factory.
    """
    return self.runtime_tracker.span_context(span_factory, **kwargs)
|
311
|
+
|
312
|
+
async def agent_step(self, *args, **kwargs) -> Any:
    """Run an agent with span tracking and optional output parsing.

    Delegates to AgentExecutor.agent_step(). See AgentExecutor.agent_step()
    for full documentation. Accessing ``self.executor`` refreshes the
    runtime tracker before execution.
    """
    return await self.executor.agent_step(*args, **kwargs)
|
318
|
+
|
319
|
+
def update_printer(self, *args, **kwargs) -> None:
    """Update printer status if printer is active.

    Delegates to RuntimeTracker.update_printer(). See
    RuntimeTracker.update_printer() for full documentation.
    """
    self.runtime_tracker.update_printer(*args, **kwargs)
|
325
|
+
|
326
|
+
# ============================================
|
327
|
+
# Context Managers & Utilities
|
328
|
+
# ============================================
|
329
|
+
|
330
|
+
@contextmanager
def run_context(self, additional_logging: Optional[Callable] = None):
    """Context manager for run lifecycle handling.

    Manages trace context initialization, printer lifecycle, and cleanup.
    Automatically starts the pipeline timer for constraint checking.

    Args:
        additional_logging: Optional callable for pipeline-specific logging

    Yields:
        Trace context for the workflow
    """
    # Start pipeline timer for constraint checking
    self.start_time = time.time()

    trace_ctx = self._initialize_run(additional_logging)
    try:
        with trace_ctx:
            yield trace_ctx
    finally:
        # Always stop the printer (and finalize the reporter) even on error.
        self.stop_printer()
|
352
|
+
|
353
|
+
async def run_span_step(self, *args, **kwargs) -> Any:
    """Execute a step with span context and printer updates.

    Delegates to AgentExecutor.run_span_step(). See
    AgentExecutor.run_span_step() for full documentation.
    """
    return await self.executor.run_span_step(*args, **kwargs)
|
359
|
+
|
360
|
+
# ============================================
|
361
|
+
# Iteration & Group Management
|
362
|
+
# ============================================
|
363
|
+
|
364
|
+
def begin_iteration(
    self,
    title: Optional[str] = None,
    border_style: str = "white"
) -> Tuple[Any, str]:
    """Start a new iteration and open its associated printer group.

    Combines context.begin_iteration() + start_group() into a single call.

    Args:
        title: Optional custom title (falsy values fall back to
            "Iteration {index}").
        border_style: Border style for the group (default: "white").

    Returns:
        Tuple of (iteration_record, group_id).
    """
    record, group_id = self.context.begin_iteration()
    self.iteration = record.index

    self.start_group(
        group_id,
        title=title or f"Iteration {record.index}",
        border_style=border_style,
        iteration=record.index,
    )
    return record, group_id
|
392
|
+
|
393
|
+
def end_iteration(self, group_id: str, is_done: bool = True) -> None:
    """End the current iteration and its associated group.

    Combines context.mark_iteration_complete() + end_group() into a single call.

    Args:
        group_id: The group ID to close
        is_done: Whether the iteration completed successfully (default: True)
    """
    self.context.mark_iteration_complete()
    self.end_group(group_id, is_done=is_done)
|
404
|
+
|
405
|
+
def begin_final_report(
    self,
    title: str = "Final Report",
    border_style: str = "white"
) -> str:
    """Start the final-report phase and open its printer group.

    Combines context.begin_final_report() + start_group() into a single call.

    Args:
        title: Group title (default: "Final Report").
        border_style: Border style for the group (default: "white").

    Returns:
        The final report group_id.
    """
    _, final_group_id = self.context.begin_final_report()
    self.start_group(final_group_id, title=title, border_style=border_style)
    return final_group_id
|
424
|
+
|
425
|
+
def end_final_report(self, group_id: str, is_done: bool = True) -> None:
    """End the final report phase and its associated group.

    Combines context.mark_final_complete() + end_group() into a single call.

    Args:
        group_id: The final report group ID to close
        is_done: Whether the final report completed successfully (default: True)
    """
    self.context.mark_final_complete()
    self.end_group(group_id, is_done=is_done)
|
436
|
+
|
437
|
+
def prepare_query(
    self,
    content: str,
    step_key: str = "prepare_query",
    span_name: str = "prepare_research_query",
    start_msg: str = "Preparing research query...",
    done_msg: str = "Research query prepared"
) -> str:
    """Prepare query/content with span context and printer updates.

    Args:
        content: The query/content to prepare
        step_key: Printer status key
        span_name: Name for the span
        start_msg: Start message for printer
        done_msg: Completion message for printer

    Returns:
        The prepared content
    """
    self.update_printer(step_key, start_msg)

    with self.span_context(function_span, name=span_name) as span:
        logger.debug(f"Prepared {span_name}: {content}")

        if span and hasattr(span, "set_output"):
            # Only a 200-char preview is recorded to keep span payloads small.
            span.set_output({"output_preview": content[:200]})

    self.update_printer(step_key, done_msg, is_done=True)
    return content
|
467
|
+
|
468
|
+
def _log_message(self, message: str) -> None:
    """Log a message using the configured logger.

    Args:
        message: Text to log at INFO level.
    """
    logger.info(message)
|
471
|
+
|
472
|
+
# ============================================
|
473
|
+
# Execution Entry Points
|
474
|
+
# ============================================
|
475
|
+
|
476
|
+
def run_sync(self, *args, **kwargs):
    """Synchronous wrapper for the async run method.

    Uses asyncio.run, which creates a fresh event loop — do not call this
    from inside an already-running event loop.
    """
    return asyncio.run(self.run(*args, **kwargs))
|
479
|
+
|
480
|
+
async def run(self, query: Any = None) -> Any:
    """Template method - DO NOT override in subclasses.

    This method provides the fixed lifecycle structure:
    1. Initialize pipeline
    2. Before execution hooks
    3. Main execution (delegated to execute())
    4. After execution hooks
    5. Finalization

    Override execute() instead to implement pipeline logic.

    Args:
        query: Optional query input (can be None for pipelines without input)

    Returns:
        Final result from finalize()
    """
    with self.run_context():
        # Phase 1: Setup
        await self.initialize_pipeline(query)

        # Phase 2: Pre-execution hooks (subclass hook first, then registry)
        await self.before_execution()
        await self._hook_registry.trigger("before_execution", context=self)

        # Phase 3: Main execution (delegated to subclass)
        result = await self.execute()

        # Phase 4: Post-execution hooks (registry first, then subclass hook)
        await self._hook_registry.trigger("after_execution", context=self, result=result)
        await self.after_execution(result)

        # Phase 5: Finalization
        final_result = await self.finalize(result)

        return final_result
|
517
|
+
|
518
|
+
# ============================================
|
519
|
+
# Lifecycle Hook Methods (Override in Subclasses)
|
520
|
+
# ============================================
|
521
|
+
|
522
|
+
async def initialize_pipeline(self, query: Any) -> None:
    """Default setup phase: format the query, store it, update the printer.

    Override this for custom initialization logic.

    Args:
        query: Input query (may be None for pipelines without input).
    """
    if query is not None:
        formatted = self.format_query(query)
        state = self.state
        if state:
            state.set_query(formatted)
    self.update_printer("initialization", "Pipeline initialized", is_done=True)
|
540
|
+
|
541
|
+
def format_query(self, query: Any) -> str:
|
542
|
+
"""Transform input query to formatted string.
|
543
|
+
|
544
|
+
Default behavior (in order of priority):
|
545
|
+
1. If query has a format() method, call it
|
546
|
+
2. If query is a BaseModel, return model_dump_json()
|
547
|
+
3. Otherwise, return str(query)
|
548
|
+
|
549
|
+
Override this to customize query formatting.
|
550
|
+
|
551
|
+
Args:
|
552
|
+
query: Input query
|
553
|
+
|
554
|
+
Returns:
|
555
|
+
Formatted query string
|
556
|
+
"""
|
557
|
+
if hasattr(query, 'format') and callable(getattr(query, 'format')):
|
558
|
+
return query.format()
|
559
|
+
if isinstance(query, BaseModel):
|
560
|
+
return query.model_dump_json(indent=2)
|
561
|
+
return str(query)
|
562
|
+
|
563
|
+
async def before_execution(self) -> None:
    """Hook called before execute().

    Use for:
    - Data loading/validation
    - Resource initialization
    - Pre-flight checks

    Override this for custom pre-execution logic. The default is a no-op.
    """
    pass
|
574
|
+
|
575
|
+
async def after_execution(self, result: Any) -> None:  # noqa: ARG002
    """Hook called after execute() completes.

    Use for:
    - Result validation
    - Cleanup operations
    - State aggregation

    Override this for custom post-execution logic. The default is a no-op.

    Args:
        result: The return value from execute()
    """
    pass
|
589
|
+
|
590
|
+
async def finalize(self, result: Any) -> Any:
|
591
|
+
"""Finalization phase - prepare final return value.
|
592
|
+
|
593
|
+
Default implementation:
|
594
|
+
- Returns context.state.final_report if available
|
595
|
+
- Otherwise returns result as-is
|
596
|
+
|
597
|
+
Override this for custom finalization logic.
|
598
|
+
|
599
|
+
Args:
|
600
|
+
result: The return value from execute()
|
601
|
+
|
602
|
+
Returns:
|
603
|
+
Final result to return from run()
|
604
|
+
"""
|
605
|
+
if self.state:
|
606
|
+
return self.state.final_report
|
607
|
+
return result
|
608
|
+
|
609
|
+
# ============================================
|
610
|
+
# Abstract Execute Method (Must Implement in Subclasses)
|
611
|
+
# ============================================
|
612
|
+
|
613
|
+
async def execute(self) -> Any:
    """Main execution logic - implement in subclass.

    This is where your pipeline logic goes. You have complete freedom:
    - Iterative loops (use run_iterative_loop helper)
    - Single-shot execution
    - Multi-phase workflows
    - Custom control flow (branching, conditional, parallel)
    - Mix of patterns

    Returns:
        Any result value (passed to after_execution and finalize)

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.

    Examples:
        # Iterative pattern
        async def execute(self):
            return await self.run_iterative_loop(
                iteration_body=self._do_iteration,
                final_body=self._write_report
            )

        # Single-shot pattern
        async def execute(self):
            data = await self.load_data()
            analysis = await self.analyze(data)
            return await self.generate_report(analysis)

        # Multi-phase pattern
        async def execute(self):
            exploration = await self._explore_phase()
            if exploration.needs_deep_dive:
                deep_dive = await self._deep_dive_phase()
            return await self._synthesize(exploration, deep_dive)
    """
    raise NotImplementedError("Subclasses must implement execute()")
|
648
|
+
|
649
|
+
# ============================================
|
650
|
+
# Workflow Helper Methods
|
651
|
+
# ============================================
|
652
|
+
|
653
|
+
async def run_iterative_loop(
    self,
    iteration_body: Callable[[Any, str], Awaitable[Any]],
    final_body: Optional[Callable[[str], Awaitable[Any]]] = None,
    should_continue: Optional[Callable[[], bool]] = None,
) -> Any:
    """Execute standard iterative loop pattern.

    Args:
        iteration_body: Async function(iteration, group_id) -> result
        final_body: Optional async function(final_group_id) -> result
        should_continue: Optional custom condition (default: checks max
            iterations/time)

    Returns:
        Result from final_body if provided, else None
    """
    should_continue_fn = should_continue or self._should_continue_iteration

    while should_continue_fn():
        # Begin iteration with its group
        iteration, group_id = self.begin_iteration()

        # Trigger before hooks
        await self._hook_registry.trigger(
            "before_iteration",
            context=self.context,
            iteration=iteration,
            group_id=group_id
        )

        succeeded = False
        try:
            await iteration_body(iteration, group_id)
            succeeded = True
        finally:
            # Trigger after hooks
            await self._hook_registry.trigger(
                "after_iteration",
                context=self.context,
                iteration=iteration,
                group_id=group_id
            )
            # Bug fix: close the iteration group even when iteration_body
            # raises; previously an exception left the printer group open.
            self.end_iteration(group_id, is_done=succeeded)

        # Check if state indicates completion
        if self.state and self.state.complete:
            break

    # Execute final body if provided
    result = None
    if final_body:
        final_group = self.begin_final_report()
        result = await final_body(final_group)
        self.end_final_report(final_group)

    return result
|
708
|
+
|
709
|
+
def _should_continue_iteration(self) -> bool:
|
710
|
+
"""Check if iteration should continue based on constraints.
|
711
|
+
|
712
|
+
Returns:
|
713
|
+
True if should continue, False otherwise
|
714
|
+
"""
|
715
|
+
# Check state completion
|
716
|
+
if self.state and self.state.complete:
|
717
|
+
return False
|
718
|
+
|
719
|
+
# Check max iterations
|
720
|
+
if self.iteration >= self.max_iterations:
|
721
|
+
logger.info("\n=== Ending Iteration Loop ===")
|
722
|
+
logger.info(f"Reached maximum iterations ({self.max_iterations})")
|
723
|
+
return False
|
724
|
+
|
725
|
+
# Check max time
|
726
|
+
if self.start_time is not None:
|
727
|
+
elapsed_minutes = (time.time() - self.start_time) / 60
|
728
|
+
if elapsed_minutes >= self.max_time_minutes:
|
729
|
+
logger.info("\n=== Ending Iteration Loop ===")
|
730
|
+
logger.info(f"Reached maximum time ({self.max_time_minutes} minutes)")
|
731
|
+
return False
|
732
|
+
|
733
|
+
return True
|
734
|
+
|
735
|
+
async def run_custom_group(
|
736
|
+
self,
|
737
|
+
group_id: str,
|
738
|
+
title: str,
|
739
|
+
body: Callable[[], Awaitable[Any]],
|
740
|
+
border_style: str = "white",
|
741
|
+
) -> Any:
|
742
|
+
"""Execute code within a custom printer group.
|
743
|
+
|
744
|
+
Args:
|
745
|
+
group_id: Unique group identifier
|
746
|
+
title: Display title for the group
|
747
|
+
body: Async function to execute within group
|
748
|
+
border_style: Border color for printer
|
749
|
+
|
750
|
+
Returns:
|
751
|
+
Result from body()
|
752
|
+
"""
|
753
|
+
self.start_group(group_id, title=title, border_style=border_style)
|
754
|
+
try:
|
755
|
+
result = await body()
|
756
|
+
return result
|
757
|
+
finally:
|
758
|
+
self.end_group(group_id, is_done=True)
|
759
|
+
|
760
|
+
async def run_parallel_steps(
|
761
|
+
self,
|
762
|
+
steps: Dict[str, Callable[[], Awaitable[Any]]],
|
763
|
+
group_id: Optional[str] = None,
|
764
|
+
) -> Dict[str, Any]:
|
765
|
+
"""Execute multiple steps in parallel.
|
766
|
+
|
767
|
+
Args:
|
768
|
+
steps: Dict mapping step_name -> async callable
|
769
|
+
group_id: Optional group to nest steps in
|
770
|
+
|
771
|
+
Returns:
|
772
|
+
Dict mapping step_name -> result
|
773
|
+
"""
|
774
|
+
async def run_step(name: str, fn: Callable):
|
775
|
+
key = f"{group_id}:{name}" if group_id else name
|
776
|
+
self.update_printer(key, f"Running {name}...", group_id=group_id)
|
777
|
+
result = await fn()
|
778
|
+
self.update_printer(key, f"Completed {name}", is_done=True, group_id=group_id)
|
779
|
+
return name, result
|
780
|
+
|
781
|
+
tasks = [run_step(name, fn) for name, fn in steps.items()]
|
782
|
+
completed = await asyncio.gather(*tasks)
|
783
|
+
return dict(completed)
|
784
|
+
|
785
|
+
async def run_if(
    self,
    condition: Union[bool, Callable[[], bool]],
    body: Callable[[], Awaitable[Any]],
    else_body: Optional[Callable[[], Awaitable[Any]]] = None,
) -> Any:
    """Run *body* when the condition holds, otherwise *else_body* if given.

    Args:
        condition: A boolean, or a zero-argument callable producing one.
        body: Executed when the condition evaluates truthy.
        else_body: Optional fallback executed when it does not.

    Returns:
        The result of whichever branch ran, or None when neither did.
    """
    # Resolve a callable condition to its boolean value.
    is_true = condition() if callable(condition) else condition
    if not is_true:
        return await else_body() if else_body else None
    return await body()
|
807
|
+
|
808
|
+
async def run_until(
    self,
    condition: Callable[[], bool],
    body: Callable[[int], Awaitable[Any]],
    max_iterations: Optional[int] = None,
) -> List[Any]:
    """Execute body repeatedly until condition is met.

    Args:
        condition: Callable returning True to stop
        body: Async function(iteration_number) -> result
        max_iterations: Optional max iterations (default: unlimited)

    Returns:
        List of results from each iteration
    """
    results: List[Any] = []
    iteration = 0

    while not condition():
        # Explicit None check: max_iterations=0 must mean "run zero
        # iterations", not "unlimited" (a bare truthiness test would
        # treat 0 as unset and loop without bound).
        if max_iterations is not None and iteration >= max_iterations:
            break

        result = await body(iteration)
        results.append(result)
        iteration += 1

    return results
|
836
|
+
|
837
|
+
# ============================================
|
838
|
+
# Integration with Runner Module
|
839
|
+
# ============================================
|
840
|
+
|
841
|
+
def _record_structured_payload(self, value: object, context_label: Optional[str] = None) -> None:
    """Store a structured payload on the current iteration state.

    Thin wrapper that forwards to runner.utils.record_structured_payload.

    Args:
        value: Payload to store (typically a BaseModel instance).
        context_label: Optional label used for debugging.
    """
    # Delegate to the shared runner utility; self.state is the target.
    record_structured_payload(self.state, value, context_label)
|
851
|
+
|
852
|
+
def _serialize_output(self, output: Any) -> str:
    """Convert agent output into a string suitable for storage.

    Thin wrapper that forwards to runner.utils.serialize_output.

    Args:
        output: Value to serialize (BaseModel, str, or other).

    Returns:
        The output rendered as a string.
    """
    serialized = serialize_output(output)
    return serialized
|
864
|
+
|
865
|
+
async def execute_tool_plan(
    self,
    plan: Any,
    tool_agents: Dict[str, Any],
    group_id: str,
) -> None:
    """Execute the tasks in a routing plan using the given tool agents.

    Thin wrapper that forwards to runner.patterns.execute_tool_plan,
    supplying this pipeline's context and callback hooks.

    Args:
        plan: AgentSelectionPlan holding the tasks to execute.
        tool_agents: Mapping of agent name to agent instance.
        group_id: Printer group used for status updates.
    """
    await execute_tool_plan(
        context=self.context,
        plan=plan,
        tool_agents=tool_agents,
        group_id=group_id,
        agent_step_fn=self.agent_step,
        update_printer_fn=self.update_printer,
    )
|
888
|
+
|
889
|
+
async def _execute_tools(
    self,
    route_plan: Any,
    tool_agents: Dict[str, Any],
    group_id: str,
) -> None:
    """Run the tool agents selected by a routing plan.

    Thin wrapper that forwards to runner.patterns.execute_tools,
    supplying this pipeline's context and callback hooks.

    Args:
        route_plan: The routing plan (AgentSelectionPlan or other).
        tool_agents: Mapping of agent name to agent instance.
        group_id: Printer group used for status updates.
    """
    await execute_tools(
        context=self.context,
        route_plan=route_plan,
        tool_agents=tool_agents,
        group_id=group_id,
        agent_step_fn=self.agent_step,
        update_printer_fn=self.update_printer,
    )
|
912
|
+
|
913
|
+
# ============================================
|
914
|
+
# High-Level Workflow Patterns
|
915
|
+
# ============================================
|
916
|
+
|
917
|
+
async def run_manager_tool_loop(
    self,
    manager_agents: Dict[str, Any],
    tool_agents: Dict[str, Any],
    workflow: List[str],
) -> Any:
    """Run the standard manager-tool iterative pattern.

    Thin wrapper that forwards to runner.patterns.run_manager_tool_loop,
    which implements: observe -> evaluate -> route -> execute tools -> repeat.

    Args:
        manager_agents: Manager agents (observe, evaluate, routing, writer).
        tool_agents: Tool agents available for routing.
        workflow: Ordered manager agent names to run each iteration
            (e.g. ["observe", "evaluate", "routing"]).

    Returns:
        Result from the final step.
    """
    return await run_manager_tool_loop(
        context=self.context,
        manager_agents=manager_agents,
        tool_agents=tool_agents,
        workflow=workflow,
        agent_step_fn=self.agent_step,
        run_iterative_loop_fn=self.run_iterative_loop,
        update_printer_fn=self.update_printer,
    )
|
946
|
+
|
947
|
+
# ============================================
|
948
|
+
# Event Hook System
|
949
|
+
# ============================================
|
950
|
+
|
951
|
+
def register_hook(
    self,
    event: str,
    callback: Callable,
    priority: int = 0
) -> None:
    """Attach a callback to a named lifecycle event.

    Thin wrapper that forwards to HookRegistry.register.

    Args:
        event: Event name (before_execution, after_execution,
            before_iteration, after_iteration, etc.)
        callback: Callable or async callable
        priority: Execution priority (higher = earlier)

    Example:
        def log_iteration(pipeline, iteration, group_id):
            logger.info(f"Starting iteration {iteration.index}")

        pipeline.register_hook("before_iteration", log_iteration)
    """
    registry = self._hook_registry
    registry.register(event, callback, priority)
|