yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Tests for shell tool executor."""
|
|
2
|
+
|
|
3
|
+
from yamlgraph.tools.shell import (
|
|
4
|
+
ShellToolConfig,
|
|
5
|
+
execute_shell_tool,
|
|
6
|
+
parse_tools,
|
|
7
|
+
sanitize_variables,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestShellToolConfig:
    """Tests for the ShellToolConfig dataclass."""

    def test_default_values(self):
        """Only ``command`` is required; every other field has a default."""
        config = ShellToolConfig(command="echo hello")
        assert config.command == "echo hello"
        assert config.description == ""
        assert config.parse == "text"
        assert config.timeout == 30
        assert config.working_dir == "."
        assert config.env == {}
        assert config.success_codes == [0]

    def test_custom_values(self):
        """Every value passed to the constructor is stored as given."""
        config = ShellToolConfig(
            command="curl http://api.example.com",
            description="Fetch API data",
            parse="json",
            timeout=60,
            working_dir="/tmp",
            env={"API_KEY": "secret"},
            success_codes=[0, 1],
        )
        # Assert every field that was set (not just a subset) so a
        # constructor that silently drops an argument is caught.
        assert config.command == "curl http://api.example.com"
        assert config.description == "Fetch API data"
        assert config.parse == "json"
        assert config.timeout == 60
        assert config.working_dir == "/tmp"
        assert config.env == {"API_KEY": "secret"}
        assert config.success_codes == [0, 1]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TestSanitizeVariables:
    """Tests for variable sanitization.

    Where exact quoting is not pinned, a sanitized value is checked by
    splitting it the way a POSIX shell would (``shlex.split``): a safely
    quoted value must come back as exactly one word equal to the original
    text, with no metacharacter interpreted.
    """

    def test_sanitizes_simple_string(self):
        """Simple strings pass through as a single literal word."""
        result = sanitize_variables({"name": "Alice"})
        # shlex.quote leaves strings made only of safe characters UNquoted,
        # but wrapping them in single quotes is also acceptable.
        assert result["name"] in ("Alice", "'Alice'")

    def test_sanitizes_shell_injection(self):
        """Command substitution is single-quoted so it is never expanded."""
        result = sanitize_variables({"name": "$(rm -rf /)"})
        # Exact match: inside single quotes the shell performs no expansion.
        assert result["name"] == "'$(rm -rf /)'"

    def test_sanitizes_semicolon_injection(self):
        """Semicolon command chaining is prevented."""
        import shlex

        result = sanitize_variables({"name": "test; rm -rf /"})
        # A bare "contains a quote" check would also accept broken quoting;
        # parsing it back proves the value stays one literal argument.
        assert shlex.split(result["name"]) == ["test; rm -rf /"]

    def test_sanitizes_pipe_injection(self):
        """Pipe injection is prevented."""
        import shlex

        result = sanitize_variables({"name": "test | cat /etc/passwd"})
        assert shlex.split(result["name"]) == ["test | cat /etc/passwd"]

    def test_handles_none_values(self):
        """None values become empty strings."""
        result = sanitize_variables({"name": None})
        assert result["name"] == ""

    def test_handles_list_values(self):
        """List values are JSON encoded and quoted."""
        import shlex

        result = sanitize_variables({"items": [1, 2, 3]})
        # The JSON text "[1, 2, 3]" contains spaces, so it must be quoted to
        # remain a single shell word.
        assert shlex.split(result["items"]) == ["[1, 2, 3]"]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TestExecuteShellTool:
    """Tests for execute_shell_tool function."""

    def test_executes_command(self):
        """Simple command executes successfully."""
        config = ShellToolConfig(command="echo hello")
        result = execute_shell_tool(config, {})
        assert result.success is True
        assert result.output.strip() == "hello"
        assert result.error is None

    def test_substitutes_variables(self):
        """Placeholders replaced with values."""
        config = ShellToolConfig(command="echo {message}")
        result = execute_shell_tool(config, {"message": "world"})
        assert result.success is True
        assert result.output.strip() == "world"

    def test_multiple_variables(self):
        """Multiple placeholders all substituted."""
        config = ShellToolConfig(command="echo {a} {b} {c}")
        result = execute_shell_tool(config, {"a": "1", "b": "2", "c": "3"})
        # Check success first so a failed run is reported as such rather
        # than as a confusing output mismatch.
        assert result.success is True
        assert result.output.strip() == "1 2 3"

    def test_parses_json_output(self):
        """JSON stdout parsed to dict."""
        # Double braces escape them from .format()
        config = ShellToolConfig(
            command='echo \'{{"name": "test", "value": 42}}\'',
            parse="json",
        )
        result = execute_shell_tool(config, {})
        assert result.success is True
        assert result.output == {"name": "test", "value": 42}

    def test_parse_none_returns_none(self):
        """parse=none returns None for side-effect commands."""
        config = ShellToolConfig(command="echo ignored", parse="none")
        result = execute_shell_tool(config, {})
        assert result.success is True
        assert result.output is None

    def test_handles_timeout(self):
        """Long-running command times out."""
        config = ShellToolConfig(command="sleep 10", timeout=1)
        result = execute_shell_tool(config, {})
        assert result.success is False
        # Guard before .lower(): a missing message must fail as an
        # assertion, not as an AttributeError on None.
        assert result.error is not None
        assert "timed out" in result.error.lower()

    def test_captures_stderr_on_error(self):
        """Non-zero exit captures stderr."""
        config = ShellToolConfig(command="ls /nonexistent_path_xyz")
        result = execute_shell_tool(config, {})
        assert result.success is False
        assert result.error is not None
        assert "No such file" in result.error or "nonexistent" in result.error.lower()

    def test_custom_success_codes(self):
        """Custom success codes treated as success."""
        # grep returns 1 when no match found
        config = ShellToolConfig(
            command="grep nonexistent /dev/null",
            success_codes=[0, 1],
        )
        result = execute_shell_tool(config, {})
        assert result.success is True

    def test_working_dir(self):
        """Command runs in specified directory."""
        config = ShellToolConfig(command="pwd", working_dir="/tmp")
        result = execute_shell_tool(config, {})
        assert result.success is True
        # Substring match: /tmp may be a symlink (e.g. /private/tmp on macOS).
        assert "/tmp" in result.output

    def test_env_variables(self):
        """Environment variables passed to command."""
        config = ShellToolConfig(
            command="echo $TEST_VAR",
            env={"TEST_VAR": "secret_value"},
        )
        result = execute_shell_tool(config, {})
        assert result.success is True
        assert "secret_value" in result.output

    def test_invalid_json_parse_fails(self):
        """Invalid JSON returns error."""
        config = ShellToolConfig(command="echo 'not json'", parse="json")
        result = execute_shell_tool(config, {})
        assert result.success is False
        # Same None guard as the timeout test before calling .lower().
        assert result.error is not None
        assert "json" in result.error.lower()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class TestParseTools:
    """Tests for parse_tools function."""

    def test_empty_config(self):
        """Empty config returns empty registry."""
        registry = parse_tools({})
        assert registry == {}

    def test_parses_single_tool(self):
        """Single tool parsed correctly."""
        config = {
            "search": {
                "command": "curl -s {url}",
                "description": "Search the web",
                "parse": "json",
                "timeout": 60,
            }
        }
        registry = parse_tools(config)
        assert "search" in registry
        assert registry["search"].command == "curl -s {url}"
        assert registry["search"].description == "Search the web"
        assert registry["search"].parse == "json"
        assert registry["search"].timeout == 60

    def test_parses_multiple_tools(self):
        """Multiple tools all parsed."""
        config = {
            "tool1": {"command": "echo 1"},
            "tool2": {"command": "echo 2"},
            "tool3": {"command": "echo 3"},
        }
        registry = parse_tools(config)
        assert len(registry) == 3
        assert all(name in registry for name in ["tool1", "tool2", "tool3"])

    def test_default_values_applied(self):
        """Missing optional fields get the same defaults as ShellToolConfig."""
        config = {"minimal": {"command": "echo hello"}}
        registry = parse_tools(config)
        tool = registry["minimal"]
        assert tool.description == ""
        assert tool.parse == "text"
        assert tool.timeout == 30
        assert tool.working_dir == "."
        assert tool.env == {}
        # Also pin success_codes, matching the defaults asserted for the
        # dataclass itself.
        assert tool.success_codes == [0]

    def test_parses_env_and_working_dir(self):
        """env and working_dir parsed correctly."""
        config = {
            "script": {
                "command": "node index.js",
                "working_dir": "./scripts",
                "env": {"NODE_ENV": "production"},
            }
        }
        registry = parse_tools(config)
        assert registry["script"].working_dir == "./scripts"
        assert registry["script"].env == {"NODE_ENV": "production"}
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Unit tests for dynamic state builder.
|
|
2
|
+
|
|
3
|
+
TDD: Red phase - these tests define the expected behavior.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from operator import add
|
|
7
|
+
from typing import Annotated, get_args, get_origin
|
|
8
|
+
|
|
9
|
+
from yamlgraph.models.state_builder import sorted_add
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestSortedAdd:
    """Behaviour of the ``sorted_add`` reducer used for map-node fan-in."""

    def test_concatenates_lists(self):
        """Two plain lists are joined, left operand first."""
        assert sorted_add([1, 2], [3, 4]) == [1, 2, 3, 4]

    def test_handles_empty_existing(self):
        """An empty left-hand list yields just the new items."""
        assert sorted_add([], [1, 2]) == [1, 2]

    def test_handles_none_existing(self):
        """A None left-hand value is treated like an empty list."""
        assert sorted_add(None, [1, 2]) == [1, 2]

    def test_handles_empty_new(self):
        """An empty right-hand list leaves the existing items untouched."""
        assert sorted_add([1, 2], []) == [1, 2]

    def test_handles_none_new(self):
        """A None right-hand value is treated like an empty list."""
        assert sorted_add([1, 2], None) == [1, 2]

    def test_sorts_by_map_index(self):
        """Dicts carrying ``_map_index`` come back ordered by that key."""
        # Two parallel branches that finished in reverse order.
        later = [{"_map_index": 2, "value": "third"}]
        earlier = [{"_map_index": 0, "value": "first"}]

        merged = sorted_add(later, earlier)

        head, tail = merged[0], merged[1]
        assert head["_map_index"] == 0
        assert head["value"] == "first"
        assert tail["_map_index"] == 2
        assert tail["value"] == "third"

    def test_sorts_multiple_out_of_order(self):
        """Five shuffled results are restored to index order."""
        arrival_order = [(3, "d"), (0, "a"), (4, "e"), (1, "b"), (2, "c")]
        shuffled = [{"_map_index": idx, "data": letter} for idx, letter in arrival_order]

        merged = sorted_add([], shuffled)

        assert [entry["data"] for entry in merged] == ["a", "b", "c", "d", "e"]

    def test_no_sort_for_non_dict_items(self):
        """Non-dict items keep their insertion order."""
        assert sorted_add([3, 1], [2]) == [3, 1, 2]

    def test_no_sort_for_dicts_without_map_index(self):
        """Dicts without ``_map_index`` keep their insertion order."""
        assert sorted_add([{"a": 1}], [{"b": 2}]) == [{"a": 1}, {"b": 2}]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class TestBuildStateClass:
    """Test dynamic TypedDict generation from graph config."""

    @staticmethod
    def _build(nodes=None):
        """Return ``build_state_class`` output for a config with *nodes* and no edges."""
        from yamlgraph.models.state_builder import build_state_class

        return build_state_class(
            {"nodes": nodes if nodes is not None else {}, "edges": []}
        )

    @staticmethod
    def _assert_list_add_reducer(annotation):
        """Assert that *annotation* is ``Annotated[list, operator.add]``."""
        assert get_origin(annotation) is Annotated
        args = get_args(annotation)
        assert args[0] is list
        assert args[1] is add

    def test_includes_base_infrastructure_fields(self):
        """State always has infrastructure fields."""
        annotations = self._build().__annotations__
        assert "thread_id" in annotations
        assert "current_step" in annotations
        assert "errors" in annotations
        assert "messages" in annotations

    def test_errors_has_reducer(self):
        """errors field uses Annotated[list, add] reducer."""
        self._assert_list_add_reducer(self._build().__annotations__["errors"])

    def test_messages_has_reducer(self):
        """messages field uses Annotated[list, add] reducer."""
        self._assert_list_add_reducer(self._build().__annotations__["messages"])

    def test_extracts_state_key_from_nodes(self):
        """state_key in node config becomes state field."""
        State = self._build(
            {
                "generate": {"prompt": "generate", "state_key": "generated"},
                "analyze": {"prompt": "analyze", "state_key": "analysis"},
            }
        )
        assert "generated" in State.__annotations__
        assert "analysis" in State.__annotations__

    def test_agent_node_adds_input_field(self):
        """Agent nodes automatically add 'input' field."""
        State = self._build({"agent": {"type": "agent", "prompt": "agent"}})
        assert "input" in State.__annotations__

    def test_agent_node_adds_tool_results_field(self):
        """Agent nodes add _tool_results field."""
        State = self._build({"agent": {"type": "agent", "prompt": "agent"}})
        assert "_tool_results" in State.__annotations__

    def test_router_node_adds_route_field(self):
        """Router nodes add _route field."""
        State = self._build(
            {
                "router": {
                    "type": "router",
                    "prompt": "router",
                    "routes": {"a": "node_a", "b": "node_b"},
                },
            }
        )
        assert "_route" in State.__annotations__

    def test_loop_tracking_fields_included(self):
        """Loop tracking fields are always included."""
        annotations = self._build().__annotations__
        assert "_loop_counts" in annotations
        assert "_loop_limit_reached" in annotations
        assert "_agent_iterations" in annotations
        assert "_agent_limit_reached" in annotations

    def test_state_is_typeddict_total_false(self):
        """Generated state is TypedDict with total=False (all optional)."""
        # TypedDict with total=False has __total__ = False
        assert self._build().__total__ is False

    def test_state_works_with_langgraph(self):
        """Generated state class works with LangGraph StateGraph."""
        from langgraph.graph import StateGraph

        State = self._build({"test": {"prompt": "test", "state_key": "result"}})

        # Should not raise
        graph = StateGraph(State)
        graph.add_node("test", lambda s: {"result": "done"})
        graph.set_entry_point("test")
        graph.set_finish_point("test")
        compiled = graph.compile()

        result = compiled.invoke({"input": "hello"})
        # Verify the node's write landed in its declared field AND the
        # caller-supplied input survived to the final state.
        assert result["result"] == "done"
        assert result["input"] == "hello"

    def test_reducer_accumulates_messages(self):
        """Messages reducer accumulates across nodes."""
        from langgraph.graph import StateGraph

        State = self._build()

        graph = StateGraph(State)
        graph.add_node("n1", lambda s: {"messages": [{"content": "a"}]})
        graph.add_node("n2", lambda s: {"messages": [{"content": "b"}]})
        graph.add_edge("n1", "n2")
        graph.set_entry_point("n1")
        graph.set_finish_point("n2")
        compiled = graph.compile()

        result = compiled.invoke({})
        assert len(result["messages"]) == 2
        # operator.add concatenates, so updates appear in execution order.
        assert [m["content"] for m in result["messages"]] == ["a", "b"]
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class TestExtractNodeFields:
    """Which state fields ``extract_node_fields`` derives from node configs."""

    @staticmethod
    def _extract(nodes):
        """Run ``extract_node_fields`` on *nodes* (imported lazily per test)."""
        from yamlgraph.models.state_builder import extract_node_fields

        return extract_node_fields(nodes)

    def test_extracts_state_key(self):
        """Each node's ``state_key`` shows up as a field."""
        extracted = self._extract(
            {
                "gen": {"state_key": "generated"},
                "analyze": {"state_key": "analysis"},
            }
        )

        for expected in ("generated", "analysis"):
            assert expected in extracted

    def test_agent_adds_special_fields(self):
        """An agent node contributes ``input`` and ``_tool_results``."""
        extracted = self._extract({"agent": {"type": "agent"}})

        for expected in ("input", "_tool_results"):
            assert expected in extracted

    def test_router_adds_route_field(self):
        """A router node contributes ``_route``."""
        extracted = self._extract({"router": {"type": "router", "routes": {}}})

        assert "_route" in extracted
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class TestCommonInputFields:
    """Common input fields present even when the graph has no nodes."""

    @staticmethod
    def _default_fields():
        """Annotations of the state class built from an empty graph config."""
        from yamlgraph.models.state_builder import build_state_class

        empty_graph = {"nodes": {}, "edges": []}
        return build_state_class(empty_graph).__annotations__

    def test_includes_topic_field(self):
        """``topic`` is available for content-generation prompts."""
        assert "topic" in self._default_fields()

    def test_includes_style_field(self):
        """``style`` is available for content-generation prompts."""
        assert "style" in self._default_fields()

    def test_includes_word_count_field(self):
        """``word_count`` is available for content-generation prompts."""
        assert "word_count" in self._default_fields()

    def test_includes_message_field(self):
        """``message`` is available for routers."""
        assert "message" in self._default_fields()

    def test_includes_input_field(self):
        """``input`` is available for agents."""
        assert "input" in self._default_fields()
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Tests for state_builder map node reducer support."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, get_args, get_origin
|
|
4
|
+
|
|
5
|
+
from yamlgraph.models.state_builder import (
|
|
6
|
+
build_state_class,
|
|
7
|
+
extract_node_fields,
|
|
8
|
+
sorted_add,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestExtractNodeFieldsMap:
    """Field extraction for a map node's ``collect`` target."""

    @staticmethod
    def _map_nodes(collect=None) -> dict:
        """Node mapping with one map node, optionally carrying ``collect``."""
        map_node = {
            "type": "map",
            "over": "{state.frames}",
            "sub_node": "expand_frame",
        }
        if collect is not None:
            map_node["collect"] = collect
        return {"expand_frames": map_node}

    def test_map_node_collect_field_added(self) -> None:
        """The ``collect`` name becomes an extracted field."""
        extracted = extract_node_fields(self._map_nodes(collect="expanded_frames"))

        assert "expanded_frames" in extracted

    def test_map_node_collect_has_sorted_reducer(self) -> None:
        """The collect field is typed ``Annotated[list, sorted_add]``."""
        extracted = extract_node_fields(self._map_nodes(collect="expanded_frames"))
        collect_type = extracted["expanded_frames"]

        # sorted_add keeps fan-in results in _map_index order.
        assert get_origin(collect_type) is Annotated
        base, reducer = get_args(collect_type)[:2]
        assert base is list
        assert reducer is sorted_add

    def test_map_node_without_collect_no_field(self) -> None:
        """Without ``collect``, no ``expanded_frames`` field is produced."""
        extracted = extract_node_fields(self._map_nodes())

        # Only the would-be collect name is checked; this does not assert
        # that the extracted field mapping is empty.
        assert "expanded_frames" not in extracted
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TestBuildStateClassMap:
    """``build_state_class`` behaviour for graphs containing a map node."""

    @staticmethod
    def _annotations_for_map_graph() -> dict:
        """Annotations of a state class built from a single collecting map node."""
        graph_config = {
            "nodes": {
                "expand_frames": {
                    "type": "map",
                    "over": "{state.frames}",
                    "sub_node": "expand_frame",
                    "collect": "expanded_frames",
                }
            }
        }
        return build_state_class(graph_config).__annotations__

    def test_build_state_includes_collect_field(self) -> None:
        """The collect target is declared on the generated state class."""
        assert "expanded_frames" in self._annotations_for_map_graph()

    def test_build_state_collect_has_sorted_reducer(self) -> None:
        """The collect target keeps its ``sorted_add`` reducer on the state."""
        collect_type = self._annotations_for_map_graph()["expanded_frames"]

        assert get_origin(collect_type) is Annotated
        base, reducer = get_args(collect_type)[:2]
        assert base is list
        assert reducer is sorted_add
|