yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
tests/unit/test_checkpointer.py ADDED
@@ -0,0 +1,212 @@
1
+ """Tests for native LangGraph checkpointer integration."""
2
+
3
+ from pathlib import Path
4
+
5
+
6
class TestGetCheckpointer:
    """Tests for get_checkpointer() function.

    Each test closes the SQLite connection it opens so connections do not
    linger until garbage collection.
    """

    @staticmethod
    def _close(checkpointer) -> None:
        """Close the SQLite connection held by *checkpointer*, if it has one."""
        # getattr keeps cleanup from masking the real assertion failure when
        # the returned object is not a SqliteSaver (or is None).
        conn = getattr(checkpointer, "conn", None)
        if conn is not None:
            conn.close()

    def test_returns_sqlite_saver_instance(self, tmp_path: Path):
        """Should return a SqliteSaver instance."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        checkpointer = get_checkpointer(db_path)
        try:
            assert isinstance(checkpointer, SqliteSaver)
        finally:
            self._close(checkpointer)

    def test_creates_database_file(self, tmp_path: Path):
        """Should create the database file as soon as the checkpointer is built."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = tmp_path / "test.db"
        assert not db_path.exists()

        checkpointer = get_checkpointer(db_path)
        try:
            # get_checkpointer() alone is expected to create the file;
            # no graph run is needed.
            assert db_path.exists()
        finally:
            self._close(checkpointer)

    def test_uses_default_path_when_none(self, monkeypatch, tmp_path: Path):
        """Should use DATABASE_PATH when db_path is None."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        # Redirect the module-level default so the test never touches the
        # real database location.
        default_db = tmp_path / "default.db"
        monkeypatch.setattr("yamlgraph.storage.checkpointer.DATABASE_PATH", default_db)

        checkpointer = get_checkpointer(None)
        try:
            assert checkpointer is not None
            # Proves the patched default was the path actually opened.
            assert default_db.exists()
        finally:
            self._close(checkpointer)

    def test_accepts_string_path(self, tmp_path: Path):
        """Should accept string path as well as Path."""
        from yamlgraph.storage.checkpointer import get_checkpointer

        db_path = str(tmp_path / "test.db")
        checkpointer = get_checkpointer(db_path)
        try:
            assert checkpointer is not None
            assert Path(db_path).exists()
        finally:
            self._close(checkpointer)
53
+
54
+
55
class TestCheckpointerWithGraph:
    """Tests for using checkpointer with a LangGraph StateGraph."""

    @staticmethod
    def _close(checkpointer) -> None:
        """Close the SQLite connection held by *checkpointer*, if it has one."""
        conn = getattr(checkpointer, "conn", None)
        if conn is not None:
            conn.close()

    def test_graph_compiles_with_checkpointer(self, tmp_path: Path):
        """Graph should compile when checkpointer is provided."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class SimpleState(TypedDict, total=False):
            value: str

        def node_fn(state: SimpleState) -> dict:
            return {"value": "updated"}

        workflow = StateGraph(SimpleState)
        workflow.add_node("test", node_fn)
        workflow.set_entry_point("test")
        workflow.add_edge("test", END)

        checkpointer = get_checkpointer(tmp_path / "test.db")
        try:
            graph = workflow.compile(checkpointer=checkpointer)
            assert graph is not None
        finally:
            self._close(checkpointer)

    def test_state_persists_with_thread_id(self, tmp_path: Path):
        """State should persist when using thread_id in config."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class CounterState(TypedDict, total=False):
            count: int

        def increment(state: CounterState) -> dict:
            return {"count": state.get("count", 0) + 1}

        def build_graph(checkpointer):
            """Compile the single-node counter graph against *checkpointer*."""
            workflow = StateGraph(CounterState)
            workflow.add_node("increment", increment)
            workflow.set_entry_point("increment")
            workflow.add_edge("increment", END)
            return workflow.compile(checkpointer=checkpointer)

        db_path = tmp_path / "test.db"
        config = {"configurable": {"thread_id": "test-thread-1"}}

        first = get_checkpointer(db_path)
        reopened = None
        try:
            graph = build_graph(first)

            # First invocation
            result1 = graph.invoke({"count": 0}, config)
            assert result1["count"] == 1

            # State should be retrievable from the same checkpointer
            state = graph.get_state(config)
            assert state.values["count"] == 1

            # Reopening the database with a new checkpointer shows the state
            # was written to the file, not just held by the first instance.
            reopened = get_checkpointer(db_path)
            graph2 = build_graph(reopened)
            assert graph2.get_state(config).values["count"] == 1
        finally:
            self._close(first)
            self._close(reopened)

    def test_get_state_history_returns_checkpoints(self, tmp_path: Path):
        """get_state_history() should return checkpoint history."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        from yamlgraph.storage.checkpointer import get_checkpointer

        class StepState(TypedDict, total=False):
            step: int

        def step1(state: StepState) -> dict:
            return {"step": 1}

        def step2(state: StepState) -> dict:
            return {"step": 2}

        workflow = StateGraph(StepState)
        workflow.add_node("step1", step1)
        workflow.add_node("step2", step2)
        workflow.set_entry_point("step1")
        workflow.add_edge("step1", "step2")
        workflow.add_edge("step2", END)

        checkpointer = get_checkpointer(tmp_path / "test.db")
        try:
            graph = workflow.compile(checkpointer=checkpointer)

            config = {"configurable": {"thread_id": "history-test"}}
            graph.invoke({}, config)

            # Get history
            history = list(graph.get_state_history(config))

            # Should have multiple checkpoints (one per step + initial)
            assert len(history) >= 2

            # Most recent first
            assert history[0].values["step"] == 2
        finally:
            self._close(checkpointer)
152
+
153
+
154
class TestGetStateHistory:
    """Tests for get_state_history helper function."""

    @staticmethod
    def _build_graph(checkpointer):
        """Compile a one-node graph whose node sets ``data`` to ``"done"``."""
        from typing import TypedDict

        from langgraph.graph import END, StateGraph

        # Not named Test* so it cannot be mistaken for a pytest test class.
        class DataState(TypedDict, total=False):
            data: str

        def node(state: DataState) -> dict:
            return {"data": "done"}

        workflow = StateGraph(DataState)
        workflow.add_node("test", node)
        workflow.set_entry_point("test")
        workflow.add_edge("test", END)
        return workflow.compile(checkpointer=checkpointer)

    @staticmethod
    def _close(checkpointer) -> None:
        """Close the SQLite connection held by *checkpointer*, if it has one."""
        conn = getattr(checkpointer, "conn", None)
        if conn is not None:
            conn.close()

    def test_returns_list_of_snapshots(self, tmp_path: Path):
        """get_state_history should return list of StateSnapshot."""
        from yamlgraph.storage.checkpointer import get_checkpointer, get_state_history

        checkpointer = get_checkpointer(tmp_path / "test.db")
        try:
            graph = self._build_graph(checkpointer)

            thread_id = "history-helper-test"
            config = {"configurable": {"thread_id": thread_id}}
            graph.invoke({}, config)

            history = get_state_history(graph, thread_id)

            assert isinstance(history, list)
            assert len(history) >= 1
            # Entries must be snapshots exposing ``values``; one of them must
            # reflect the completed run (order-agnostic on purpose).
            assert any(snap.values.get("data") == "done" for snap in history)
        finally:
            self._close(checkpointer)

    def test_empty_history_for_unknown_thread(self, tmp_path: Path):
        """Should return empty list for non-existent thread."""
        from yamlgraph.storage.checkpointer import get_checkpointer, get_state_history

        checkpointer = get_checkpointer(tmp_path / "test.db")
        try:
            graph = self._build_graph(checkpointer)

            history = get_state_history(graph, "non-existent-thread")

            assert history == []
        finally:
            self._close(checkpointer)
tests/unit/test_checkpointer_factory.py ADDED
@@ -0,0 +1,212 @@
1
+ """Unit tests for checkpointer factory.
2
+
3
+ TDD tests for 002: Redis Checkpointer feature.
4
+ Tests get_checkpointer() factory with env var expansion and async mode.
5
+ """
6
+
7
+ import os
8
+ from unittest.mock import MagicMock, patch
9
+
10
+ import pytest
11
+
12
+ from yamlgraph.storage.checkpointer_factory import (
13
+ expand_env_vars,
14
+ get_checkpointer,
15
+ )
16
+
17
+
18
class TestExpandEnvVars:
    """Test environment variable expansion."""

    def test_expand_single_var(self):
        """Should expand ${VAR} pattern."""
        with patch.dict(os.environ, {"REDIS_URL": "redis://localhost:6379"}):
            result = expand_env_vars("${REDIS_URL}")
            assert result == "redis://localhost:6379"

    def test_expand_multiple_vars(self):
        """Should expand multiple ${VAR} patterns."""
        with patch.dict(os.environ, {"HOST": "localhost", "PORT": "6379"}):
            result = expand_env_vars("redis://${HOST}:${PORT}/0")
            assert result == "redis://localhost:6379/0"

    def test_expand_missing_var_keeps_original(self):
        """Missing env vars should keep original ${VAR} pattern."""
        # patch.dict snapshots os.environ and restores it on exit, so removing
        # NONEXISTENT here cannot leak into the rest of the test session.
        with patch.dict(os.environ):
            os.environ.pop("NONEXISTENT", None)
            result = expand_env_vars("${NONEXISTENT}")
        assert result == "${NONEXISTENT}"

    def test_expand_non_string_returns_unchanged(self):
        """Non-string values should pass through unchanged."""
        assert expand_env_vars(123) == 123
        assert expand_env_vars(None) is None
        assert expand_env_vars(["a", "b"]) == ["a", "b"]

    def test_expand_no_vars_returns_original(self):
        """String without ${} should return unchanged."""
        result = expand_env_vars("redis://localhost:6379")
        assert result == "redis://localhost:6379"

    def test_expand_empty_string(self):
        """Empty string should return empty string."""
        assert expand_env_vars("") == ""
54
+
55
+
56
class TestGetCheckpointerMemory:
    """Test in-memory checkpointer (default)."""

    def test_memory_checkpointer_default(self):
        """A config without ``type`` should fall back to memory."""
        from langgraph.checkpoint.memory import InMemorySaver

        # Deliberately omit "type" so the fallback is actually exercised.
        # The dict is non-empty so this test does not depend on how an
        # empty config is treated.
        config = {"ttl": 60}
        saver = get_checkpointer(config)

        assert isinstance(saver, InMemorySaver)

    def test_memory_checkpointer_explicit(self):
        """Explicit type: memory should work."""
        config = {"type": "memory"}
        saver = get_checkpointer(config)

        from langgraph.checkpoint.memory import InMemorySaver

        assert isinstance(saver, InMemorySaver)

    def test_none_config_returns_none(self):
        """None config should return None."""
        assert get_checkpointer(None) is None
80
+
81
+
82
class TestGetCheckpointerSqlite:
    """Exercise the SQLite-backed checkpointer."""

    def test_sqlite_checkpointer_memory(self):
        """An in-memory SQLite path (``:memory:``) yields a SqliteSaver."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        sqlite_cfg = {"type": "sqlite", "path": ":memory:"}

        assert isinstance(get_checkpointer(sqlite_cfg), SqliteSaver)

    def test_sqlite_expands_env_var(self):
        """``${VAR}`` references in the SQLite path are expanded first."""
        from langgraph.checkpoint.sqlite import SqliteSaver

        fake_env = {"DB_PATH": ":memory:"}
        sqlite_cfg = {"type": "sqlite", "path": "${DB_PATH}"}
        with patch.dict(os.environ, fake_env):
            built = get_checkpointer(sqlite_cfg)

        assert isinstance(built, SqliteSaver)
103
+
104
+
105
class TestGetCheckpointerRedis:
    """Test Redis checkpointer (mocked).

    Redis modules are swapped in or blocked via ``patch.dict("sys.modules", ...)``
    so no Redis server or redis extra is needed, and ``sys.modules`` is
    restored after every test.
    """

    @staticmethod
    def _reload_factory():
        """Reload ``checkpointer_factory`` and return the module.

        Call this while ``sys.modules`` is patched so the reloaded module
        sees the patched Redis modules even if it binds them at import time.
        """
        import importlib

        from yamlgraph.storage import checkpointer_factory

        return importlib.reload(checkpointer_factory)

    def test_redis_checkpointer_sync(self):
        """Redis sync saver should be created."""
        mock_saver = MagicMock()
        mock_redis_module = MagicMock()
        mock_redis_module.RedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis": mock_redis_module}
        ):
            factory = self._reload_factory()

            with patch.dict(os.environ, {"REDIS_URL": "redis://localhost:6379"}):
                config = {"type": "redis", "url": "${REDIS_URL}", "ttl": 120}
                saver = factory.get_checkpointer(config)

        mock_redis_module.RedisSaver.from_conn_string.assert_called_once_with(
            "redis://localhost:6379",
            ttl={"default_ttl": 120},
        )
        mock_saver.setup.assert_called_once()
        assert saver is mock_saver

    def test_redis_checkpointer_async(self):
        """Redis async saver should be created with async_mode=True."""
        mock_saver = MagicMock()
        mock_aio_module = MagicMock()
        mock_aio_module.AsyncRedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis.aio": mock_aio_module}
        ):
            factory = self._reload_factory()

            config = {"type": "redis", "url": "redis://localhost:6379", "ttl": 60}
            saver = factory.get_checkpointer(config, async_mode=True)

        mock_aio_module.AsyncRedisSaver.from_conn_string.assert_called_once_with(
            "redis://localhost:6379",
            ttl={"default_ttl": 60},
        )
        # Async saver should NOT call setup() - caller must await asetup()
        mock_saver.setup.assert_not_called()
        assert saver is mock_saver

    def test_redis_import_error_helpful_message(self):
        """Missing redis package should give helpful error."""
        # A None entry in sys.modules makes importing that name raise
        # ImportError, simulating a missing redis extra regardless of what is
        # actually installed. patch.dict restores sys.modules afterwards
        # instead of permanently deleting entries.
        blocked = {
            "langgraph.checkpoint.redis": None,
            "langgraph.checkpoint.redis.aio": None,
        }
        config = {"type": "redis", "url": "redis://localhost:6379"}

        with patch.dict("sys.modules", blocked):
            factory = self._reload_factory()
            with pytest.raises(ImportError) as exc_info:
                factory.get_checkpointer(config)

        assert "pip install yamlgraph[redis]" in str(exc_info.value)

    def test_redis_default_ttl(self):
        """Redis should use default TTL of 60 if not specified."""
        mock_saver = MagicMock()
        mock_redis_module = MagicMock()
        mock_redis_module.RedisSaver.from_conn_string.return_value = mock_saver

        with patch.dict(
            "sys.modules", {"langgraph.checkpoint.redis": mock_redis_module}
        ):
            factory = self._reload_factory()

            config = {"type": "redis", "url": "redis://localhost:6379"}
            saver = factory.get_checkpointer(config)

        mock_redis_module.RedisSaver.from_conn_string.assert_called_once_with(
            "redis://localhost:6379",
            ttl={"default_ttl": 60},
        )
        assert saver is mock_saver
200
+
201
+
202
class TestGetCheckpointerErrors:
    """Error paths of get_checkpointer()."""

    def test_unknown_type_raises_error(self):
        """An unrecognised ``type`` is rejected with a ValueError naming it."""
        bogus = {"type": "unknown_db"}

        with pytest.raises(ValueError, match="Unknown checkpointer type: unknown_db"):
            get_checkpointer(bogus)
tests/unit/test_cli.py ADDED
@@ -0,0 +1,121 @@
1
+ """Tests for yamlgraph.cli module."""
2
+
3
+ import argparse
4
+
5
+ from yamlgraph.cli.validators import validate_run_args
6
+ from yamlgraph.config import MAX_TOPIC_LENGTH, MAX_WORD_COUNT, MIN_WORD_COUNT
7
+
8
+
9
class TestValidateRunArgs:
    """Tests for validate_run_args function."""

    def _create_args(self, topic="test topic", word_count=300, style="informative"):
        """Helper to create args namespace."""
        return argparse.Namespace(
            topic=topic,
            word_count=word_count,
            style=style,
        )

    def test_valid_args(self):
        """Valid arguments should pass validation."""
        args = self._create_args()
        assert validate_run_args(args) is True

    def test_empty_topic(self):
        """Empty topic should fail validation."""
        args = self._create_args(topic="")
        assert validate_run_args(args) is False

    def test_whitespace_only_topic(self):
        """Whitespace-only topic should fail validation."""
        args = self._create_args(topic=" ")
        assert validate_run_args(args) is False

    def test_topic_too_long(self):
        """Topic exceeding max length should be truncated with warning."""
        long_topic = "x" * (MAX_TOPIC_LENGTH + 100)
        args = self._create_args(topic=long_topic)
        # Should pass but truncate the topic
        assert validate_run_args(args) is True
        assert len(args.topic) == MAX_TOPIC_LENGTH

    def test_topic_at_max_length(self):
        """Topic at max length should pass validation untouched."""
        max_topic = "x" * MAX_TOPIC_LENGTH
        args = self._create_args(topic=max_topic)
        assert validate_run_args(args) is True
        # Guards against an off-by-one that truncates at the boundary.
        assert args.topic == max_topic

    def test_word_count_too_low(self):
        """Word count below minimum should fail validation."""
        args = self._create_args(word_count=MIN_WORD_COUNT - 1)
        assert validate_run_args(args) is False

    def test_word_count_too_high(self):
        """Word count above maximum should fail validation."""
        args = self._create_args(word_count=MAX_WORD_COUNT + 1)
        assert validate_run_args(args) is False

    def test_word_count_at_min(self):
        """Word count at minimum should pass validation."""
        args = self._create_args(word_count=MIN_WORD_COUNT)
        assert validate_run_args(args) is True

    def test_word_count_at_max(self):
        """Word count at maximum should pass validation."""
        args = self._create_args(word_count=MAX_WORD_COUNT)
        assert validate_run_args(args) is True
68
+
69
+
70
class TestFormatResult:
    """Tests for generic result formatting."""

    def test_format_result_with_any_pydantic_model(self, capsys):
        """CLI should format any Pydantic model, not just known ones."""
        from pydantic import BaseModel

        from yamlgraph.cli.commands import _format_result

        class CustomResult(BaseModel):
            title: str
            score: float
            items: list[str]

        result = {
            "current_step": "done",
            "custom": CustomResult(title="Test Title", score=0.95, items=["a", "b"]),
        }

        _format_result(result)
        captured = capsys.readouterr()
        assert "Test Title" in captured.out
        assert "0.95" in captured.out

    def test_format_result_skips_internal_keys(self, capsys):
        """Internal keys should be skipped."""
        from yamlgraph.cli.commands import _format_result

        result = {
            "current_step": "done",
            "_route": "positive",
            "_loop_counts": {"draft": 2},
            "response": "Hello!",
        }

        _format_result(result)
        captured = capsys.readouterr()
        assert "_route" not in captured.out
        assert "_loop_counts" not in captured.out
        assert "Hello!" in captured.out

    def test_format_result_truncates_long_strings(self, capsys):
        """Long strings should be truncated."""
        from yamlgraph.cli.commands import _format_result

        long_text = "x" * 500
        result = {"summary": long_text}

        _format_result(result)
        captured = capsys.readouterr()
        assert "..." in captured.out
        # The full value must not be echoed verbatim anywhere in the output.
        assert long_text not in captured.out
        assert len(captured.out) < 400  # Truncated
tests/unit/test_cli_package.py ADDED
@@ -0,0 +1,81 @@
1
+ """Tests for CLI package structure (Phase 7.1).
2
+
3
+ TDD tests for splitting cli.py into a cli/ package.
4
+ """
5
+
6
+ import argparse
7
+
8
+ # =============================================================================
9
+ # Package Structure Tests
10
+ # =============================================================================
11
+
12
+
13
class TestCLIPackageStructure:
    """Verify the CLI is laid out as a package exposing the expected modules."""

    def test_cli_package_importable(self):
        """The ``yamlgraph.cli`` package imports cleanly."""
        import yamlgraph.cli as cli_package

        assert cli_package is not None

    def test_main_function_available(self):
        """``main`` is exposed at package level and is callable."""
        from yamlgraph.cli import main as entry_point

        assert callable(entry_point)

    def test_validators_submodule_exists(self):
        """The ``validators`` submodule can be imported from the package."""
        from yamlgraph.cli import validators as validators_module

        assert validators_module is not None

    def test_validate_run_args_in_validators(self):
        """``validate_run_args`` lives in the validators submodule."""
        from yamlgraph.cli.validators import validate_run_args as validator

        assert callable(validator)

    def test_commands_submodule_exists(self):
        """The ``commands`` submodule can be imported from the package."""
        from yamlgraph.cli import commands as commands_module

        assert commands_module is not None

    def test_cmd_list_runs_in_commands(self):
        """``cmd_list_runs`` lives in the commands submodule."""
        from yamlgraph.cli.commands import cmd_list_runs as list_runs

        assert callable(list_runs)
51
+
52
+
53
+ # =============================================================================
54
+ # Validator Tests (moved from cli module)
55
+ # =============================================================================
56
+
57
+
58
class TestValidatorsModule:
    """Check validate_run_args as exposed by the validators submodule."""

    def _create_run_args(self, topic="test topic", word_count=300, style="informative"):
        """Build an argparse namespace shaped like the ``run`` command's args."""
        fields = {"topic": topic, "word_count": word_count, "style": style}
        return argparse.Namespace(**fields)

    def test_validate_run_args_valid(self):
        """Default, well-formed arguments are accepted."""
        from yamlgraph.cli.validators import validate_run_args

        namespace = self._create_run_args()
        assert validate_run_args(namespace) is True

    def test_validate_run_args_empty_topic(self):
        """An empty topic is rejected."""
        from yamlgraph.cli.validators import validate_run_args

        namespace = self._create_run_args(topic="")
        assert validate_run_args(namespace) is False