adk-graph-workflow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adk_graph_workflow-0.1.0.dist-info/METADATA +13 -0
- adk_graph_workflow-0.1.0.dist-info/RECORD +13 -0
- adk_graph_workflow-0.1.0.dist-info/WHEEL +4 -0
- graph_workflow/__init__.py +104 -0
- graph_workflow/compiler.py +405 -0
- graph_workflow/data/schema.json +140 -0
- graph_workflow/errors.py +38 -0
- graph_workflow/evaluator.py +137 -0
- graph_workflow/models.py +112 -0
- graph_workflow/resolver.py +87 -0
- graph_workflow/runner.py +136 -0
- graph_workflow/schema.py +33 -0
- graph_workflow/validator.py +141 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: adk-graph-workflow
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: YAML-defined graph workflows for Google ADK agents
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Requires-Dist: google-adk>=1.0.0
|
|
7
|
+
Requires-Dist: jsonschema>=4.0
|
|
8
|
+
Requires-Dist: pydantic>=2.0
|
|
9
|
+
Requires-Dist: pyyaml>=6.0
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: pytest-asyncio>=0.21; extra == 'dev'
|
|
12
|
+
Requires-Dist: pytest-cov>=4.0; extra == 'dev'
|
|
13
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
graph_workflow/__init__.py,sha256=4VUu4k-oOWSAXUrZLm7VCVDW9ptYug-2GY_Y4ebHatc,3132
|
|
2
|
+
graph_workflow/compiler.py,sha256=PMHcqUM0nDeuQRT8h6AYdzm0ujFyGWQof2b-h4gfVDI,14661
|
|
3
|
+
graph_workflow/errors.py,sha256=jVSuPAJbTdIOO5enhNcuUAwVUH1JsuR-bWfNR8zmnZI,1166
|
|
4
|
+
graph_workflow/evaluator.py,sha256=Lg3T3Lmvf_uqV6qUJ9guJpSkn-2Pl5YcFgolN13VNVA,4562
|
|
5
|
+
graph_workflow/models.py,sha256=1FC6LBqViWD4-g7asohhNdrB1QOIYd_t1vsvHLUVsvA,2819
|
|
6
|
+
graph_workflow/resolver.py,sha256=_Nv9DA-ORbiJxzdzRjdoGU33k4nYyUiQfsG2UNgBJKw,3030
|
|
7
|
+
graph_workflow/runner.py,sha256=wzsvIfgl9lRSmYCP45bjxa9lDWQnx3X9t5hpENBTWUs,4935
|
|
8
|
+
graph_workflow/schema.py,sha256=DwsKUBq1IYNhVg9NNWEVJGrVhSa3n-EcX4g0w5fuKps,922
|
|
9
|
+
graph_workflow/validator.py,sha256=Rg2cZM4rYct8jLJP7pnGRzGJtWLrv12TSetXC2Jkpjc,5866
|
|
10
|
+
graph_workflow/data/schema.json,sha256=6Ex4IYzDW9DKRBFXOEw6Jy48OczBD4TrMkUTsCkljCQ,3840
|
|
11
|
+
adk_graph_workflow-0.1.0.dist-info/METADATA,sha256=IV5sW40bH1tA4CpiJxWvguqx3ZYLkPSllP-w6UTdNaQ,428
|
|
12
|
+
adk_graph_workflow-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
13
|
+
adk_graph_workflow-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""ADK Graph Workflow — YAML-defined graph workflows for Google ADK agents."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from graph_workflow.compiler import CompiledEdge, CompiledGraph, CompiledNode, GraphCompiler
|
|
9
|
+
from graph_workflow.errors import (
|
|
10
|
+
ConditionEvalError,
|
|
11
|
+
FunctionResolutionError,
|
|
12
|
+
GraphCompilationError,
|
|
13
|
+
GraphExecutionError,
|
|
14
|
+
GraphValidationError,
|
|
15
|
+
GraphWorkflowError,
|
|
16
|
+
)
|
|
17
|
+
from graph_workflow.evaluator import ConditionEvaluator
|
|
18
|
+
from graph_workflow.models import (
|
|
19
|
+
EdgeDef,
|
|
20
|
+
FnRef,
|
|
21
|
+
FunctionDef,
|
|
22
|
+
FunctionNodeDef,
|
|
23
|
+
GraphWorkflowDef,
|
|
24
|
+
LanggraphAgentNodeDef,
|
|
25
|
+
LoopAgentNodeDef,
|
|
26
|
+
LlmAgentNodeDef,
|
|
27
|
+
NodeDef,
|
|
28
|
+
ParallelAgentNodeDef,
|
|
29
|
+
SequentialAgentNodeDef,
|
|
30
|
+
ToolRef,
|
|
31
|
+
)
|
|
32
|
+
from graph_workflow.resolver import FunctionResolver
|
|
33
|
+
from graph_workflow.runner import GraphRunnerAgent
|
|
34
|
+
from graph_workflow.schema import SchemaValidator
|
|
35
|
+
from graph_workflow.validator import GraphValidator
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _auto_resolve_langgraph_refs(workflow: GraphWorkflowDef, registry: dict) -> None:
    """Best-effort import of ``graph`` references used by langgraph_agent nodes.

    Scans every node; for each ``LanggraphAgentNodeDef`` whose dotted ``graph``
    path is missing from *registry*, attempts a module import plus attribute
    lookup and stores the result under the dotted path. Import/attribute
    failures are silently ignored so the compiler can report missing graphs
    with its own, more specific error later.
    """
    langgraph_nodes = (
        nd for nd in workflow.nodes.values() if isinstance(nd, LanggraphAgentNodeDef)
    )
    for node_def in langgraph_nodes:
        ref = node_def.graph
        if ref in registry:
            continue
        mod_name, _, attr = ref.rpartition(".")
        if not mod_name:
            # Bare name with no module part — nothing we can import.
            continue
        try:
            registry[ref] = getattr(importlib.import_module(mod_name), attr)
        except (ImportError, AttributeError):
            # Leave unresolved; GraphCompiler raises a clear error if needed.
            pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def from_config(config_path: str) -> GraphRunnerAgent:
    """Load a graph workflow from a YAML config file and return a ready-to-run agent.

    Pipeline: read YAML -> JSON-schema validation -> pydantic model ->
    structural validation -> resolve declared functions -> best-effort
    langgraph graph resolution -> compile -> wrap in a GraphRunnerAgent.

    Raises:
        FileNotFoundError: if *config_path* does not exist.
    """
    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    raw = yaml.safe_load(path.read_text(encoding="utf-8"))
    # Schema check on the raw mapping first, so the user sees schema-level
    # errors before pydantic's coercion/validation runs.
    SchemaValidator().validate(raw)
    workflow = GraphWorkflowDef.model_validate(raw)
    # Structural checks beyond the JSON schema — presumably edge/node
    # cross-references; see validator.py for the exact rules.
    GraphValidator().validate(workflow)
    resolver = FunctionResolver(workflow.functions)
    registry = {name: resolver.resolve(name) for name in workflow.functions}
    # Best-effort: import langgraph graph refs not declared under `functions`.
    _auto_resolve_langgraph_refs(workflow, registry)
    compiled = GraphCompiler(registry).compile(workflow)
    return GraphRunnerAgent(name=workflow.name, graph=compiled)
|
|
70
|
+
|
|
71
|
+
# Public API surface of the package; mirrors the imports above.
__all__ = [
    # Convenience entry point
    "from_config",
    # Models (workflow definition schema)
    "EdgeDef",
    "FnRef",
    "FunctionDef",
    "FunctionNodeDef",
    "GraphWorkflowDef",
    "LanggraphAgentNodeDef",
    "LoopAgentNodeDef",
    "LlmAgentNodeDef",
    "NodeDef",
    "ParallelAgentNodeDef",
    "SequentialAgentNodeDef",
    "ToolRef",
    # Errors (all derive from GraphWorkflowError)
    "ConditionEvalError",
    "FunctionResolutionError",
    "GraphCompilationError",
    "GraphExecutionError",
    "GraphValidationError",
    "GraphWorkflowError",
    # Core (compilation and execution machinery)
    "CompiledEdge",
    "CompiledGraph",
    "CompiledNode",
    "ConditionEvaluator",
    "FunctionResolver",
    "GraphCompiler",
    "GraphRunnerAgent",
    "GraphValidator",
    "SchemaValidator",
]
|
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""Graph compiler — compiles workflow definitions into ADK agents."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import importlib
|
|
5
|
+
import inspect
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, AsyncGenerator, Callable
|
|
8
|
+
|
|
9
|
+
from google.adk.agents import (
|
|
10
|
+
BaseAgent,
|
|
11
|
+
LlmAgent,
|
|
12
|
+
LoopAgent,
|
|
13
|
+
ParallelAgent,
|
|
14
|
+
SequentialAgent,
|
|
15
|
+
)
|
|
16
|
+
from google.adk.events import Event
|
|
17
|
+
from google.adk.tools.base_tool import BaseTool
|
|
18
|
+
from google.adk.tools.base_toolset import BaseToolset
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from google.adk.agents import LanggraphAgent
|
|
22
|
+
except ImportError:
|
|
23
|
+
LanggraphAgent = None # type: ignore[misc, assignment]
|
|
24
|
+
|
|
25
|
+
from graph_workflow.errors import GraphCompilationError
|
|
26
|
+
from graph_workflow.models import (
|
|
27
|
+
EdgeDef,
|
|
28
|
+
FunctionNodeDef,
|
|
29
|
+
GraphWorkflowDef,
|
|
30
|
+
LanggraphAgentNodeDef,
|
|
31
|
+
LlmAgentNodeDef,
|
|
32
|
+
LoopAgentNodeDef,
|
|
33
|
+
ParallelAgentNodeDef,
|
|
34
|
+
SequentialAgentNodeDef,
|
|
35
|
+
ToolRef,
|
|
36
|
+
)
|
|
37
|
+
from graph_workflow.resolver import interpolate_value
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class CompiledNode:
    """An executable graph node: either a plain callable or an ADK agent.

    Exactly one of ``callable``/``agent`` is expected to be populated by the
    compiler; the ``is_function``/``is_agent`` properties report which flavor
    this node is.
    """

    key: str
    callable: Callable[..., Any] | None = None
    agent: BaseAgent | None = None
    output_key: str | None = None

    @property
    def is_function(self) -> bool:
        """True when this node wraps a plain Python callable."""
        return self.callable is not None

    @property
    def is_agent(self) -> bool:
        """True when this node wraps a compiled ADK agent."""
        return self.agent is not None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class CompiledEdge:
    """A compiled edge with resolved target and condition."""

    # Key of the destination node in the compiled graph.
    to_key: str
    # Optional condition expression; None means unconditional.
    # Presumably evaluated by ConditionEvaluator at runtime — confirm in runner.py.
    condition: str | None = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class CompiledGraph:
    """A fully compiled graph ready for execution."""

    # Key of the node where execution starts.
    entry: str
    # Optional key of a designated terminal node.
    # NOTE(review): exact exit semantics live in runner.py — confirm there.
    exit: str | None
    # All compiled nodes, keyed by node key.
    nodes: dict[str, CompiledNode]
    # Outgoing edges per source-node key, in definition order.
    adjacency: dict[str, list[CompiledEdge]]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class _InlineFunctionAgent(BaseAgent):
    """Wraps a function node so it can act as a sub-agent in containers.

    Lets a FunctionNodeDef participate in Sequential/Parallel/Loop containers:
    the wrapped callable receives a snapshot of session state, and its
    (optionally awaited) result is written back under ``output_key``.
    """

    def __init__(
        self,
        name: str,
        fn: Callable[..., Any],
        output_key: str | None = None,
    ):
        super().__init__(name=name)
        # NOTE(review): ADK's BaseAgent is a pydantic model — confirm that
        # plain assignment of non-field private attributes is permitted here.
        self._fn = fn
        self._output_key = output_key

    async def _run_async_impl(
        self, ctx
    ) -> AsyncGenerator[Event, None]:
        # Copy so the function cannot mutate session state directly; writes
        # happen only via output_key below.
        state = dict(ctx.session.state)
        result = self._fn(state)
        # Support both sync and async functions.
        if inspect.isawaitable(result):
            result = await result
        if self._output_key and result is not None:
            ctx.session.state[self._output_key] = result
        # Emit one event attributed to this agent.
        yield Event(author=self.name)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _interpolate_tool_args(
    args: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Return *args* with ``${ENV_VAR}`` references expanded.

    Each value is passed through :func:`interpolate_value`. ``None`` passes
    through unchanged so callers can distinguish "no args" from "empty args".
    """
    if args is None:
        return None
    return {name: interpolate_value(raw) for name, raw in args.items()}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class GraphCompiler:
    """Compiles a GraphWorkflowDef into a CompiledGraph.

    Uses two-pass compilation:
    1. Non-container nodes (FunctionNodeDef, LlmAgentNodeDef, LanggraphAgentNodeDef)
    2. Container nodes (SequentialAgentNodeDef, ParallelAgentNodeDef, LoopAgentNodeDef)

    Container nodes are compiled second so that their sub_agents have
    already been compiled and can be resolved from ``_compiled``.
    """

    def __init__(self, func_registry: dict[str, Callable[..., Any]]):
        # Resolved callables keyed by name; also holds langgraph graph
        # objects keyed by their dotted reference string.
        self._func_registry = func_registry
        # Cache of nodes finished during the current compile() call.
        self._compiled: dict[str, CompiledNode] = {}
        # Workflow currently being compiled (set by compile()).
        self._workflow: GraphWorkflowDef | None = None
        # Keys being compiled recursively; used for cycle detection.
        self._compiling: set[str] = set()

    def compile(self, workflow: GraphWorkflowDef) -> CompiledGraph:
        """Compile *workflow* into a :class:`CompiledGraph`."""
        # Reset per-compile state so a compiler instance is reusable.
        self._compiled = {}
        self._compiling = set()
        self._workflow = workflow

        # First pass: non-container nodes
        for key, node_def in workflow.nodes.items():
            if not isinstance(
                node_def,
                (SequentialAgentNodeDef, ParallelAgentNodeDef, LoopAgentNodeDef),
            ):
                self._compiled[key] = self._compile_node(key, node_def)

        # Second pass: container nodes (sub_agents now available)
        for key, node_def in workflow.nodes.items():
            if isinstance(
                node_def,
                (SequentialAgentNodeDef, ParallelAgentNodeDef, LoopAgentNodeDef),
            ):
                # May already exist if a container compiled it recursively.
                if key not in self._compiled:
                    self._compiled[key] = self._compile_node(key, node_def)

        adjacency = self._build_adjacency(workflow.edges)
        return CompiledGraph(
            entry=workflow.entry,
            exit=workflow.exit,
            nodes=dict(self._compiled),
            adjacency=adjacency,
        )

    # -- node compilation ----------------------------------------------------

    def _compile_node(self, key: str, node_def: Any) -> CompiledNode:
        # Dispatch on the concrete node-definition type.
        if isinstance(node_def, FunctionNodeDef):
            return self._compile_function_node(key, node_def)
        if isinstance(node_def, LlmAgentNodeDef):
            return self._compile_llm_agent(key, node_def)
        if isinstance(node_def, SequentialAgentNodeDef):
            return self._compile_sequential_agent(key, node_def)
        if isinstance(node_def, ParallelAgentNodeDef):
            return self._compile_parallel_agent(key, node_def)
        if isinstance(node_def, LoopAgentNodeDef):
            return self._compile_loop_agent(key, node_def)
        if isinstance(node_def, LanggraphAgentNodeDef):
            return self._compile_langgraph_agent(key, node_def)
        node_type = getattr(node_def, "type", type(node_def).__name__)
        raise GraphCompilationError(
            f"Unknown node type: {node_type}",
            node_id=key,
        )

    def _compile_function_node(
        self, key: str, node_def: FunctionNodeDef
    ) -> CompiledNode:
        # Functions must have been resolved into the registry beforehand.
        func = self._func_registry.get(node_def.function)
        if func is None:
            raise GraphCompilationError(
                f"Function '{node_def.function}' not found in registry",
                node_id=key,
            )
        return CompiledNode(key=key, callable=func, output_key=node_def.output_key)

    def _compile_llm_agent(
        self, key: str, node_def: LlmAgentNodeDef
    ) -> CompiledNode:
        resolved_tools = self._resolve_tools(node_def.tools, key)
        kwargs: dict[str, Any] = {
            "name": key,
            "model": node_def.model,
            "instruction": node_def.instruction,
            "tools": resolved_tools,
            "output_key": node_def.output_key,
        }
        # Only forward output_schema when set, keeping LlmAgent's default otherwise.
        if node_def.output_schema is not None:
            kwargs["output_schema"] = node_def.output_schema
        try:
            agent = LlmAgent(**kwargs)
        except Exception as exc:
            raise GraphCompilationError(
                f"Failed to create LlmAgent: {exc}",
                node_id=key,
            ) from exc
        return CompiledNode(key=key, agent=agent, output_key=node_def.output_key)

    def _resolve_tools(self, tools: list[ToolRef], key: str) -> list[Any]:
        """Resolve tool references to actual tool objects.

        Follows ADK's native _resolve_tools pattern:
        1. No dot in name - built-in tool from google.adk.tools
        2. Dot in name - dotted path import
        3. Dispatch by type: instance, class, factory function, or plain callable
        """
        resolved: list[Any] = []
        for tool_ref in tools:
            obj = self._import_tool(tool_ref.name, key)
            args = _interpolate_tool_args(tool_ref.args)

            if isinstance(obj, BaseTool | BaseToolset):
                # Already an instance — use as-is.
                resolved.append(obj)
            elif inspect.isclass(obj) and issubclass(obj, BaseTool | BaseToolset):
                # Tool class — instantiate with the (interpolated) args.
                try:
                    resolved.append(obj(**(args or {})))
                except Exception as exc:
                    raise GraphCompilationError(
                        f"Failed to instantiate tool class '{tool_ref.name}': {exc}",
                        node_id=key,
                    ) from exc
            elif callable(obj):
                # Factory function when args are given, plain callable otherwise.
                # NOTE(review): the factory receives the args dict as a single
                # positional argument, not unpacked — confirm intended signatures.
                if args is not None:
                    resolved.append(obj(args))
                else:
                    resolved.append(obj)
            else:
                raise GraphCompilationError(
                    f"Tool '{tool_ref.name}' is not a valid tool "
                    f"(expected Callable, BaseTool, or BaseToolset, "
                    f"got {type(obj).__name__})",
                    node_id=key,
                )
        return resolved

    def _import_tool(self, name: str, key: str) -> Any:
        """Import a tool by name — built-in or dotted path."""
        # Bare names are looked up in google.adk.tools (ADK built-ins).
        if "." not in name:
            try:
                module = importlib.import_module("google.adk.tools")
            except ImportError as exc:
                raise GraphCompilationError(
                    f"Cannot import google.adk.tools for built-in tool '{name}': {exc}",
                    node_id=key,
                ) from exc
            try:
                return getattr(module, name)
            except AttributeError as exc:
                raise GraphCompilationError(
                    f"Built-in tool '{name}' not found in google.adk.tools",
                    node_id=key,
                ) from exc

        # Dotted names: import the module, then fetch the attribute.
        module_path, _, obj_name = name.rpartition(".")
        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            raise GraphCompilationError(
                f"Cannot import module '{module_path}' for tool '{name}'",
                node_id=key,
            ) from exc
        try:
            return getattr(module, obj_name)
        except AttributeError as exc:
            raise GraphCompilationError(
                f"Tool '{obj_name}' not found in module '{module_path}'",
                node_id=key,
            ) from exc

    def _compile_sequential_agent(
        self, key: str, node_def: SequentialAgentNodeDef
    ) -> CompiledNode:
        sub_agents = self._resolve_sub_agents(node_def.sub_agents, key)
        try:
            agent = SequentialAgent(name=key, sub_agents=sub_agents)
        except Exception as exc:
            raise GraphCompilationError(
                f"Failed to create SequentialAgent: {exc}",
                node_id=key,
            ) from exc
        return CompiledNode(key=key, agent=agent)

    def _compile_parallel_agent(
        self, key: str, node_def: ParallelAgentNodeDef
    ) -> CompiledNode:
        sub_agents = self._resolve_sub_agents(node_def.sub_agents, key)
        try:
            agent = ParallelAgent(name=key, sub_agents=sub_agents)
        except Exception as exc:
            raise GraphCompilationError(
                f"Failed to create ParallelAgent: {exc}",
                node_id=key,
            ) from exc
        return CompiledNode(key=key, agent=agent)

    def _compile_loop_agent(
        self, key: str, node_def: LoopAgentNodeDef
    ) -> CompiledNode:
        sub_agents = self._resolve_sub_agents(node_def.sub_agents, key)
        try:
            agent = LoopAgent(
                name=key,
                sub_agents=sub_agents,
                max_iterations=node_def.max_iterations,
            )
        except Exception as exc:
            raise GraphCompilationError(
                f"Failed to create LoopAgent: {exc}",
                node_id=key,
            ) from exc
        return CompiledNode(key=key, agent=agent)

    def _compile_langgraph_agent(
        self, key: str, node_def: LanggraphAgentNodeDef
    ) -> CompiledNode:
        # Graph objects live in the same registry as resolved functions.
        graph = self._func_registry.get(node_def.graph)
        if graph is None:
            raise GraphCompilationError(
                f"Graph function '{node_def.graph}' not found in registry",
                node_id=key,
            )
        # LanggraphAgent import is optional (see module-level try/except).
        if LanggraphAgent is None:
            raise GraphCompilationError(
                "LanggraphAgent is not available — install langchain-core and langgraph packages",
                node_id=key,
            )
        try:
            agent = LanggraphAgent(
                name=key, graph=graph, instruction=node_def.instruction
            )
        except Exception as exc:
            raise GraphCompilationError(
                f"Failed to create LanggraphAgent: {exc}",
                node_id=key,
            ) from exc
        return CompiledNode(key=key, agent=agent)

    # -- helpers -------------------------------------------------------------

    def _resolve_sub_agents(
        self, sub_agent_keys: list[str], parent_key: str
    ) -> list[BaseAgent]:
        """Map container sub_agent keys to BaseAgent instances.

        Compiles not-yet-compiled sub-nodes on demand (guarding against
        cycles via ``_compiling``) and wraps function nodes in
        _InlineFunctionAgent so they can live inside containers.
        """
        agents: list[BaseAgent] = []
        for k in sub_agent_keys:
            compiled = self._compiled.get(k)
            # If not compiled yet (nested container), try recursive compilation
            if compiled is None and self._workflow is not None:
                if k in self._compiling:
                    raise GraphCompilationError(
                        f"Circular sub-agent reference detected: '{k}'",
                        node_id=parent_key,
                    )
                node_def = self._workflow.nodes.get(k)
                if node_def is not None and k not in self._compiled:
                    self._compiling.add(k)
                    try:
                        self._compiled[k] = self._compile_node(k, node_def)
                    finally:
                        # Always clear the cycle marker, even on failure.
                        self._compiling.discard(k)
                    compiled = self._compiled[k]
            if compiled is None:
                raise GraphCompilationError(
                    f"Sub-agent '{k}' has not been compiled",
                    node_id=parent_key,
                    detail=f"Compiled keys: {sorted(self._compiled.keys())}",
                )
            if compiled.agent is not None:
                agents.append(compiled.agent)
            elif compiled.callable is not None:
                # Wrap function nodes so they can be used as sub-agents
                agents.append(
                    _InlineFunctionAgent(
                        name=k,
                        fn=compiled.callable,
                        output_key=compiled.output_key,
                    )
                )
            else:
                raise GraphCompilationError(
                    f"Sub-agent '{k}' is not an agent or function",
                    node_id=parent_key,
                )
        return agents

    @staticmethod
    def _build_adjacency(edges: list[EdgeDef]) -> dict[str, list[CompiledEdge]]:
        # Group outgoing edges by source-node key, preserving edge order.
        adj: dict[str, list[CompiledEdge]] = {}
        for edge in edges:
            adj.setdefault(edge.from_, []).append(
                CompiledEdge(to_key=edge.to, condition=edge.condition)
            )
        return adj
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
{
|
|
2
|
+
"$schema": "http://json-schema.org/draft-07/schema#",
|
|
3
|
+
"title": "GraphWorkflow",
|
|
4
|
+
"type": "object",
|
|
5
|
+
"required": ["version", "name", "entry", "nodes"],
|
|
6
|
+
"additionalProperties": false,
|
|
7
|
+
"properties": {
|
|
8
|
+
"version": {
|
|
9
|
+
"type": "string",
|
|
10
|
+
"const": "2.0"
|
|
11
|
+
},
|
|
12
|
+
"name": {
|
|
13
|
+
"type": "string",
|
|
14
|
+
"minLength": 1
|
|
15
|
+
},
|
|
16
|
+
"entry": {
|
|
17
|
+
"type": "string",
|
|
18
|
+
"minLength": 1
|
|
19
|
+
},
|
|
20
|
+
"exit": {
|
|
21
|
+
"type": "string"
|
|
22
|
+
},
|
|
23
|
+
"nodes": {
|
|
24
|
+
"type": "object",
|
|
25
|
+
"minProperties": 1,
|
|
26
|
+
"additionalProperties": {
|
|
27
|
+
"$ref": "#/definitions/NodeDef"
|
|
28
|
+
}
|
|
29
|
+
},
|
|
30
|
+
"edges": {
|
|
31
|
+
"type": "array",
|
|
32
|
+
"items": {
|
|
33
|
+
"$ref": "#/definitions/EdgeDef"
|
|
34
|
+
}
|
|
35
|
+
},
|
|
36
|
+
"functions": {
|
|
37
|
+
"type": "object",
|
|
38
|
+
"additionalProperties": {
|
|
39
|
+
"$ref": "#/definitions/FunctionDef"
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
},
|
|
43
|
+
"definitions": {
|
|
44
|
+
"EdgeDef": {
|
|
45
|
+
"type": "object",
|
|
46
|
+
"required": ["from", "to"],
|
|
47
|
+
"additionalProperties": false,
|
|
48
|
+
"properties": {
|
|
49
|
+
"from": {"type": "string", "minLength": 1},
|
|
50
|
+
"to": {"type": "string", "minLength": 1},
|
|
51
|
+
"condition": {"type": "string"}
|
|
52
|
+
}
|
|
53
|
+
},
|
|
54
|
+
"FunctionDef": {
|
|
55
|
+
"type": "object",
|
|
56
|
+
"required": ["callable"],
|
|
57
|
+
"additionalProperties": false,
|
|
58
|
+
"properties": {
|
|
59
|
+
"callable": {"type": "string", "minLength": 1},
|
|
60
|
+
"args": {"type": "object"}
|
|
61
|
+
}
|
|
62
|
+
},
|
|
63
|
+
"NodeDef": {
|
|
64
|
+
"oneOf": [
|
|
65
|
+
{"$ref": "#/definitions/FunctionNodeDef"},
|
|
66
|
+
{"$ref": "#/definitions/LlmAgentNodeDef"},
|
|
67
|
+
{"$ref": "#/definitions/SequentialAgentNodeDef"},
|
|
68
|
+
{"$ref": "#/definitions/ParallelAgentNodeDef"},
|
|
69
|
+
{"$ref": "#/definitions/LoopAgentNodeDef"},
|
|
70
|
+
{"$ref": "#/definitions/LanggraphAgentNodeDef"}
|
|
71
|
+
]
|
|
72
|
+
},
|
|
73
|
+
"FunctionNodeDef": {
|
|
74
|
+
"type": "object",
|
|
75
|
+
"required": ["type", "function"],
|
|
76
|
+
"properties": {
|
|
77
|
+
"type": {"type": "string", "const": "function"},
|
|
78
|
+
"function": {"type": "string", "minLength": 1},
|
|
79
|
+
"output_key": {"type": "string"}
|
|
80
|
+
}
|
|
81
|
+
},
|
|
82
|
+
"LlmAgentNodeDef": {
|
|
83
|
+
"type": "object",
|
|
84
|
+
"required": ["type"],
|
|
85
|
+
"properties": {
|
|
86
|
+
"type": {"type": "string", "const": "llm_agent"},
|
|
87
|
+
"model": {"type": "string"},
|
|
88
|
+
"instruction": {"type": "string"},
|
|
89
|
+
"tools": {
|
|
90
|
+
"type": "array",
|
|
91
|
+
"items": {
|
|
92
|
+
"type": "object",
|
|
93
|
+
"required": ["name"],
|
|
94
|
+
"additionalProperties": false,
|
|
95
|
+
"properties": {
|
|
96
|
+
"name": {"type": "string", "minLength": 1},
|
|
97
|
+
"args": {"type": "object"}
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
},
|
|
101
|
+
"output_key": {"type": "string"},
|
|
102
|
+
"output_schema": {"type": "object"}
|
|
103
|
+
}
|
|
104
|
+
},
|
|
105
|
+
"SequentialAgentNodeDef": {
|
|
106
|
+
"type": "object",
|
|
107
|
+
"required": ["type"],
|
|
108
|
+
"properties": {
|
|
109
|
+
"type": {"type": "string", "const": "sequential_agent"},
|
|
110
|
+
"sub_agents": {"type": "array", "items": {"type": "string"}}
|
|
111
|
+
}
|
|
112
|
+
},
|
|
113
|
+
"ParallelAgentNodeDef": {
|
|
114
|
+
"type": "object",
|
|
115
|
+
"required": ["type"],
|
|
116
|
+
"properties": {
|
|
117
|
+
"type": {"type": "string", "const": "parallel_agent"},
|
|
118
|
+
"sub_agents": {"type": "array", "items": {"type": "string"}}
|
|
119
|
+
}
|
|
120
|
+
},
|
|
121
|
+
"LoopAgentNodeDef": {
|
|
122
|
+
"type": "object",
|
|
123
|
+
"required": ["type"],
|
|
124
|
+
"properties": {
|
|
125
|
+
"type": {"type": "string", "const": "loop_agent"},
|
|
126
|
+
"sub_agents": {"type": "array", "items": {"type": "string"}},
|
|
127
|
+
"max_iterations": {"type": "integer", "minimum": 1}
|
|
128
|
+
}
|
|
129
|
+
},
|
|
130
|
+
"LanggraphAgentNodeDef": {
|
|
131
|
+
"type": "object",
|
|
132
|
+
"required": ["type", "graph"],
|
|
133
|
+
"properties": {
|
|
134
|
+
"type": {"type": "string", "const": "langgraph_agent"},
|
|
135
|
+
"graph": {"type": "string", "minLength": 1},
|
|
136
|
+
"instruction": {"type": "string"}
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
}
|
graph_workflow/errors.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Error types for graph workflow."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GraphWorkflowError(Exception):
    """Root of the graph-workflow exception hierarchy.

    Carries optional context — the node that triggered the failure and a
    free-form detail string — both of which are appended to ``str(exc)``
    when set.
    """

    def __init__(self, message: str, *, node_id: str | None = None, detail: str | None = None):
        super().__init__(message)
        self.node_id = node_id
        self.detail = detail

    def __str__(self) -> str:
        segments = [super().__str__()]
        segments.extend(
            f"{label}={value}"
            for label, value in (("node_id", self.node_id), ("detail", self.detail))
            if value
        )
        return " | ".join(segments)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GraphValidationError(GraphWorkflowError):
    """Raised when a graph definition fails structural validation (validator.py)."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GraphCompilationError(GraphWorkflowError):
    """Raised when a graph definition cannot be compiled into agents (compiler.py)."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class GraphExecutionError(GraphWorkflowError):
    """Raised when graph execution encounters a runtime error (runner.py)."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ConditionEvalError(GraphWorkflowError):
    """Raised when an edge condition expression cannot be parsed or evaluated."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FunctionResolutionError(GraphWorkflowError):
    """Raised when a function reference cannot be resolved to a callable."""
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""AST-based safe condition evaluation for graph edges."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import ast
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from graph_workflow.errors import ConditionEvalError
|
|
8
|
+
|
|
9
|
+
_ALLOWED_NODES = (
|
|
10
|
+
ast.Expression,
|
|
11
|
+
ast.Constant,
|
|
12
|
+
ast.Name,
|
|
13
|
+
ast.Attribute,
|
|
14
|
+
ast.Compare,
|
|
15
|
+
ast.BoolOp,
|
|
16
|
+
ast.UnaryOp,
|
|
17
|
+
ast.And,
|
|
18
|
+
ast.Or,
|
|
19
|
+
ast.Not,
|
|
20
|
+
ast.Eq,
|
|
21
|
+
ast.NotEq,
|
|
22
|
+
ast.Lt,
|
|
23
|
+
ast.LtE,
|
|
24
|
+
ast.Gt,
|
|
25
|
+
ast.GtE,
|
|
26
|
+
ast.In,
|
|
27
|
+
ast.NotIn,
|
|
28
|
+
ast.Is,
|
|
29
|
+
ast.IsNot,
|
|
30
|
+
ast.Load,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ConditionEvaluator:
|
|
35
|
+
"""Safely evaluates condition expressions against a state dict using an AST whitelist."""
|
|
36
|
+
|
|
37
|
+
def evaluate(self, expression: str, state: dict[str, Any]) -> bool:
|
|
38
|
+
try:
|
|
39
|
+
tree = ast.parse(expression, mode="eval")
|
|
40
|
+
except SyntaxError as exc:
|
|
41
|
+
raise ConditionEvalError(
|
|
42
|
+
f"Syntax error in condition: {expression}",
|
|
43
|
+
detail=str(exc),
|
|
44
|
+
) from exc
|
|
45
|
+
|
|
46
|
+
self._validate_tree(tree)
|
|
47
|
+
try:
|
|
48
|
+
result = self._eval_node(tree.body, state)
|
|
49
|
+
except Exception as exc:
|
|
50
|
+
if isinstance(exc, ConditionEvalError):
|
|
51
|
+
raise
|
|
52
|
+
raise ConditionEvalError(
|
|
53
|
+
f"Error evaluating condition: {expression}",
|
|
54
|
+
detail=str(exc),
|
|
55
|
+
) from exc
|
|
56
|
+
|
|
57
|
+
return bool(result)
|
|
58
|
+
|
|
59
|
+
def _validate_tree(self, tree: ast.Expression) -> None:
|
|
60
|
+
for node in ast.walk(tree):
|
|
61
|
+
if not isinstance(node, _ALLOWED_NODES):
|
|
62
|
+
raise ConditionEvalError(
|
|
63
|
+
f"Disallowed AST node in condition: {type(node).__name__}",
|
|
64
|
+
detail=f"Node {type(node).__name__} is not in the allowed whitelist",
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
def _eval_node(self, node: ast.AST, state: dict[str, Any]) -> Any:
|
|
68
|
+
if isinstance(node, ast.Constant):
|
|
69
|
+
return node.value
|
|
70
|
+
if isinstance(node, ast.Name):
|
|
71
|
+
if node.id not in state:
|
|
72
|
+
raise ConditionEvalError(
|
|
73
|
+
f"Undefined variable: {node.id}",
|
|
74
|
+
detail=f"Available keys: {sorted(state.keys())}",
|
|
75
|
+
)
|
|
76
|
+
return state[node.id]
|
|
77
|
+
if isinstance(node, ast.Attribute):
|
|
78
|
+
value = self._eval_node(node.value, state)
|
|
79
|
+
try:
|
|
80
|
+
return getattr(value, node.attr)
|
|
81
|
+
except AttributeError:
|
|
82
|
+
if isinstance(value, dict):
|
|
83
|
+
return value[node.attr]
|
|
84
|
+
raise
|
|
85
|
+
if isinstance(node, ast.Compare):
|
|
86
|
+
left = self._eval_node(node.left, state)
|
|
87
|
+
for op, comparator in zip(node.ops, node.comparators):
|
|
88
|
+
right = self._eval_node(comparator, state)
|
|
89
|
+
if not self._eval_compare(op, left, right):
|
|
90
|
+
return False
|
|
91
|
+
left = right
|
|
92
|
+
return True
|
|
93
|
+
if isinstance(node, ast.BoolOp):
|
|
94
|
+
if isinstance(node.op, ast.And):
|
|
95
|
+
result = True
|
|
96
|
+
for value in node.values:
|
|
97
|
+
result = self._eval_node(value, state)
|
|
98
|
+
if not result:
|
|
99
|
+
return result
|
|
100
|
+
return result
|
|
101
|
+
else: # ast.Or
|
|
102
|
+
result = False
|
|
103
|
+
for value in node.values:
|
|
104
|
+
result = self._eval_node(value, state)
|
|
105
|
+
if result:
|
|
106
|
+
return result
|
|
107
|
+
return result
|
|
108
|
+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
|
|
109
|
+
return not self._eval_node(node.operand, state)
|
|
110
|
+
|
|
111
|
+
raise ConditionEvalError(
|
|
112
|
+
f"Unsupported node type during evaluation: {type(node).__name__}"
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
@staticmethod
def _eval_compare(op: ast.cmpop, left: Any, right: Any) -> bool:
    """Apply a single comparison-operator node to two evaluated operands.

    Covers the operator set admitted by the parse-time whitelist; any
    other operator type raises ConditionEvalError.
    """
    # Table-driven dispatch: the first matching ast operator type wins.
    handlers = (
        (ast.Eq, lambda a, b: a == b),
        (ast.NotEq, lambda a, b: a != b),
        (ast.Lt, lambda a, b: a < b),
        (ast.LtE, lambda a, b: a <= b),
        (ast.Gt, lambda a, b: a > b),
        (ast.GtE, lambda a, b: a >= b),
        (ast.In, lambda a, b: a in b),
        (ast.NotIn, lambda a, b: a not in b),
        (ast.Is, lambda a, b: a is b),
        (ast.IsNot, lambda a, b: a is not b),
    )
    for op_type, apply in handlers:
        if isinstance(op, op_type):
            return apply(left, right)
    raise ConditionEvalError(f"Unsupported comparison operator: {type(op).__name__}")
|
graph_workflow/models.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Pydantic data models for graph workflow definitions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Annotated, Any, Literal, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FnRef(BaseModel):
    """Reference to a function in the functions registry via the "$fn" key."""

    # Workflow documents use the "$fn" key; the alias maps it onto a valid
    # Python attribute name.
    fn_name: str = Field(alias="$fn")

    # Accept construction by field name ("fn_name") as well as by the alias.
    model_config = {"populate_by_name": True}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class EdgeDef(BaseModel):
    """An edge connecting two nodes, optionally with a condition."""

    # "from" is a Python keyword, so the field is exposed as ``from_`` and
    # aliased to the "from" key used in workflow documents.
    from_: str = Field(alias="from")
    # Key of the destination node.
    to: str
    # Optional guard expression; None makes this an unconditional edge.
    condition: str | None = None

    # Accept construction by field name as well as by the alias.
    model_config = {"populate_by_name": True}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FunctionDef(BaseModel):
    """Definition of a Python function to be resolved at compile time."""

    # Dotted import path in "module.path.callable" form.
    callable: str
    # Keyword arguments pre-bound to the callable at resolution time.
    args: dict[str, Any] = Field(default_factory=dict)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class FunctionNodeDef(BaseModel):
    """A node that executes a Python function directly."""

    # Discriminator value for the NodeDef union.
    type: Literal["function"] = "function"
    # Name of an entry in the workflow's functions registry.
    function: str
    # Session-state key the function's result is stored under, if set.
    output_key: str | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ToolRef(BaseModel):
    """Reference to a tool — mirrors ADK's ToolConfig."""

    # Tool identifier; resolution happens in the compiler (not visible here).
    name: str
    # Optional tool-specific configuration arguments.
    args: dict[str, Any] | None = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class LlmAgentNodeDef(BaseModel):
    """A node that wraps an LLM agent."""

    # Discriminator value for the NodeDef union.
    type: Literal["llm_agent"] = "llm_agent"
    # Model identifier handed to the underlying agent.
    model: str = "gemini-2.0-flash"
    # System instruction for the agent; empty by default.
    instruction: str = ""
    # Tools made available to the agent.
    tools: list[ToolRef] = Field(default_factory=list)
    # Session-state key the agent's output is stored under, if set.
    output_key: str | None = None
    # Optional schema for structured output — presumably a JSON Schema;
    # confirm the expected shape against the compiler.
    output_schema: dict[str, Any] | None = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SequentialAgentNodeDef(BaseModel):
    """A node that runs sub-agents sequentially."""

    # Discriminator value for the NodeDef union.
    type: Literal["sequential_agent"] = "sequential_agent"
    # Keys of other nodes in this workflow, run in order; GraphValidator
    # checks that each key exists.
    sub_agents: list[str] = Field(default_factory=list)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ParallelAgentNodeDef(BaseModel):
    """A node that runs sub-agents in parallel."""

    # Discriminator value for the NodeDef union.
    type: Literal["parallel_agent"] = "parallel_agent"
    # Keys of other nodes in this workflow; GraphValidator checks existence.
    sub_agents: list[str] = Field(default_factory=list)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class LoopAgentNodeDef(BaseModel):
    """A node that runs sub-agents in a loop."""

    # Discriminator value for the NodeDef union.
    type: Literal["loop_agent"] = "loop_agent"
    # Keys of other nodes in this workflow; GraphValidator checks existence.
    sub_agents: list[str] = Field(default_factory=list)
    # Upper bound on loop iterations.
    max_iterations: int = 10
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class LanggraphAgentNodeDef(BaseModel):
    """A node that wraps a LangGraph graph."""

    # Discriminator value for the NodeDef union.
    type: Literal["langgraph_agent"] = "langgraph_agent"
    # Reference to the LangGraph graph to wrap; resolution is done by the
    # compiler (not visible here) — confirm expected format there.
    graph: str
    # Instruction text for the wrapped graph; empty by default.
    instruction: str = ""
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Discriminated union of every node kind; pydantic selects the concrete
# model via each payload's "type" field (see Field(discriminator="type")).
NodeDef = Annotated[
    Union[
        FunctionNodeDef,
        LlmAgentNodeDef,
        SequentialAgentNodeDef,
        ParallelAgentNodeDef,
        LoopAgentNodeDef,
        LanggraphAgentNodeDef,
    ],
    Field(discriminator="type"),
]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class GraphWorkflowDef(BaseModel):
    """Top-level graph workflow definition."""

    # Workflow format version; current default is "2.0".
    version: str = "2.0"
    # Human-readable workflow name.
    name: str
    # Key (into ``nodes``) of the node where execution starts.
    entry: str
    # Optional key of the terminal node; None means no designated exit.
    exit: str | None = None
    # All node definitions, keyed by node id.
    nodes: dict[str, NodeDef]
    # Directed edges between nodes; empty by default.
    edges: list[EdgeDef] = Field(default_factory=list)
    # Registry of named Python functions referenced by function nodes.
    functions: dict[str, FunctionDef] = Field(default_factory=dict)
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Dynamic function resolution for graph workflow nodes."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import importlib
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from graph_workflow.errors import FunctionResolutionError
|
|
10
|
+
from graph_workflow.models import FunctionDef
|
|
11
|
+
|
|
12
|
+
# Matches "${NAME}" environment-variable references; group(1) captures NAME
# (word characters only).
_VAR_PATTERN = re.compile(r"\$\{(\w+)\}")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def interpolate_value(value: Any) -> Any:
    """Recursively substitute ``${ENV_VAR}`` references in config values.

    Strings have each ``${NAME}`` replaced with the value of the NAME
    environment variable when it is set; unset references are left verbatim.
    Dicts and lists are rebuilt with every element interpolated.  All other
    values pass through unchanged.
    """
    if isinstance(value, dict):
        return {key: interpolate_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return [interpolate_value(item) for item in value]
    if isinstance(value, str):

        def _lookup(match: re.Match) -> str:
            # Unset variables fall back to the original "${NAME}" text.
            return os.environ.get(match.group(1), match.group(0))

        return _VAR_PATTERN.sub(_lookup, value)
    return value
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FunctionResolver:
    """Resolves function references to actual Python callables."""

    def __init__(self, functions: dict[str, FunctionDef] | None = None):
        # Registry of named FunctionDef entries; empty when none supplied.
        self._functions = functions or {}

    def resolve(self, name: str) -> Callable[..., Any]:
        """Return the callable registered under *name*.

        Raises:
            FunctionResolutionError: when *name* is not in the registry.
        """
        if name not in self._functions:
            raise FunctionResolutionError(
                f"Function not found in registry: {name}",
                detail=f"Available functions: {sorted(self._functions.keys())}",
            )
        entry = self._functions[name]
        return self._resolve_callable(entry.callable, entry.args)

    def _resolve_callable(self, dotted_path: str, args: dict[str, Any]) -> Callable[..., Any]:
        """Import ``module.path.attr`` and bind any configured default kwargs.

        Raises:
            FunctionResolutionError: on a malformed path, a failed import,
                a missing attribute, or a non-callable target.
        """
        module_name, _, attr = dotted_path.rpartition(".")
        if not module_name:
            raise FunctionResolutionError(
                f"Invalid callable path (missing module): {dotted_path}",
                detail="Expected format: module.path.callable",
            )

        try:
            mod = importlib.import_module(module_name)
        except ImportError as exc:
            raise FunctionResolutionError(
                f"Cannot import module: {module_name}",
                detail=str(exc),
            ) from exc

        try:
            target = getattr(mod, attr)
        except AttributeError as exc:
            raise FunctionResolutionError(
                f"Attribute '{attr}' not found in module '{module_name}'",
                detail=str(exc),
            ) from exc

        if not callable(target):
            raise FunctionResolutionError(
                f"'{dotted_path}' is not callable",
            )

        defaults = self._interpolate_args(args)
        if not defaults:
            # No configured arguments: hand back the target untouched.
            return target

        def _with_defaults(*call_args: Any, **call_kwargs: Any) -> Any:
            # Explicit kwargs at call time override the configured defaults.
            return target(*call_args, **{**defaults, **call_kwargs})

        return _with_defaults

    def _interpolate_args(self, args: dict[str, Any]) -> dict[str, Any]:
        """Environment-interpolate every configured argument value."""
        return {key: interpolate_value(val) for key, val in args.items()}
|
graph_workflow/runner.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
"""GraphRunnerAgent — runtime graph executor for ADK."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import AsyncGenerator
|
|
5
|
+
|
|
6
|
+
from pydantic import PrivateAttr
|
|
7
|
+
from google.adk.agents import BaseAgent, InvocationContext
|
|
8
|
+
from google.adk.events import Event
|
|
9
|
+
|
|
10
|
+
from graph_workflow.compiler import CompiledGraph, CompiledNode
|
|
11
|
+
from graph_workflow.errors import GraphExecutionError
|
|
12
|
+
from graph_workflow.evaluator import ConditionEvaluator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GraphRunnerAgent(BaseAgent):
    """An ADK BaseAgent that executes a compiled graph workflow at runtime.

    Starting from the compiled graph's entry node, each node is executed —
    function nodes directly, agent nodes by delegating to ``run_async`` —
    and edges are followed (evaluating edge conditions against session
    state) until the exit node is reached or no outgoing edge matches.
    """

    # Pydantic-private attributes: excluded from the model's schema and
    # serialization.
    _graph: CompiledGraph = PrivateAttr()
    _evaluator: ConditionEvaluator = PrivateAttr()

    def __init__(
        self,
        name: str,
        graph: CompiledGraph,
        description: str = "",
    ):
        super().__init__(name=name, description=description)
        self._graph = graph
        self._evaluator = ConditionEvaluator()

    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        """Walk the graph from entry onward, yielding events from agent nodes.

        Raises:
            GraphExecutionError: on a missing node, an unconditional revisit
                (cycle), or when the step budget (2x node count) is spent.
        """
        # Store user message in session state so LLM agents can access it.
        user_content = getattr(ctx, "user_content", None)
        if user_content and user_content.parts:
            user_text = "".join(
                p.text for p in user_content.parts if p.text
            ).strip()
            if user_text and "user_message" not in ctx.session.state:
                ctx.session.state["user_message"] = user_text
                # Also add as a user event so LlmAgent sees it in conversation
                # history.  Uses the module-level Event import; the previous
                # local re-import was redundant and has been removed.
                user_event = Event(
                    author="user",
                    content=user_content,
                )
                ctx.session.events.append(user_event)

        current_key = self._graph.entry
        visited: set[str] = set()
        max_steps = len(self._graph.nodes) * 2  # Safety bound
        steps = 0

        while current_key is not None:
            if steps >= max_steps:
                raise GraphExecutionError(
                    f"Graph execution exceeded max steps ({max_steps})",
                    node_id=current_key,
                    detail="Possible infinite loop in graph",
                )

            # A revisit is only legitimate when the node can branch; an
            # unconditional revisit could never terminate.
            if current_key in visited and not self._has_condition_from(current_key):
                raise GraphExecutionError(
                    f"Revisiting node '{current_key}' without conditional edges",
                    node_id=current_key,
                    detail="Possible cycle in graph",
                )

            steps += 1
            visited.add(current_key)

            node = self._graph.nodes.get(current_key)
            if node is None:
                raise GraphExecutionError(
                    f"Node '{current_key}' not found in compiled graph",
                    node_id=current_key,
                )

            # Execute the node.
            if node.is_function:
                await self._execute_function(node, ctx)
            elif node.is_agent:
                async for event in node.agent.run_async(ctx):
                    # Merge state deltas from agent events into session state.
                    if event.actions and event.actions.state_delta:
                        ctx.session.state.update(event.actions.state_delta)
                    yield event

            # Check if we've reached the exit.
            if current_key == self._graph.exit:
                break

            # Find next node via edges.
            current_key = self._find_next(current_key, ctx)

    async def _execute_function(
        self, node: CompiledNode, ctx: InvocationContext
    ) -> None:
        """Run a function node against a snapshot of session state.

        Awaits coroutine results, and stores a non-None result into session
        state under the node's ``output_key`` when one is configured.

        Raises:
            GraphExecutionError: wrapping any exception from the callable.
        """
        try:
            import inspect  # local: only needed for the awaitable check

            # Snapshot so the callable mutating its argument does not
            # silently alter session state.
            state = dict(ctx.session.state)
            result = node.callable(state)
            if inspect.isawaitable(result):
                result = await result
            if node.output_key and result is not None:
                ctx.session.state[node.output_key] = result
        except Exception as exc:
            raise GraphExecutionError(
                f"Function execution failed for node '{node.key}'",
                node_id=node.key,
                detail=str(exc),
            ) from exc

    def _find_next(
        self, current_key: str, ctx: InvocationContext
    ) -> str | None:
        """Return the target of the first matching outgoing edge, or None.

        Unconditional edges match immediately; conditional edges are
        evaluated against a copy of session state with string values
        stripped of surrounding whitespace.
        """
        edges = self._graph.adjacency.get(current_key, [])
        if not edges:
            return None

        state = {
            k: v.strip() if isinstance(v, str) else v
            for k, v in ctx.session.state.items()
        }
        for edge in edges:
            if edge.condition is None:
                return edge.to_key
            if self._evaluator.evaluate(edge.condition, state):
                return edge.to_key

        return None

    def _has_condition_from(self, key: str) -> bool:
        """Return True when any outgoing edge of *key* carries a condition."""
        edges = self._graph.adjacency.get(key, [])
        return any(e.condition is not None for e in edges)
|
graph_workflow/schema.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""JSON Schema validation for graph workflow definitions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import jsonschema
|
|
8
|
+
|
|
9
|
+
from graph_workflow.errors import GraphValidationError
|
|
10
|
+
|
|
11
|
+
# Location of the JSON Schema document shipped as package data.
_SCHEMA_DIR = Path(__file__).parent / "data"
_SCHEMA_PATH = _SCHEMA_DIR / "schema.json"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _load_schema() -> dict:
    """Load and parse the bundled workflow JSON Schema.

    Returns:
        The schema document as a dict.
    """
    # JSON interchange is UTF-8 (RFC 8259); pin the encoding explicitly so
    # loading does not depend on the platform's locale default.
    return json.loads(_SCHEMA_PATH.read_text(encoding="utf-8"))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SchemaValidator:
    """Validates raw workflow dicts against the JSON Schema."""

    def __init__(self, schema: dict | None = None):
        # A falsy/absent schema falls back to the packaged one.
        self._schema = schema or _load_schema()

    def validate(self, data: dict) -> None:
        """Check *data* against the schema.

        Raises:
            GraphValidationError: when validation fails; the JSON path of
                the offending element is carried in ``detail``.
        """
        try:
            jsonschema.validate(instance=data, schema=self._schema)
        except jsonschema.ValidationError as err:
            message = f"JSON Schema validation failed: {err.message}"
            raise GraphValidationError(message, detail=err.json_path) from err
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Structural validation for graph workflow definitions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from graph_workflow.errors import GraphValidationError
|
|
5
|
+
from graph_workflow.models import (
|
|
6
|
+
FunctionNodeDef,
|
|
7
|
+
GraphWorkflowDef,
|
|
8
|
+
LoopAgentNodeDef,
|
|
9
|
+
ParallelAgentNodeDef,
|
|
10
|
+
SequentialAgentNodeDef,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GraphValidator:
    """Validates structural properties of a GraphWorkflowDef beyond JSON Schema.

    Hard violations raise GraphValidationError; soft issues are returned
    from :meth:`validate` as human-readable warning strings.
    """

    def validate(self, workflow: GraphWorkflowDef) -> list[str]:
        """Run all structural checks; return accumulated warnings.

        Raises:
            GraphValidationError: on the first hard violation found.
        """
        warnings: list[str] = []
        self._check_entry_node(workflow)
        self._check_exit_node(workflow, warnings)
        self._check_edge_nodes_exist(workflow)
        self._check_self_loops(workflow)
        self._check_exit_no_out_edges(workflow)
        self._check_entry_no_in_edges(workflow)
        self._check_single_default_edge(workflow)
        self._check_container_sub_agents(workflow)
        self._check_function_references(workflow)
        self._check_reachable_nodes(workflow, warnings)
        return warnings

    def _check_entry_node(self, workflow: GraphWorkflowDef) -> None:
        """The entry key must name an existing node."""
        if workflow.entry not in workflow.nodes:
            raise GraphValidationError(
                f"Entry node '{workflow.entry}' not found in nodes",
                node_id=workflow.entry,
            )

    def _check_exit_node(self, workflow: GraphWorkflowDef, warnings: list[str]) -> None:
        """An exit key, when present, must name an existing node; warn if absent."""
        if workflow.exit is None:
            warnings.append("No exit node defined — graph may run indefinitely")
            return
        if workflow.exit not in workflow.nodes:
            raise GraphValidationError(
                f"Exit node '{workflow.exit}' not found in nodes",
                node_id=workflow.exit,
            )

    def _check_edge_nodes_exist(self, workflow: GraphWorkflowDef) -> None:
        """Both endpoints of every edge must be defined nodes."""
        node_ids = set(workflow.nodes.keys())
        for edge in workflow.edges:
            if edge.from_ not in node_ids:
                raise GraphValidationError(
                    f"Edge source '{edge.from_}' not found in nodes",
                    node_id=edge.from_,
                )
            if edge.to not in node_ids:
                raise GraphValidationError(
                    f"Edge target '{edge.to}' not found in nodes",
                    node_id=edge.to,
                )

    def _check_container_sub_agents(self, workflow: GraphWorkflowDef) -> None:
        """Every sub-agent of a container node must be a defined node."""
        node_ids = set(workflow.nodes.keys())
        for nid, node in workflow.nodes.items():
            sub_agents: list[str] = []
            if isinstance(node, (SequentialAgentNodeDef, ParallelAgentNodeDef, LoopAgentNodeDef)):
                sub_agents = node.sub_agents
            for child_id in sub_agents:
                if child_id not in node_ids:
                    raise GraphValidationError(
                        f"Sub-agent '{child_id}' of container node '{nid}' not found in nodes",
                        node_id=nid,
                        detail=f"Missing sub-agent: {child_id}",
                    )

    def _check_self_loops(self, workflow: GraphWorkflowDef) -> None:
        """No edge may start and end on the same node."""
        for edge in workflow.edges:
            if edge.from_ == edge.to:
                raise GraphValidationError(
                    f"Self-loop detected on node '{edge.from_}'",
                    node_id=edge.from_,
                )

    def _check_exit_no_out_edges(self, workflow: GraphWorkflowDef) -> None:
        """The exit node must be terminal (no outgoing edges)."""
        if workflow.exit is None:
            return
        for edge in workflow.edges:
            if edge.from_ == workflow.exit:
                raise GraphValidationError(
                    f"Exit node '{workflow.exit}' has outgoing edge",
                    node_id=workflow.exit,
                )

    def _check_entry_no_in_edges(self, workflow: GraphWorkflowDef) -> None:
        """The entry node must be a source (no incoming edges)."""
        for edge in workflow.edges:
            if edge.to == workflow.entry:
                raise GraphValidationError(
                    f"Entry node '{workflow.entry}' has incoming edge",
                    node_id=workflow.entry,
                )

    def _check_single_default_edge(self, workflow: GraphWorkflowDef) -> None:
        """Each node may have at most one unconditional outgoing edge.

        More than one would make edge selection order-dependent.
        """
        from_groups: dict[str, int] = {}
        for edge in workflow.edges:
            if edge.condition is None:
                count = from_groups.get(edge.from_, 0) + 1
                from_groups[edge.from_] = count
                if count > 1:
                    raise GraphValidationError(
                        f"Multiple default (unconditional) edges from node '{edge.from_}'",
                        node_id=edge.from_,
                    )

    def _check_function_references(self, workflow: GraphWorkflowDef) -> None:
        """Every function node must reference a registered function."""
        func_names = set(workflow.functions.keys())
        for nid, node in workflow.nodes.items():
            if isinstance(node, FunctionNodeDef):
                if node.function not in func_names:
                    raise GraphValidationError(
                        f"Function node '{nid}' references undefined function '{node.function}'",
                        node_id=nid,
                        detail=f"Available functions: {sorted(func_names)}",
                    )

    def _check_reachable_nodes(self, workflow: GraphWorkflowDef, warnings: list[str]) -> None:
        """Warn about nodes that no edge path connects to the entry node."""
        edge_map: dict[str, list[str]] = {}
        for edge in workflow.edges:
            edge_map.setdefault(edge.from_, []).append(edge.to)

        # BFS from the entry.  A cursor into the frontier list replaces the
        # previous list.pop(0), which shifts the whole list (O(n) per
        # dequeue, accidentally O(n^2) overall).
        reachable: set[str] = {workflow.entry}
        frontier: list[str] = [workflow.entry]
        cursor = 0
        while cursor < len(frontier):
            current = frontier[cursor]
            cursor += 1
            for target in edge_map.get(current, []):
                if target not in reachable:
                    reachable.add(target)
                    frontier.append(target)

        for nid in workflow.nodes:
            if nid not in reachable:
                warnings.append(f"Node '{nid}' is not reachable from entry")
|