yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,276 @@
1
+ """Tests for Phase 6.3: Conversation Memory.
2
+
3
+ Tests that agent nodes:
4
+ 1. Return messages to state for accumulation
5
+ 2. Store raw tool results in state
6
+ 3. Support multi-turn conversations via thread_id
7
+ """
8
+
9
+ from unittest.mock import MagicMock, patch
10
+
11
+ from langchain_core.messages import (
12
+ AIMessage,
13
+ HumanMessage,
14
+ SystemMessage,
15
+ )
16
+
17
+
18
class TestAgentReturnsMessages:
    """Agent nodes must hand their message list back in the node output."""

    def test_agent_returns_messages_in_state(self):
        """A run without tool calls still yields a ``messages`` entry."""
        from yamlgraph.tools.agent import create_agent_node

        # LLM stub that answers immediately and requests no tools.
        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.return_value = MagicMock(
            content="Analysis complete",
            tool_calls=[],
        )

        node_config = {"tools": [], "state_key": "result"}
        with patch("yamlgraph.tools.agent.create_llm", return_value=llm):
            agent = create_agent_node("agent", node_config, {})
            output = agent({"input": "test"})

        # The node output must expose the conversation for accumulation.
        assert "messages" in output, "Agent should return messages for accumulation"
        message_count = len(output["messages"])
        assert message_count >= 2, "Should have at least system + user + AI messages"

    def test_agent_messages_include_all_types(self):
        """A tool-calling run returns system, human, AI and tool messages."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        # Real AIMessage objects (not mocks) so type-name checks are meaningful:
        # the first reply requests a tool, the second one finishes the run.
        scripted_replies = [
            AIMessage(
                content="",
                tool_calls=[{"name": "test_tool", "args": {}, "id": "call_1"}],
            ),
            AIMessage(content="Done"),
        ]

        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.side_effect = scripted_replies

        available_tools = {
            "test_tool": ShellToolConfig(
                command="echo test",
                description="Test tool",
            ),
        }
        node_config = {"tools": ["test_tool"], "state_key": "result"}

        with (
            patch("yamlgraph.tools.agent.create_llm", return_value=llm),
            patch("yamlgraph.tools.agent.execute_shell_tool") as shell_exec,
        ):
            # The shell tool itself is stubbed; no command is executed.
            shell_exec.return_value = MagicMock(success=True, output="tool output")

            agent = create_agent_node("agent", node_config, available_tools)
            output = agent({"input": "test"})

        seen_types = {type(message).__name__ for message in output["messages"]}

        assert "SystemMessage" in seen_types, "Should include system message"
        assert "HumanMessage" in seen_types, "Should include human message"
        assert "AIMessage" in seen_types, "Should include AI message"
        assert "ToolMessage" in seen_types, "Should include tool message"
90
+
91
+
92
class TestToolResultsPersistence:
    """Raw tool results are recorded only when ``tool_results_key`` is set."""

    def test_tool_results_stored_in_state(self):
        """A single tool call produces one recorded entry under the key."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        # First LLM turn asks for git_log, second turn ends the run.
        wants_tool = MagicMock(
            content="",
            tool_calls=[{"name": "git_log", "args": {"count": "5"}, "id": "call_1"}],
        )
        wraps_up = MagicMock(content="Report ready", tool_calls=[])

        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.side_effect = [wants_tool, wraps_up]

        git_log = ShellToolConfig(
            command="git log -n {count}",
            description="Get git log",
        )
        node_config = {
            "tools": ["git_log"],
            "state_key": "report",
            "tool_results_key": "_tool_results",
        }

        with (
            patch("yamlgraph.tools.agent.create_llm", return_value=llm),
            patch("yamlgraph.tools.agent.execute_shell_tool") as shell_exec,
        ):
            shell_exec.return_value = MagicMock(
                success=True, output="commit abc123\nAuthor: Test"
            )

            agent = create_agent_node("agent", node_config, {"git_log": git_log})
            output = agent({"input": "analyze"})

        assert "_tool_results" in output, "Should include tool_results in state"
        recorded = output["_tool_results"]
        assert len(recorded) == 1, "Should have one tool result"

        entry = recorded[0]
        assert entry["tool"] == "git_log"
        assert entry["args"] == {"count": "5"}
        assert "commit abc123" in entry["output"]
        assert entry["success"] is True

    def test_tool_results_key_is_optional(self):
        """Without ``tool_results_key`` no ``_tool_results`` entry appears."""
        from yamlgraph.tools.agent import create_agent_node

        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.return_value = MagicMock(content="Done", tool_calls=[])

        # Deliberately no tool_results_key in the node configuration.
        node_config = {"tools": [], "state_key": "result"}

        with patch("yamlgraph.tools.agent.create_llm", return_value=llm):
            agent = create_agent_node("agent", node_config, {})
            output = agent({"input": "test"})

        assert "_tool_results" not in output

    def test_multiple_tool_calls_all_stored(self):
        """Two tool calls in one LLM turn are both recorded."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        wants_both = MagicMock(
            content="",
            tool_calls=[
                {"name": "tool_a", "args": {}, "id": "call_1"},
                {"name": "tool_b", "args": {}, "id": "call_2"},
            ],
        )
        wraps_up = MagicMock(content="Done", tool_calls=[])

        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.side_effect = [wants_both, wraps_up]

        available_tools = {
            "tool_a": ShellToolConfig(command="echo a", description="A"),
            "tool_b": ShellToolConfig(command="echo b", description="B"),
        }
        node_config = {
            "tools": ["tool_a", "tool_b"],
            "state_key": "result",
            "tool_results_key": "_tool_results",
        }

        with (
            patch("yamlgraph.tools.agent.create_llm", return_value=llm),
            patch("yamlgraph.tools.agent.execute_shell_tool") as shell_exec,
        ):
            shell_exec.return_value = MagicMock(success=True, output="output")

            agent = create_agent_node("agent", node_config, available_tools)
            output = agent({"input": "test"})

        recorded = output["_tool_results"]
        assert len(recorded) == 2

        recorded_names = {entry["tool"] for entry in recorded}
        assert "tool_a" in recorded_names
        assert "tool_b" in recorded_names
216
+
217
+
218
class TestMultiTurnConversation:
    """Tests for multi-turn conversation support."""

    def test_existing_messages_preserved(self):
        """Agent should return messages when state already holds a history.

        NOTE(review): only a lower bound on the returned message count is
        asserted; whether the earlier history is echoed back or only the new
        turn is returned is left to the implementation.
        """
        from yamlgraph.tools.agent import create_agent_node

        mock_response = MagicMock()
        mock_response.content = "Follow-up response"
        mock_response.tool_calls = []

        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = mock_response

        # Simulate state that already carries one completed exchange.
        existing_messages = [
            SystemMessage(content="You are helpful."),
            HumanMessage(content="First question"),
            AIMessage(content="First answer"),
        ]

        with patch("yamlgraph.tools.agent.create_llm", return_value=mock_llm):
            node_fn = create_agent_node(
                "agent",
                {"tools": [], "state_key": "result"},
                {},
            )
            result = node_fn(
                {
                    "input": "Follow-up question",
                    "messages": existing_messages,
                }
            )

        messages = result["messages"]
        # At minimum the human + AI messages of this turn must come back.
        assert len(messages) >= 2, "Should return messages for accumulation"

    def test_agent_state_message_reducer_works(self):
        """Dynamic state's Annotated[list, add] should accumulate messages."""
        from operator import add as add_op
        from typing import get_type_hints

        from yamlgraph.models.state_builder import build_state_class

        State = build_state_class({"nodes": {"agent": {"type": "agent"}}})
        # include_extras=True keeps the Annotated wrapper (and its metadata)
        # instead of stripping it down to the bare type.
        hints = get_type_hints(State, include_extras=True)

        messages_hint = hints.get("messages")
        assert messages_hint is not None, "State should have messages field"

        # Require the Annotated metadata unconditionally: a guarded check
        # would let the test pass vacuously if the reducer were dropped.
        metadata = getattr(messages_hint, "__metadata__", None)
        assert metadata is not None, "messages should be Annotated with a reducer"
        assert add_op in metadata, "messages should use add reducer"
@@ -0,0 +1,145 @@
1
+ """Tests for yamlgraph.storage.database module."""
2
+
3
+ from yamlgraph.models import create_initial_state
4
+
5
+
6
class TestYamlGraphDB:
    """Save / load / list / delete behaviour of the ``temp_db`` fixture."""

    def test_db_initialization(self, temp_db):
        """The database file exists once the fixture has been created."""
        assert temp_db.db_path.exists()

    def test_save_and_load_state(self, temp_db, sample_state):
        """A saved state round-trips its topic and thread id."""
        key = sample_state["thread_id"]
        temp_db.save_state(key, sample_state, status="completed")

        restored = temp_db.load_state(key)

        assert restored is not None
        assert restored["topic"] == sample_state["topic"]
        assert restored["thread_id"] == key

    def test_load_nonexistent_state(self, temp_db):
        """An unknown thread id loads as ``None``."""
        assert temp_db.load_state("nonexistent") is None

    def test_update_existing_state(self, temp_db, empty_state):
        """Saving twice under one thread id keeps the latest state."""
        key = empty_state["thread_id"]

        # First write.
        temp_db.save_state(key, empty_state, status="running")

        # Second write with a changed field.
        empty_state["current_step"] = "generate"
        temp_db.save_state(key, empty_state, status="completed")

        restored = temp_db.load_state(key)
        assert restored["current_step"] == "generate"

    def test_list_runs_empty(self, temp_db):
        """With nothing saved, ``list_runs`` returns an empty list."""
        assert temp_db.list_runs() == []

    def test_list_runs_with_data(self, temp_db):
        """Every saved thread shows up in ``list_runs``."""
        # Build both states first, then persist them in insertion order.
        pending = {
            "thread1": (
                create_initial_state(topic="test1", thread_id="thread1"),
                "completed",
            ),
            "thread2": (
                create_initial_state(topic="test2", thread_id="thread2"),
                "running",
            ),
        }
        for key, (state, status) in pending.items():
            temp_db.save_state(key, state, status=status)

        listed = temp_db.list_runs()

        assert len(listed) == 2
        listed_ids = [run["thread_id"] for run in listed]
        assert "thread1" in listed_ids
        assert "thread2" in listed_ids

    def test_list_runs_limit(self, temp_db):
        """``limit`` caps the number of returned runs."""
        for index in range(5):
            key = f"thread{index}"
            temp_db.save_state(
                key,
                create_initial_state(topic=f"test{index}", thread_id=key),
            )

        assert len(temp_db.list_runs(limit=3)) == 3

    def test_delete_run(self, temp_db, empty_state):
        """Deleting a saved run returns True and the state is gone."""
        key = empty_state["thread_id"]
        temp_db.save_state(key, empty_state)

        deleted = temp_db.delete_run(key)
        assert deleted is True

        assert temp_db.load_state(key) is None

    def test_delete_nonexistent_run(self, temp_db):
        """Deleting an unknown thread id returns False."""
        assert temp_db.delete_run("nonexistent") is False

    def test_serialize_state_with_pydantic(self, temp_db, sample_state):
        """The ``generated`` entry comes back as a plain dict after loading."""
        key = sample_state["thread_id"]
        temp_db.save_state(key, sample_state)

        restored = temp_db.load_state(key)

        # After serialization the Pydantic model is expected to be a dict.
        generated = restored["generated"]
        assert isinstance(generated, dict)
        assert generated["title"] == "Test Article"
96
+
97
+
98
class TestConnectionPool:
    """Tests for connection pooling."""

    def test_pooled_mode_works(self, tmp_path):
        """Database should save and load state in pooled mode."""
        from yamlgraph.storage.database import YamlGraphDB

        db_path = tmp_path / "pooled_test.db"
        db = YamlGraphDB(db_path=db_path, use_pool=True, pool_size=3)

        try:
            db.save_state("test-thread", {"topic": "test"})
            loaded = db.load_state("test-thread")
            assert loaded["topic"] == "test"
        finally:
            # Always release pooled connections, even on assertion failure.
            db.close()

    def test_pool_reuses_connections(self, tmp_path):
        """Pool should hand back the same connection after it is returned."""
        from yamlgraph.storage.database import ConnectionPool

        db_path = tmp_path / "pool_test.db"
        pool = ConnectionPool(db_path, pool_size=2)

        try:
            # Borrow a connection and return it straight away.
            with pool.get_connection() as first:
                pass

            # The next borrow should receive that same object.
            with pool.get_connection() as second:
                pass

            # Both references are still alive, so identity is a reliable check.
            assert first is second
        finally:
            # Close the pool even when the assertion fails, matching the
            # cleanup style of test_pooled_mode_works.
            pool.close_all()

    def test_close_method(self, tmp_path):
        """Close should clean up connections without raising."""
        from yamlgraph.storage.database import YamlGraphDB

        db_path = tmp_path / "close_test.db"
        db = YamlGraphDB(db_path=db_path, use_pool=True, pool_size=2)

        # Touch the database so there is something to close.
        db.save_state("test", {"data": "value"})

        # Close should not raise.
        db.close()
@@ -0,0 +1,104 @@
1
+ """Tests for deprecation module.
2
+
3
+ TDD tests for DeprecationError and deprecation utilities.
4
+ """
5
+
6
+ import pytest
7
+
8
+
9
class TestDeprecationError:
    """Checks on the ``DeprecationError`` exception type."""

    def test_deprecation_error_exists(self):
        """``DeprecationError`` can be imported and is an ``Exception``."""
        from yamlgraph.cli.deprecation import DeprecationError

        assert issubclass(DeprecationError, Exception)

    def test_deprecation_error_message(self):
        """The message names both commands and says "deprecated"."""
        from yamlgraph.cli.deprecation import DeprecationError

        error = DeprecationError(
            old_command="route",
            new_command="graph run graphs/router-demo.yaml --var message=...",
        )

        assert "route" in str(error)
        assert "graph run" in str(error)
        # Case-insensitive: only the word itself matters here.
        assert "deprecated" in str(error).lower()

    def test_deprecation_error_has_attributes(self):
        """The old and new commands are exposed as attributes."""
        from yamlgraph.cli.deprecation import DeprecationError

        replacement = "graph run graphs/reflexion-demo.yaml"
        error = DeprecationError(old_command="refine", new_command=replacement)

        assert error.old_command == "refine"
        assert error.new_command == "graph run graphs/reflexion-demo.yaml"
42
+
43
+
44
class TestDeprecatedCommand:
    """Checks on the ``deprecated_command`` helper."""

    def test_deprecated_command_exists(self):
        """``deprecated_command`` can be imported and is callable."""
        from yamlgraph.cli.deprecation import deprecated_command

        assert callable(deprecated_command)

    def test_deprecated_command_raises(self):
        """Calling the helper raises ``DeprecationError`` naming the old command."""
        from yamlgraph.cli.deprecation import DeprecationError, deprecated_command

        replacement = "graph run graphs/router-demo.yaml --var message=..."
        with pytest.raises(DeprecationError) as caught:
            deprecated_command("route", replacement)

        assert "route" in str(caught.value)

    def test_deprecated_command_with_mapping(self):
        """The replacement's ``--var`` mapping shows up in the error text."""
        from yamlgraph.cli.deprecation import DeprecationError, deprecated_command

        replacement = "graph run graphs/reflexion-demo.yaml --var topic=X"
        with pytest.raises(DeprecationError) as caught:
            deprecated_command("refine --topic X", replacement)

        assert "topic=X" in str(caught.value)
76
+
77
+
78
class TestCommandMappings:
    """Checks on ``get_replacement_command`` lookups."""

    def test_get_replacement_for_route(self):
        """``route`` maps to a ``graph run`` of the router demo."""
        from yamlgraph.cli.deprecation import get_replacement_command

        replacement = get_replacement_command("route", {"message": "hello"})

        assert "graph run" in replacement
        assert "router-demo.yaml" in replacement
        assert "message=hello" in replacement

    def test_get_replacement_for_refine(self):
        """``refine`` maps to a ``graph run`` of the reflexion demo."""
        from yamlgraph.cli.deprecation import get_replacement_command

        replacement = get_replacement_command("refine", {"topic": "AI"})

        assert "graph run" in replacement
        assert "reflexion-demo.yaml" in replacement
        assert "topic=AI" in replacement

    def test_get_replacement_unknown_command(self):
        """A command with no mapping yields ``None``."""
        from yamlgraph.cli.deprecation import get_replacement_command

        assert get_replacement_command("unknown", {}) is None
@@ -0,0 +1,172 @@
1
+ """Tests for yamlgraph.executor module."""
2
+
3
+ import pytest
4
+
5
+ from yamlgraph.executor import format_prompt, load_prompt
6
+
7
+
8
class TestLoadPrompt:
    """Checks on ``load_prompt``."""

    def test_load_existing_prompt(self):
        """The ``generate`` prompt provides both system and user parts."""
        loaded = load_prompt("generate")
        assert "system" in loaded
        assert "user" in loaded

    def test_load_analyze_prompt(self):
        """The ``analyze`` prompt's user part has a ``{content}`` placeholder."""
        loaded = load_prompt("analyze")
        assert "system" in loaded
        assert "{content}" in loaded["user"]

    def test_load_nonexistent_prompt(self):
        """An unknown prompt name raises ``FileNotFoundError``."""
        with pytest.raises(FileNotFoundError):
            load_prompt("nonexistent_prompt")
27
+
28
+
29
class TestFormatPrompt:
    """Checks on ``format_prompt`` placeholder substitution."""

    def test_format_single_variable(self):
        """A single placeholder is replaced by its value."""
        rendered = format_prompt("Hello, {name}!", {"name": "World"})
        assert rendered == "Hello, World!"

    def test_format_multiple_variables(self):
        """Several placeholders are each replaced."""
        values = {"topic": "AI", "style": "casual"}
        rendered = format_prompt("Topic: {topic}, Style: {style}", values)
        assert rendered == "Topic: AI, Style: casual"

    def test_format_empty_variables(self):
        """A template without placeholders passes through unchanged."""
        rendered = format_prompt("No variables here", {})
        assert rendered == "No variables here"

    def test_format_missing_variable_raises(self):
        """A placeholder with no supplied value raises ``KeyError``."""
        with pytest.raises(KeyError):
            format_prompt("Hello, {name}!", {})

    def test_format_with_numbers(self):
        """Integer values are rendered into the text."""
        rendered = format_prompt("Count: {word_count}", {"word_count": 300})
        assert rendered == "Count: 300"
61
+
62
+
63
class TestPromptExecutorGraphRelative:
    """Tests for PromptExecutor with graph-relative prompts."""

    def test_execute_with_graph_path_and_prompts_relative(self, tmp_path):
        """Executor should resolve prompts relative to graph when configured."""
        from unittest.mock import MagicMock, patch

        from yamlgraph.executor import PromptExecutor

        # Lay out a prompt colocated with the graph file.
        graph_dir = tmp_path / "questionnaires" / "audit"
        prompts_dir = graph_dir / "prompts"
        prompts_dir.mkdir(parents=True)

        prompt_file = prompts_dir / "opening.yaml"
        prompt_file.write_text(
            """
system: You are an audit assistant.
user: Generate opening for {questionnaire_name}.
"""
        )

        graph_path = graph_dir / "graph.yaml"
        graph_path.touch()  # Just needs to exist for path resolution

        # Stub the LLM so no API call is made.
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = MagicMock(content="Welcome to the audit.")

        executor = PromptExecutor()

        with patch.object(executor, "_get_llm", return_value=mock_llm):
            # "prompts/opening" is expected to resolve next to graph_path.
            result = executor.execute(
                prompt_name="prompts/opening",
                variables={"questionnaire_name": "Financial Audit"},
                graph_path=graph_path,
                prompts_relative=True,
            )

        assert result == "Welcome to the audit."
        mock_llm.invoke.assert_called_once()

    def test_execute_with_prompts_dir_override(self, tmp_path):
        """Executor should use explicit prompts_dir when provided."""
        from unittest.mock import MagicMock, patch

        from yamlgraph.executor import PromptExecutor

        # Prompt lives in an explicitly supplied directory.
        prompts_dir = tmp_path / "my_prompts"
        prompts_dir.mkdir()

        prompt_file = prompts_dir / "greeting.yaml"
        prompt_file.write_text(
            """
system: You are helpful.
user: Say hello to {name}.
"""
        )

        mock_llm = MagicMock()
        mock_llm.invoke.return_value = MagicMock(content="Hello!")

        executor = PromptExecutor()

        with patch.object(executor, "_get_llm", return_value=mock_llm):
            result = executor.execute(
                prompt_name="greeting",
                variables={"name": "World"},
                prompts_dir=prompts_dir,
            )

        assert result == "Hello!"

    def test_execute_prompt_function_passes_path_params(self, tmp_path):
        """execute_prompt() should accept and forward path params."""
        from unittest.mock import MagicMock, patch

        from yamlgraph.executor import execute_prompt

        # A prompt file is created for realism; the executor is mocked below,
        # so this test only inspects the arguments that get forwarded.
        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        (prompts_dir / "test.yaml").write_text(
            """
system: Test system.
user: Test {msg}.
"""
        )

        with patch("yamlgraph.executor.get_executor") as mock_get:
            mock_executor = MagicMock()
            mock_executor.execute.return_value = "OK"
            mock_get.return_value = mock_executor

            execute_prompt(
                prompt_name="test",
                variables={"msg": "hello"},
                prompts_dir=prompts_dir,
            )

        # Verify path params were forwarded as keyword arguments.
        mock_executor.execute.assert_called_once()
        call_kwargs = mock_executor.execute.call_args.kwargs
        assert call_kwargs["prompts_dir"] == prompts_dir