yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Tests for yamlgraph.utils.template module - Variable extraction and validation."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestExtractVariables:
    """Behavioral tests for ``extract_variables``.

    Covers plain ``{var}`` placeholders, Jinja2 ``{{ expr }}`` / ``{% stmt %}``
    syntax, and the names the extractor must leave out of the result
    (the injected ``state``, loop targets, and Jinja2 builtins).
    """

    def test_extract_simple_variables(self):
        """Every ``{var}`` placeholder is reported."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("Hello {name}, your style is {style}.")
        assert found == {"name", "style"}

    def test_extract_single_variable(self):
        """A lone placeholder yields a one-element set."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("Welcome {user}!")
        assert found == {"user"}

    def test_extract_no_variables(self):
        """Plain text yields an empty set."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("No variables here")
        assert found == set()

    def test_extract_duplicate_variables(self):
        """A repeated placeholder is reported once."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("{name} and {name} again")
        assert found == {"name"}

    def test_extract_jinja2_variable(self):
        """A bare ``{{ var }}`` expression is reported."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("Hello {{ name }}!")
        assert "name" in found

    def test_extract_jinja2_variable_with_field_access(self):
        """For ``{{ var.field }}`` only the base name ``var`` is reported."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("User: {{ user.name }}")
        assert "user" in found

    def test_extract_jinja2_loop_variable(self):
        """The iterable of a ``{% for %}`` loop is an input; its target is not."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("{% for item in items %}{{ item.name }}{% endfor %}")
        assert "items" in found
        # 'item' is bound by the loop itself, so callers never supply it.
        assert "item" not in found

    def test_extract_jinja2_if_variable(self):
        """A name tested in ``{% if %}`` is reported."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("{% if show_details %}Details here{% endif %}")
        assert "show_details" in found

    def test_exclude_state_variable(self):
        """``state`` is supplied by the framework, never by the caller."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("{{ state.topic }}")
        # node_factory injects state, so it must not be listed as required.
        assert "state" not in found

    def test_exclude_jinja2_builtins(self):
        """Jinja2 builtins such as ``range`` and ``loop`` are not inputs."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("{% for i in range(10) %}{{ loop.index }}{% endfor %}")
        for builtin in ("range", "loop"):
            assert builtin not in found

    def test_mixed_simple_and_jinja2(self):
        """Both placeholder syntaxes are honored within one template."""
        from yamlgraph.utils.template import extract_variables as extract

        found = extract("Simple {name} and Jinja2 {{ topic }}")
        for expected in ("name", "topic"):
            assert expected in found
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class TestValidateVariables:
    """Behavioral tests for ``validate_variables``.

    The validator takes ``(template, provided_values, prompt_name)`` and is
    expected to raise ``ValueError`` naming every missing variable.
    """

    def test_validate_all_provided(self):
        """No error when every placeholder has a value."""
        from yamlgraph.utils.template import validate_variables as validate

        supplied = {"name": "World", "style": "formal"}
        # Must complete without raising.
        validate("Hello {name}, style: {style}", supplied, "greet")

    def test_validate_missing_single_variable(self):
        """One absent placeholder produces a ValueError that names it."""
        from yamlgraph.utils.template import validate_variables as validate

        with pytest.raises(ValueError, match="Missing required variable.*name"):
            validate("Hello {name}, style: {style}", {"style": "formal"}, "greet")

    def test_validate_missing_multiple_variables(self):
        """The error names ALL absent placeholders, not just the first."""
        from yamlgraph.utils.template import validate_variables as validate

        with pytest.raises(ValueError) as excinfo:
            validate("Hello {name}, style: {style}", {}, "greet")
        message = str(excinfo.value)
        for missing in ("name", "style"):
            assert missing in message

    def test_validate_extra_variables_ok(self):
        """Surplus values that the template never uses are tolerated."""
        from yamlgraph.utils.template import validate_variables as validate

        supplied = {"name": "World", "extra": "ignored"}
        # Extra keys are fine; this must not raise.
        validate("Hello {name}", supplied, "greet")

    def test_validate_prompt_name_in_error(self):
        """The prompt name appears in the error for easier debugging."""
        from yamlgraph.utils.template import validate_variables as validate

        with pytest.raises(ValueError, match="greet"):
            validate("Hello {name}", {}, "greet")

    def test_validate_empty_template(self):
        """A template with no placeholders accepts an empty mapping."""
        from yamlgraph.utils.template import validate_variables as validate

        # Must complete without raising.
        validate("No variables here", {}, "static")

    def test_validate_jinja2_template(self):
        """Jinja2 loop iterables are enforced as required inputs."""
        from yamlgraph.utils.template import validate_variables as validate

        with pytest.raises(ValueError, match="items"):
            validate(
                "{% for item in items %}{{ item }}{% endfor %}", {}, "list_template"
            )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class TestExecutePromptValidation:
    """Integration checks that ``execute_prompt`` validates its variables.

    Both tests use the ``greet`` prompt, which (per these assertions)
    requires ``name`` and ``style``.
    """

    def test_execute_prompt_raises_on_missing_variable(self):
        """A missing required variable surfaces as a ValueError naming it."""
        from yamlgraph.executor import execute_prompt as run_prompt

        partial = {"style": "formal"}  # 'name' deliberately omitted
        with pytest.raises(ValueError, match="Missing required variable.*name"):
            run_prompt(prompt_name="greet", variables=partial)

    def test_execute_prompt_lists_all_missing_variables(self):
        """Every missing variable is reported, not only the first one."""
        from yamlgraph.executor import execute_prompt as run_prompt

        with pytest.raises(ValueError) as excinfo:
            # Both 'name' and 'style' deliberately omitted.
            run_prompt(prompt_name="greet", variables={})
        reported = str(excinfo.value)
        for missing in ("name", "style"):
            assert missing in reported
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Tests for type: tool_call integration in graph_loader.
|
|
2
|
+
|
|
3
|
+
TDD Phase 3b: Wire tool_call node into graph compilation.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
|
+
|
|
9
|
+
from yamlgraph.graph_loader import GraphConfig, _compile_node
|
|
10
|
+
from yamlgraph.map_compiler import compile_map_node
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Sample tools for testing
|
|
14
|
+
def sample_search(path: str) -> dict:
    """Stand-in search tool: echo *path* with a fixed pair of fake matches."""
    canned_matches = ["line1", "line2"]
    return dict(path=path, matches=canned_matches)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def sample_read(file: str, start: int = 1, end: int = 10) -> dict:
    """Stand-in read tool: report *file* and the inclusive line span start..end."""
    line_numbers = [*range(start, end + 1)]
    return {"file": file, "lines": line_numbers}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def tools_registry() -> dict:
    """Name -> callable registry handed to tool_call nodes as callable_registry."""
    registry = dict(search_file=sample_search, read_lines=sample_read)
    return registry
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture
def minimal_config() -> GraphConfig:
    """Build a minimal GraphConfig that also declares two python tools.

    The graph has a single prompt node wired ``START -> dummy -> END``.
    Its ``tools`` section points at ``sample_search`` and ``sample_read``,
    which are defined in this module.

    The tool ``module`` is taken from ``__name__`` rather than a hard-coded
    dotted path. In the usual package layout it resolves to
    ``tests.unit.test_tool_call_integration``, exactly as before. It also
    stays correct if the file is moved or renamed, or is imported under a
    different package path (e.g. a different pytest import mode).

    Returns:
        GraphConfig built from the dict below.
    """
    this_module = __name__
    config_dict = {
        "version": "1.0",
        "name": "test",
        "nodes": {
            "dummy": {
                "prompt": "test",
                "state_key": "result",
            }
        },
        "edges": [
            {"from": "START", "to": "dummy"},
            {"from": "dummy", "to": "END"},
        ],
        "tools": {
            "search_file": {
                "type": "python",
                "module": this_module,
                "function": "sample_search",
            },
            "read_lines": {
                "type": "python",
                "module": this_module,
                "function": "sample_read",
            },
        },
    }
    return GraphConfig(config_dict)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TestCompileToolCallNode:
    """Compilation and execution checks for ``type: tool_call`` nodes."""

    def test_compiles_tool_call_node(self, minimal_config, tools_registry):
        """_compile_node registers a tool_call node on the graph."""
        from operator import add
        from typing import Annotated

        # State schema; the ``add`` reducer concatenates discovery_findings lists.
        class TestState:
            discovery_findings: Annotated[list, add] = []

        graph = StateGraph(TestState)

        map_info = _compile_node(
            "test_tool_call",
            {
                "type": "tool_call",
                "tool": "{state.task.tool}",
                "args": "{state.task.args}",
                "state_key": "result",
            },
            graph,
            minimal_config,
            tools={},
            python_tools={},
            websearch_tools={},
            callable_registry=tools_registry,  # tool_call resolves tools from here
        )

        # A non-map node hands back no map info.
        assert map_info is None
        assert "test_tool_call" in graph.nodes

    def test_tool_call_node_executes(self, tools_registry):
        """The node picks the tool and args from ``state.task`` and runs it."""
        from yamlgraph.node_factory import create_tool_call_node

        node_fn = create_tool_call_node(
            "exec_tool",
            {
                "tool": "{state.task.tool}",
                "args": "{state.task.args}",
                "state_key": "result",
            },
            tools_registry,
        )

        task = {"id": 1, "tool": "search_file", "args": {"path": "foo.py"}}
        outcome = node_fn({"task": task})["result"]

        assert outcome["success"] is True
        assert outcome["result"]["path"] == "foo.py"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class TestMapWithToolCall:
    """Map nodes whose per-item sub-node is a ``tool_call``."""

    def test_map_with_tool_call_sub_node(self, tools_registry):
        """compile_map_node accepts a ``type: tool_call`` sub-node."""
        from operator import add
        from typing import Annotated

        # State schema: the plan to fan out over, plus a list-concatenating
        # reducer for collected findings.
        class TestState:
            discovery_plan: dict = {}
            discovery_findings: Annotated[list, add] = []

        graph = StateGraph(TestState)

        sub_node = {
            "type": "tool_call",
            "tool": "{state.task.tool}",
            "args": "{state.task.args}",
            "state_key": "discovery_result",
        }
        map_config = {
            "type": "map",
            "over": "{state.discovery_plan.tasks}",
            "as": "task",
            "node": sub_node,
            "collect": "discovery_findings",
        }

        # map_compiler must know how to build tool_call sub-nodes.
        edge_fn, sub_name = compile_map_node(
            "execute_discovery",
            map_config,
            graph,
            defaults={},
            tools_registry=tools_registry,  # registry consumed by tool_call sub-nodes
        )

        assert callable(edge_fn)
        assert sub_name == "_map_execute_discovery_sub"
        assert sub_name in graph.nodes
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Tests for type: tool_call node in node_factory.
|
|
2
|
+
|
|
3
|
+
TDD Phase 3: Dynamic tool execution from state.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from yamlgraph.constants import NodeType
|
|
9
|
+
from yamlgraph.node_factory import create_tool_call_node
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Sample tools for testing
|
|
13
|
+
def sample_tool(path: str, pattern: str = ".*") -> dict:
    """Echo the received arguments alongside two fixed fake hits."""
    hits = ["line1", "line2"]
    return dict(path=path, pattern=pattern, found=hits)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def failing_tool(path: str) -> dict:
    """Simulate a broken tool: raise ValueError mentioning *path*, every time."""
    message = f"Cannot process: {path}"
    raise ValueError(message)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def simple_tool() -> str:
    """Argument-free tool that always yields the same string."""
    output = "simple result"
    return output
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@pytest.fixture
def tools_registry() -> dict:
    """Registry of the sample tools, keyed by the names tasks refer to."""
    return dict(
        search_file=sample_tool,
        failing_tool=failing_tool,
        simple_tool=simple_tool,
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TestNodeTypeConstant:
    """Guard that the tool_call node type is registered on ``NodeType``."""

    def test_tool_call_in_node_type(self):
        """``NodeType.TOOL_CALL`` compares equal to its YAML spelling."""
        expected = "tool_call"
        assert NodeType.TOOL_CALL == expected
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TestCreateToolCallNode:
    """Exercise the ``create_tool_call_node`` factory end to end.

    Every test uses the same node config, in which both the tool name and its
    arguments are read from ``state.task`` and the outcome is stored under
    ``result``.
    """

    @staticmethod
    def _build(node_name, registry):
        """Create a tool_call node using a freshly built copy of the shared config."""
        node_config = {
            "tool": "{state.task.tool}",
            "args": "{state.task.args}",
            "state_key": "result",
        }
        return create_tool_call_node(node_name, node_config, registry)

    def test_creates_callable(self, tools_registry):
        """The factory hands back something callable."""
        assert callable(self._build("test_node", tools_registry))

    def test_resolves_tool_from_state(self, tools_registry):
        """The tool to run is looked up from ``state.task.tool``."""
        node_fn = self._build("test_node", tools_registry)

        task = {
            "id": 1,
            "tool": "search_file",
            "args": {"path": "foo.py", "pattern": "def"},
        }
        payload = node_fn({"task": task})["result"]

        assert payload["success"] is True
        assert payload["tool"] == "search_file"

    def test_resolves_args_from_state(self, tools_registry):
        """Arguments from ``state.task.args`` reach the tool unchanged."""
        node_fn = self._build("test_node", tools_registry)

        task = {
            "id": 1,
            "tool": "search_file",
            "args": {"path": "bar.py", "pattern": "class"},
        }
        echoed = node_fn({"task": task})["result"]["result"]

        # sample_tool echoes its arguments back.
        assert echoed["path"] == "bar.py"
        assert echoed["pattern"] == "class"

    def test_successful_execution(self, tools_registry):
        """A successful run reports success, the task id, and no error."""
        node_fn = self._build("test_node", tools_registry)

        task = {"id": 42, "tool": "search_file", "args": {"path": "x.py"}}
        payload = node_fn({"task": task})["result"]

        assert payload["success"] is True
        assert payload["task_id"] == 42
        assert payload["error"] is None
        assert "found" in payload["result"]

    def test_unknown_tool_handling(self, tools_registry):
        """An unregistered tool name yields success=False, not an exception."""
        node_fn = self._build("test_node", tools_registry)

        task = {"id": 1, "tool": "nonexistent_tool", "args": {}}
        payload = node_fn({"task": task})["result"]

        assert payload["success"] is False
        assert "Unknown tool" in payload["error"]
        assert payload["tool"] == "nonexistent_tool"

    def test_tool_exception_handling(self, tools_registry):
        """Exceptions from the tool are captured into the result payload."""
        node_fn = self._build("test_node", tools_registry)

        task = {"id": 5, "tool": "failing_tool", "args": {"path": "bad.py"}}
        payload = node_fn({"task": task})["result"]

        assert payload["success"] is False
        assert "Cannot process: bad.py" in payload["error"]
        assert payload["task_id"] == 5

    def test_includes_current_step(self, tools_registry):
        """The node name is written to ``current_step`` for state tracking."""
        node_fn = self._build("my_tool_call", tools_registry)

        task = {"id": 1, "tool": "simple_tool", "args": {}}
        update = node_fn({"task": task})

        assert update["current_step"] == "my_tool_call"

    def test_empty_args(self, tools_registry):
        """Tools that take no arguments work with an empty args mapping."""
        node_fn = self._build("test_node", tools_registry)

        task = {"id": 1, "tool": "simple_tool", "args": {}}
        payload = node_fn({"task": task})["result"]

        assert payload["success"] is True
        assert payload["result"] == "simple result"
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Tests for tool nodes (type: tool)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from yamlgraph.tools.nodes import create_tool_node
|
|
6
|
+
from yamlgraph.tools.shell import ShellToolConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestCreateToolNode:
    """Behavior of nodes built by ``create_tool_node`` around shell tools."""

    def test_executes_named_tool(self):
        """Only the tool named in the node config is executed."""
        registry = {
            "echo_tool": ShellToolConfig(command="echo hello"),
            "other_tool": ShellToolConfig(command="echo other"),
        }
        run = create_tool_node("test_node", {"tool": "echo_tool"}, registry)

        output = run({})

        assert output["test_node"].strip() == "hello"
        assert output["current_step"] == "test_node"

    def test_resolves_variables_from_state(self):
        """``{state.x}`` references in ``variables`` are filled from state."""
        registry = {"greet": ShellToolConfig(command="echo Hello {name}")}
        cfg = {"tool": "greet", "variables": {"name": "{state.user_name}"}}

        run = create_tool_node("greet_node", cfg, registry)
        output = run({"user_name": "Alice"})

        assert "Alice" in output["greet_node"]

    def test_stores_result_in_state_key(self):
        """Tool output lands under the configured ``state_key``."""
        registry = {"data_tool": ShellToolConfig(command="echo data_value")}
        cfg = {"tool": "data_tool", "state_key": "my_data"}

        run = create_tool_node("fetch_node", cfg, registry)
        output = run({})

        assert "my_data" in output
        assert output["my_data"].strip() == "data_value"

    def test_on_error_skip(self):
        """With ``on_error: skip`` a failing command does not raise."""
        registry = {"fail_tool": ShellToolConfig(command="exit 1")}
        cfg = {"tool": "fail_tool", "on_error": "skip"}

        run = create_tool_node("fail_node", cfg, registry)
        output = run({})

        # No exception; the failure is reported in the returned update instead.
        assert output["current_step"] == "fail_node"
        recorded_error = "errors" in output
        assert recorded_error or output.get("fail_node") is None

    def test_on_error_fail_raises(self):
        """With ``on_error: fail`` a failing command raises RuntimeError."""
        registry = {"fail_tool": ShellToolConfig(command="exit 1")}
        cfg = {"tool": "fail_tool", "on_error": "fail"}

        run = create_tool_node("fail_node", cfg, registry)

        with pytest.raises(RuntimeError):
            run({})

    def test_nested_state_variable(self):
        """Dotted paths such as ``{state.location.lat}`` are resolved."""
        registry = {"geo": ShellToolConfig(command="echo {lat},{lon}")}
        cfg = {
            "tool": "geo",
            "variables": {
                "lat": "{state.location.lat}",
                "lon": "{state.location.lon}",
            },
        }

        run = create_tool_node("geo_node", cfg, registry)
        printed = run({"location": {"lat": "37.7749", "lon": "-122.4194"}})["geo_node"]

        for coordinate in ("37.7749", "-122.4194"):
            assert coordinate in printed

    def test_missing_tool_raises(self):
        """Naming a tool absent from the registry fails at build time."""
        with pytest.raises(KeyError):
            create_tool_node("bad_node", {"tool": "nonexistent"}, {})

    def test_json_parse_tool(self):
        """``parse="json"`` turns the command's stdout into a dict.

        The doubled braces are presumably needed because commands undergo
        ``{var}`` substitution; they should reach the shell as single braces.
        """
        registry = {
            "json_tool": ShellToolConfig(
                command="echo '{{\"count\": 42}}'",
                parse="json",
            ),
        }

        run = create_tool_node("json_node", {"tool": "json_tool"}, registry)
        output = run({})

        assert output["json_node"] == {"count": 42}
|