yamlgraph-0.3.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,190 @@
1
+ """Tests for yamlgraph.utils.template module - Variable extraction and validation."""
2
+
3
+ import pytest
4
+
5
+
6
class TestExtractVariables:
    """Unit tests for ``extract_variables`` (placeholder discovery in templates)."""

    @staticmethod
    def _extract(text):
        """Run ``extract_variables`` on *text*, importing it lazily on each call."""
        from yamlgraph.utils.template import extract_variables

        return extract_variables(text)

    def test_extract_simple_variables(self):
        """Every single-brace ``{var}`` placeholder is reported."""
        found = self._extract("Hello {name}, your style is {style}.")
        assert found == {"name", "style"}

    def test_extract_single_variable(self):
        """A template with one placeholder yields exactly that name."""
        found = self._extract("Welcome {user}!")
        assert found == {"user"}

    def test_extract_no_variables(self):
        """Plain text with no placeholders yields an empty set."""
        found = self._extract("No variables here")
        assert found == set()

    def test_extract_duplicate_variables(self):
        """A placeholder used more than once is reported once."""
        found = self._extract("{name} and {name} again")
        assert found == {"name"}

    def test_extract_jinja2_variable(self):
        """Jinja2 ``{{ var }}`` expressions are recognised."""
        found = self._extract("Hello {{ name }}!")
        assert "name" in found

    def test_extract_jinja2_variable_with_field_access(self):
        """For ``{{ var.field }}`` the base name ``var`` is reported."""
        found = self._extract("User: {{ user.name }}")
        assert "user" in found

    def test_extract_jinja2_loop_variable(self):
        """A ``{% for %}`` iterable is required; the loop target is not."""
        found = self._extract("{% for item in items %}{{ item.name }}{% endfor %}")
        assert "items" in found
        # 'item' is bound by the loop itself, so callers never supply it.
        assert "item" not in found

    def test_extract_jinja2_if_variable(self):
        """Names tested in ``{% if %}`` blocks count as required inputs."""
        found = self._extract("{% if show_details %}Details here{% endif %}")
        assert "show_details" in found

    def test_exclude_state_variable(self):
        """``state`` is provided by the framework and is never reported."""
        found = self._extract("{{ state.topic }}")
        # state is injected by node_factory, so it is not a caller-supplied input.
        assert "state" not in found

    def test_exclude_jinja2_builtins(self):
        """Jinja2 builtins such as ``range`` and ``loop`` are not inputs."""
        found = self._extract("{% for i in range(10) %}{{ loop.index }}{% endfor %}")
        for builtin in ("range", "loop"):
            assert builtin not in found

    def test_mixed_simple_and_jinja2(self):
        """``{var}`` and ``{{ var }}`` syntaxes can coexist in one template."""
        found = self._extract("Simple {name} and Jinja2 {{ topic }}")
        for expected in ("name", "topic"):
            assert expected in found
101
+
102
+
103
class TestValidateVariables:
    """Unit tests for ``validate_variables`` (missing-input detection)."""

    @staticmethod
    def _validator():
        """Return ``validate_variables``, imported lazily."""
        from yamlgraph.utils.template import validate_variables

        return validate_variables

    def test_validate_all_provided(self):
        """Supplying every placeholder passes without error."""
        validate = self._validator()
        inputs = {"name": "World", "style": "formal"}
        validate("Hello {name}, style: {style}", inputs, "greet")  # must not raise

    def test_validate_missing_single_variable(self):
        """One absent placeholder triggers a ValueError that names it."""
        validate = self._validator()
        with pytest.raises(ValueError, match="Missing required variable.*name"):
            validate("Hello {name}, style: {style}", {"style": "formal"}, "greet")

    def test_validate_missing_multiple_variables(self):
        """The error lists every absent placeholder, not only the first."""
        validate = self._validator()
        with pytest.raises(ValueError) as excinfo:
            validate("Hello {name}, style: {style}", {}, "greet")
        message = str(excinfo.value)
        for missing in ("name", "style"):
            assert missing in message

    def test_validate_extra_variables_ok(self):
        """Unused extra inputs are tolerated."""
        validate = self._validator()
        # Extra keys are simply ignored; this must not raise.
        validate("Hello {name}", {"name": "World", "extra": "ignored"}, "greet")

    def test_validate_prompt_name_in_error(self):
        """The prompt name is included in the error message."""
        validate = self._validator()
        with pytest.raises(ValueError, match="greet"):
            validate("Hello {name}", {}, "greet")

    def test_validate_empty_template(self):
        """A template without placeholders needs no inputs."""
        validate = self._validator()
        validate("No variables here", {}, "static")  # must not raise

    def test_validate_jinja2_template(self):
        """Jinja2 loop iterables are validated like any other input."""
        validate = self._validator()
        with pytest.raises(ValueError, match="items"):
            validate(
                "{% for item in items %}{{ item }}{% endfor %}", {}, "list_template"
            )
164
+
165
+
166
class TestExecutePromptValidation:
    """Integration checks that ``execute_prompt`` validates its template inputs."""

    def test_execute_prompt_raises_on_missing_variable(self):
        """Omitting ``name`` produces a descriptive ValueError."""
        from yamlgraph.executor import execute_prompt

        partial_inputs = {"style": "formal"}  # 'name' deliberately left out
        with pytest.raises(ValueError, match="Missing required variable.*name"):
            execute_prompt(prompt_name="greet", variables=partial_inputs)

    def test_execute_prompt_lists_all_missing_variables(self):
        """Every missing input is reported, not just the first one found."""
        from yamlgraph.executor import execute_prompt

        with pytest.raises(ValueError) as excinfo:
            # Both 'name' and 'style' are missing here.
            execute_prompt(prompt_name="greet", variables={})
        reported = str(excinfo.value)
        for key in ("name", "style"):
            assert key in reported
@@ -0,0 +1,164 @@
1
+ """Tests for type: tool_call integration in graph_loader.
2
+
3
+ TDD Phase 3b: Wire tool_call node into graph compilation.
4
+ """
5
+
6
+ import pytest
7
+ from langgraph.graph import StateGraph
8
+
9
+ from yamlgraph.graph_loader import GraphConfig, _compile_node
10
+ from yamlgraph.map_compiler import compile_map_node
11
+
12
+
13
+ # Sample tools for testing
14
def sample_search(path: str) -> dict:
    """Stand-in search tool: echo *path* together with two canned matches."""
    canned_matches = ["line1", "line2"]
    return {"path": path, "matches": canned_matches}
17
+
18
+
19
def sample_read(file: str, start: int = 1, end: int = 10) -> dict:
    """Stand-in read tool: report *file* and the inclusive line span start..end."""
    line_numbers = [*range(start, end + 1)]
    return {"file": file, "lines": line_numbers}
22
+
23
+
24
@pytest.fixture
def tools_registry() -> dict:
    """Map tool names to the sample callables resolved by tool_call nodes."""
    return dict(search_file=sample_search, read_lines=sample_read)
31
+
32
+
33
@pytest.fixture
def minimal_config() -> GraphConfig:
    """Build the smallest valid GraphConfig that also declares python tools."""
    # Both tools point back at functions defined in this test module.
    this_module = "tests.unit.test_tool_call_integration"
    python_tools = {
        "search_file": {
            "type": "python",
            "module": this_module,
            "function": "sample_search",
        },
        "read_lines": {
            "type": "python",
            "module": this_module,
            "function": "sample_read",
        },
    }
    return GraphConfig(
        {
            "version": "1.0",
            "name": "test",
            "nodes": {"dummy": {"prompt": "test", "state_key": "result"}},
            "edges": [
                {"from": "START", "to": "dummy"},
                {"from": "dummy", "to": "END"},
            ],
            "tools": python_tools,
        }
    )
63
+
64
+
65
class TestCompileToolCallNode:
    """Compilation and execution of ``type: tool_call`` nodes."""

    def test_compiles_tool_call_node(self, minimal_config, tools_registry):
        """_compile_node adds a tool_call node to the graph and returns no map info."""
        from operator import add
        from typing import Annotated

        # discovery_findings is annotated with operator.add as its reducer.
        class TestState:
            discovery_findings: Annotated[list, add] = []

        builder = StateGraph(TestState)
        tool_call_spec = {
            "type": "tool_call",
            "tool": "{state.task.tool}",
            "args": "{state.task.args}",
            "state_key": "result",
        }

        map_info = _compile_node(
            "test_tool_call",
            tool_call_spec,
            builder,
            minimal_config,
            tools={},
            python_tools={},
            websearch_tools={},
            # tool_call nodes look their callables up in callable_registry.
            callable_registry=tools_registry,
        )

        assert map_info is None  # a tool_call node carries no map info
        assert "test_tool_call" in builder.nodes

    def test_tool_call_node_executes(self, tools_registry):
        """The generated node invokes the tool named in state with the state's args."""
        from yamlgraph.node_factory import create_tool_call_node

        run_node = create_tool_call_node(
            "exec_tool",
            {
                "tool": "{state.task.tool}",
                "args": "{state.task.args}",
                "state_key": "result",
            },
            tools_registry,
        )
        task = {"id": 1, "tool": "search_file", "args": {"path": "foo.py"}}

        outcome = run_node({"task": task})["result"]

        assert outcome["success"] is True
        assert outcome["result"]["path"] == "foo.py"
124
+
125
+
126
class TestMapWithToolCall:
    """Map nodes whose per-item sub-node is a ``type: tool_call`` node."""

    def test_map_with_tool_call_sub_node(self, tools_registry):
        """compile_map_node accepts a tool_call sub-node and registers it on the graph."""
        from operator import add
        from typing import Annotated

        class TestState:
            discovery_plan: dict = {}
            discovery_findings: Annotated[list, add] = []

        builder = StateGraph(TestState)
        sub_node_spec = {
            "type": "tool_call",
            "tool": "{state.task.tool}",
            "args": "{state.task.args}",
            "state_key": "discovery_result",
        }
        map_spec = {
            "type": "map",
            "over": "{state.discovery_plan.tasks}",
            "as": "task",
            "node": sub_node_spec,
            "collect": "discovery_findings",
        }

        # map_compiler has to handle tool_call sub-nodes via the tools_registry argument.
        edge_fn, sub_name = compile_map_node(
            "execute_discovery",
            map_spec,
            builder,
            defaults={},
            tools_registry=tools_registry,
        )

        assert callable(edge_fn)
        assert sub_name == "_map_execute_discovery_sub"
        assert sub_name in builder.nodes
@@ -0,0 +1,178 @@
1
+ """Tests for type: tool_call node in node_factory.
2
+
3
+ TDD Phase 3: Dynamic tool execution from state.
4
+ """
5
+
6
+ import pytest
7
+
8
+ from yamlgraph.constants import NodeType
9
+ from yamlgraph.node_factory import create_tool_call_node
10
+
11
+
12
+ # Sample tools for testing
13
def sample_tool(path: str, pattern: str = ".*") -> dict:
    """Stand-in tool that echoes its arguments plus two fake hits."""
    return dict(path=path, pattern=pattern, found=["line1", "line2"])
16
+
17
+
18
def failing_tool(path: str) -> dict:
    """Stand-in tool that always raises ValueError mentioning *path*."""
    message = f"Cannot process: {path}"
    raise ValueError(message)
21
+
22
+
23
def simple_tool() -> str:
    """Stand-in tool that takes no arguments and returns a fixed string."""
    result = "simple result"
    return result
26
+
27
+
28
@pytest.fixture
def tools_registry() -> dict:
    """Name-to-callable registry exposing this module's sample tools."""
    return dict(
        search_file=sample_tool,
        failing_tool=failing_tool,
        simple_tool=simple_tool,
    )
36
+
37
+
38
class TestNodeTypeConstant:
    """The NodeType enum exposes a TOOL_CALL member."""

    def test_tool_call_in_node_type(self):
        """NodeType.TOOL_CALL compares equal to the YAML keyword 'tool_call'."""
        expected = "tool_call"
        assert NodeType.TOOL_CALL == expected
44
+
45
+
46
class TestCreateToolCallNode:
    """Behaviour of node functions built by ``create_tool_call_node``."""

    @staticmethod
    def _make_node(registry, name="test_node"):
        """Build a node that reads both the tool name and its args from ``state.task``.

        A fresh config dict is created on every call, so tests never share one.
        """
        spec = {
            "tool": "{state.task.tool}",
            "args": "{state.task.args}",
            "state_key": "result",
        }
        return create_tool_call_node(name, spec, registry)

    @staticmethod
    def _task_state(task_id, tool, args):
        """Wrap one task description in the state shape the node reads."""
        return {"task": {"id": task_id, "tool": tool, "args": args}}

    def test_creates_callable(self, tools_registry):
        """The factory returns a callable node function."""
        assert callable(self._make_node(tools_registry))

    def test_resolves_tool_from_state(self, tools_registry):
        """The tool to invoke is read from ``state.task.tool`` at run time."""
        node = self._make_node(tools_registry)
        state = self._task_state(
            1, "search_file", {"path": "foo.py", "pattern": "def"}
        )

        outcome = node(state)["result"]

        assert outcome["success"] is True
        assert outcome["tool"] == "search_file"

    def test_resolves_args_from_state(self, tools_registry):
        """``state.task.args`` is forwarded to the selected tool."""
        node = self._make_node(tools_registry)
        state = self._task_state(
            1, "search_file", {"path": "bar.py", "pattern": "class"}
        )

        echoed = node(state)["result"]["result"]

        assert echoed["path"] == "bar.py"
        assert echoed["pattern"] == "class"

    def test_successful_execution(self, tools_registry):
        """A successful call reports success, the task id, no error and the output."""
        node = self._make_node(tools_registry)

        outcome = node(self._task_state(42, "search_file", {"path": "x.py"}))["result"]

        assert outcome["success"] is True
        assert outcome["task_id"] == 42
        assert outcome["error"] is None
        assert "found" in outcome["result"]

    def test_unknown_tool_handling(self, tools_registry):
        """An unregistered tool name yields success=False rather than raising."""
        node = self._make_node(tools_registry)

        outcome = node(self._task_state(1, "nonexistent_tool", {}))["result"]

        assert outcome["success"] is False
        assert "Unknown tool" in outcome["error"]
        assert outcome["tool"] == "nonexistent_tool"

    def test_tool_exception_handling(self, tools_registry):
        """Exceptions raised inside the tool are captured into the result."""
        node = self._make_node(tools_registry)

        outcome = node(
            self._task_state(5, "failing_tool", {"path": "bad.py"})
        )["result"]

        assert outcome["success"] is False
        assert "Cannot process: bad.py" in outcome["error"]
        assert outcome["task_id"] == 5

    def test_includes_current_step(self, tools_registry):
        """The node name is echoed back as ``current_step``."""
        node = self._make_node(tools_registry, name="my_tool_call")

        update = node(self._task_state(1, "simple_tool", {}))

        assert update["current_step"] == "my_tool_call"

    def test_empty_args(self, tools_registry):
        """Tools that take no parameters work with an empty args mapping."""
        node = self._make_node(tools_registry)

        outcome = node(self._task_state(1, "simple_tool", {}))["result"]

        assert outcome["success"] is True
        assert outcome["result"] == "simple result"
@@ -0,0 +1,129 @@
1
+ """Tests for tool nodes (type: tool)."""
2
+
3
+ import pytest
4
+
5
+ from yamlgraph.tools.nodes import create_tool_node
6
+ from yamlgraph.tools.shell import ShellToolConfig
7
+
8
+
9
class TestCreateToolNode:
    """Behaviour of node functions built by ``create_tool_node`` (``type: tool``)."""

    def test_executes_named_tool(self):
        """Only the tool named in the node config is executed."""
        registry = {
            "echo_tool": ShellToolConfig(command="echo hello"),
            "other_tool": ShellToolConfig(command="echo other"),
        }
        run = create_tool_node("test_node", {"tool": "echo_tool"}, registry)

        output = run({})

        assert output["test_node"].strip() == "hello"
        assert output["current_step"] == "test_node"

    def test_resolves_variables_from_state(self):
        """``{state.x}`` references in ``variables`` are filled in from state."""
        registry = {"greet": ShellToolConfig(command="echo Hello {name}")}
        spec = {"tool": "greet", "variables": {"name": "{state.user_name}"}}
        run = create_tool_node("greet_node", spec, registry)

        output = run({"user_name": "Alice"})

        assert "Alice" in output["greet_node"]

    def test_stores_result_in_state_key(self):
        """Tool output is written under the configured ``state_key``."""
        registry = {"data_tool": ShellToolConfig(command="echo data_value")}
        spec = {"tool": "data_tool", "state_key": "my_data"}

        output = create_tool_node("fetch_node", spec, registry)({})

        assert "my_data" in output
        assert output["my_data"].strip() == "data_value"

    def test_on_error_skip(self):
        """With ``on_error: skip`` a failing command does not raise."""
        registry = {"fail_tool": ShellToolConfig(command="exit 1")}
        spec = {"tool": "fail_tool", "on_error": "skip"}

        output = create_tool_node("fail_node", spec, registry)({})

        assert output["current_step"] == "fail_node"
        # Either an error record is present or the result slot is empty.
        assert "errors" in output or output.get("fail_node") is None

    def test_on_error_fail_raises(self):
        """With ``on_error: fail`` a failing command raises RuntimeError."""
        registry = {"fail_tool": ShellToolConfig(command="exit 1")}
        spec = {"tool": "fail_tool", "on_error": "fail"}
        run = create_tool_node("fail_node", spec, registry)

        with pytest.raises(RuntimeError):
            run({})

    def test_nested_state_variable(self):
        """Dotted paths such as ``{state.location.lat}`` reach nested values."""
        registry = {"geo": ShellToolConfig(command="echo {lat},{lon}")}
        spec = {
            "tool": "geo",
            "variables": {
                "lat": "{state.location.lat}",
                "lon": "{state.location.lon}",
            },
        }
        run = create_tool_node("geo_node", spec, registry)

        geo_output = run({"location": {"lat": "37.7749", "lon": "-122.4194"}})[
            "geo_node"
        ]

        for coordinate in ("37.7749", "-122.4194"):
            assert coordinate in geo_output

    def test_missing_tool_raises(self):
        """Referencing an unregistered tool fails when the node is created."""
        with pytest.raises(KeyError):
            create_tool_node("bad_node", {"tool": "nonexistent"}, {})

    def test_json_parse_tool(self):
        """``parse="json"`` turns the command's output into a dict."""
        registry = {
            "json_tool": ShellToolConfig(
                command="echo '{{\"count\": 42}}'",
                parse="json",
            ),
        }

        output = create_tool_node("json_node", {"tool": "json_tool"}, registry)({})

        assert output["json_node"] == {"count": 42}