yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Tests for native LangGraph checkpointer integration."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TestGetCheckpointer:
    """Tests for get_checkpointer() function."""

    def test_returns_sqlite_saver_instance(self, tmp_path: Path):
        """Should return a SqliteSaver instance."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        checkpointer = get_checkpointer(db_path)

        assert isinstance(checkpointer, SqliteSaver)

    def test_creates_database_file(self, tmp_path: Path):
        """Should create the database file when the checkpointer is built."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        assert not db_path.exists()

        checkpointer = get_checkpointer(db_path)

        # Nothing else touches the saver before this check, so the file must
        # be created by the get_checkpointer() call itself.
        assert checkpointer is not None
        assert db_path.exists()

    def test_uses_default_path_when_none(self, monkeypatch, tmp_path: Path):
        """Should use DATABASE_PATH when db_path is None."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        # Redirect the module-level default to a temp file so the test never
        # touches the real database.
        default_db = tmp_path / "default.db"
        monkeypatch.setattr("yamlgraph.storage.checkpointer.DATABASE_PATH", default_db)

        checkpointer = get_checkpointer(None)

        assert checkpointer is not None
        # A non-None result alone would also be produced if the patched default
        # were ignored; the file's existence proves DATABASE_PATH was used.
        assert default_db.exists()

    def test_accepts_string_path(self, tmp_path: Path):
        """Should accept string path as well as Path."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        checkpointer = get_checkpointer(str(db_path))

        assert checkpointer is not None
        # The string must be honoured as the same location a Path would be.
        assert db_path.exists()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TestCheckpointerWithGraph:
    """Exercise get_checkpointer() together with a compiled LangGraph StateGraph."""

    def test_graph_compiles_with_checkpointer(self, tmp_path: Path):
        """Compiling a one-node graph with the SQLite saver succeeds."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class SimpleState(TypedDict, total=False):
            value: str

        def set_value(state: SimpleState) -> dict:
            return {"value": "updated"}

        builder = StateGraph(SimpleState)
        builder.add_node("test", set_value)
        builder.set_entry_point("test")
        builder.add_edge("test", END)

        compiled = builder.compile(checkpointer=get_checkpointer(tmp_path / "test.db"))

        assert compiled is not None

    def test_state_persists_with_thread_id(self, tmp_path: Path):
        """A thread_id in the config lets get_state() read back the final values."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class CounterState(TypedDict, total=False):
            count: int

        def bump(state: CounterState) -> dict:
            return {"count": state.get("count", 0) + 1}

        builder = StateGraph(CounterState)
        builder.add_node("increment", bump)
        builder.set_entry_point("increment")
        builder.add_edge("increment", END)
        compiled = builder.compile(checkpointer=get_checkpointer(tmp_path / "test.db"))

        thread_config = {"configurable": {"thread_id": "test-thread-1"}}

        # A single run bumps the counter exactly once...
        first = compiled.invoke({"count": 0}, thread_config)
        assert first["count"] == 1

        # ...and the saver hands the same value back for that thread.
        snapshot = compiled.get_state(thread_config)
        assert snapshot.values["count"] == 1

    def test_get_state_history_returns_checkpoints(self, tmp_path: Path):
        """Running a two-node graph leaves several checkpoints, newest first."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class StepState(TypedDict, total=False):
            step: int

        def first_step(state: StepState) -> dict:
            return {"step": 1}

        def second_step(state: StepState) -> dict:
            return {"step": 2}

        builder = StateGraph(StepState)
        builder.add_node("step1", first_step)
        builder.add_node("step2", second_step)
        builder.set_entry_point("step1")
        builder.add_edge("step1", "step2")
        builder.add_edge("step2", END)
        compiled = builder.compile(checkpointer=get_checkpointer(tmp_path / "test.db"))

        thread_config = {"configurable": {"thread_id": "history-test"}}
        compiled.invoke({}, thread_config)

        checkpoints = [*compiled.get_state_history(thread_config)]

        # Expect the input checkpoint plus at least one per executed step.
        assert len(checkpoints) >= 2

        # History is yielded newest-first, so the head reflects step2's write.
        assert checkpoints[0].values["step"] == 2
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TestGetStateHistory:
    """Cover the get_state_history(graph, thread_id) helper."""

    def test_returns_list_of_snapshots(self, tmp_path: Path):
        """A thread that has run yields a non-empty list of snapshots."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer, get_state_history

        class HistoryState(TypedDict, total=False):
            data: str

        def finish(state: HistoryState) -> dict:
            return {"data": "done"}

        builder = StateGraph(HistoryState)
        builder.add_node("test", finish)
        builder.set_entry_point("test")
        builder.add_edge("test", END)
        compiled = builder.compile(checkpointer=get_checkpointer(tmp_path / "test.db"))

        thread = "history-helper-test"
        compiled.invoke({}, {"configurable": {"thread_id": thread}})

        snapshots = get_state_history(compiled, thread)

        assert isinstance(snapshots, list)
        assert len(snapshots) >= 1

    def test_empty_history_for_unknown_thread(self, tmp_path: Path):
        """A thread id that never ran produces an empty list, not an error."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer, get_state_history

        class HistoryState(TypedDict, total=False):
            data: str

        def finish(state: HistoryState) -> dict:
            return {"data": "done"}

        builder = StateGraph(HistoryState)
        builder.add_node("test", finish)
        builder.set_entry_point("test")
        builder.add_edge("test", END)
        compiled = builder.compile(checkpointer=get_checkpointer(tmp_path / "test.db"))

        # The graph is never invoked, so no checkpoints exist for any thread.
        assert get_state_history(compiled, "non-existent-thread") == []
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Unit tests for checkpointer factory.
|
|
2
|
+
|
|
3
|
+
TDD tests for 002: Redis Checkpointer feature.
|
|
4
|
+
Tests get_checkpointer() factory with env var expansion and async mode.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from unittest.mock import MagicMock, patch
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
|
|
12
|
+
from yamlgraph.storage.checkpointer_factory import (
|
|
13
|
+
expand_env_vars,
|
|
14
|
+
get_checkpointer,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestExpandEnvVars:
    """Test environment variable expansion."""

    def test_expand_single_var(self):
        """Should expand ${VAR} pattern."""
        with patch.dict(os.environ, {"REDIS_URL": "redis://localhost:6379"}):
            result = expand_env_vars("${REDIS_URL}")
            assert result == "redis://localhost:6379"

    def test_expand_multiple_vars(self):
        """Should expand multiple ${VAR} patterns."""
        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "6379"}):
            result = expand_env_vars("redis://${HOST}:${PORT}/0")
            assert result == "redis://localhost:6379/0"

    def test_expand_missing_var_keeps_original(self):
        """Missing env vars should keep original ${VAR} pattern."""
        # patch.dict snapshots os.environ and restores it on exit, so removing
        # NONEXISTENT is scoped to this test instead of permanently mutating
        # the process environment for every test that runs afterwards.
        with patch.dict(os.environ):
            os.environ.pop("NONEXISTENT", None)
            result = expand_env_vars("${NONEXISTENT}")
        assert result == "${NONEXISTENT}"

    def test_expand_non_string_returns_unchanged(self):
        """Non-string values should pass through unchanged."""
        assert expand_env_vars(123) == 123
        assert expand_env_vars(None) is None
        assert expand_env_vars(["a", "b"]) == ["a", "b"]

    def test_expand_no_vars_returns_original(self):
        """String without ${} should return unchanged."""
        result = expand_env_vars("redis://localhost:6379")
        assert result == "redis://localhost:6379"

    def test_expand_empty_string(self):
        """Empty string should return empty string."""
        assert expand_env_vars("") == ""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TestGetCheckpointerMemory:
    """Test in-memory checkpointer (default)."""

    def test_memory_checkpointer_default(self):
        """A config that omits "type" should fall back to memory."""
        from langgraph.checkpoint.memory import InMemorySaver

        # Deliberately omit "type" so the default is actually exercised; the
        # explicit {"type": "memory"} case is covered by the next test. The
        # dict is kept non-empty because an empty config may be treated like
        # None (see test_none_config_returns_none); "ttl" is an arbitrary
        # extra key. NOTE(review): confirm the memory backend ignores it.
        config = {"ttl": 60}
        saver = get_checkpointer(config)

        assert isinstance(saver, InMemorySaver)

    def test_memory_checkpointer_explicit(self):
        """Explicit type: memory should work."""
        from langgraph.checkpoint.memory import InMemorySaver

        config = {"type": "memory"}
        saver = get_checkpointer(config)

        assert isinstance(saver, InMemorySaver)

    def test_none_config_returns_none(self):
        """None config should return None."""
        assert get_checkpointer(None) is None
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class TestGetCheckpointerSqlite:
    """SQLite-backed checkpointer configurations."""

    def test_sqlite_checkpointer_memory(self):
        """A ":memory:" path produces a SqliteSaver without touching disk."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        saver = get_checkpointer({"type": "sqlite", "path": ":memory:"})

        assert isinstance(saver, SqliteSaver)

    def test_sqlite_expands_env_var(self):
        """A ${DB_PATH} placeholder in "path" is resolved from the environment."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        env_overrides = {"DB_PATH": ":memory:"}
        with patch.dict(os.environ, env_overrides):
            saver = get_checkpointer({"type": "sqlite", "path": "${DB_PATH}"})

        assert isinstance(saver, SqliteSaver)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class TestGetCheckpointerRedis:
    """Test Redis checkpointer (mocked)."""

    # Import paths the factory may use for the optional redis backend.
    _REDIS_MODULES = ("langgraph.checkpoint.redis", "langgraph.checkpoint.redis.aio")

    @staticmethod
    def _reload_factory():
        """Re-import checkpointer_factory so it picks up a patched sys.modules."""
        import importlib

        from yamlgraph.storage import checkpointer_factory

        return importlib.reload(checkpointer_factory)

    def test_redis_checkpointer_sync(self):
        """Redis sync saver should be created."""
        mock_saver = MagicMock()
        mock_redis_module = MagicMock()
        mock_redis_module.RedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis": mock_redis_module}
        ):
            factory = self._reload_factory()

            with patch.dict(os.environ, {"REDIS_URL": "redis://localhost:6379"}):
                config = {"type": "redis", "url": "${REDIS_URL}", "ttl": 120}
                saver = factory.get_checkpointer(config)

            mock_redis_module.RedisSaver.from_conn_string.assert_called_once_with(
                "redis://localhost:6379",
                ttl={"default_ttl": 120},
            )
            mock_saver.setup.assert_called_once()
            assert saver is mock_saver

    def test_redis_checkpointer_async(self):
        """Redis async saver should be created with async_mode=True."""
        mock_saver = MagicMock()
        mock_aio_module = MagicMock()
        mock_aio_module.AsyncRedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis.aio": mock_aio_module}
        ):
            factory = self._reload_factory()

            config = {"type": "redis", "url": "redis://localhost:6379", "ttl": 60}
            saver = factory.get_checkpointer(config, async_mode=True)

            mock_aio_module.AsyncRedisSaver.from_conn_string.assert_called_once_with(
                "redis://localhost:6379",
                ttl={"default_ttl": 60},
            )
            # Async saver should NOT call setup() - caller must await asetup()
            mock_saver.setup.assert_not_called()
            assert saver is mock_saver

    def test_redis_import_error_helpful_message(self):
        """Missing redis package should give helpful error."""
        # A None entry in sys.modules makes any import of that name raise
        # ImportError, so this simulates the optional package being absent
        # even when it is installed. patch.dict restores sys.modules on exit,
        # keeping the simulated absence scoped to this test.
        blocked = dict.fromkeys(self._REDIS_MODULES)
        config = {"type": "redis", "url": "redis://localhost:6379"}

        with patch.dict("sys.modules", blocked):
            with pytest.raises(ImportError) as exc_info:
                get_checkpointer(config)

        assert "pip install yamlgraph[redis]" in str(exc_info.value)

    def test_redis_default_ttl(self):
        """Redis should use default TTL of 60 if not specified."""
        mock_saver = MagicMock()
        mock_redis_module = MagicMock()
        mock_redis_module.RedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis": mock_redis_module}
        ):
            factory = self._reload_factory()

            config = {"type": "redis", "url": "redis://localhost:6379"}
            factory.get_checkpointer(config)

            mock_redis_module.RedisSaver.from_conn_string.assert_called_once_with(
                "redis://localhost:6379",
                ttl={"default_ttl": 60},
            )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class TestGetCheckpointerErrors:
    """Error paths of get_checkpointer()."""

    def test_unknown_type_raises_error(self):
        """An unrecognised backend type is rejected with a descriptive ValueError."""
        # The expected text has no regex metacharacters, so ``match`` (a
        # re.search) behaves as a plain substring check here.
        with pytest.raises(ValueError, match="Unknown checkpointer type: unknown_db"):
            get_checkpointer({"type": "unknown_db"})
|
tests/unit/test_cli.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Tests for yamlgraph.cli module."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
from yamlgraph.cli.validators import validate_run_args
|
|
6
|
+
from yamlgraph.config import MAX_TOPIC_LENGTH, MAX_WORD_COUNT, MIN_WORD_COUNT
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestValidateRunArgs:
    """Topic and word-count limits enforced by validate_run_args()."""

    @staticmethod
    def _make_args(topic="test topic", word_count=300, style="informative"):
        """Build the argparse namespace validate_run_args() expects."""
        return argparse.Namespace(topic=topic, word_count=word_count, style=style)

    def test_valid_args(self):
        """The default namespace sits inside every limit."""
        assert validate_run_args(self._make_args()) is True

    def test_empty_topic(self):
        """A blank topic is rejected."""
        assert validate_run_args(self._make_args(topic="")) is False

    def test_whitespace_only_topic(self):
        """A topic made only of whitespace is rejected."""
        assert validate_run_args(self._make_args(topic=" ")) is False

    def test_topic_too_long(self):
        """An over-long topic is accepted but cut down to MAX_TOPIC_LENGTH."""
        args = self._make_args(topic="x" * (MAX_TOPIC_LENGTH + 100))

        # Validation passes, and the namespace is shortened in place.
        assert validate_run_args(args) is True
        assert len(args.topic) == MAX_TOPIC_LENGTH

    def test_topic_at_max_length(self):
        """A topic exactly at the cap is accepted."""
        assert validate_run_args(self._make_args(topic="x" * MAX_TOPIC_LENGTH)) is True

    def test_word_count_too_low(self):
        """One below MIN_WORD_COUNT is rejected."""
        assert validate_run_args(self._make_args(word_count=MIN_WORD_COUNT - 1)) is False

    def test_word_count_too_high(self):
        """One above MAX_WORD_COUNT is rejected."""
        assert validate_run_args(self._make_args(word_count=MAX_WORD_COUNT + 1)) is False

    def test_word_count_at_min(self):
        """MIN_WORD_COUNT itself is accepted (inclusive lower bound)."""
        assert validate_run_args(self._make_args(word_count=MIN_WORD_COUNT)) is True

    def test_word_count_at_max(self):
        """MAX_WORD_COUNT itself is accepted (inclusive upper bound)."""
        assert validate_run_args(self._make_args(word_count=MAX_WORD_COUNT)) is True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TestFormatResult:
    """_format_result() should render arbitrary graph output generically."""

    def test_format_result_with_any_pydantic_model(self, capsys):
        """Fields of an unfamiliar Pydantic model still reach stdout."""
        from pydantic import BaseModel

        from yamlgraph.cli.commands import _format_result

        class CustomResult(BaseModel):
            title: str
            score: float
            items: list[str]

        payload = CustomResult(title="Test Title", score=0.95, items=["a", "b"])
        _format_result({"current_step": "done", "custom": payload})

        out = capsys.readouterr().out
        assert "Test Title" in out
        assert "0.95" in out

    def test_format_result_skips_internal_keys(self, capsys):
        """Underscore-prefixed bookkeeping keys are hidden; regular keys are shown."""
        from yamlgraph.cli.commands import _format_result

        _format_result(
            {
                "current_step": "done",
                "_route": "positive",
                "_loop_counts": {"draft": 2},
                "response": "Hello!",
            }
        )

        out = capsys.readouterr().out
        for hidden_key in ("_route", "_loop_counts"):
            assert hidden_key not in out
        assert "Hello!" in out

    def test_format_result_truncates_long_strings(self, capsys):
        """A 500-character value is shortened and marked with an ellipsis."""
        from yamlgraph.cli.commands import _format_result

        _format_result({"summary": "x" * 500})

        out = capsys.readouterr().out
        assert "..." in out
        assert len(out) < 400  # well below the raw 500 characters
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Tests for CLI package structure (Phase 7.1).
|
|
2
|
+
|
|
3
|
+
TDD tests for splitting cli.py into a cli/ package.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
# =============================================================================
|
|
9
|
+
# Package Structure Tests
|
|
10
|
+
# =============================================================================
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestCLIPackageStructure:
    """The cli/ package exposes main() plus validators and commands submodules."""

    def test_cli_package_importable(self):
        """Importing yamlgraph.cli as a package succeeds."""
        import yamlgraph.cli as cli_package

        assert cli_package is not None

    def test_main_function_available(self):
        """The package re-exports a callable main() entry point."""
        from yamlgraph.cli import main as entry_point

        assert callable(entry_point)

    def test_validators_submodule_exists(self):
        """yamlgraph.cli.validators can be imported from the package."""
        from yamlgraph.cli import validators as validators_module

        assert validators_module is not None

    def test_validate_run_args_in_validators(self):
        """validate_run_args lives in the validators submodule."""
        from yamlgraph.cli.validators import validate_run_args as validator

        assert callable(validator)

    def test_commands_submodule_exists(self):
        """yamlgraph.cli.commands can be imported from the package."""
        from yamlgraph.cli import commands as commands_module

        assert commands_module is not None

    def test_cmd_list_runs_in_commands(self):
        """cmd_list_runs lives in the commands submodule."""
        from yamlgraph.cli.commands import cmd_list_runs as list_runs

        assert callable(list_runs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# =============================================================================
|
|
54
|
+
# Validator Tests (moved from cli module)
|
|
55
|
+
# =============================================================================
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TestValidatorsModule:
    """validate_run_args() behaves correctly when imported from cli.validators."""

    @staticmethod
    def _run_namespace(topic="test topic", word_count=300, style="informative"):
        """Assemble a run-command namespace with overridable fields."""
        return argparse.Namespace(topic=topic, word_count=word_count, style=style)

    def test_validate_run_args_valid(self):
        """The default run namespace validates."""
        from yamlgraph.cli.validators import validate_run_args

        assert validate_run_args(self._run_namespace()) is True

    def test_validate_run_args_empty_topic(self):
        """An empty topic is refused."""
        from yamlgraph.cli.validators import validate_run_args

        assert validate_run_args(self._run_namespace(topic="")) is False
|