yamlgraph 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of yamlgraph might be problematic. Click here for more details.

Files changed (111)
  1. examples/__init__.py +1 -0
  2. examples/storyboard/__init__.py +1 -0
  3. examples/storyboard/generate_videos.py +335 -0
  4. examples/storyboard/nodes/__init__.py +10 -0
  5. examples/storyboard/nodes/animated_character_node.py +248 -0
  6. examples/storyboard/nodes/animated_image_node.py +138 -0
  7. examples/storyboard/nodes/character_node.py +162 -0
  8. examples/storyboard/nodes/image_node.py +118 -0
  9. examples/storyboard/nodes/replicate_tool.py +238 -0
  10. examples/storyboard/retry_images.py +118 -0
  11. tests/__init__.py +1 -0
  12. tests/conftest.py +178 -0
  13. tests/integration/__init__.py +1 -0
  14. tests/integration/test_animated_storyboard.py +63 -0
  15. tests/integration/test_cli_commands.py +242 -0
  16. tests/integration/test_map_demo.py +50 -0
  17. tests/integration/test_memory_demo.py +281 -0
  18. tests/integration/test_pipeline_flow.py +105 -0
  19. tests/integration/test_providers.py +163 -0
  20. tests/integration/test_resume.py +75 -0
  21. tests/unit/__init__.py +1 -0
  22. tests/unit/test_agent_nodes.py +200 -0
  23. tests/unit/test_checkpointer.py +212 -0
  24. tests/unit/test_cli.py +121 -0
  25. tests/unit/test_cli_package.py +81 -0
  26. tests/unit/test_compile_graph_map.py +132 -0
  27. tests/unit/test_conditions_routing.py +253 -0
  28. tests/unit/test_config.py +93 -0
  29. tests/unit/test_conversation_memory.py +270 -0
  30. tests/unit/test_database.py +145 -0
  31. tests/unit/test_deprecation.py +104 -0
  32. tests/unit/test_executor.py +60 -0
  33. tests/unit/test_executor_async.py +179 -0
  34. tests/unit/test_export.py +150 -0
  35. tests/unit/test_expressions.py +178 -0
  36. tests/unit/test_format_prompt.py +145 -0
  37. tests/unit/test_generic_report.py +200 -0
  38. tests/unit/test_graph_commands.py +327 -0
  39. tests/unit/test_graph_loader.py +299 -0
  40. tests/unit/test_graph_schema.py +193 -0
  41. tests/unit/test_inline_schema.py +151 -0
  42. tests/unit/test_issues.py +164 -0
  43. tests/unit/test_jinja2_prompts.py +85 -0
  44. tests/unit/test_langsmith.py +319 -0
  45. tests/unit/test_llm_factory.py +109 -0
  46. tests/unit/test_llm_factory_async.py +118 -0
  47. tests/unit/test_loops.py +403 -0
  48. tests/unit/test_map_node.py +144 -0
  49. tests/unit/test_no_backward_compat.py +56 -0
  50. tests/unit/test_node_factory.py +225 -0
  51. tests/unit/test_prompts.py +166 -0
  52. tests/unit/test_python_nodes.py +198 -0
  53. tests/unit/test_reliability.py +298 -0
  54. tests/unit/test_result_export.py +234 -0
  55. tests/unit/test_router.py +296 -0
  56. tests/unit/test_sanitize.py +99 -0
  57. tests/unit/test_schema_loader.py +295 -0
  58. tests/unit/test_shell_tools.py +229 -0
  59. tests/unit/test_state_builder.py +331 -0
  60. tests/unit/test_state_builder_map.py +104 -0
  61. tests/unit/test_state_config.py +197 -0
  62. tests/unit/test_template.py +190 -0
  63. tests/unit/test_tool_nodes.py +129 -0
  64. yamlgraph/__init__.py +35 -0
  65. yamlgraph/builder.py +110 -0
  66. yamlgraph/cli/__init__.py +139 -0
  67. yamlgraph/cli/__main__.py +6 -0
  68. yamlgraph/cli/commands.py +232 -0
  69. yamlgraph/cli/deprecation.py +92 -0
  70. yamlgraph/cli/graph_commands.py +382 -0
  71. yamlgraph/cli/validators.py +37 -0
  72. yamlgraph/config.py +67 -0
  73. yamlgraph/constants.py +66 -0
  74. yamlgraph/error_handlers.py +226 -0
  75. yamlgraph/executor.py +275 -0
  76. yamlgraph/executor_async.py +122 -0
  77. yamlgraph/graph_loader.py +337 -0
  78. yamlgraph/map_compiler.py +138 -0
  79. yamlgraph/models/__init__.py +36 -0
  80. yamlgraph/models/graph_schema.py +141 -0
  81. yamlgraph/models/schemas.py +124 -0
  82. yamlgraph/models/state_builder.py +236 -0
  83. yamlgraph/node_factory.py +240 -0
  84. yamlgraph/routing.py +87 -0
  85. yamlgraph/schema_loader.py +160 -0
  86. yamlgraph/storage/__init__.py +17 -0
  87. yamlgraph/storage/checkpointer.py +72 -0
  88. yamlgraph/storage/database.py +320 -0
  89. yamlgraph/storage/export.py +269 -0
  90. yamlgraph/tools/__init__.py +1 -0
  91. yamlgraph/tools/agent.py +235 -0
  92. yamlgraph/tools/nodes.py +124 -0
  93. yamlgraph/tools/python_tool.py +178 -0
  94. yamlgraph/tools/shell.py +205 -0
  95. yamlgraph/utils/__init__.py +47 -0
  96. yamlgraph/utils/conditions.py +157 -0
  97. yamlgraph/utils/expressions.py +111 -0
  98. yamlgraph/utils/langsmith.py +308 -0
  99. yamlgraph/utils/llm_factory.py +118 -0
  100. yamlgraph/utils/llm_factory_async.py +105 -0
  101. yamlgraph/utils/logging.py +127 -0
  102. yamlgraph/utils/prompts.py +116 -0
  103. yamlgraph/utils/sanitize.py +98 -0
  104. yamlgraph/utils/template.py +102 -0
  105. yamlgraph/utils/validators.py +181 -0
  106. yamlgraph-0.1.1.dist-info/METADATA +854 -0
  107. yamlgraph-0.1.1.dist-info/RECORD +111 -0
  108. yamlgraph-0.1.1.dist-info/WHEEL +5 -0
  109. yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
  110. yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
  111. yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
@@ -0,0 +1,118 @@
1
+ """Unit tests for async LLM factory module."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+
7
+ from yamlgraph.utils.llm_factory_async import (
8
+ create_llm_async,
9
+ get_executor,
10
+ invoke_async,
11
+ shutdown_executor,
12
+ )
13
+
14
+
15
class TestGetExecutor:
    """Tests for get_executor function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    def test_creates_executor(self):
        """Should create a ThreadPoolExecutor."""
        # Local import keeps this block independent of the module's import list.
        from concurrent.futures import ThreadPoolExecutor

        executor = get_executor()
        # The docstring promises a ThreadPoolExecutor. A bare "is not None"
        # check would accept any object, so pin the concrete type.
        assert isinstance(executor, ThreadPoolExecutor)

    def test_returns_same_executor(self):
        """Should return the same executor on subsequent calls."""
        executor1 = get_executor()
        executor2 = get_executor()
        assert executor1 is executor2
32
+
33
+
34
class TestShutdownExecutor:
    """Tests for shutdown_executor function."""

    def teardown_method(self):
        """Clean up executor after each test.

        test_shutdown_cleans_up finishes by creating a fresh executor. Without
        this hook that executor would outlive the test, unlike the other
        executor test classes, which all shut it down after every test.
        """
        shutdown_executor()

    def test_shutdown_cleans_up(self):
        """Shutdown should clean up executor."""
        # Create an executor
        executor1 = get_executor()
        assert executor1 is not None

        # Shutdown
        shutdown_executor()

        # Next call should create a new executor
        executor2 = get_executor()
        assert executor2 is not executor1

    def test_shutdown_when_none(self):
        """Shutdown when no executor should not raise."""
        shutdown_executor()  # Ensure clean state
        shutdown_executor()  # Should not raise
54
+
55
+
56
class TestCreateLLMAsync:
    """Exercise create_llm_async, the coroutine that builds an LLM instance."""

    def teardown_method(self):
        """Release the shared executor once each test has finished."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_creates_llm(self):
        """Should create an LLM instance."""
        model = await create_llm_async(provider="anthropic", temperature=0.5)
        assert model is not None
        assert model.temperature == 0.5

    @pytest.mark.asyncio
    async def test_uses_default_provider(self):
        """Should use default provider when not specified."""
        # Blank out PROVIDER so the factory has to fall back to its default.
        env_override = {"PROVIDER": ""}
        with patch.dict("os.environ", env_override, clear=False):
            model = await create_llm_async(temperature=0.7)
            # Default is anthropic
            class_name = model.__class__.__name__.lower()
            assert "anthropic" in class_name
77
+
78
+
79
class TestInvokeAsync:
    """Exercise invoke_async against mocked LLM objects."""

    def teardown_method(self):
        """Release the shared executor once each test has finished."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_invoke_returns_string(self):
        """Should return string content when no output model."""
        reply = MagicMock()
        reply.content = "Hello, world!"
        fake_llm = MagicMock()
        fake_llm.invoke.return_value = reply

        prompt_messages = [MagicMock()]
        output = await invoke_async(fake_llm, prompt_messages)

        assert output == "Hello, world!"
        fake_llm.invoke.assert_called_once_with(prompt_messages)

    @pytest.mark.asyncio
    async def test_invoke_with_output_model(self):
        """Should use structured output when model provided."""
        from pydantic import BaseModel

        class TestOutput(BaseModel):
            value: str

        # The structured wrapper is what actually produces the parsed model.
        structured = MagicMock()
        structured.invoke.return_value = TestOutput(value="test")
        fake_llm = MagicMock()
        fake_llm.with_structured_output.return_value = structured

        prompt_messages = [MagicMock()]
        output = await invoke_async(
            fake_llm, prompt_messages, output_model=TestOutput
        )

        assert isinstance(output, TestOutput)
        assert output.value == "test"
        fake_llm.with_structured_output.assert_called_once_with(TestOutput)
@@ -0,0 +1,403 @@
1
+ """Tests for Section 3: Self-Correction Loops (Reflexion).
2
+
3
+ TDD tests for expression conditions, loop tracking, and cyclic graphs.
4
+ """
5
+
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ # =============================================================================
11
+ # Test: Expression Condition Parsing
12
+ # =============================================================================
13
+
14
+
15
class TestExpressionConditions:
    """Check how evaluate_condition interprets condition expressions against state."""

    @staticmethod
    def _evaluate(expression, state):
        """Lazily import evaluate_condition and apply it to *expression* and *state*."""
        from yamlgraph.utils.conditions import evaluate_condition

        return evaluate_condition(expression, state)

    def test_evaluate_condition_exists(self):
        """evaluate_condition function should exist."""
        from yamlgraph.utils.conditions import evaluate_condition

        assert callable(evaluate_condition)

    def test_less_than_comparison(self):
        """Evaluates 'score < 0.8' correctly."""
        for score, expected in ((0.5, True), (0.9, False)):
            assert self._evaluate("score < 0.8", {"score": score}) is expected

    def test_greater_than_comparison(self):
        """Evaluates 'score > 0.5' correctly."""
        for score, expected in ((0.7, True), (0.3, False)):
            assert self._evaluate("score > 0.5", {"score": score}) is expected

    def test_less_than_or_equal(self):
        """Evaluates 'score <= 0.8' correctly."""
        for score, expected in ((0.8, True), (0.9, False)):
            assert self._evaluate("score <= 0.8", {"score": score}) is expected

    def test_greater_than_or_equal(self):
        """Evaluates 'score >= 0.8' correctly."""
        for score, expected in ((0.8, True), (0.7, False)):
            assert self._evaluate("score >= 0.8", {"score": score}) is expected

    def test_equality_comparison(self):
        """Evaluates 'status == \"approved\"' correctly."""
        for status, expected in (("approved", True), ("pending", False)):
            outcome = self._evaluate('status == "approved"', {"status": status})
            assert outcome is expected

    def test_inequality_comparison(self):
        """Evaluates 'error != null' correctly."""
        for error, expected in (("something", True), (None, False)):
            assert self._evaluate("error != null", {"error": error}) is expected

    def test_nested_attribute_access(self):
        """Evaluates 'critique.score >= 0.8' from state."""
        # Using object with attribute; the same state is reused while the
        # attribute value changes underneath it.
        critique = MagicMock()
        state = {"critique": critique}
        for score, expected in ((0.85, True), (0.7, False)):
            critique.score = score
            assert self._evaluate("critique.score >= 0.8", state) is expected

    def test_compound_and_condition(self):
        """Evaluates 'score < 0.8 and iteration < 3'."""
        expression = "score < 0.8 and iteration < 3"
        cases = (
            ({"score": 0.5, "iteration": 2}, True),
            ({"score": 0.9, "iteration": 2}, False),
            ({"score": 0.5, "iteration": 5}, False),
        )
        for state, expected in cases:
            assert self._evaluate(expression, state) is expected

    def test_compound_or_condition(self):
        """Evaluates 'approved == true or override == true'."""
        expression = "approved == true or override == true"
        cases = (
            ({"approved": True, "override": False}, True),
            ({"approved": False, "override": True}, True),
            ({"approved": False, "override": False}, False),
        )
        for state, expected in cases:
            assert self._evaluate(expression, state) is expected

    def test_invalid_expression_raises(self):
        """Malformed expression raises ValueError."""
        with pytest.raises(ValueError):
            self._evaluate("score <<< 0.8", {})

    def test_missing_attribute_returns_false(self):
        """Missing attribute in state returns False gracefully."""
        # Should not raise, should return False for missing attribute
        assert self._evaluate("score < 0.8", {}) is False
139
+
140
+
141
+ # =============================================================================
142
+ # Test: Loop Tracking
143
+ # =============================================================================
144
+
145
+
146
class TestLoopTracking:
    """Tests for loop iteration tracking."""

    def test_state_has_loop_counts_field(self):
        """Dynamic state should declare a _loop_counts field."""
        from yamlgraph.models.state_builder import build_state_class

        State = build_state_class({"nodes": {}})
        # build_state_class must add _loop_counts to the generated class's
        # annotations, even for a graph with no nodes.
        # (A former check on a hand-built plain dict was removed: it asserted
        # nothing about yamlgraph.)
        assert "_loop_counts" in State.__annotations__

    def test_node_increments_loop_counter(self):
        """Each node execution increments its counter in _loop_counts."""
        from yamlgraph.node_factory import create_node_function

        node_config = {
            "prompt": "test_prompt",
            "state_key": "result",
        }

        with patch("yamlgraph.node_factory.execute_prompt") as mock_execute:
            mock_execute.return_value = "test result"

            node_fn = create_node_function("critique", node_config, {})

            # First call - no counter in state yet, so it should start at 1
            state = {"message": "test"}
            result = node_fn(state)
            assert result.get("_loop_counts", {}).get("critique") == 1

            # Second call - the existing counter of 1 should become 2
            state = {"message": "test", "_loop_counts": {"critique": 1}}
            result = node_fn(state)
            assert result.get("_loop_counts", {}).get("critique") == 2
184
+
185
+
186
+ # =============================================================================
187
+ # Test: Loop Limits Configuration
188
+ # =============================================================================
189
+
190
+
191
class TestLoopLimits:
    """Tests for loop_limits configuration."""

    def test_parses_loop_limits_from_yaml(self):
        """GraphConfig parses loop_limits section."""
        from yamlgraph.graph_loader import GraphConfig

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {
                "draft": {"prompt": "draft"},
                "critique": {"prompt": "critique"},
            },
            "edges": [
                {"from": "START", "to": "draft"},
                {"from": "draft", "to": "critique"},
                {"from": "critique", "to": "END"},
            ],
            "loop_limits": {
                "critique": 3,
            },
        }
        config = GraphConfig(config_dict)
        assert config.loop_limits == {"critique": 3}

    def test_loop_limits_defaults_to_empty(self):
        """Missing loop_limits defaults to empty dict."""
        from yamlgraph.graph_loader import GraphConfig

        config_dict = {
            "version": "1.0",
            "name": "test",
            "nodes": {"node1": {"prompt": "p1"}},
            "edges": [{"from": "START", "to": "node1"}, {"from": "node1", "to": "END"}],
        }
        config = GraphConfig(config_dict)
        assert config.loop_limits == {}

    def test_node_checks_loop_limit(self):
        """Node execution checks loop limit before running."""
        from yamlgraph.node_factory import create_node_function

        node_config = {
            "prompt": "test_prompt",
            "state_key": "result",
            "loop_limit": 3,  # Node-level limit
        }

        with patch("yamlgraph.node_factory.execute_prompt") as mock_execute:
            mock_execute.return_value = "test result"

            node_fn = create_node_function("critique", node_config, {})

            # Under limit (2 < 3) - should execute
            state = {"_loop_counts": {"critique": 2}}
            result = node_fn(state)
            assert "result" in result
            calls_before_limit = mock_execute.call_count

            # At limit (3) - should skip/terminate. The limit is checked
            # *before* running, so the prompt must not be executed again.
            state = {"_loop_counts": {"critique": 3}}
            result = node_fn(state)
            assert result.get("_loop_limit_reached") is True
            assert mock_execute.call_count == calls_before_limit
254
+
255
+
256
+ # =============================================================================
257
+ # Test: Cyclic Edges
258
+ # =============================================================================
259
+
260
+
261
class TestCyclicEdges:
    """Graphs containing a refinement cycle are accepted and compile."""

    @staticmethod
    def _cyclic_config(state_keys=None):
        """Build a draft -> critique -> refine -> critique graph definition.

        ``state_keys`` optionally maps a node name to the ``state_key`` it
        should declare. Nodes not listed get only a ``prompt``.
        """
        nodes = {name: {"prompt": name} for name in ("draft", "critique", "refine")}
        for name, key in (state_keys or {}).items():
            nodes[name]["state_key"] = key
        return {
            "version": "1.0",
            "name": "test",
            "nodes": nodes,
            "edges": [
                {"from": "START", "to": "draft"},
                {"from": "draft", "to": "critique"},
                {
                    "from": "critique",
                    "to": "refine",
                    "condition": "critique.score < 0.8",
                },
                {"from": "critique", "to": "END", "condition": "critique.score >= 0.8"},
                {"from": "refine", "to": "critique"},  # Backward edge (cycle)
            ],
            "loop_limits": {"critique": 3},
        }

    def test_allows_backward_edges(self):
        """Graph config allows edges pointing to earlier nodes."""
        from yamlgraph.graph_loader import GraphConfig

        # Constructing the config must not raise despite the cycle.
        assert GraphConfig(self._cyclic_config()) is not None

    def test_compiles_cyclic_graph(self):
        """Cyclic graph compiles to StateGraph."""
        from yamlgraph.graph_loader import GraphConfig, compile_graph

        raw = self._cyclic_config(
            {
                "draft": "current_draft",
                "critique": "critique",
                "refine": "current_draft",
            }
        )
        compiled = compile_graph(GraphConfig(raw))
        assert compiled is not None
321
+
322
+
323
+ # =============================================================================
324
+ # Test: Pydantic Models
325
+ # =============================================================================
326
+
327
+
328
class TestReflexionModels:
    """Reflexion-style draft/critique models, exercised through conftest fixtures.

    Note: Demo models were removed from yamlgraph.models in Section 10.
    These tests use fixture models to prove the pattern still works.
    """

    def test_draft_content_model_exists(self):
        """DraftContent-like fixture model can be created."""
        from tests.conftest import FixtureDraftContent as DraftModel

        assert DraftModel is not None

    def test_draft_content_fields(self):
        """DraftContent-like model has content and version fields."""
        from tests.conftest import FixtureDraftContent as DraftModel

        draft = DraftModel(content="Test essay", version=1)
        assert (draft.content, draft.version) == ("Test essay", 1)

    def test_critique_model_exists(self):
        """Critique-like fixture model can be created."""
        from tests.conftest import FixtureCritique as CritiqueModel

        assert CritiqueModel is not None

    def test_critique_fields(self):
        """Critique-like model has score, feedback, issues, should_refine fields."""
        from tests.conftest import FixtureCritique as CritiqueModel

        fields = {
            "score": 0.75,
            "feedback": "Improve transitions",
            "issues": ["Weak intro", "No conclusion"],
            "should_refine": True,
        }
        critique = CritiqueModel(**fields)
        assert critique.score == 0.75
        assert critique.feedback == "Improve transitions"
        assert len(critique.issues) == 2
        assert critique.should_refine is True
369
+
370
+
371
+ # =============================================================================
372
+ # Test: Reflexion Demo Graph
373
+ # =============================================================================
374
+
375
+
376
class TestReflexionDemoGraph:
    """Tests for the reflexion-demo.yaml graph."""

    @staticmethod
    def _demo_graph_path():
        """Return the path to graphs/reflexion-demo.yaml, independent of the cwd.

        The previous cwd-relative path "graphs/reflexion-demo.yaml" only
        resolved when pytest was launched from the repository root. This
        module lives in tests/unit/, so the repository root is two directories
        above it. When run from the root, the resolved file is unchanged.
        """
        from pathlib import Path

        repo_root = Path(__file__).resolve().parents[2]
        return str(repo_root / "graphs" / "reflexion-demo.yaml")

    def test_demo_graph_loads(self):
        """reflexion-demo.yaml loads without error."""
        from yamlgraph.graph_loader import load_graph_config

        config = load_graph_config(self._demo_graph_path())
        assert config.name == "reflexion-demo"
        assert "draft" in config.nodes
        assert "critique" in config.nodes
        assert "refine" in config.nodes

    def test_demo_graph_has_loop_limits(self):
        """reflexion-demo.yaml has loop_limits configured."""
        from yamlgraph.graph_loader import load_graph_config

        config = load_graph_config(self._demo_graph_path())
        assert "critique" in config.loop_limits
        assert config.loop_limits["critique"] >= 3

    def test_demo_graph_compiles(self):
        """reflexion-demo.yaml compiles to StateGraph."""
        from yamlgraph.graph_loader import compile_graph, load_graph_config

        config = load_graph_config(self._demo_graph_path())
        graph = compile_graph(config)
        assert graph is not None
@@ -0,0 +1,144 @@
1
+ """Tests for type: map node functionality."""
2
+
3
+ from unittest.mock import MagicMock
4
+
5
+ import pytest
6
+
7
+ from yamlgraph.map_compiler import compile_map_node, wrap_for_reducer
8
+
9
+
10
class TestWrapForReducer:
    """Check how wrap_for_reducer reshapes a node's output for a list reducer."""

    def test_wraps_result_in_list(self):
        """Wrap node output for reducer aggregation."""

        def doubler(state: dict) -> dict:
            return {"result": state["item"] * 2}

        reducer_ready = wrap_for_reducer(doubler, "collected", "result")
        assert reducer_ready({"item": 5}) == {"collected": [10]}

    def test_preserves_map_index(self):
        """Preserve _map_index in wrapped output."""

        def echo(state: dict) -> dict:
            return {"data": state["value"]}

        reducer_ready = wrap_for_reducer(echo, "results", "data")
        produced = reducer_ready({"value": "test", "_map_index": 2})

        expected = {"results": [{"_map_index": 2, "value": "test"}]}
        assert produced == expected

    def test_extracts_state_key(self):
        """Extract specific state_key from node result."""

        def multi_output(state: dict) -> dict:
            # Only "frame_data" should survive; "other" must be dropped.
            return {"frame_data": {"before": "a", "after": "b"}, "other": "ignore"}

        reducer_ready = wrap_for_reducer(multi_output, "frames", "frame_data")
        assert reducer_ready({}) == {"frames": [{"before": "a", "after": "b"}]}
45
+
46
+
47
class TestCompileMapNode:
    """Check compile_map_node: sub-node registration and the fan-out edge it returns."""

    @staticmethod
    def _map_config(over="{items}"):
        """Return the map-node config shared by every test in this class."""
        return {
            "over": over,
            "as": "item",
            "collect": "results",
            "node": {"type": "llm", "prompt": "test", "state_key": "result"},
        }

    def _compile_expand(self, builder=None, over="{items}"):
        """Compile an "expand" map node against *builder* (a fresh mock by default)."""
        target = MagicMock() if builder is None else builder
        return compile_map_node("expand", self._map_config(over), target, {})

    def test_creates_map_edge_function(self):
        """compile_map_node returns a map edge function."""
        fan_out, sub_name = self._compile_expand()

        # Should return callable and sub-node name
        assert callable(fan_out)
        assert sub_name == "_map_expand_sub"

    def test_map_edge_returns_send_list(self):
        """Map edge function returns list of Send objects."""
        from langgraph.types import Send

        fan_out, sub_name = self._compile_expand()
        dispatched = fan_out({"items": ["a", "b", "c"]})

        assert len(dispatched) == 3
        assert all(isinstance(packet, Send) for packet in dispatched)
        first, second = dispatched[0], dispatched[1]
        assert first.node == sub_name
        assert first.arg["item"] == "a"
        assert first.arg["_map_index"] == 0
        assert second.arg["item"] == "b"
        assert second.arg["_map_index"] == 1

    def test_map_edge_empty_list(self):
        """Empty list returns empty Send list."""
        fan_out, _ = self._compile_expand()
        assert fan_out({"items": []}) == []

    def test_adds_wrapped_sub_node_to_builder(self):
        """compile_map_node adds wrapped sub-node to builder."""
        builder = MagicMock()
        self._compile_expand(builder)

        # Exactly one node is registered, under the derived sub-node name.
        builder.add_node.assert_called_once()
        assert builder.add_node.call_args.args[0] == "_map_expand_sub"

    def test_validates_over_is_list(self):
        """Map edge validates that 'over' resolves to a list."""
        fan_out, _ = self._compile_expand(over="{not_a_list}")

        with pytest.raises(TypeError, match="must resolve to list"):
            fan_out({"not_a_list": "string"})