yamlgraph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic. Click here for more details.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""Tests for Phase 6.3: Conversation Memory.
|
|
2
|
+
|
|
3
|
+
Tests that agent nodes:
|
|
4
|
+
1. Return messages to state for accumulation
|
|
5
|
+
2. Store raw tool results in state
|
|
6
|
+
3. Support multi-turn conversations via thread_id
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from unittest.mock import MagicMock, patch
|
|
10
|
+
|
|
11
|
+
from langchain_core.messages import (
|
|
12
|
+
AIMessage,
|
|
13
|
+
HumanMessage,
|
|
14
|
+
SystemMessage,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestAgentReturnsMessages:
    """Tests for message accumulation in agent state."""

    def test_agent_returns_messages_in_state(self):
        """Agent node should return messages for state accumulation."""
        from yamlgraph.tools.agent import create_agent_node

        # Setup mock LLM: one response with no tool calls, so the agent
        # loop finishes after a single invoke.
        mock_response = MagicMock()
        mock_response.content = "Analysis complete"
        mock_response.tool_calls = []

        mock_llm = MagicMock()
        # bind_tools returns the same mock so invoke() is observable on it.
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = mock_response

        with patch("yamlgraph.tools.agent.create_llm", return_value=mock_llm):
            node_fn = create_agent_node(
                "agent",
                {"tools": [], "state_key": "result"},
                {},
            )
            result = node_fn({"input": "test"})

        # Should include messages in output
        assert "messages" in result, "Agent should return messages for accumulation"
        # The threshold of 2 only guarantees the user prompt plus the AI reply;
        # the failure message previously claimed three (system + user + AI),
        # which did not match the check being made.
        assert (
            len(result["messages"]) >= 2
        ), "Should have at least user + AI messages"

    def test_agent_messages_include_all_types(self):
        """Agent should include system, user, AI, and tool messages."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        # Mock LLM that calls a tool then responds
        # Use actual AIMessage for proper type checking
        tool_response = AIMessage(
            content="", tool_calls=[{"name": "test_tool", "args": {}, "id": "call_1"}]
        )

        final_response = AIMessage(content="Done")

        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        # First invoke requests a tool call; second invoke ends the loop.
        mock_llm.invoke.side_effect = [tool_response, final_response]

        tool_config = ShellToolConfig(
            command="echo test",
            description="Test tool",
        )

        with patch("yamlgraph.tools.agent.create_llm", return_value=mock_llm):
            with patch("yamlgraph.tools.agent.execute_shell_tool") as mock_exec:
                mock_exec.return_value = MagicMock(success=True, output="tool output")

                node_fn = create_agent_node(
                    "agent",
                    {"tools": ["test_tool"], "state_key": "result"},
                    {"test_tool": tool_config},
                )
                result = node_fn({"input": "test"})

        messages = result["messages"]
        # Compare by class name so the check does not depend on which
        # langchain module each message class is imported from.
        types = [type(m).__name__ for m in messages]

        assert "SystemMessage" in types, "Should include system message"
        assert "HumanMessage" in types, "Should include human message"
        assert "AIMessage" in types, "Should include AI message"
        assert "ToolMessage" in types, "Should include tool message"
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class TestToolResultsPersistence:
    """Raw tool outputs can optionally be persisted into agent state."""

    @staticmethod
    def _scripted_llm(*responses):
        """Return a fake LLM whose successive ``invoke`` calls yield *responses*.

        ``bind_tools`` hands back the same fake so the scripted ``invoke``
        sequence is what the agent actually talks to.
        """
        llm = MagicMock()
        llm.bind_tools.return_value = llm
        llm.invoke.side_effect = list(responses)
        return llm

    def test_tool_results_stored_in_state(self):
        """A configured tool_results_key collects each executed tool call."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        asks_for_tool = MagicMock(
            content="",
            tool_calls=[{"name": "git_log", "args": {"count": "5"}, "id": "call_1"}],
        )
        wraps_up = MagicMock(content="Report ready", tool_calls=[])
        fake_llm = self._scripted_llm(asks_for_tool, wraps_up)

        git_log = ShellToolConfig(
            command="git log -n {count}",
            description="Get git log",
        )
        node_config = {
            "tools": ["git_log"],
            "state_key": "report",
            "tool_results_key": "_tool_results",
        }

        with patch(
            "yamlgraph.tools.agent.create_llm", return_value=fake_llm
        ), patch("yamlgraph.tools.agent.execute_shell_tool") as fake_exec:
            fake_exec.return_value = MagicMock(
                success=True, output="commit abc123\nAuthor: Test"
            )
            node = create_agent_node("agent", node_config, {"git_log": git_log})
            state = node({"input": "analyze"})

        assert "_tool_results" in state, "Should include tool_results in state"
        assert len(state["_tool_results"]) == 1, "Should have one tool result"

        record = state["_tool_results"][0]
        assert record["tool"] == "git_log"
        assert record["args"] == {"count": "5"}
        assert "commit abc123" in record["output"]
        assert record["success"] is True

    def test_tool_results_key_is_optional(self):
        """Without tool_results_key, raw results are not stored."""
        from yamlgraph.tools.agent import create_agent_node

        # A fixed return_value (not a one-shot script) for the single answer.
        fake_llm = MagicMock()
        fake_llm.bind_tools.return_value = fake_llm
        fake_llm.invoke.return_value = MagicMock(content="Done", tool_calls=[])

        with patch("yamlgraph.tools.agent.create_llm", return_value=fake_llm):
            node = create_agent_node(
                "agent",
                {"tools": [], "state_key": "result"},  # No tool_results_key
                {},
            )
            state = node({"input": "test"})

        # Nothing configured, so nothing persisted.
        assert "_tool_results" not in state

    def test_multiple_tool_calls_all_stored(self):
        """Every tool call in a single LLM turn ends up in the stored results."""
        from yamlgraph.tools.agent import create_agent_node
        from yamlgraph.tools.shell import ShellToolConfig

        asks_for_tools = MagicMock(
            content="",
            tool_calls=[
                {"name": "tool_a", "args": {}, "id": "call_1"},
                {"name": "tool_b", "args": {}, "id": "call_2"},
            ],
        )
        fake_llm = self._scripted_llm(
            asks_for_tools, MagicMock(content="Done", tool_calls=[])
        )

        registry = {
            "tool_a": ShellToolConfig(command="echo a", description="A"),
            "tool_b": ShellToolConfig(command="echo b", description="B"),
        }
        node_config = {
            "tools": ["tool_a", "tool_b"],
            "state_key": "result",
            "tool_results_key": "_tool_results",
        }

        with patch(
            "yamlgraph.tools.agent.create_llm", return_value=fake_llm
        ), patch("yamlgraph.tools.agent.execute_shell_tool") as fake_exec:
            fake_exec.return_value = MagicMock(success=True, output="output")
            node = create_agent_node("agent", node_config, registry)
            state = node({"input": "test"})

        assert len(state["_tool_results"]) == 2
        recorded = {entry["tool"] for entry in state["_tool_results"]}
        assert "tool_a" in recorded
        assert "tool_b" in recorded
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class TestMultiTurnConversation:
    """Tests for multi-turn conversation support."""

    def test_existing_messages_preserved(self):
        """Agent should preserve existing messages from state."""
        from yamlgraph.tools.agent import create_agent_node

        mock_response = MagicMock()
        mock_response.content = "Follow-up response"
        mock_response.tool_calls = []

        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = mock_response

        # Simulate state with existing messages
        existing_messages = [
            SystemMessage(content="You are helpful."),
            HumanMessage(content="First question"),
            AIMessage(content="First answer"),
        ]

        with patch("yamlgraph.tools.agent.create_llm", return_value=mock_llm):
            node_fn = create_agent_node(
                "agent",
                {"tools": [], "state_key": "result"},
                {},
            )
            result = node_fn(
                {
                    "input": "Follow-up question",
                    "messages": existing_messages,
                }
            )

        messages = result["messages"]
        # Should include new messages (at minimum human + AI for this turn)
        # The exact count depends on implementation - key is messages are returned
        # NOTE(review): this only checks that messages are returned; it does not
        # verify that the earlier turn is preserved, as the docstring suggests.
        assert len(messages) >= 2, "Should return messages for accumulation"

    def test_agent_state_message_reducer_works(self):
        """Dynamic state's Annotated[list, add] should accumulate messages."""
        from operator import add as add_op
        from typing import get_type_hints

        from yamlgraph.models.state_builder import build_state_class

        State = build_state_class({"nodes": {"agent": {"type": "agent"}}})
        # include_extras=True keeps the Annotated wrapper (and its reducer).
        hints = get_type_hints(State, include_extras=True)

        # Check messages field has reducer annotation
        messages_hint = hints.get("messages")
        assert messages_hint is not None, "State should have messages field"

        # A plain, un-annotated field carries no reducer at all. Previously
        # this was guarded by an `if`, so the test passed vacuously in exactly
        # the case it exists to catch.
        assert hasattr(
            messages_hint, "__metadata__"
        ), "messages should be Annotated with a reducer"
        assert (
            add_op in messages_hint.__metadata__
        ), "messages should use add reducer"
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Tests for yamlgraph.storage.database module."""
|
|
2
|
+
|
|
3
|
+
from yamlgraph.models import create_initial_state
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestYamlGraphDB:
    """Round-trip, listing and deletion behavior of the YamlGraphDB store."""

    def test_db_initialization(self, temp_db):
        """The database file exists as soon as the fixture hands it over."""
        assert temp_db.db_path.exists()

    def test_save_and_load_state(self, temp_db, sample_state):
        """A saved state loads back with its topic and thread id intact."""
        tid = sample_state["thread_id"]
        temp_db.save_state(tid, sample_state, status="completed")

        restored = temp_db.load_state(tid)
        assert restored is not None
        assert restored["topic"] == sample_state["topic"]
        assert restored["thread_id"] == tid

    def test_load_nonexistent_state(self, temp_db):
        """An unknown thread id yields None rather than raising."""
        assert temp_db.load_state("nonexistent") is None

    def test_update_existing_state(self, temp_db, empty_state):
        """Saving twice under one thread id keeps the latest values."""
        tid = empty_state["thread_id"]

        # First write: run still in progress.
        temp_db.save_state(tid, empty_state, status="running")

        # Second write for the same thread supersedes the first.
        empty_state["current_step"] = "generate"
        temp_db.save_state(tid, empty_state, status="completed")

        assert temp_db.load_state(tid)["current_step"] == "generate"

    def test_list_runs_empty(self, temp_db):
        """A fresh database lists no runs."""
        assert temp_db.list_runs() == []

    def test_list_runs_with_data(self, temp_db):
        """Every saved thread shows up in the run listing."""
        first = create_initial_state(topic="test1", thread_id="thread1")
        second = create_initial_state(topic="test2", thread_id="thread2")

        temp_db.save_state("thread1", first, status="completed")
        temp_db.save_state("thread2", second, status="running")

        runs = temp_db.list_runs()
        assert len(runs) == 2
        listed = {run["thread_id"] for run in runs}
        assert "thread1" in listed
        assert "thread2" in listed

    def test_list_runs_limit(self, temp_db):
        """The limit argument caps the number of runs returned."""
        for n in range(5):
            tid = f"thread{n}"
            temp_db.save_state(tid, create_initial_state(topic=f"test{n}", thread_id=tid))

        assert len(temp_db.list_runs(limit=3)) == 3

    def test_delete_run(self, temp_db, empty_state):
        """Deleting a stored run reports success and removes it."""
        tid = empty_state["thread_id"]
        temp_db.save_state(tid, empty_state)

        assert temp_db.delete_run(tid) is True
        assert temp_db.load_state(tid) is None

    def test_delete_nonexistent_run(self, temp_db):
        """Deleting an unknown run reports False."""
        assert temp_db.delete_run("nonexistent") is False

    def test_serialize_state_with_pydantic(self, temp_db, sample_state):
        """Pydantic values in state survive storage as plain dicts."""
        tid = sample_state["thread_id"]
        temp_db.save_state(tid, sample_state)

        generated = temp_db.load_state(tid)["generated"]
        # After the round trip the model is a dict, not a Pydantic instance.
        assert isinstance(generated, dict)
        assert generated["title"] == "Test Article"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class TestConnectionPool:
    """Tests for connection pooling."""

    def test_pooled_mode_works(self, tmp_path):
        """Database should work in pooled mode."""
        from yamlgraph.storage.database import YamlGraphDB

        db_path = tmp_path / "pooled_test.db"
        db = YamlGraphDB(db_path=db_path, use_pool=True, pool_size=3)

        try:
            # Save and load should work
            db.save_state("test-thread", {"topic": "test"})
            loaded = db.load_state("test-thread")
            assert loaded["topic"] == "test"
        finally:
            db.close()

    def test_pool_reuses_connections(self, tmp_path):
        """Pool should reuse connections."""
        from yamlgraph.storage.database import ConnectionPool

        db_path = tmp_path / "pool_test.db"
        pool = ConnectionPool(db_path, pool_size=2)

        try:
            # Get a connection, use it, return it
            with pool.get_connection() as conn1:
                pass

            # Next connection should be the same one (reused from pool).
            # conn1 stays bound, so an identity check is well-defined.
            with pool.get_connection() as conn2:
                pass

            assert conn2 is conn1
        finally:
            # Previously skipped when the assertion failed, leaking the
            # pooled sqlite connections.
            pool.close_all()

    def test_close_method(self, tmp_path):
        """Close should clean up connections."""
        from yamlgraph.storage.database import YamlGraphDB

        db_path = tmp_path / "close_test.db"
        db = YamlGraphDB(db_path=db_path, use_pool=True, pool_size=2)

        try:
            # Use the connection to create one
            db.save_state("test", {"data": "value"})
        finally:
            # Close should not raise; run it even if save_state failed.
            db.close()
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Tests for deprecation module.
|
|
2
|
+
|
|
3
|
+
TDD tests for DeprecationError and deprecation utilities.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestDeprecationError:
    """The exception raised when a retired CLI command is invoked."""

    def test_deprecation_error_exists(self):
        """The class is importable and is an ordinary Exception subclass."""
        from yamlgraph.cli.deprecation import DeprecationError

        assert issubclass(DeprecationError, Exception)

    def test_deprecation_error_message(self):
        """The rendered message names both commands and says 'deprecated'."""
        from yamlgraph.cli.deprecation import DeprecationError

        rendered = str(
            DeprecationError(
                old_command="route",
                new_command="graph run graphs/router-demo.yaml --var message=...",
            )
        )

        assert "route" in rendered
        assert "graph run" in rendered
        assert "deprecated" in rendered.lower()

    def test_deprecation_error_has_attributes(self):
        """Both commands remain accessible as attributes on the exception."""
        from yamlgraph.cli.deprecation import DeprecationError

        old = "refine"
        new = "graph run graphs/reflexion-demo.yaml"
        err = DeprecationError(old_command=old, new_command=new)

        assert err.old_command == old
        assert err.new_command == new
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TestDeprecatedCommand:
    """The helper that aborts a retired command with its replacement."""

    def test_deprecated_command_exists(self):
        """The helper is importable and callable."""
        from yamlgraph.cli.deprecation import deprecated_command

        assert callable(deprecated_command)

    def test_deprecated_command_raises(self):
        """Calling the helper always raises DeprecationError."""
        from yamlgraph.cli.deprecation import DeprecationError, deprecated_command

        replacement = "graph run graphs/router-demo.yaml --var message=..."
        with pytest.raises(DeprecationError) as raised:
            deprecated_command("route", replacement)

        assert "route" in str(raised.value)

    def test_deprecated_command_with_mapping(self):
        """Variable assignments in the replacement appear in the error text."""
        from yamlgraph.cli.deprecation import DeprecationError, deprecated_command

        replacement = "graph run graphs/reflexion-demo.yaml --var topic=X"
        with pytest.raises(DeprecationError) as raised:
            deprecated_command("refine --topic X", replacement)

        assert "topic=X" in str(raised.value)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class TestCommandMappings:
    """Lookup of the graph-based replacement for each retired command."""

    def test_get_replacement_for_route(self):
        """'route' maps onto the router demo graph with its message var."""
        from yamlgraph.cli.deprecation import get_replacement_command

        replacement = get_replacement_command("route", {"message": "hello"})
        for fragment in ("graph run", "router-demo.yaml", "message=hello"):
            assert fragment in replacement

    def test_get_replacement_for_refine(self):
        """'refine' maps onto the reflexion demo graph with its topic var."""
        from yamlgraph.cli.deprecation import get_replacement_command

        replacement = get_replacement_command("refine", {"topic": "AI"})
        for fragment in ("graph run", "reflexion-demo.yaml", "topic=AI"):
            assert fragment in replacement

    def test_get_replacement_unknown_command(self):
        """Commands without a mapping produce None."""
        from yamlgraph.cli.deprecation import get_replacement_command

        assert get_replacement_command("unknown", {}) is None
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Tests for yamlgraph.executor module."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from yamlgraph.executor import format_prompt, load_prompt
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestLoadPrompt:
    """Loading prompt definitions by name."""

    def test_load_existing_prompt(self):
        """The 'generate' prompt provides both system and user sections."""
        prompt = load_prompt("generate")
        for section in ("system", "user"):
            assert section in prompt

    def test_load_analyze_prompt(self):
        """The 'analyze' prompt templates the content into its user section."""
        analyze = load_prompt("analyze")
        assert "system" in analyze
        assert "{content}" in analyze["user"]

    def test_load_nonexistent_prompt(self):
        """An unknown prompt name surfaces as FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            load_prompt("nonexistent_prompt")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestFormatPrompt:
    """Placeholder substitution performed by format_prompt."""

    def test_format_single_variable(self):
        """One placeholder is replaced by its value."""
        assert format_prompt("Hello, {name}!", {"name": "World"}) == "Hello, World!"

    def test_format_multiple_variables(self):
        """Several placeholders are each replaced."""
        values = {"topic": "AI", "style": "casual"}
        rendered = format_prompt("Topic: {topic}, Style: {style}", values)
        assert rendered == "Topic: AI, Style: casual"

    def test_format_empty_variables(self):
        """A template without placeholders is returned unchanged."""
        plain = "No variables here"
        assert format_prompt(plain, {}) == "No variables here"

    def test_format_missing_variable_raises(self):
        """A placeholder with no supplied value raises KeyError."""
        with pytest.raises(KeyError):
            format_prompt("Hello, {name}!", {})

    def test_format_with_numbers(self):
        """Non-string values are rendered via their string form."""
        assert format_prompt("Count: {word_count}", {"word_count": 300}) == "Count: 300"
|