yamlgraph-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of yamlgraph might be problematic.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0

tests/unit/test_no_backward_compat.py
@@ -0,0 +1,56 @@
"""Test that 'backward compatibility' markers are cleaned up in source code.

Per project guidelines (.github/copilot-instructions.md):
"Term 'backward compatibility' is a key indicator for a refactoring need."

This test fails if any Python source files contain backward compatibility markers,
ensuring deprecated code gets cleaned up rather than accumulating.
"""

import subprocess
from pathlib import Path


class TestNoBackwardCompatibilityMarkers:
    """Ensure no backward compatibility markers exist in source code."""

    def test_no_backward_compat_in_yamlgraph_source(self):
        """Source files should not contain 'backward compatibility' markers.

        Allowed exceptions:
        - deprecation.py: Documents the DeprecationError pattern
        - Tests in this file
        """
        project_root = Path(__file__).parent.parent.parent
        yamlgraph_dir = project_root / "yamlgraph"

        result = subprocess.run(
            [
                "grep",
                "-rn",
                "-i",
                "backward compatib",
                str(yamlgraph_dir),
                "--include=*.py",
            ],
            capture_output=True,
            text=True,
        )

        if result.returncode == 0:  # Found matches
            lines = result.stdout.strip().split("\n")
            # Filter out allowed files
            violations = [
                line
                for line in lines
                if "deprecation.py" not in line  # Pattern documentation
            ]

            if violations:
                msg = (
                    "Found 'backward compatibility' markers in source code.\n"
                    "Per guidelines, this signals refactoring need.\n"
                    "Clean up deprecated code or move to deprecation.py.\n\n"
                    "Violations:\n" + "\n".join(f"  {v}" for v in violations)
                )
                raise AssertionError(msg)

tests/unit/test_node_factory.py
@@ -0,0 +1,225 @@
"""Tests for node factory - class resolution, template resolution, and node creation.

Split from test_graph_loader.py for better organization and file size management.
"""

from unittest.mock import patch

import pytest

from tests.conftest import FixtureGeneratedContent
from yamlgraph.node_factory import (
    create_node_function,
    resolve_class,
    resolve_template,
)

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def sample_state():
    """Sample pipeline state."""
    return {
        "thread_id": "test-123",
        "topic": "machine learning",
        "style": "informative",
        "word_count": 300,
        "generated": None,
        "analysis": None,
        "final_summary": None,
        "current_step": "init",
        "error": None,
        "errors": [],
    }


@pytest.fixture
def state_with_generated(sample_state):
    """State with generated content."""
    state = dict(sample_state)
    state["generated"] = FixtureGeneratedContent(
        title="Test Title",
        content="Test content about ML.",
        word_count=50,
        tags=["test"],
    )
    return state


# =============================================================================
# TestResolveClass
# =============================================================================


class TestResolveClass:
    """Tests for dynamic class importing."""

    def test_resolve_existing_class(self):
        """Import a real class from dotted path."""
        cls = resolve_class("yamlgraph.models.GenericReport")
        # Just verify it resolves to a class with expected attributes
        assert cls is not None
        assert hasattr(cls, "model_fields")  # Pydantic model check

    def test_resolve_state_class(self):
        """Dynamic state class can be built."""
        from yamlgraph.models.state_builder import build_state_class

        cls = build_state_class({"nodes": {}})
        # Dynamic state is a TypedDict
        assert cls is not None
        assert hasattr(cls, "__annotations__")

    def test_resolve_invalid_module_raises(self):
        """Invalid module raises ImportError."""
        with pytest.raises((ImportError, ModuleNotFoundError)):
            resolve_class("nonexistent.module.Class")

    def test_resolve_invalid_class_raises(self):
        """Invalid class name raises AttributeError."""
        with pytest.raises(AttributeError):
            resolve_class("yamlgraph.models.NonexistentClass")


# =============================================================================
# TestResolveTemplate
# =============================================================================


class TestResolveTemplate:
    """Tests for template resolution against state."""

    def test_simple_state_access(self, sample_state):
        """'{state.topic}' resolves to state['topic']."""
        result = resolve_template("{state.topic}", sample_state)
        assert result == "machine learning"

    def test_nested_state_access(self, state_with_generated):
        """'{state.generated.content}' resolves nested attrs."""
        result = resolve_template("{state.generated.content}", state_with_generated)
        assert result == "Test content about ML."

    def test_missing_state_returns_none(self, sample_state):
        """Missing state key returns None."""
        result = resolve_template("{state.generated.content}", sample_state)
        assert result is None

    def test_literal_string_unchanged(self, sample_state):
        """Non-template strings returned as-is."""
        result = resolve_template("literal value", sample_state)
        assert result == "literal value"

    def test_int_access(self, sample_state):
        """Integer values resolved correctly."""
        result = resolve_template("{state.word_count}", sample_state)
        assert result == 300

    def test_list_access(self, state_with_generated):
        """List values resolved correctly."""
        result = resolve_template("{state.generated.tags}", state_with_generated)
        assert result == ["test"]


# =============================================================================
# TestCreateNodeFunction
# =============================================================================


class TestCreateNodeFunction:
    """Tests for node function factory."""

    def test_node_calls_execute_prompt(self, sample_state):
        """Generated node calls execute_prompt with config."""
        node_config = {
            "type": "llm",
            "prompt": "generate",
            "output_model": "yamlgraph.models.GenericReport",
            "temperature": 0.8,
            "variables": {"topic": "{state.topic}"},
            "state_key": "generated",
        }

        mock_result = FixtureGeneratedContent(
            title="Test",
            content="Content",
            word_count=100,
            tags=[],
        )

        with patch(
            "yamlgraph.node_factory.execute_prompt", return_value=mock_result
        ) as mock:
            node_fn = create_node_function(
                "generate", node_config, {"provider": "mistral"}
            )
            result = node_fn(sample_state)

        mock.assert_called_once()
        call_kwargs = mock.call_args
        assert call_kwargs[1]["prompt_name"] == "generate"
        assert call_kwargs[1]["temperature"] == 0.8
        assert call_kwargs[1]["variables"]["topic"] == "machine learning"

        assert result["generated"] == mock_result
        assert result["current_step"] == "generate"

    def test_node_checks_requirements(self, sample_state):
        """Node returns error if requires not met."""
        node_config = {
            "type": "llm",
            "prompt": "analyze",
            "variables": {},
            "state_key": "analysis",
            "requires": ["generated"],  # generated is None in sample_state
        }

        node_fn = create_node_function("analyze", node_config, {})
        result = node_fn(sample_state)

        assert result.get("errors")
        assert "generated" in result["errors"][0].message

    def test_node_handles_exception(self, sample_state):
        """Exceptions become PipelineError."""
        node_config = {
            "type": "llm",
            "prompt": "generate",
            "variables": {"topic": "{state.topic}"},
            "state_key": "generated",
        }

        with patch(
            "yamlgraph.node_factory.execute_prompt", side_effect=ValueError("API Error")
        ):
            node_fn = create_node_function("generate", node_config, {})
            result = node_fn(sample_state)

        assert result.get("errors")
        assert "API Error" in result["errors"][0].message

    def test_node_uses_defaults(self, sample_state):
        """Node uses default provider/temperature from config."""
        node_config = {
            "type": "llm",
            "prompt": "generate",
            "variables": {},
            "state_key": "generated",
            # No temperature specified - should use default
        }
        defaults = {"provider": "anthropic", "temperature": 0.5}

        mock_result = FixtureGeneratedContent(
            title="T", content="C", word_count=1, tags=[]
        )

        with patch(
            "yamlgraph.node_factory.execute_prompt", return_value=mock_result
        ) as mock:
            node_fn = create_node_function("generate", node_config, defaults)
            node_fn(sample_state)

        assert mock.call_args[1]["temperature"] == 0.5
        assert mock.call_args[1]["provider"] == "anthropic"

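The TestResolveTemplate cases above pin down a small contract for template strings: "{state.attr}" paths are looked up against the state (dict keys first, then object attributes), unresolved paths yield None, and non-template strings pass through untouched. A minimal sketch of that contract, written here for illustration only (this is not yamlgraph's actual resolve_template implementation):

def resolve_template_sketch(template: str, state: dict):
    """Illustrative only: mirrors the behavior the tests above assert."""
    if not (template.startswith("{state.") and template.endswith("}")):
        return template  # literal strings are returned as-is
    value = state
    for part in template[len("{state.") : -1].split("."):
        if isinstance(value, dict):
            value = value.get(part)
        else:
            value = getattr(value, part, None)
        if value is None:
            return None  # missing keys/attrs resolve to None
    return value

assert resolve_template_sketch("{state.topic}", {"topic": "ml"}) == "ml"
assert resolve_template_sketch("literal value", {}) == "literal value"
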
tests/unit/test_prompts.py
@@ -0,0 +1,166 @@
"""Tests for yamlgraph.utils.prompts module.

TDD: Red phase - write tests before implementation.
"""

from pathlib import Path

import pytest


class TestResolvePromptPath:
    """Tests for resolve_prompt_path function."""

    def test_resolve_standard_prompt(self, tmp_path: Path):
        """Should resolve prompt in standard prompts/ directory."""
        from yamlgraph.utils.prompts import resolve_prompt_path

        # Create temp prompt file
        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        prompt_file = prompts_dir / "greet.yaml"
        prompt_file.write_text("system: Hello\nuser: Hi {name}")

        result = resolve_prompt_path("greet", prompts_dir=prompts_dir)

        assert result == prompt_file
        assert result.exists()

    def test_resolve_nested_prompt(self, tmp_path: Path):
        """Should resolve nested prompt like map-demo/generate_ideas."""
        from yamlgraph.utils.prompts import resolve_prompt_path

        # Create nested prompt structure
        prompts_dir = tmp_path / "prompts"
        nested_dir = prompts_dir / "map-demo"
        nested_dir.mkdir(parents=True)
        prompt_file = nested_dir / "generate_ideas.yaml"
        prompt_file.write_text("system: Generate\nuser: {topic}")

        result = resolve_prompt_path("map-demo/generate_ideas", prompts_dir=prompts_dir)

        assert result == prompt_file

    def test_resolve_external_example_prompt(self, tmp_path: Path, monkeypatch):
        """Should resolve external example like examples/storyboard/expand_story."""
        from yamlgraph.utils.prompts import resolve_prompt_path

        # Create external example structure: {parent}/prompts/{basename}.yaml
        example_dir = tmp_path / "examples" / "storyboard"
        prompts_subdir = example_dir / "prompts"
        prompts_subdir.mkdir(parents=True)
        prompt_file = prompts_subdir / "expand_story.yaml"
        prompt_file.write_text("system: Expand\nuser: {story}")

        # Change to tmp_path so relative paths resolve correctly
        monkeypatch.chdir(tmp_path)

        # Standard prompts dir doesn't have it, should fall back to external
        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir(exist_ok=True)

        result = resolve_prompt_path(
            "examples/storyboard/expand_story",
            prompts_dir=prompts_dir,
        )

        assert result.resolve() == prompt_file.resolve()

    def test_resolve_nonexistent_raises(self, tmp_path: Path):
        """Should raise FileNotFoundError for missing prompt."""
        from yamlgraph.utils.prompts import resolve_prompt_path

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()

        with pytest.raises(FileNotFoundError, match="Prompt not found"):
            resolve_prompt_path("nonexistent", prompts_dir=prompts_dir)

    def test_resolve_uses_default_prompts_dir(self):
        """Should use PROMPTS_DIR from config when not specified."""
        from yamlgraph.utils.prompts import resolve_prompt_path

        # This should find the real greet.yaml in prompts/
        result = resolve_prompt_path("greet")

        assert result.exists()
        assert result.name == "greet.yaml"


class TestLoadPrompt:
    """Tests for load_prompt function."""

    def test_load_existing_prompt(self, tmp_path: Path):
        """Should load and parse YAML prompt file."""
        from yamlgraph.utils.prompts import load_prompt

        # Create temp prompt file
        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        prompt_file = prompts_dir / "test.yaml"
        prompt_file.write_text("system: You are helpful\nuser: Hello {name}")

        result = load_prompt("test", prompts_dir=prompts_dir)

        assert result["system"] == "You are helpful"
        assert result["user"] == "Hello {name}"

    def test_load_prompt_with_schema(self, tmp_path: Path):
        """Should load prompt with inline schema section."""
        from yamlgraph.utils.prompts import load_prompt

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        prompt_file = prompts_dir / "structured.yaml"
        prompt_file.write_text("""
system: Analyze content
user: "{content}"
schema:
  name: Analysis
  fields:
    summary:
      type: str
      description: Brief summary
""")

        result = load_prompt("structured", prompts_dir=prompts_dir)

        assert "schema" in result
        assert result["schema"]["name"] == "Analysis"

    def test_load_nonexistent_raises(self, tmp_path: Path):
        """Should raise FileNotFoundError for missing prompt."""
        from yamlgraph.utils.prompts import load_prompt

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()

        with pytest.raises(FileNotFoundError):
            load_prompt("missing", prompts_dir=prompts_dir)

    def test_load_real_generate_prompt(self):
        """Should load the real generate.yaml from prompts/."""
        from yamlgraph.utils.prompts import load_prompt

        result = load_prompt("generate")

        assert "system" in result
        assert "user" in result


class TestLoadPromptPath:
    """Tests for load_prompt_path (returns Path + parsed content)."""

    def test_load_prompt_path_returns_both(self, tmp_path: Path):
        """Should return both path and parsed content."""
        from yamlgraph.utils.prompts import load_prompt_path

        prompts_dir = tmp_path / "prompts"
        prompts_dir.mkdir()
        prompt_file = prompts_dir / "dual.yaml"
        prompt_file.write_text("system: Test\nuser: Hello")

        path, content = load_prompt_path("dual", prompts_dir=prompts_dir)

        assert path == prompt_file
        assert content["system"] == "Test"

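For orientation, the prompt-loading API exercised above can be driven end to end as follows; the directory layout and the "summarize" prompt are invented for illustration, while the call signatures (prompt name plus an optional prompts_dir keyword) mirror the tests:

import tempfile
from pathlib import Path

from yamlgraph.utils.prompts import load_prompt, resolve_prompt_path

with tempfile.TemporaryDirectory() as tmp:
    prompts_dir = Path(tmp) / "prompts"
    prompts_dir.mkdir()
    # Hypothetical prompt file; real projects keep these under prompts/.
    (prompts_dir / "summarize.yaml").write_text(
        'system: Summarize content\nuser: "{content}"\n'
    )

    path = resolve_prompt_path("summarize", prompts_dir=prompts_dir)
    prompt = load_prompt("summarize", prompts_dir=prompts_dir)

    assert path == prompts_dir / "summarize.yaml"
    assert prompt["system"] == "Summarize content"
    assert prompt["user"] == "{content}"
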
tests/unit/test_python_nodes.py
@@ -0,0 +1,198 @@
"""Tests for Python tool nodes (type: python)."""

import pytest

from yamlgraph.tools.python_tool import (
    PythonToolConfig,
    create_python_node,
    load_python_function,
    parse_python_tools,
)


class TestPythonToolConfig:
    """Tests for PythonToolConfig dataclass."""

    def test_basic_config(self):
        """Can create config with required fields."""
        config = PythonToolConfig(
            module="os.path",
            function="join",
        )
        assert config.module == "os.path"
        assert config.function == "join"
        assert config.description == ""

    def test_config_with_description(self):
        """Can create config with description."""
        config = PythonToolConfig(
            module="json",
            function="dumps",
            description="Serialize to JSON",
        )
        assert config.description == "Serialize to JSON"


class TestLoadPythonFunction:
    """Tests for load_python_function."""

    def test_loads_stdlib_function(self):
        """Can load function from stdlib."""
        config = PythonToolConfig(module="os.path", function="join")
        func = load_python_function(config)
        assert callable(func)
        assert func("a", "b") == "a/b"

    def test_loads_json_dumps(self):
        """Can load json.dumps."""
        config = PythonToolConfig(module="json", function="dumps")
        func = load_python_function(config)
        assert func({"a": 1}) == '{"a": 1}'

    def test_raises_on_invalid_module(self):
        """Raises ImportError for non-existent module."""
        config = PythonToolConfig(module="nonexistent.module", function="foo")
        with pytest.raises(ImportError, match="Cannot import module"):
            load_python_function(config)

    def test_raises_on_invalid_function(self):
        """Raises AttributeError for non-existent function."""
        config = PythonToolConfig(module="os.path", function="nonexistent_func")
        with pytest.raises(AttributeError, match="not found in module"):
            load_python_function(config)

    def test_raises_on_non_callable(self):
        """Raises TypeError if attribute is not callable."""
        config = PythonToolConfig(module="os", function="name")
        with pytest.raises(TypeError, match="not callable"):
            load_python_function(config)


class TestParsePythonTools:
    """Tests for parse_python_tools."""

    def test_parses_python_tools(self):
        """Extracts only type: python tools."""
        tools_config = {
            "shell_tool": {"command": "echo hello"},
            "python_tool": {
                "type": "python",
                "module": "json",
                "function": "dumps",
            },
        }
        result = parse_python_tools(tools_config)

        assert len(result) == 1
        assert "python_tool" in result
        assert result["python_tool"].module == "json"
        assert result["python_tool"].function == "dumps"

    def test_skips_shell_tools(self):
        """Does not include shell tools."""
        tools_config = {
            "git_log": {
                "type": "shell",
                "command": "git log",
            },
        }
        result = parse_python_tools(tools_config)
        assert len(result) == 0

    def test_skips_incomplete_python_tools(self):
        """Skips Python tools missing module or function."""
        tools_config = {
            "missing_module": {"type": "python", "function": "foo"},
            "missing_function": {"type": "python", "module": "json"},
        }
        result = parse_python_tools(tools_config)
        assert len(result) == 0

    def test_includes_description(self):
        """Parses description field."""
        tools_config = {
            "my_tool": {
                "type": "python",
                "module": "json",
                "function": "loads",
                "description": "Parse JSON",
            },
        }
        result = parse_python_tools(tools_config)
        assert result["my_tool"].description == "Parse JSON"


class TestCreatePythonNode:
    """Tests for create_python_node."""

    def test_creates_node_function(self):
        """Creates callable node function."""
        python_tools = {
            "my_tool": PythonToolConfig(
                module="tests.unit.test_python_nodes",
                function="sample_node_function",
            ),
        }
        node_config = {"tool": "my_tool", "state_key": "result"}

        node_fn = create_python_node("test_node", node_config, python_tools)
        assert callable(node_fn)

    def test_raises_on_missing_tool(self):
        """Raises if tool not in registry."""
        python_tools = {}
        node_config = {"tool": "nonexistent"}

        with pytest.raises(KeyError, match="not found"):
            create_python_node("test_node", node_config, python_tools)

    def test_raises_on_missing_tool_key(self):
        """Raises if node config missing tool key."""
        python_tools = {}
        node_config = {}

        with pytest.raises(ValueError, match="must specify"):
            create_python_node("test_node", node_config, python_tools)

    def test_node_returns_dict_from_function(self):
        """Node returns function's dict result with current_step."""
        python_tools = {
            "dict_tool": PythonToolConfig(
                module="tests.unit.test_python_nodes",
                function="sample_node_function",
            ),
        }
        node_config = {"tool": "dict_tool"}

        node_fn = create_python_node("test_node", node_config, python_tools)
        result = node_fn({"input": "hello"})

        assert result["current_step"] == "test_node"
        assert "output" in result

    def test_node_wraps_non_dict_return(self):
        """Node wraps non-dict return in state_key."""
        python_tools = {
            "scalar_tool": PythonToolConfig(
                module="tests.unit.test_python_nodes",
                function="scalar_return_function",
            ),
        }
        node_config = {"tool": "scalar_tool", "state_key": "my_value"}

        node_fn = create_python_node("test_node", node_config, python_tools)
        result = node_fn({})

        assert result["my_value"] == 42
        assert result["current_step"] == "test_node"


# Sample functions for testing
def sample_node_function(state: dict) -> dict:
    """Sample node function that returns a dict."""
    return {"output": f"processed: {state.get('input', 'none')}"}


def scalar_return_function(state: dict) -> int:
    """Sample function that returns a scalar."""
    return 42

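Taken together, the tests above suggest the following wiring for a Python tool node; the tool name, node name, and input state below are invented, json.dumps is the same stdlib target the tests use, and exactly what the tool function receives as input is an assumption here (the sample functions above take the full state dict):

from yamlgraph.tools.python_tool import create_python_node, parse_python_tools

# Tools section as it might appear once parsed from a graph YAML file.
tools_config = {
    "to_json": {
        "type": "python",
        "module": "json",
        "function": "dumps",
        "description": "Serialize the state to JSON",
    },
}
python_tools = parse_python_tools(tools_config)

# Node config uses the keys the tests exercise: which tool to call and which
# state key should receive a non-dict return value.
node_config = {"tool": "to_json", "state_key": "as_json"}
node_fn = create_python_node("serialize", node_config, python_tools)

# Per the tests, a non-dict return is wrapped under "as_json" and the result
# records current_step="serialize".
result = node_fn({"payload": {"a": 1}})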
|