yamlgraph-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph has been flagged as potentially problematic; see the registry's advisory page for details.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""Tests for agent nodes (type: agent).
|
|
2
|
+
|
|
3
|
+
Agent nodes allow the LLM to autonomously decide which tools to call
|
|
4
|
+
in a loop until it has enough information to respond.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from unittest.mock import MagicMock, patch
|
|
8
|
+
|
|
9
|
+
from yamlgraph.tools.agent import build_langchain_tool, create_agent_node
|
|
10
|
+
from yamlgraph.tools.shell import ShellToolConfig
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestBuildLangchainTool:
    """Unit tests for ``build_langchain_tool``.

    Each test wraps a ``ShellToolConfig`` in a LangChain tool and checks a
    single property of the result: its name, its description, or what
    invoking it produces.
    """

    @staticmethod
    def _build(tool_name, command, description):
        """Build a LangChain tool called *tool_name* from a shell command spec."""
        shell_config = ShellToolConfig(command=command, description=description)
        return build_langchain_tool(tool_name, shell_config)

    def test_creates_tool_with_name(self):
        """Tool has correct name."""
        built = self._build("my_tool", "echo test", "Test tool")
        assert built.name == "my_tool"

    def test_creates_tool_with_description(self):
        """Tool has correct description."""
        built = self._build("test", "echo test", "A helpful test tool")
        assert built.description == "A helpful test tool"

    def test_tool_executes_command(self):
        """Tool invocation runs shell command."""
        echo_tool = self._build("echo", "echo {message}", "Echo a message")
        output = echo_tool.invoke({"message": "hello"})
        assert "hello" in output
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TestCreateAgentNode:
    """Tests for create_agent_node function.

    Each LLM-driven test patches ``yamlgraph.tools.agent.create_llm`` with a
    MagicMock whose ``bind_tools`` returns itself, so ``invoke`` yields the
    canned responses configured per test. A response with non-empty
    ``tool_calls`` asks the agent to run a tool; an empty list lets it finish.
    """

    @patch("yamlgraph.tools.agent.create_llm")
    def test_agent_completes_without_tools(self, mock_create_llm):
        """Agent can finish with no tool calls."""
        # Mock LLM that returns a direct answer (no tool calls)
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.tool_calls = []
        mock_response.content = "The answer is 42"
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = mock_response
        mock_create_llm.return_value = mock_llm

        tools = {
            "search": ShellToolConfig(command="echo search", description="Search"),
        }
        node_config = {
            "prompt": "agent",
            "tools": ["search"],
            "max_iterations": 5,
            "state_key": "result",
        }

        node_fn = create_agent_node("agent", node_config, tools)
        result = node_fn({"input": "What is the meaning of life?"})

        assert result["result"] == "The answer is 42"
        assert result["_agent_iterations"] == 1

    @patch("yamlgraph.tools.agent.create_llm")
    def test_agent_calls_tool(self, mock_create_llm):
        """LLM tool call executes shell command."""
        # Mock LLM that first calls a tool, then returns answer
        mock_llm = MagicMock()

        # First response: call a tool
        first_response = MagicMock()
        first_response.tool_calls = [
            {"id": "call1", "name": "echo", "args": {"message": "test"}}
        ]
        first_response.content = ""

        # Second response: final answer
        second_response = MagicMock()
        second_response.tool_calls = []
        second_response.content = "I echoed: test"

        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.side_effect = [first_response, second_response]
        mock_create_llm.return_value = mock_llm

        tools = {
            "echo": ShellToolConfig(command="echo {message}", description="Echo"),
        }
        node_config = {
            "prompt": "agent",
            "tools": ["echo"],
            "max_iterations": 5,
            "state_key": "result",
        }

        node_fn = create_agent_node("agent", node_config, tools)
        result = node_fn({"input": "Echo something"})

        assert result["result"] == "I echoed: test"
        assert result["_agent_iterations"] == 2

    @patch("yamlgraph.tools.agent.create_llm")
    def test_max_iterations_enforced(self, mock_create_llm):
        """Stops after max_iterations reached."""
        # Mock LLM that always calls a tool (never finishes)
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.tool_calls = [
            {"id": "call1", "name": "search", "args": {"query": "more"}}
        ]
        mock_response.content = "Still searching..."
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = mock_response
        mock_create_llm.return_value = mock_llm

        tools = {
            "search": ShellToolConfig(command="echo searching", description="Search"),
        }
        node_config = {
            "prompt": "agent",
            "tools": ["search"],
            "max_iterations": 3,
            "state_key": "result",
        }

        node_fn = create_agent_node("agent", node_config, tools)
        result = node_fn({"input": "Search forever"})

        # Should stop at max_iterations
        assert result["_agent_limit_reached"] is True
        assert mock_llm.invoke.call_count == 3

    @patch("yamlgraph.tools.agent.create_llm")
    def test_tool_result_returned_to_llm(self, mock_create_llm):
        """LLM sees tool output in next turn."""
        mock_llm = MagicMock()

        # First: call tool
        first_response = MagicMock()
        first_response.tool_calls = [
            {"id": "call1", "name": "calc", "args": {"expr": "2+2"}}
        ]
        first_response.content = ""

        # Second: answer based on tool result
        second_response = MagicMock()
        second_response.tool_calls = []
        second_response.content = "The result is 4"

        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.side_effect = [first_response, second_response]
        mock_create_llm.return_value = mock_llm

        tools = {
            "calc": ShellToolConfig(
                command="echo 4",  # Simulates python calc
                description="Calculate",
            ),
        }
        node_config = {
            "prompt": "agent",
            "tools": ["calc"],
            "max_iterations": 5,
            "state_key": "answer",
        }

        node_fn = create_agent_node("agent", node_config, tools)
        node_fn({"input": "What is 2+2?"})

        # Check that second invoke received messages with tool result
        second_call_messages = mock_llm.invoke.call_args_list[1][0][0]
        # Should have: system, user, ai (with tool call), tool result
        assert len(second_call_messages) >= 4

        # The count alone doesn't prove the tool ran: the tool's output must be
        # sent back, associated with the id of the call that requested it.
        # NOTE(review): assumes the agent emits LangChain ToolMessage objects
        # (which carry ``tool_call_id``); confirm against yamlgraph.tools.agent.
        tool_messages = [
            message
            for message in second_call_messages
            if getattr(message, "tool_call_id", None) == "call1"
        ]
        assert tool_messages, "tool result for call1 was not passed to the LLM"
        assert "4" in str(tool_messages[0].content)

    @patch("yamlgraph.tools.agent.create_llm")
    def test_default_max_iterations(self, mock_create_llm):
        """Default max_iterations is 5."""
        # Mock LLM that always calls a tool, so only the iteration cap stops it.
        mock_llm = MagicMock()
        looping_response = MagicMock()
        looping_response.tool_calls = [{"id": "call1", "name": "test", "args": {}}]
        looping_response.content = "Still working..."
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = looping_response
        mock_create_llm.return_value = mock_llm

        tools = {
            "test": ShellToolConfig(command="echo test", description="Test"),
        }
        node_config = {
            "prompt": "agent",
            "tools": ["test"],
            "state_key": "result",
            # No max_iterations specified
        }

        node_fn = create_agent_node("agent", node_config, tools)
        assert callable(node_fn)

        result = node_fn({"input": "Loop forever"})

        # Without an explicit limit the agent must stop after the default 5 turns.
        assert result["_agent_limit_reached"] is True
        assert mock_llm.invoke.call_count == 5
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Tests for native LangGraph checkpointer integration."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestGetCheckpointer:
    """Tests for get_checkpointer() function.

    Every test points the checkpointer at a file under ``tmp_path`` so the
    real application database is never touched.
    """

    def test_returns_sqlite_saver_instance(self, tmp_path: Path):
        """Should return a SqliteSaver instance."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        checkpointer = get_checkpointer(db_path)

        assert isinstance(checkpointer, SqliteSaver)

    def test_creates_database_file(self, tmp_path: Path):
        """Should create the database file as soon as the checkpointer is built."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        assert not db_path.exists()

        get_checkpointer(db_path)

        # No further access to the checkpointer is needed: the file must
        # already exist once get_checkpointer() has returned.
        assert db_path.exists()

    def test_uses_default_path_when_none(self, monkeypatch, tmp_path: Path):
        """Should use DATABASE_PATH when db_path is None."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        # Redirect the module-level default into tmp_path
        default_db = tmp_path / "default.db"
        monkeypatch.setattr("yamlgraph.storage.checkpointer.DATABASE_PATH", default_db)

        checkpointer = get_checkpointer(None)

        assert checkpointer is not None
        # Proves the patched default was actually used. Without this check,
        # a get_checkpointer that ignored DATABASE_PATH would still pass,
        # silently writing to the real database.
        assert default_db.exists()

    def test_accepts_string_path(self, tmp_path: Path):
        """Should accept string path as well as Path."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = str(tmp_path / "test.db")
        checkpointer = get_checkpointer(db_path)

        assert checkpointer is not None
        # The string must be honoured as a location, not merely accepted.
        assert Path(db_path).exists()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TestCheckpointerWithGraph:
    """Tests for using checkpointer with a LangGraph StateGraph.

    Each test wires a tiny linear graph, compiles it against a SQLite
    checkpointer stored under ``tmp_path`` and inspects the compiled graph.
    """

    @staticmethod
    def _compile_linear(state_schema, steps, db_path):
        """Compile ``steps`` as the linear graph ``steps[0] -> ... -> END``.

        Args:
            state_schema: TypedDict class used as the graph state.
            steps: Ordered ``(node_name, node_fn)`` pairs; the first one is
                the entry point.
            db_path: Location of the SQLite checkpoint database.

        Returns:
            The graph compiled with ``get_checkpointer(db_path)``.
        """
        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        builder = StateGraph(state_schema)
        for node_name, node_callable in steps:
            builder.add_node(node_name, node_callable)

        names = [node_name for node_name, _ in steps]
        builder.set_entry_point(names[0])
        # Chain consecutive nodes, with the last one flowing into END.
        for source, target in zip(names, [*names[1:], END]):
            builder.add_edge(source, target)

        return builder.compile(checkpointer=get_checkpointer(db_path))

    def test_graph_compiles_with_checkpointer(self, tmp_path: Path):
        """Graph should compile when checkpointer is provided."""
        from typing import TypedDict

        class SimpleState(TypedDict, total=False):
            value: str

        def node_fn(state: SimpleState) -> dict:
            return {"value": "updated"}

        graph = self._compile_linear(
            SimpleState, [("test", node_fn)], tmp_path / "test.db"
        )

        assert graph is not None

    def test_state_persists_with_thread_id(self, tmp_path: Path):
        """State should persist when using thread_id in config."""
        from typing import TypedDict

        class CounterState(TypedDict, total=False):
            count: int

        def increment(state: CounterState) -> dict:
            return {"count": state.get("count", 0) + 1}

        graph = self._compile_linear(
            CounterState, [("increment", increment)], tmp_path / "test.db"
        )
        thread_config = {"configurable": {"thread_id": "test-thread-1"}}

        # The first invocation bumps the counter from 0 to 1 ...
        first_run = graph.invoke({"count": 0}, thread_config)
        assert first_run["count"] == 1

        # ... and the checkpointer can hand that state back afterwards.
        snapshot = graph.get_state(thread_config)
        assert snapshot.values["count"] == 1

    def test_get_state_history_returns_checkpoints(self, tmp_path: Path):
        """get_state_history() should return checkpoint history."""
        from typing import TypedDict

        class StepState(TypedDict, total=False):
            step: int

        def step1(state: StepState) -> dict:
            return {"step": 1}

        def step2(state: StepState) -> dict:
            return {"step": 2}

        graph = self._compile_linear(
            StepState, [("step1", step1), ("step2", step2)], tmp_path / "test.db"
        )
        thread_config = {"configurable": {"thread_id": "history-test"}}
        graph.invoke({}, thread_config)

        checkpoints = list(graph.get_state_history(thread_config))

        # Several checkpoints are expected (one per step plus the initial
        # one), with the most recent first.
        assert len(checkpoints) >= 2
        assert checkpoints[0].values["step"] == 2
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TestGetStateHistory:
    """Tests for the get_state_history helper function."""

    @staticmethod
    def _compile_single_node_graph(db_path):
        """Compile a one-node graph (``test`` -> END) backed by a SQLite checkpointer.

        The single node writes ``{"data": "done"}`` into the state.
        """
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class DataState(TypedDict, total=False):
            data: str

        def mark_done(state: DataState) -> dict:
            return {"data": "done"}

        builder = StateGraph(DataState)
        builder.add_node("test", mark_done)
        builder.set_entry_point("test")
        builder.add_edge("test", END)
        return builder.compile(checkpointer=get_checkpointer(db_path))

    def test_returns_list_of_snapshots(self, tmp_path: Path):
        """get_state_history should return list of StateSnapshot."""
        from yamlgraph.storage.checkpointer import get_state_history

        graph = self._compile_single_node_graph(tmp_path / "test.db")
        thread_id = "history-helper-test"
        graph.invoke({}, {"configurable": {"thread_id": thread_id}})

        snapshots = get_state_history(graph, thread_id)

        assert isinstance(snapshots, list)
        assert len(snapshots) >= 1

    def test_empty_history_for_unknown_thread(self, tmp_path: Path):
        """Should return empty list for non-existent thread."""
        from yamlgraph.storage.checkpointer import get_state_history

        graph = self._compile_single_node_graph(tmp_path / "test.db")

        assert get_state_history(graph, "non-existent-thread") == []
|
tests/unit/test_cli.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Tests for yamlgraph.cli module."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
from yamlgraph.cli.validators import validate_run_args
|
|
6
|
+
from yamlgraph.config import MAX_TOPIC_LENGTH, MAX_WORD_COUNT, MIN_WORD_COUNT
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestValidateRunArgs:
|
|
10
|
+
"""Tests for validate_run_args function."""
|
|
11
|
+
|
|
12
|
+
def _create_args(self, topic="test topic", word_count=300, style="informative"):
|
|
13
|
+
"""Helper to create args namespace."""
|
|
14
|
+
return argparse.Namespace(
|
|
15
|
+
topic=topic,
|
|
16
|
+
word_count=word_count,
|
|
17
|
+
style=style,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
def test_valid_args(self):
|
|
21
|
+
"""Valid arguments should pass validation."""
|
|
22
|
+
args = self._create_args()
|
|
23
|
+
assert validate_run_args(args) is True
|
|
24
|
+
|
|
25
|
+
def test_empty_topic(self):
|
|
26
|
+
"""Empty topic should fail validation."""
|
|
27
|
+
args = self._create_args(topic="")
|
|
28
|
+
assert validate_run_args(args) is False
|
|
29
|
+
|
|
30
|
+
def test_whitespace_only_topic(self):
|
|
31
|
+
"""Whitespace-only topic should fail validation."""
|
|
32
|
+
args = self._create_args(topic=" ")
|
|
33
|
+
assert validate_run_args(args) is False
|
|
34
|
+
|
|
35
|
+
def test_topic_too_long(self):
|
|
36
|
+
"""Topic exceeding max length should be truncated with warning."""
|
|
37
|
+
long_topic = "x" * (MAX_TOPIC_LENGTH + 100)
|
|
38
|
+
args = self._create_args(topic=long_topic)
|
|
39
|
+
# Should pass but truncate the topic
|
|
40
|
+
assert validate_run_args(args) is True
|
|
41
|
+
assert len(args.topic) == MAX_TOPIC_LENGTH
|
|
42
|
+
|
|
43
|
+
def test_topic_at_max_length(self):
|
|
44
|
+
"""Topic at max length should pass validation."""
|
|
45
|
+
max_topic = "x" * MAX_TOPIC_LENGTH
|
|
46
|
+
args = self._create_args(topic=max_topic)
|
|
47
|
+
assert validate_run_args(args) is True
|
|
48
|
+
|
|
49
|
+
def test_word_count_too_low(self):
|
|
50
|
+
"""Word count below minimum should fail validation."""
|
|
51
|
+
args = self._create_args(word_count=MIN_WORD_COUNT - 1)
|
|
52
|
+
assert validate_run_args(args) is False
|
|
53
|
+
|
|
54
|
+
def test_word_count_too_high(self):
|
|
55
|
+
"""Word count above maximum should fail validation."""
|
|
56
|
+
args = self._create_args(word_count=MAX_WORD_COUNT + 1)
|
|
57
|
+
assert validate_run_args(args) is False
|
|
58
|
+
|
|
59
|
+
def test_word_count_at_min(self):
|
|
60
|
+
"""Word count at minimum should pass validation."""
|
|
61
|
+
args = self._create_args(word_count=MIN_WORD_COUNT)
|
|
62
|
+
assert validate_run_args(args) is True
|
|
63
|
+
|
|
64
|
+
def test_word_count_at_max(self):
|
|
65
|
+
"""Word count at maximum should pass validation."""
|
|
66
|
+
args = self._create_args(word_count=MAX_WORD_COUNT)
|
|
67
|
+
assert validate_run_args(args) is True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TestFormatResult:
    """Tests for generic result formatting via ``_format_result``."""

    @staticmethod
    def _render(result, capsys):
        """Run the CLI formatter on *result* and return everything it printed."""
        from yamlgraph.cli.commands import _format_result

        _format_result(result)
        return capsys.readouterr().out

    def test_format_result_with_any_pydantic_model(self, capsys):
        """CLI should format any Pydantic model, not just known ones."""
        from pydantic import BaseModel

        class CustomResult(BaseModel):
            title: str
            score: float
            items: list[str]

        printed = self._render(
            {
                "current_step": "done",
                "custom": CustomResult(title="Test Title", score=0.95, items=["a", "b"]),
            },
            capsys,
        )

        assert "Test Title" in printed
        assert "0.95" in printed

    def test_format_result_skips_internal_keys(self, capsys):
        """Internal keys should be skipped."""
        printed = self._render(
            {
                "current_step": "done",
                "_route": "positive",
                "_loop_counts": {"draft": 2},
                "response": "Hello!",
            },
            capsys,
        )

        # Underscore-prefixed bookkeeping keys stay hidden ...
        for hidden_key in ("_route", "_loop_counts"):
            assert hidden_key not in printed
        # ... while ordinary fields are shown.
        assert "Hello!" in printed

    def test_format_result_truncates_long_strings(self, capsys):
        """Long strings should be truncated."""
        printed = self._render({"summary": "x" * 500}, capsys)

        assert "..." in printed
        assert len(printed) < 400  # Truncated
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Tests for CLI package structure (Phase 7.1).
|
|
2
|
+
|
|
3
|
+
TDD tests for splitting cli.py into a cli/ package.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
# =============================================================================
|
|
9
|
+
# Package Structure Tests
|
|
10
|
+
# =============================================================================
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestCLIPackageStructure:
    """Import-level checks for the ``yamlgraph.cli`` package layout.

    Each test imports one submodule or public name and checks that it
    resolved to something usable.
    """

    def test_cli_package_importable(self):
        """yamlgraph.cli should be importable as package."""
        import yamlgraph.cli as cli_package

        assert cli_package is not None

    def test_main_function_available(self):
        """main() should be available from package."""
        from yamlgraph.cli import main as entry_point

        assert callable(entry_point)

    def test_validators_submodule_exists(self):
        """validators submodule should exist."""
        from yamlgraph.cli import validators as validators_module

        assert validators_module is not None

    def test_validate_run_args_in_validators(self):
        """validate_run_args should be in validators module."""
        from yamlgraph.cli.validators import validate_run_args as validator

        assert callable(validator)

    def test_commands_submodule_exists(self):
        """commands submodule should exist."""
        from yamlgraph.cli import commands as commands_module

        assert commands_module is not None

    def test_cmd_list_runs_in_commands(self):
        """cmd_list_runs should be in commands module."""
        from yamlgraph.cli.commands import cmd_list_runs as list_runs

        assert callable(list_runs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# =============================================================================
|
|
54
|
+
# Validator Tests (moved from cli module)
|
|
55
|
+
# =============================================================================
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TestValidatorsModule:
|
|
59
|
+
"""Tests for validators module functionality."""
|
|
60
|
+
|
|
61
|
+
def _create_run_args(self, topic="test topic", word_count=300, style="informative"):
|
|
62
|
+
"""Helper to create run args namespace."""
|
|
63
|
+
return argparse.Namespace(
|
|
64
|
+
topic=topic,
|
|
65
|
+
word_count=word_count,
|
|
66
|
+
style=style,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
def test_validate_run_args_valid(self):
|
|
70
|
+
"""Valid run args pass validation."""
|
|
71
|
+
from yamlgraph.cli.validators import validate_run_args
|
|
72
|
+
|
|
73
|
+
args = self._create_run_args()
|
|
74
|
+
assert validate_run_args(args) is True
|
|
75
|
+
|
|
76
|
+
def test_validate_run_args_empty_topic(self):
|
|
77
|
+
"""Empty topic fails validation."""
|
|
78
|
+
from yamlgraph.cli.validators import validate_run_args
|
|
79
|
+
|
|
80
|
+
args = self._create_run_args(topic="")
|
|
81
|
+
assert validate_run_args(args) is False
|