yamlgraph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
tests/unit/test_llm_factory_async.py
ADDED

@@ -0,0 +1,118 @@
"""Unit tests for async LLM factory module."""

from unittest.mock import MagicMock, patch

import pytest

from yamlgraph.utils.llm_factory_async import (
    create_llm_async,
    get_executor,
    invoke_async,
    shutdown_executor,
)


class TestGetExecutor:
    """Tests for get_executor function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    def test_creates_executor(self):
        """Should create a ThreadPoolExecutor."""
        executor = get_executor()
        assert executor is not None

    def test_returns_same_executor(self):
        """Should return the same executor on subsequent calls."""
        executor1 = get_executor()
        executor2 = get_executor()
        assert executor1 is executor2


class TestShutdownExecutor:
    """Tests for shutdown_executor function."""

    def test_shutdown_cleans_up(self):
        """Shutdown should clean up executor."""
        # Create an executor
        executor1 = get_executor()
        assert executor1 is not None

        # Shutdown
        shutdown_executor()

        # Next call should create a new executor
        executor2 = get_executor()
        assert executor2 is not executor1

    def test_shutdown_when_none(self):
        """Shutdown when no executor should not raise."""
        shutdown_executor()  # Ensure clean state
        shutdown_executor()  # Should not raise


class TestCreateLLMAsync:
    """Tests for create_llm_async function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_creates_llm(self):
        """Should create an LLM instance."""
        llm = await create_llm_async(provider="anthropic", temperature=0.5)
        assert llm is not None
        assert llm.temperature == 0.5

    @pytest.mark.asyncio
    async def test_uses_default_provider(self):
        """Should use default provider when not specified."""
        with patch.dict("os.environ", {"PROVIDER": ""}, clear=False):
            llm = await create_llm_async(temperature=0.7)
            # Default is anthropic
            assert "anthropic" in llm.__class__.__name__.lower()


class TestInvokeAsync:
    """Tests for invoke_async function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_invoke_returns_string(self):
        """Should return string content when no output model."""
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = "Hello, world!"
        mock_llm.invoke.return_value = mock_response

        messages = [MagicMock()]
        result = await invoke_async(mock_llm, messages)

        assert result == "Hello, world!"
        mock_llm.invoke.assert_called_once_with(messages)

    @pytest.mark.asyncio
    async def test_invoke_with_output_model(self):
        """Should use structured output when model provided."""
        from pydantic import BaseModel

        class TestOutput(BaseModel):
            value: str

        mock_llm = MagicMock()
        mock_structured_llm = MagicMock()
        mock_llm.with_structured_output.return_value = mock_structured_llm
        mock_structured_llm.invoke.return_value = TestOutput(value="test")

        messages = [MagicMock()]
        result = await invoke_async(mock_llm, messages, output_model=TestOutput)

        assert isinstance(result, TestOutput)
        assert result.value == "test"
        mock_llm.with_structured_output.assert_called_once_with(TestOutput)
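The TestGetExecutor/TestInvokeAsync cases above pin down a small contract: a lazily created, process-wide thread pool, a shutdown_executor() that resets it, and an invoke_async() that runs the blocking llm.invoke off the event loop, unwrapping .content unless an output_model requests structured output. As a reading aid, here is a minimal sketch of a module shape that would satisfy these tests; it is an assumption inferred from them, not the code shipped in this wheel, and create_llm_async is omitted because it would construct a real provider client.

# Hypothetical sketch, inferred from the tests above -- not yamlgraph's actual module.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

_executor: Optional[ThreadPoolExecutor] = None


def get_executor() -> ThreadPoolExecutor:
    """Return the shared executor, creating it lazily on first use."""
    global _executor
    if _executor is None:
        _executor = ThreadPoolExecutor(max_workers=4)  # pool size is an assumption
    return _executor


def shutdown_executor() -> None:
    """Shut down and discard the shared executor; safe to call when none exists."""
    global _executor
    if _executor is not None:
        _executor.shutdown(wait=False)
        _executor = None


async def invoke_async(llm, messages, output_model=None):
    """Run a blocking LLM invoke in the shared executor and unwrap the result."""
    loop = asyncio.get_running_loop()
    target = llm.with_structured_output(output_model) if output_model else llm
    response = await loop.run_in_executor(get_executor(), target.invoke, messages)
    # Structured results are returned as-is; plain chat responses unwrap .content.
    return response if output_model else response.content

Off-loading a synchronous LangChain-style client to a thread pool is one common way to expose it from async node functions without blocking the event loop; the packaged module may well do this differently.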
tests/unit/test_loops.py
ADDED

@@ -0,0 +1,403 @@
"""Tests for Section 3: Self-Correction Loops (Reflexion).

TDD tests for expression conditions, loop tracking, and cyclic graphs.
"""

from unittest.mock import MagicMock, patch

import pytest

# =============================================================================
# Test: Expression Condition Parsing
# =============================================================================


class TestExpressionConditions:
    """Tests for condition expression evaluation."""

    def test_evaluate_condition_exists(self):
        """evaluate_condition function should exist."""
        from yamlgraph.utils.conditions import evaluate_condition

        assert callable(evaluate_condition)

    def test_less_than_comparison(self):
        """Evaluates 'score < 0.8' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"score": 0.5}
        assert evaluate_condition("score < 0.8", state) is True

        state = {"score": 0.9}
        assert evaluate_condition("score < 0.8", state) is False

    def test_greater_than_comparison(self):
        """Evaluates 'score > 0.5' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"score": 0.7}
        assert evaluate_condition("score > 0.5", state) is True

        state = {"score": 0.3}
        assert evaluate_condition("score > 0.5", state) is False

    def test_less_than_or_equal(self):
        """Evaluates 'score <= 0.8' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"score": 0.8}
        assert evaluate_condition("score <= 0.8", state) is True

        state = {"score": 0.9}
        assert evaluate_condition("score <= 0.8", state) is False

    def test_greater_than_or_equal(self):
        """Evaluates 'score >= 0.8' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"score": 0.8}
        assert evaluate_condition("score >= 0.8", state) is True

        state = {"score": 0.7}
        assert evaluate_condition("score >= 0.8", state) is False

    def test_equality_comparison(self):
        """Evaluates 'status == \"approved\"' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"status": "approved"}
        assert evaluate_condition('status == "approved"', state) is True

        state = {"status": "pending"}
        assert evaluate_condition('status == "approved"', state) is False

    def test_inequality_comparison(self):
        """Evaluates 'error != null' correctly."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"error": "something"}
        assert evaluate_condition("error != null", state) is True

        state = {"error": None}
        assert evaluate_condition("error != null", state) is False

    def test_nested_attribute_access(self):
        """Evaluates 'critique.score >= 0.8' from state."""
        from yamlgraph.utils.conditions import evaluate_condition

        # Using object with attribute
        critique = MagicMock()
        critique.score = 0.85
        state = {"critique": critique}
        assert evaluate_condition("critique.score >= 0.8", state) is True

        critique.score = 0.7
        assert evaluate_condition("critique.score >= 0.8", state) is False

    def test_compound_and_condition(self):
        """Evaluates 'score < 0.8 and iteration < 3'."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"score": 0.5, "iteration": 2}
        assert evaluate_condition("score < 0.8 and iteration < 3", state) is True

        state = {"score": 0.9, "iteration": 2}
        assert evaluate_condition("score < 0.8 and iteration < 3", state) is False

        state = {"score": 0.5, "iteration": 5}
        assert evaluate_condition("score < 0.8 and iteration < 3", state) is False

    def test_compound_or_condition(self):
        """Evaluates 'approved == true or override == true'."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {"approved": True, "override": False}
        assert evaluate_condition("approved == true or override == true", state) is True

        state = {"approved": False, "override": True}
        assert evaluate_condition("approved == true or override == true", state) is True

        state = {"approved": False, "override": False}
        assert (
            evaluate_condition("approved == true or override == true", state) is False
        )

    def test_invalid_expression_raises(self):
        """Malformed expression raises ValueError."""
        from yamlgraph.utils.conditions import evaluate_condition

        with pytest.raises(ValueError):
            evaluate_condition("score <<< 0.8", {})

    def test_missing_attribute_returns_false(self):
        """Missing attribute in state returns False gracefully."""
        from yamlgraph.utils.conditions import evaluate_condition

        state = {}
        # Should not raise, should return False for missing attribute
        assert evaluate_condition("score < 0.8", state) is False


# =============================================================================
# Test: Loop Tracking
# =============================================================================


class TestLoopTracking:
    """Tests for loop iteration tracking."""

    def test_state_has_loop_counts_field(self):
        """Dynamic state should have _loop_counts field."""
        from yamlgraph.models.state_builder import build_state_class

        State = build_state_class({"nodes": {}})
        # Should have _loop_counts in annotations
        assert "_loop_counts" in State.__annotations__

        # And work at runtime
        state = {"_loop_counts": {"critique": 2}}
        assert state["_loop_counts"]["critique"] == 2

    def test_node_increments_loop_counter(self):
        """Each node execution increments its counter in _loop_counts."""
        from yamlgraph.node_factory import create_node_function

        node_config = {
            "prompt": "test_prompt",
            "state_key": "result",
        }

        with patch("yamlgraph.node_factory.execute_prompt") as mock_execute:
            mock_execute.return_value = "test result"

            node_fn = create_node_function("critique", node_config, {})

            # First call - should initialize counter
            state = {"message": "test"}
            result = node_fn(state)
            assert result.get("_loop_counts", {}).get("critique") == 1

            # Second call - should increment
            state = {"message": "test", "_loop_counts": {"critique": 1}}
            result = node_fn(state)
            assert result.get("_loop_counts", {}).get("critique") == 2


# =============================================================================
# Test: Loop Limits Configuration
# =============================================================================


class TestLoopLimits:
    """Tests for loop_limits configuration."""

    def test_parses_loop_limits_from_yaml(self):
        """GraphConfig parses loop_limits section."""
        from yamlgraph.graph_loader import GraphConfig

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {
                "draft": {"prompt": "draft"},
                "critique": {"prompt": "critique"},
            },
            "edges": [
                {"from": "START", "to": "draft"},
                {"from": "draft", "to": "critique"},
                {"from": "critique", "to": "END"},
            ],
            "loop_limits": {
                "critique": 3,
            },
        }
        config = GraphConfig(config_dict)
        assert config.loop_limits == {"critique": 3}

    def test_loop_limits_defaults_to_empty(self):
        """Missing loop_limits defaults to empty dict."""
        from yamlgraph.graph_loader import GraphConfig

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {"node1": {"prompt": "p1"}},
            "edges": [{"from": "START", "to": "node1"}, {"from": "node1", "to": "END"}],
        }
        config = GraphConfig(config_dict)
        assert config.loop_limits == {}

    def test_node_checks_loop_limit(self):
        """Node execution checks loop limit before running."""
        from yamlgraph.node_factory import create_node_function

        node_config = {
            "prompt": "test_prompt",
            "state_key": "result",
            "loop_limit": 3,  # Node-level limit
        }

        with patch("yamlgraph.node_factory.execute_prompt") as mock_execute:
            mock_execute.return_value = "test result"

            node_fn = create_node_function("critique", node_config, {})

            # Under limit - should execute
            state = {"_loop_counts": {"critique": 2}}
            result = node_fn(state)
            assert "result" in result

            # At limit - should skip/terminate
            state = {"_loop_counts": {"critique": 3}}
            result = node_fn(state)
            assert result.get("_loop_limit_reached") is True


# =============================================================================
# Test: Cyclic Edges
# =============================================================================


class TestCyclicEdges:
    """Tests for cyclic graph support."""

    def test_allows_backward_edges(self):
        """Graph config allows edges pointing to earlier nodes."""
        from yamlgraph.graph_loader import GraphConfig

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {
                "draft": {"prompt": "draft"},
                "critique": {"prompt": "critique"},
                "refine": {"prompt": "refine"},
            },
            "edges": [
                {"from": "START", "to": "draft"},
                {"from": "draft", "to": "critique"},
                {
                    "from": "critique",
                    "to": "refine",
                    "condition": "critique.score < 0.8",
                },
                {"from": "critique", "to": "END", "condition": "critique.score >= 0.8"},
                {"from": "refine", "to": "critique"},  # Backward edge (cycle)
            ],
            "loop_limits": {"critique": 3},
        }
        # Should not raise
        config = GraphConfig(config_dict)
        assert config is not None

    def test_compiles_cyclic_graph(self):
        """Cyclic graph compiles to StateGraph."""
        from yamlgraph.graph_loader import GraphConfig, compile_graph

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {
                "draft": {"prompt": "draft", "state_key": "current_draft"},
                "critique": {"prompt": "critique", "state_key": "critique"},
                "refine": {"prompt": "refine", "state_key": "current_draft"},
            },
            "edges": [
                {"from": "START", "to": "draft"},
                {"from": "draft", "to": "critique"},
                {
                    "from": "critique",
                    "to": "refine",
                    "condition": "critique.score < 0.8",
                },
                {"from": "critique", "to": "END", "condition": "critique.score >= 0.8"},
                {"from": "refine", "to": "critique"},  # Cycle
            ],
            "loop_limits": {"critique": 3},
        }
        config = GraphConfig(config_dict)
        graph = compile_graph(config)
        assert graph is not None


# =============================================================================
# Test: Pydantic Models
# =============================================================================


class TestReflexionModels:
    """Tests for DraftContent and Critique-like fixture models.

    Note: Demo models were removed from yamlgraph.models in Section 10.
    These tests use fixture models to prove the pattern still works.
    """

    def test_draft_content_model_exists(self):
        """DraftContent-like fixture model can be created."""
        from tests.conftest import FixtureDraftContent

        assert FixtureDraftContent is not None

    def test_draft_content_fields(self):
        """DraftContent-like model has content and version fields."""
        from tests.conftest import FixtureDraftContent

        draft = FixtureDraftContent(content="Test essay", version=1)
        assert draft.content == "Test essay"
        assert draft.version == 1

    def test_critique_model_exists(self):
        """Critique-like fixture model can be created."""
        from tests.conftest import FixtureCritique

        assert FixtureCritique is not None

    def test_critique_fields(self):
        """Critique-like model has score, feedback, issues, should_refine fields."""
        from tests.conftest import FixtureCritique

        critique = FixtureCritique(
            score=0.75,
            feedback="Improve transitions",
            issues=["Weak intro", "No conclusion"],
            should_refine=True,
        )
        assert critique.score == 0.75
        assert critique.feedback == "Improve transitions"
        assert len(critique.issues) == 2
        assert critique.should_refine is True


# =============================================================================
# Test: Reflexion Demo Graph
# =============================================================================


class TestReflexionDemoGraph:
    """Tests for the reflexion-demo.yaml graph."""

    def test_demo_graph_loads(self):
        """reflexion-demo.yaml loads without error."""
        from yamlgraph.graph_loader import load_graph_config

        config = load_graph_config("graphs/reflexion-demo.yaml")
        assert config.name == "reflexion-demo"
        assert "draft" in config.nodes
        assert "critique" in config.nodes
        assert "refine" in config.nodes

    def test_demo_graph_has_loop_limits(self):
        """reflexion-demo.yaml has loop_limits configured."""
        from yamlgraph.graph_loader import load_graph_config

        config = load_graph_config("graphs/reflexion-demo.yaml")
        assert "critique" in config.loop_limits
        assert config.loop_limits["critique"] >= 3

    def test_demo_graph_compiles(self):
        """reflexion-demo.yaml compiles to StateGraph."""
        from yamlgraph.graph_loader import compile_graph, load_graph_config

        config = load_graph_config("graphs/reflexion-demo.yaml")
        graph = compile_graph(config)
        assert graph is not None
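The TestExpressionConditions cases above effectively specify evaluate_condition: the six comparison operators, and/or chaining, the literals true/false/null, dotted attribute access into state objects, ValueError on malformed expressions, and a graceful False when a referenced name is missing from state. A hypothetical sketch that satisfies those cases, using the standard-library ast module as a restricted evaluator (an assumption for illustration; the packaged yamlgraph.utils.conditions may be implemented differently):

# Hypothetical sketch -- a restricted-AST evaluator consistent with the tests above.
import ast
import operator

_OPS = {
    ast.Lt: operator.lt,
    ast.Gt: operator.gt,
    ast.LtE: operator.le,
    ast.GtE: operator.ge,
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
}
_LITERALS = {"true": True, "false": False, "null": None}


def evaluate_condition(expression: str, state: dict) -> bool:
    """Evaluate a condition such as 'critique.score >= 0.8' against a state dict."""
    try:
        tree = ast.parse(expression, mode="eval")
    except SyntaxError as exc:
        raise ValueError(f"Invalid condition: {expression!r}") from exc

    def resolve(node):
        # Constants, bare names (looked up in state), and dotted attribute access.
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            return _LITERALS[node.id] if node.id in _LITERALS else state[node.id]
        if isinstance(node, ast.Attribute):
            return getattr(resolve(node.value), node.attr)
        raise ValueError(f"Unsupported term in condition: {expression!r}")

    def evaluate(node) -> bool:
        if isinstance(node, ast.BoolOp):
            parts = [evaluate(value) for value in node.values]
            return all(parts) if isinstance(node.op, ast.And) else any(parts)
        if isinstance(node, ast.Compare) and len(node.ops) == 1:
            op = _OPS.get(type(node.ops[0]))
            if op is None:
                raise ValueError(f"Unsupported operator in condition: {expression!r}")
            return bool(op(resolve(node.left), resolve(node.comparators[0])))
        raise ValueError(f"Unsupported condition: {expression!r}")

    try:
        return evaluate(tree.body)
    except (KeyError, AttributeError, TypeError):
        # Missing state keys/attributes (or incomparable values) fail closed.
        return False

Parsing with ast and whitelisting node types keeps arbitrary code out of the YAML conditions while still allowing the compound expressions the cyclic-edge tests rely on.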
tests/unit/test_map_node.py
ADDED

@@ -0,0 +1,144 @@
"""Tests for type: map node functionality."""

from unittest.mock import MagicMock

import pytest

from yamlgraph.map_compiler import compile_map_node, wrap_for_reducer


class TestWrapForReducer:
    """Tests for wrap_for_reducer helper."""

    def test_wraps_result_in_list(self):
        """Wrap node output for reducer aggregation."""

        def simple_node(state: dict) -> dict:
            return {"result": state["item"] * 2}

        wrapped = wrap_for_reducer(simple_node, "collected", "result")
        result = wrapped({"item": 5})

        assert result == {"collected": [10]}

    def test_preserves_map_index(self):
        """Preserve _map_index in wrapped output."""

        def node_fn(state: dict) -> dict:
            return {"data": state["value"]}

        wrapped = wrap_for_reducer(node_fn, "results", "data")
        result = wrapped({"value": "test", "_map_index": 2})

        assert result == {"results": [{"_map_index": 2, "value": "test"}]}

    def test_extracts_state_key(self):
        """Extract specific state_key from node result."""

        def node_fn(state: dict) -> dict:
            return {"frame_data": {"before": "a", "after": "b"}, "other": "ignore"}

        wrapped = wrap_for_reducer(node_fn, "frames", "frame_data")
        result = wrapped({})

        assert result == {"frames": [{"before": "a", "after": "b"}]}


class TestCompileMapNode:
    """Tests for compile_map_node function."""

    def test_creates_map_edge_function(self):
        """compile_map_node returns a map edge function."""
        config = {
            "over": "{items}",
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }
        builder = MagicMock()
        defaults = {}

        map_edge, sub_node_name = compile_map_node("expand", config, builder, defaults)

        # Should return callable and sub-node name
        assert callable(map_edge)
        assert sub_node_name == "_map_expand_sub"

    def test_map_edge_returns_send_list(self):
        """Map edge function returns list of Send objects."""
        from langgraph.types import Send

        config = {
            "over": "{items}",
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }
        builder = MagicMock()
        defaults = {}

        map_edge, sub_node_name = compile_map_node("expand", config, builder, defaults)

        state = {"items": ["a", "b", "c"]}
        sends = map_edge(state)

        assert len(sends) == 3
        assert all(isinstance(s, Send) for s in sends)
        assert sends[0].node == sub_node_name
        assert sends[0].arg["item"] == "a"
        assert sends[0].arg["_map_index"] == 0
        assert sends[1].arg["item"] == "b"
        assert sends[1].arg["_map_index"] == 1

    def test_map_edge_empty_list(self):
        """Empty list returns empty Send list."""
        config = {
            "over": "{items}",
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }
        builder = MagicMock()
        defaults = {}

        map_edge, _ = compile_map_node("expand", config, builder, defaults)

        state = {"items": []}
        sends = map_edge(state)

        assert sends == []

    def test_adds_wrapped_sub_node_to_builder(self):
        """compile_map_node adds wrapped sub-node to builder."""
        config = {
            "over": "{items}",
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }
        builder = MagicMock()
        defaults = {}

        compile_map_node("expand", config, builder, defaults)

        # Should call builder.add_node
        builder.add_node.assert_called_once()
        call_args = builder.add_node.call_args
        assert call_args[0][0] == "_map_expand_sub"

    def test_validates_over_is_list(self):
        """Map edge validates that 'over' resolves to a list."""
        config = {
            "over": "{not_a_list}",
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }
        builder = MagicMock()
        defaults = {}

        map_edge, _ = compile_map_node("expand", config, builder, defaults)

        state = {"not_a_list": "string"}
        with pytest.raises(TypeError, match="must resolve to list"):
            map_edge(state)
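The wrap_for_reducer behaviour exercised above (run the inner node, pull out its state_key, emit the value as a one-element list under the collect key so a list-concatenating reducer can merge the fan-out results, and carry _map_index along when it was injected by a Send) could look roughly like the sketch below. This is inferred from the tests, not the wheel's map_compiler code; in particular, how a dict result and _map_index are combined is an assumption.

# Hypothetical sketch of wrap_for_reducer, consistent with the tests above.
from typing import Any, Callable


def wrap_for_reducer(
    node_fn: Callable[[dict], dict], collect_key: str, state_key: str
) -> Callable[[dict], dict]:
    """Wrap node_fn so its state_key output is collected as a list under collect_key."""

    def wrapped(state: dict) -> dict:
        value: Any = node_fn(state)[state_key]
        if "_map_index" in state:
            # Keep the branch's position so collected results can be re-ordered later.
            if isinstance(value, dict):
                value = {"_map_index": state["_map_index"], **value}
            else:
                value = {"_map_index": state["_map_index"], "value": value}
        return {collect_key: [value]}

    return wrapped

Each Send produced by the map edge runs the wrapped sub-node on one item, and a list-concatenating reducer on the collect key (e.g. operator.add in the state annotation) merges the single-element lists back into one aggregated result.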