yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
yamlgraph/tools/agent.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
"""Agent node factory for LLM-driven tool loops.
|
|
2
|
+
|
|
3
|
+
This module provides the agent node type that allows the LLM to
|
|
4
|
+
autonomously decide which tools to call until it has enough
|
|
5
|
+
information to provide a final answer.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
|
|
16
|
+
|
|
17
|
+
from yamlgraph.tools.python_tool import PythonToolConfig, load_python_function
|
|
18
|
+
from yamlgraph.tools.shell import ShellToolConfig, execute_shell_tool
|
|
19
|
+
from yamlgraph.utils.llm_factory import create_llm
|
|
20
|
+
from yamlgraph.utils.prompts import load_prompt
|
|
21
|
+
|
|
22
|
+
# Module-level logger: agent-loop progress is logged at INFO, per-iteration
# details at DEBUG, and missing tools / iteration limits at WARNING.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def build_langchain_tool(name: str, config: ShellToolConfig) -> Callable:
    """Wrap a shell tool configuration as a LangChain ``StructuredTool``.

    Every ``{placeholder}`` found in ``config.command`` becomes a required
    string argument of the tool, so the LLM has to provide a value for each.
    When the command has no placeholders, no explicit schema is passed and
    LangChain infers one from the wrapper function.

    Args:
        name: Tool name for LLM to reference
        config: Shell tool configuration (command template and description)

    Returns:
        A ``StructuredTool`` whose calls run the shell command and return the
        stripped output, ``"Success"`` when the output is ``None``, or
        ``"Error: <error>"`` when the command fails.
    """
    import re

    from langchain_core.tools import StructuredTool
    from pydantic import Field, create_model

    placeholders = re.findall(r"\{(\w+)\}", config.command)

    args_schema = None
    if placeholders:
        # Repeated placeholders collapse into a single field (dict keys).
        field_specs = {}
        for placeholder in placeholders:
            field_specs[placeholder] = (
                str,
                Field(description=f"Value for {placeholder}"),
            )
        args_schema = create_model(f"{name}_args", **field_specs)

    # The function name and docstring are kept as-is: LangChain falls back to
    # the docstring as the tool description when config.description is None.
    def execute_tool_with_dict(**kwargs) -> str:
        """Execute shell command with provided arguments."""
        outcome = execute_shell_tool(config, kwargs)
        if not outcome.success:
            return f"Error: {outcome.error}"
        if outcome.output is None:
            return "Success"
        return str(outcome.output).strip()

    return StructuredTool.from_function(
        func=execute_tool_with_dict,
        name=name,
        description=config.description,
        args_schema=args_schema,
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def build_python_tool(name: str, config: PythonToolConfig) -> Any:
    """Convert Python tool config to LangChain StructuredTool.

    The tool's argument schema is derived from the target function's
    signature:
    - Each named parameter becomes a field typed by its annotation (``str``
      when unannotated).
    - ``*args`` and ``**kwargs`` are skipped.
    - Parameters that declare a default value become optional fields carrying
      that default, so the LLM is not forced to supply them.

    Args:
        name: Tool name for LLM to reference
        config: Python tool configuration

    Returns:
        LangChain StructuredTool. Invoking it returns ``str(result)``,
        ``"Success"`` when the function returns ``None``, or
        ``"Error: <exception>"`` when the function raises.
    """
    from langchain_core.tools import StructuredTool
    from pydantic import Field, create_model

    # The target function is resolved once, when the tool is built.
    func = load_python_function(config)

    empty = inspect.Parameter.empty
    fields: dict[str, Any] = {}
    for param_name, param in inspect.signature(func).parameters.items():
        # Skip *args, **kwargs: they have no fixed name to expose to the LLM.
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        # Get type annotation or default to str
        param_type = param.annotation if param.annotation is not empty else str

        description = f"Parameter: {param_name}"
        if param.default is empty:
            field_info = Field(description=description)
        else:
            # Keep optional Python parameters optional in the tool schema;
            # otherwise the LLM is told every argument is required.
            field_info = Field(default=param.default, description=description)
        fields[param_name] = (param_type, field_info)

    # Create dynamic Pydantic model
    ArgsModel = create_model(f"{name}_args", **fields) if fields else None

    def execute_python(**kwargs) -> str:
        """Execute the Python function and return result as string."""
        # Broad catch is deliberate: errors are returned as text so the agent
        # can relay them to the LLM instead of aborting the run.
        try:
            result = func(**kwargs)
            return str(result) if result is not None else "Success"
        except Exception as e:
            return f"Error: {e}"

    return StructuredTool.from_function(
        func=execute_python,
        name=name,
        description=config.description,
        args_schema=ArgsModel,
    )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _execute_tool_config(
    tool_config: Any, tool_args: dict[str, Any]
) -> tuple[Any, bool]:
    """Execute one resolved tool and return ``(output, success)``.

    Dispatches on the config type:
    - ``ShellToolConfig`` runs via ``execute_shell_tool``.
    - ``PythonToolConfig`` loads its function and calls it with ``tool_args``
      as keyword arguments.
    - Anything else is treated as a LangChain tool and called with
      ``.invoke(tool_args)``; its output is returned unconverted.

    Failures come back as ``("Error: ...", False)`` rather than being raised.
    """
    if isinstance(tool_config, ShellToolConfig):
        result = execute_shell_tool(tool_config, tool_args)
        # NOTE(review): unlike the StructuredTool wrapper from
        # build_langchain_tool, this path neither strips the output nor maps
        # None to "Success" - confirm whether the two should agree.
        output = str(result.output) if result.success else f"Error: {result.error}"
        return output, result.success

    # Broad catches below are deliberate: any tool failure becomes an error
    # message the LLM can see instead of aborting the whole graph run.
    if isinstance(tool_config, PythonToolConfig):
        try:
            func = load_python_function(tool_config)
            return str(func(**tool_args)), True
        except Exception as e:
            return f"Error: {e}", False

    try:
        return tool_config.invoke(tool_args), True
    except Exception as e:
        return f"Error: {e}", False


def create_agent_node(
    node_name: str,
    node_config: dict[str, Any],
    tools: dict[str, ShellToolConfig],
    websearch_tools: dict[str, Any] | None = None,
    python_tools: dict[str, PythonToolConfig] | None = None,
) -> Callable[[dict], dict]:
    """Create an agent node that loops with tool calls.

    The agent will:
    1. Send the prompt to the LLM with available tools
    2. If LLM returns tool calls, execute them and feed results back
    3. Repeat until LLM returns without tool calls or max_iterations reached

    Args:
        node_name: Name of the node in the graph
        node_config: Node configuration from YAML
        tools: Registry of available shell tools
        websearch_tools: Registry of web search tools (LangChain StructuredTool)
        python_tools: Registry of Python tools (PythonToolConfig)

    Returns:
        Node function that runs the agent loop. Its result dict contains the
        final answer under ``state_key``, ``current_step``,
        ``_agent_iterations``, the full ``messages`` list, and (when the
        limit is hit) ``_agent_limit_reached``. It also contains the raw tool
        outputs under ``tool_results_key`` if that key is configured and any
        tool ran.

    Config options:
        - tools: List of tool names to make available
        - max_iterations: Max tool-call loops (default: 5)
        - state_key: Key to store final answer (default: node_name)
        - prompt: Prompt file name (default: "agent")
        - tool_results_key: Optional key to store raw tool outputs
    """
    import re

    if websearch_tools is None:
        websearch_tools = {}
    if python_tools is None:
        python_tools = {}

    tool_names = node_config.get("tools", [])
    max_iterations = node_config.get("max_iterations", 5)
    state_key = node_config.get("state_key", node_name)
    prompt_name = node_config.get("prompt", "agent")
    tool_results_key = node_config.get("tool_results_key")

    # lc_tools are bound to the LLM. tool_lookup maps each name to the object
    # actually executed: the config for shell/python tools, the tool object
    # itself for websearch tools.
    lc_tools = []
    tool_lookup: dict[str, Any] = {}
    for name in tool_names:
        if name in tools:
            lc_tools.append(build_langchain_tool(name, tools[name]))
            tool_lookup[name] = tools[name]
        elif name in websearch_tools:
            # Websearch tool - already a LangChain tool
            lc_tools.append(websearch_tools[name])
            tool_lookup[name] = websearch_tools[name]
        elif name in python_tools:
            lc_tools.append(build_python_tool(name, python_tools[name]))
            tool_lookup[name] = python_tools[name]
        else:
            logger.warning(
                "Tool '%s' not found in shell, websearch, or python registries",
                name,
            )

    # Compiled once per node instead of on every node invocation.
    placeholder_re = re.compile(r"\{(\w+)\}")

    def node_fn(state: dict) -> dict:
        """Execute the agent loop."""
        # Load prompts - fail fast if missing
        prompt_config = load_prompt(prompt_name)
        system_prompt = prompt_config.get("system", "")
        user_template = prompt_config.get("user", "{input}")

        # Substitute {key} from state; keys absent from state stay as "{key}".
        user_prompt = placeholder_re.sub(
            lambda m: str(state.get(m.group(1), m.group(0))), user_template
        )

        # Multi-turn: continue an existing conversation (no new system
        # message); otherwise start with system + user. `or []` also covers
        # a messages key explicitly set to None.
        existing_messages = list(state.get("messages") or [])
        if existing_messages:
            messages = existing_messages + [HumanMessage(content=user_prompt)]
        else:
            messages = [
                SystemMessage(content=system_prompt),
                HumanMessage(content=user_prompt),
            ]

        # Track raw tool outputs for persistence
        tool_results: list[dict] = []

        def build_result(answer: Any, iterations: int, limit_reached: bool) -> dict:
            """Assemble the node's state update."""
            result: dict[str, Any] = {
                state_key: answer,
                "current_step": node_name,
                "_agent_iterations": iterations,
            }
            if limit_reached:
                result["_agent_limit_reached"] = True
            result["messages"] = messages  # Return for accumulation
            if tool_results_key and tool_results:
                result[tool_results_key] = tool_results
            return result

        llm = create_llm().bind_tools(lc_tools)

        logger.info(
            "🤖 Starting agent loop: %s (max %s iterations)",
            node_name,
            max_iterations,
        )
        logger.debug("Tools available: %s", [t.name for t in lc_tools])
        logger.debug("User prompt: %s...", user_prompt[:100])

        for iteration in range(max_iterations):
            logger.debug("Agent iteration %s/%s", iteration + 1, max_iterations)

            response = llm.invoke(messages)
            messages.append(response)
            logger.debug("Response tool_calls: %s", response.tool_calls)

            if not response.tool_calls:
                # Done - LLM finished reasoning
                logger.info("✓ Agent completed after %s iterations", iteration + 1)
                return build_result(response.content, iteration + 1, False)

            for call_index, tool_call in enumerate(response.tool_calls):
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
                # ToolMessage needs a string id. Providers may omit the id or
                # send None, and several calls in one response must not share
                # a fallback id; the first call keeps the plain
                # "call_<iteration>" form.
                fallback_id = (
                    f"call_{iteration}"
                    if call_index == 0
                    else f"call_{iteration}_{call_index}"
                )
                tool_id = tool_call.get("id") or fallback_id

                logger.info("🔧 Calling tool: %s(%s)", tool_name, tool_args)

                tool_config = tool_lookup.get(tool_name)
                if tool_config:
                    output, success = _execute_tool_config(tool_config, tool_args)
                else:
                    output, success = f"Error: Unknown tool '{tool_name}'", False

                tool_results.append(
                    {
                        "tool": tool_name,
                        "args": tool_args,
                        "output": output,
                        "success": success,
                    }
                )
                messages.append(ToolMessage(content=output, tool_call_id=tool_id))

        logger.warning("Agent hit max iterations (%s)", max_iterations)
        # NOTE: after a full round of tool calls the last message is a
        # ToolMessage, so the value stored under state_key is that tool's
        # output rather than an LLM answer (or the user prompt when
        # max_iterations <= 0 and the loop never ran).
        last_message = messages[-1]
        last_content = (
            last_message.content if hasattr(last_message, "content") else ""
        )
        return build_result(last_content, max_iterations, True)

    return node_fn
|
|
yamlgraph/tools/graph_linter.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Graph linter for validating YAML graph files.
|
|
2
|
+
|
|
3
|
+
Checks for common issues:
|
|
4
|
+
- Missing state declarations
|
|
5
|
+
- Undefined tool references
|
|
6
|
+
- Missing prompt files
|
|
7
|
+
- Unreachable nodes
|
|
8
|
+
- Invalid node types
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import re
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import yaml
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
# NOTE(review): this logger is not used by any function in this module.
logger = logging.getLogger(__name__)

# Valid node types
# check_node_types() reports any explicit node ``type`` outside this set as
# error E005; nodes without a ``type`` key are not checked.
# NOTE(review): the package also ships tests for interrupt, passthrough,
# tool-call and subgraph nodes - confirm whether those type names should be
# accepted here, otherwise valid graphs may be flagged.
VALID_NODE_TYPES = {"llm", "router", "agent", "map", "python"}

# Built-in state fields that don't need declaration
# check_state_declarations() treats these as always declared, alongside the
# graph's ``state`` keys and every node's ``state_key``.
BUILTIN_STATE_FIELDS = {
    "thread_id",
    "current_step",
    "error",
    "errors",
    "messages",
    "_loop_counts",
    "_loop_limit_reached",
    "_agent_iterations",
    "_agent_limit_reached",
    "started_at",
    "completed_at",
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LintIssue(BaseModel):
    """A single lint issue found in the graph."""

    # lint_graph() marks a result invalid only for "error" issues; the checks
    # in this module emit "error" (E00x codes) and "warning" (W00x codes).
    severity: str  # "error", "warning", "info"
    code: str  # e.g., "E001", "W002"
    # Human-readable description naming the offending node/tool/variable.
    message: str
    # Source line in the YAML file; none of the checks in this module set it.
    line: int | None = None
    # Suggested remediation text, if any.
    fix: str | None = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class LintResult(BaseModel):
    """Result of linting a graph file."""

    # Graph path as a string (lint_graph() stores str(graph_path)).
    file: str
    # Issues from every check, in the order lint_graph() runs the checks.
    issues: list[LintIssue]
    # True when no issue has severity "error"; warnings do not invalidate.
    valid: bool
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _load_graph(graph_path: Path) -> dict[str, Any]:
    """Load and parse a YAML graph file.

    Parsed with ``yaml.safe_load``. An empty document yields ``{}``. The file
    is decoded as UTF-8 explicitly, so the result does not depend on the
    platform's locale encoding (e.g. cp1252 on Windows).

    Args:
        graph_path: Path to the graph YAML file.

    Returns:
        The parsed YAML document, or ``{}`` if it is empty.
    """
    with open(graph_path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _extract_variables(text: str) -> set[str]:
|
|
67
|
+
"""Extract {variable} placeholders from text.
|
|
68
|
+
|
|
69
|
+
Ignores escaped {{variable}} (doubled braces).
|
|
70
|
+
"""
|
|
71
|
+
# Find all {word} patterns but not {{word}}
|
|
72
|
+
# First, temporarily replace {{ and }} to protect them
|
|
73
|
+
protected = text.replace("{{", "\x00").replace("}}", "\x01")
|
|
74
|
+
matches = re.findall(r"\{(\w+)\}", protected)
|
|
75
|
+
return set(matches)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_prompt_path(prompt_name: str, prompts_dir: Path) -> Path:
|
|
79
|
+
"""Get the full path to a prompt file."""
|
|
80
|
+
return prompts_dir / f"{prompt_name}.yaml"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def check_state_declarations(
    graph_path: Path, project_root: Path | None = None
) -> list[LintIssue]:
    """Check if variables used in prompts/tools are declared in state.

    A ``{variable}`` counts as declared when it is one of:
    - a key of the graph's ``state`` section,
    - a built-in field (``BUILTIN_STATE_FIELDS``),
    - some node's ``state_key``.
    Prompt variables may also be supplied by the node's own ``variables``
    mapping. Shell tools referenced by ``agent`` nodes are skipped because the
    LLM supplies their arguments.

    Args:
        graph_path: Path to the graph YAML file
        project_root: Root directory containing prompts/ folder
            (defaults to the graph file's directory)

    Returns:
        List of lint issues for missing state declarations: E001 for shell
        tool commands and E002 for prompt files. Within each tool/prompt,
        variables are reported in sorted order so output is deterministic.
    """
    issues: list[LintIssue] = []
    graph = _load_graph(graph_path)

    if project_root is None:
        project_root = graph_path.parent
    prompts_dir = project_root / "prompts"

    # YAML sections can be present but empty (e.g. "state:" -> None).
    nodes = graph.get("nodes") or {}

    declared_state = set((graph.get("state") or {}).keys())
    declared_state.update(BUILTIN_STATE_FIELDS)

    agent_tools: set[str] = set()
    for node_config in nodes.values():
        # Node outputs (state_key) become available in state at runtime.
        if "state_key" in node_config:
            declared_state.add(node_config["state_key"])
        # Tools used by agent nodes get their variables from the LLM.
        if node_config.get("type") == "agent":
            agent_tools.update(node_config.get("tools") or [])

    for tool_name, tool_config in (graph.get("tools") or {}).items():
        if tool_config.get("type") != "shell" or tool_name in agent_tools:
            continue
        command = tool_config.get("command") or ""
        for var in sorted(_extract_variables(command)):
            if var not in declared_state:
                issues.append(
                    LintIssue(
                        severity="error",
                        code="E001",
                        message=f"Variable '{var}' used in tool '{tool_name}' "
                        f"but not declared in state",
                        fix=f"Add '{var}: str' to the state section",
                    )
                )

    for node_config in nodes.values():
        prompt_name = node_config.get("prompt")
        if not prompt_name:
            continue
        prompt_path = _get_prompt_path(prompt_name, prompts_dir)
        # Missing prompt files are reported separately (E004 in check_prompt_files).
        if not prompt_path.exists():
            continue
        prompt_content = prompt_path.read_text(encoding="utf-8")

        # Node-level variables provide values for prompt placeholders
        node_variables = set((node_config.get("variables") or {}).keys())

        for var in sorted(_extract_variables(prompt_content)):
            if var not in declared_state and var not in node_variables:
                issues.append(
                    LintIssue(
                        severity="error",
                        code="E002",
                        message=f"Variable '{var}' used in prompt "
                        f"'{prompt_name}' but not declared in state",
                        fix=f"Add '{var}: str' to the state section",
                    )
                )

    return issues
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def check_tool_references(graph_path: Path) -> list[LintIssue]:
    """Check that all tool references in nodes are defined.

    Only the ``tools`` list on each node counts as a reference.
    NOTE(review): if other node kinds can reference tools through different
    keys, those tools will be reported as unused (W001) - confirm.

    Args:
        graph_path: Path to the graph YAML file

    Returns:
        List of lint issues for undefined/unused tools: one E003 per
        undefined reference (per node), then one W001 per defined-but-unused
        tool, sorted by tool name.
    """
    issues: list[LintIssue] = []
    graph = _load_graph(graph_path)

    # "or {}" / "or []" tolerate sections that are present but empty in YAML.
    defined_tools = set((graph.get("tools") or {}).keys())
    used_tools: set[str] = set()

    for node_name, node_config in (graph.get("nodes") or {}).items():
        for tool in node_config.get("tools") or []:
            used_tools.add(tool)
            if tool not in defined_tools:
                issues.append(
                    LintIssue(
                        severity="error",
                        code="E003",
                        message=f"Tool '{tool}' referenced in node '{node_name}' "
                        f"but not defined in tools section",
                        fix=f"Add tool '{tool}' to the tools section or remove reference",
                    )
                )

    # Sorted so the warning order is stable across runs (set iteration order
    # for strings varies with hash randomization).
    for tool in sorted(defined_tools - used_tools):
        issues.append(
            LintIssue(
                severity="warning",
                code="W001",
                message=f"Tool '{tool}' is defined but never used",
                fix=f"Remove unused tool '{tool}' from tools section",
            )
        )

    return issues
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def check_prompt_files(
    graph_path: Path, project_root: Path | None = None
) -> list[LintIssue]:
    """Report nodes whose ``prompt`` has no matching file under ``prompts/``.

    Args:
        graph_path: Path to the graph YAML file
        project_root: Directory whose ``prompts/`` folder holds the prompt
            files; defaults to the directory containing the graph file.

    Returns:
        One E004 error per node whose ``<prompt>.yaml`` does not exist.
    """
    graph = _load_graph(graph_path)
    root = graph_path.parent if project_root is None else project_root
    prompts_dir = root / "prompts"

    missing: list[LintIssue] = []
    for node_name, node_config in graph.get("nodes", {}).items():
        prompt_name = node_config.get("prompt")
        if not prompt_name:
            continue
        if _get_prompt_path(prompt_name, prompts_dir).exists():
            continue
        missing.append(
            LintIssue(
                severity="error",
                code="E004",
                message=f"Prompt file '{prompt_name}.yaml' not found for node '{node_name}'",
                fix=f"Create file: prompts/{prompt_name}.yaml",
            )
        )
    return missing
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def check_edge_coverage(graph_path: Path) -> list[LintIssue]:
    """Check that all nodes are reachable and have paths to END.

    Runs a forward traversal from ``START`` and a backward traversal from
    ``END`` over the ``edges`` list, where an edge's ``to`` may be a single
    node name or a list of names. A node that START cannot reach gets W002.
    A reachable node that cannot reach END gets W003.

    Args:
        graph_path: Path to the graph YAML file

    Returns:
        List of lint issues for unreachable/dead-end nodes, in node
        declaration order.
    """
    issues: list[LintIssue] = []
    graph = _load_graph(graph_path)

    # "or" guards tolerate sections that are present but empty in YAML.
    nodes = graph.get("nodes") or {}
    edges = graph.get("edges") or []

    def normalize_targets(target) -> list:
        """Handle both single target and list of targets."""
        if isinstance(target, list):
            return target
        return [target] if target else []

    # Forward traversal: every node reachable from START (END excluded).
    reachable_from_start: set = set()
    frontier = {"START"}
    while frontier:
        current = frontier.pop()
        for edge in edges:
            if edge.get("from") != current:
                continue
            for target in normalize_targets(edge.get("to")):
                if target != "END" and target not in reachable_from_start:
                    reachable_from_start.add(target)
                    frontier.add(target)

    # Backward traversal: every node from which END is reachable (START excluded).
    can_reach_end: set = set()
    frontier = {"END"}
    while frontier:
        current = frontier.pop()
        for edge in edges:
            if current not in normalize_targets(edge.get("to")):
                continue
            source = edge.get("from")
            if source != "START" and source not in can_reach_end:
                can_reach_end.add(source)
                frontier.add(source)

    # Iterate in declaration order so the issue order is deterministic.
    for node in nodes:
        if node not in reachable_from_start:
            issues.append(
                LintIssue(
                    severity="warning",
                    code="W002",
                    message=f"Node '{node}' is not reachable from START",
                    fix=f"Add edge from START or another node to '{node}'",
                )
            )
        elif node not in can_reach_end:
            issues.append(
                LintIssue(
                    severity="warning",
                    code="W003",
                    message=f"Node '{node}' has no path to END",
                    fix=f"Add edge from '{node}' to END or another node",
                )
            )

    return issues
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def check_node_types(graph_path: Path) -> list[LintIssue]:
    """Flag nodes whose explicit ``type`` is not in ``VALID_NODE_TYPES``.

    Nodes without a ``type`` (or with an empty one) are not checked.

    Args:
        graph_path: Path to the graph YAML file

    Returns:
        One E005 error per node with an unrecognized type.
    """
    graph = _load_graph(graph_path)
    allowed_types = ", ".join(sorted(VALID_NODE_TYPES))

    issues: list[LintIssue] = []
    for node_name, node_config in graph.get("nodes", {}).items():
        node_type = node_config.get("type")
        if not node_type or node_type in VALID_NODE_TYPES:
            continue
        issues.append(
            LintIssue(
                severity="error",
                code="E005",
                message=f"Invalid node type '{node_type}' in node '{node_name}'",
                fix=f"Use one of: {allowed_types}",
            )
        )
    return issues
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def lint_graph(
    graph_path: Path | str, project_root: Path | str | None = None
) -> LintResult:
    """Lint a YAML graph file for issues.

    Runs every check in this module, in this order: state declarations, tool
    references, prompt files, edge coverage, node types. Their issues are
    collected into one result.

    Args:
        graph_path: Path to the graph YAML file
        project_root: Root directory containing prompts/ folder. ``None`` or
            an empty string means "use the graph file's directory".

    Returns:
        LintResult with all issues found. ``valid`` is False if any issue has
        severity "error"; warnings do not invalidate the graph.
    """
    graph_path = Path(graph_path)
    # Map any falsy value (None, "") to None so the checks fall back to
    # graph_path.parent; an empty str would otherwise reach `"" / "prompts"`
    # and raise TypeError.
    root = Path(project_root) if project_root else None

    all_issues: list[LintIssue] = [
        *check_state_declarations(graph_path, root),
        *check_tool_references(graph_path),
        *check_prompt_files(graph_path, root),
        *check_edge_coverage(graph_path),
        *check_node_types(graph_path),
    ]

    has_errors = any(issue.severity == "error" for issue in all_issues)
    return LintResult(
        file=str(graph_path),
        issues=all_issues,
        valid=not has_errors,
    )
|