yamlgraph-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0

tests/unit/test_executor_async.py
@@ -0,0 +1,179 @@
"""Unit tests for async executor module."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from yamlgraph.executor_async import execute_prompt_async, execute_prompts_concurrent
from yamlgraph.utils.llm_factory_async import shutdown_executor


class TestExecutePromptAsync:
    """Tests for execute_prompt_async function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_executes_prompt(self):
        """Should execute a prompt and return result."""
        with patch("yamlgraph.executor_async.invoke_async") as mock_invoke:
            mock_invoke.return_value = "Hello, World!"

            result = await execute_prompt_async(
                "greet",
                variables={"name": "World", "style": "friendly"},
            )

            assert result == "Hello, World!"
            mock_invoke.assert_called_once()

    @pytest.mark.asyncio
    async def test_passes_output_model(self):
        """Should pass output model to invoke_async."""
        from pydantic import BaseModel

        class TestModel(BaseModel):
            greeting: str

        with patch("yamlgraph.executor_async.invoke_async") as mock_invoke:
            mock_invoke.return_value = TestModel(greeting="Hi")

            result = await execute_prompt_async(
                "greet",
                variables={"name": "Test", "style": "casual"},
                output_model=TestModel,
            )

            assert isinstance(result, TestModel)
            # Check output_model was passed (positional arg)
            call_args = mock_invoke.call_args
            assert call_args[0][2] is TestModel  # 3rd positional arg

    @pytest.mark.asyncio
    async def test_validates_variables(self):
        """Should raise error for missing variables."""
        with pytest.raises(ValueError, match="Missing required variable"):
            await execute_prompt_async("greet", variables={})

    @pytest.mark.asyncio
    async def test_uses_provider_from_yaml(self):
        """Should extract provider from YAML metadata."""
        with (
            patch("yamlgraph.executor_async.load_prompt") as mock_load,
            patch("yamlgraph.executor_async.invoke_async") as mock_invoke,
            patch("yamlgraph.executor_async.create_llm") as mock_create_llm,
        ):
            mock_load.return_value = {
                "system": "You are helpful.",
                "user": "Hello {name}",
                "provider": "mistral",
            }
            mock_invoke.return_value = "Response"
            mock_create_llm.return_value = MagicMock()

            await execute_prompt_async("test", variables={"name": "User"})

            mock_create_llm.assert_called_once()
            call_kwargs = mock_create_llm.call_args[1]
            assert call_kwargs["provider"] == "mistral"


class TestExecutePromptsConcurrent:
    """Tests for execute_prompts_concurrent function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_executes_multiple_prompts(self):
        """Should execute multiple prompts concurrently."""
        with patch(
            "yamlgraph.executor_async.execute_prompt_async", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.side_effect = ["Result 1", "Result 2", "Result 3"]

            results = await execute_prompts_concurrent(
                [
                    {"prompt_name": "greet", "variables": {"name": "A", "style": "x"}},
                    {"prompt_name": "greet", "variables": {"name": "B", "style": "y"}},
                    {"prompt_name": "greet", "variables": {"name": "C", "style": "z"}},
                ]
            )

            assert len(results) == 3
            assert results == ["Result 1", "Result 2", "Result 3"]
            assert mock_execute.call_count == 3

    @pytest.mark.asyncio
    async def test_preserves_order(self):
        """Should return results in same order as input."""
        with patch(
            "yamlgraph.executor_async.execute_prompt_async", new_callable=AsyncMock
        ) as mock_execute:
            # Simulate varying response times
            async def delayed_response(prompt_name, **kwargs):
                name = kwargs.get("variables", {}).get("name", "")
                if name == "slow":
                    await asyncio.sleep(0.01)
                return f"Response for {name}"

            mock_execute.side_effect = delayed_response

            results = await execute_prompts_concurrent(
                [
                    {
                        "prompt_name": "greet",
                        "variables": {"name": "slow", "style": "a"},
                    },
                    {
                        "prompt_name": "greet",
                        "variables": {"name": "fast", "style": "b"},
                    },
                ]
            )

            assert results[0] == "Response for slow"
            assert results[1] == "Response for fast"

    @pytest.mark.asyncio
    async def test_empty_list(self):
        """Should handle empty prompt list."""
        results = await execute_prompts_concurrent([])
        assert results == []

    @pytest.mark.asyncio
    async def test_passes_all_options(self):
        """Should pass all options to execute_prompt_async."""
        from pydantic import BaseModel

        class TestModel(BaseModel):
            value: str

        with patch(
            "yamlgraph.executor_async.execute_prompt_async", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = TestModel(value="test")

            await execute_prompts_concurrent(
                [
                    {
                        "prompt_name": "test",
                        "variables": {"x": "y"},
                        "output_model": TestModel,
                        "temperature": 0.5,
                        "provider": "openai",
                    }
                ]
            )

            mock_execute.assert_called_once_with(
                prompt_name="test",
                variables={"x": "y"},
                output_model=TestModel,
                temperature=0.5,
                provider="openai",
            )
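
The suite above pins down the public surface of the async executor: execute_prompt_async takes a prompt name plus its template variables (with optional output_model, temperature, and provider), and execute_prompts_concurrent fans out a list of such specs while preserving input order. A rough usage sketch follows; it assumes a "greet" prompt definition and provider credentials exist, and is illustrative rather than package documentation.

# Usage sketch (not from the package docs): assumes a "greet" prompt
# definition and provider credentials are available in the environment.
import asyncio

from yamlgraph.executor_async import execute_prompt_async, execute_prompts_concurrent
from yamlgraph.utils.llm_factory_async import shutdown_executor


async def main() -> None:
    # Single prompt: name plus the variables its template requires.
    greeting = await execute_prompt_async(
        "greet",
        variables={"name": "World", "style": "friendly"},
    )
    print(greeting)

    # Fan out several prompt specs; results come back in input order.
    results = await execute_prompts_concurrent(
        [
            {"prompt_name": "greet", "variables": {"name": "A", "style": "terse"}},
            {"prompt_name": "greet", "variables": {"name": "B", "style": "formal"}},
        ]
    )
    print(results)


asyncio.run(main())
shutdown_executor()  # mirror the tests' teardown to release the worker pool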

tests/unit/test_export.py
@@ -0,0 +1,150 @@
"""Tests for yamlgraph.storage.export module."""

import json

from tests.conftest import FixtureGeneratedContent
from yamlgraph.storage.export import _serialize_state, export_state


class TestExportState:
    """Tests for export_state function."""

    def test_export_creates_file(self, temp_output_dir, sample_state):
        """Export should create a JSON file."""
        filepath = export_state(sample_state, output_dir=temp_output_dir)
        assert filepath.exists()
        assert filepath.suffix == ".json"

    def test_export_file_contains_valid_json(self, temp_output_dir, sample_state):
        """Exported file should contain valid JSON."""
        filepath = export_state(sample_state, output_dir=temp_output_dir)
        with open(filepath) as f:
            data = json.load(f)
        assert "topic" in data
        assert "thread_id" in data

    def test_export_filename_format(self, temp_output_dir, sample_state):
        """Filename should include prefix and thread_id."""
        filepath = export_state(
            sample_state,
            output_dir=temp_output_dir,
            prefix="test_export",
        )
        assert "test_export" in filepath.name
        assert sample_state["thread_id"] in filepath.name

    def test_export_creates_output_dir(self, tmp_path, sample_state):
        """Export should create output directory if it doesn't exist."""
        new_dir = tmp_path / "new_outputs"
        filepath = export_state(sample_state, output_dir=new_dir)
        assert new_dir.exists()
        assert filepath.exists()


class TestSerializeState:
    """Tests for _serialize_state function."""

    def test_serialize_simple_state(self, empty_state):
        """Simple state should serialize unchanged."""
        result = _serialize_state(empty_state)
        assert result["topic"] == empty_state["topic"]
        assert result["style"] == empty_state["style"]

    def test_serialize_pydantic_models(self):
        """Pydantic models should be converted to dicts."""
        content = FixtureGeneratedContent(
            title="Test",
            content="Content",
            word_count=1,
            tags=["tag"],
        )
        state = {"generated": content}
        result = _serialize_state(state)
        assert isinstance(result["generated"], dict)
        assert result["generated"]["title"] == "Test"

    def test_serialize_preserves_none(self, empty_state):
        """None values should be preserved."""
        # Add a None field to test serialization
        empty_state["generated"] = None
        result = _serialize_state(empty_state)
        assert result["generated"] is None
        assert result["error"] is None


class TestExportSummaryGeneric:
    """Tests for generic export_summary behavior."""

    def test_export_summary_with_any_pydantic_model(self):
        """export_summary should work with any Pydantic model, not just demo-specific ones."""
        from pydantic import BaseModel

        from yamlgraph.storage.export import export_summary

        class CustomModel(BaseModel):
            name: str
            value: int

        state = {
            "thread_id": "test-123",
            "topic": "custom topic",
            "custom_field": CustomModel(name="test", value=42),
        }

        summary = export_summary(state)

        # Should include core fields
        assert summary["thread_id"] == "test-123"
        assert summary["topic"] == "custom topic"

    def test_export_summary_extracts_scalar_fields(self):
        """export_summary should extract key scalar fields from any model."""
        from pydantic import BaseModel

        from yamlgraph.storage.export import export_summary

        class ReportContent(BaseModel):
            headline: str
            body: str
            author: str

        state = {
            "thread_id": "report-1",
            "topic": "report topic",
            "report": ReportContent(
                headline="Breaking News",
                body="Content here...",
                author="Alice",
            ),
        }

        summary = export_summary(state)
        # Should extract and include scalar fields
        assert "report" in summary or any(k.startswith("report") for k in summary)

    def test_export_summary_no_demo_model_dependencies(self):
        """export_summary should not import demo-specific model types."""
        import ast
        import inspect

        from yamlgraph.storage import export

        source = inspect.getsource(export)
        tree = ast.parse(source)

        demo_models = {
            "GeneratedContent",
            "Analysis",
            "ToneClassification",
            "DraftContent",
            "Critique",
            "SearchResults",
            "FinalReport",
        }

        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom):
                if node.module and "schemas" in node.module:
                    imported_names = {alias.name for alias in node.names}
                    overlap = imported_names & demo_models
                    assert not overlap, f"export.py imports demo models: {overlap}"
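
Taken together, these tests describe a small export API: export_state writes the Pydantic-aware serialized state to a JSON file whose name carries a prefix and the thread_id, creating the output directory on demand, while export_summary reduces a state dict to its core scalar fields. A minimal sketch under those assumptions; the state shape and the Draft model here are invented for illustration.

# Sketch only: the state contents and the "Draft" model are hypothetical.
from pathlib import Path

from pydantic import BaseModel

from yamlgraph.storage.export import export_state, export_summary


class Draft(BaseModel):
    title: str
    body: str


state = {
    "thread_id": "demo-1",
    "topic": "exports",
    "draft": Draft(title="Hello", body="..."),
}

# Writes JSON under output_dir (created if missing); the filename includes
# the prefix and the thread_id.
path = export_state(state, output_dir=Path("outputs"), prefix="demo")
print(path)

# Compact summary keeping core scalar fields such as thread_id and topic.
print(export_summary(state))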

tests/unit/test_expressions.py
@@ -0,0 +1,178 @@
"""Tests for state expression resolution."""

import pytest
from pydantic import BaseModel

from yamlgraph.utils.expressions import (
    resolve_state_expression,
    resolve_state_path,
    resolve_template,
)


class TestResolveStateExpression:
    """Tests for resolve_state_expression function."""

    def test_simple_key(self):
        """Resolve simple state key."""
        state = {"name": "test"}
        result = resolve_state_expression("{name}", state)
        assert result == "test"

    def test_nested_path(self):
        """Resolve nested path like {state.story.panels}."""
        state = {"story": {"panels": ["a", "b", "c"]}}
        result = resolve_state_expression("{state.story.panels}", state)
        assert result == ["a", "b", "c"]

    def test_state_prefix_stripped(self):
        """The 'state.' prefix is optional and stripped."""
        state = {"story": {"title": "My Story"}}
        # With prefix
        assert resolve_state_expression("{state.story.title}", state) == "My Story"
        # Without prefix
        assert resolve_state_expression("{story.title}", state) == "My Story"

    def test_literal_passthrough(self):
        """Non-expression strings pass through unchanged."""
        result = resolve_state_expression("literal string", {})
        assert result == "literal string"

    def test_non_string_passthrough(self):
        """Non-string values pass through unchanged."""
        result = resolve_state_expression(42, {})
        assert result == 42

    def test_missing_key_raises(self):
        """Missing key raises KeyError."""
        state = {"foo": "bar"}
        with pytest.raises(KeyError):
            resolve_state_expression("{missing}", state)

    def test_missing_nested_key_raises(self):
        """Missing nested key raises KeyError."""
        state = {"story": {"title": "test"}}
        with pytest.raises(KeyError):
            resolve_state_expression("{story.panels}", state)

    def test_deeply_nested_path(self):
        """Resolve deeply nested paths."""
        state = {"a": {"b": {"c": {"d": "deep"}}}}
        result = resolve_state_expression("{a.b.c.d}", state)
        assert result == "deep"

    def test_list_result(self):
        """Can resolve to list values."""
        state = {"items": [1, 2, 3]}
        result = resolve_state_expression("{items}", state)
        assert result == [1, 2, 3]

    def test_dict_result(self):
        """Can resolve to dict values."""
        state = {"config": {"key": "value"}}
        result = resolve_state_expression("{config}", state)
        assert result == {"key": "value"}

    def test_object_attribute_access(self):
        """Can resolve object attributes (Pydantic models)."""

        class MockModel:
            def __init__(self):
                self.title = "Test Title"
                self.panels = ["panel 1", "panel 2"]

        state = {"story": MockModel()}
        result = resolve_state_expression("{state.story.panels}", state)
        assert result == ["panel 1", "panel 2"]

    def test_mixed_dict_and_object_access(self):
        """Can resolve mixed dict and object paths."""

        class Inner:
            def __init__(self):
                self.value = "found"

        state = {"outer": {"middle": Inner()}}
        result = resolve_state_expression("{outer.middle.value}", state)
        assert result == "found"


class TestResolveStatePath:
    """Tests for resolve_state_path - the core resolution function."""

    def test_simple_key(self):
        """Should resolve simple key."""
        state = {"score": 0.8}
        assert resolve_state_path("score", state) == 0.8

    def test_nested_dict_path(self):
        """Should resolve nested dict path."""
        state = {"critique": {"score": 0.9}}
        assert resolve_state_path("critique.score", state) == 0.9

    def test_deeply_nested(self):
        """Should resolve deeply nested path."""
        state = {"a": {"b": {"c": {"d": 42}}}}
        assert resolve_state_path("a.b.c.d", state) == 42

    def test_missing_key_returns_none(self):
        """Should return None for missing key."""
        state = {"a": 1}
        assert resolve_state_path("b", state) is None

    def test_missing_nested_returns_none(self):
        """Should return None for missing nested path."""
        state = {"a": {"b": 1}}
        assert resolve_state_path("a.c", state) is None

    def test_pydantic_model_attribute(self):
        """Should resolve Pydantic model attribute."""

        class Critique(BaseModel):
            score: float
            feedback: str

        state = {"critique": Critique(score=0.75, feedback="Good")}
        assert resolve_state_path("critique.score", state) == 0.75
        assert resolve_state_path("critique.feedback", state) == "Good"

    def test_empty_path_returns_none(self):
        """Should return None for empty path."""
        assert resolve_state_path("", {"a": 1}) is None


class TestResolveTemplate:
    """Tests for resolve_template - optional resolution returning None."""

    def test_state_template(self):
        """Should resolve {state.field} template."""
        state = {"topic": "AI"}
        assert resolve_template("{state.topic}", state) == "AI"

    def test_nested_template(self):
        """Should resolve nested path template."""
        state = {"config": {"max_tokens": 100}}
        assert resolve_template("{state.config.max_tokens}", state) == 100

    def test_missing_returns_none(self):
        """Should return None for missing path."""
        state = {"a": 1}
        assert resolve_template("{state.missing}", state) is None

    def test_non_string_passthrough(self):
        """Should pass through non-string values."""
        assert resolve_template(123, {}) == 123

    def test_non_state_template_passthrough(self):
        """Should pass through non-state templates."""
        assert resolve_template("{other.field}", {}) == "{other.field}"
        assert resolve_template("plain text", {}) == "plain text"

    def test_pydantic_model(self):
        """Should resolve Pydantic model attribute."""

        class Draft(BaseModel):
            text: str

        state = {"draft": Draft(text="Content")}
        assert resolve_template("{state.draft.text}", state) == "Content"
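
The contracts worth remembering from the three resolver suites: resolve_state_expression is strict and raises KeyError on a missing key, resolve_state_path is a lenient dotted-path lookup that returns None, and resolve_template only touches "{state.*}" strings. A side-by-side sketch of those behaviors, with invented state contents:

# Side-by-side sketch of the three resolvers' contracts; state is invented.
from yamlgraph.utils.expressions import (
    resolve_state_expression,
    resolve_state_path,
    resolve_template,
)

state = {"story": {"title": "My Story", "panels": ["a", "b"]}}

# Strict: "{...}" expressions resolve (the "state." prefix is optional);
# a missing key raises KeyError, and literals pass through unchanged.
resolve_state_expression("{state.story.panels}", state)  # -> ["a", "b"]
resolve_state_expression("literal string", state)        # -> "literal string"

# Lenient dotted-path lookup: None instead of an exception.
resolve_state_path("story.title", state)    # -> "My Story"
resolve_state_path("story.missing", state)  # -> None

# Template form: only "{state.*}" strings resolve; others pass through,
# and a missing path yields None.
resolve_template("{state.story.title}", state)  # -> "My Story"
resolve_template("{other.field}", state)        # -> "{other.field}"
resolve_template("{state.missing}", state)      # -> None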

tests/unit/test_format_prompt.py
@@ -0,0 +1,145 @@
"""Tests for prompt formatting with Jinja2 support."""

import pytest

from yamlgraph.executor import format_prompt


class TestFormatPrompt:
    """Test the format_prompt function with both simple and Jinja2 templates."""

    def test_simple_format_basic(self):
        """Test basic string formatting with {variable} syntax."""
        template = "Hello {name}!"
        variables = {"name": "World"}
        result = format_prompt(template, variables)
        assert result == "Hello World!"

    def test_simple_format_multiple_variables(self):
        """Test formatting with multiple variables."""
        template = "Topic: {topic}, Style: {style}, Words: {word_count}"
        variables = {"topic": "AI", "style": "casual", "word_count": 500}
        result = format_prompt(template, variables)
        assert result == "Topic: AI, Style: casual, Words: 500"

    def test_simple_format_missing_variable(self):
        """Test that missing variables raise KeyError."""
        template = "Hello {name}!"
        variables = {}
        with pytest.raises(KeyError):
            format_prompt(template, variables)

    def test_jinja2_basic_variable(self):
        """Test Jinja2 template with basic {{ variable }} syntax."""
        template = "Hello {{ name }}!"
        variables = {"name": "World"}
        result = format_prompt(template, variables)
        assert result == "Hello World!"

    def test_jinja2_for_loop(self):
        """Test Jinja2 template with for loop."""
        template = """{% for item in items %}
- {{ item }}
{% endfor %}"""
        variables = {"items": ["apple", "banana", "cherry"]}
        result = format_prompt(template, variables)
        # Jinja2 preserves whitespace from template
        assert "- apple" in result
        assert "- banana" in result
        assert "- cherry" in result

    def test_jinja2_conditional(self):
        """Test Jinja2 template with if/else."""
        template = """{% if premium %}Premium User{% else %}Regular User{% endif %}"""

        result_premium = format_prompt(template, {"premium": True})
        assert result_premium == "Premium User"

        result_regular = format_prompt(template, {"premium": False})
        assert result_regular == "Regular User"

    def test_jinja2_filter_slice(self):
        """Test Jinja2 template with slice filter."""
        template = "Summary: {{ text[:50] }}..."
        variables = {
            "text": "This is a very long text that should be truncated to show only first fifty characters"
        }
        result = format_prompt(template, variables)
        # Check that the text is sliced to 50 characters
        assert result.startswith(
            "Summary: This is a very long text that should be truncated"
        )
        assert result.endswith("...")
        assert len(result) < len(variables["text"]) + len("Summary: ...")

    def test_jinja2_filter_upper(self):
        """Test Jinja2 template with upper filter."""
        template = "{{ name | upper }}"
        variables = {"name": "world"}
        result = format_prompt(template, variables)
        assert result == "WORLD"

    def test_jinja2_complex_template(self):
        """Test complex Jinja2 template with loops and conditionals."""
        template = """Items in {{ category }}:
{% for item in items %}
{% if item.available %}
- {{ item.name }}: ${{ item.price }}
{% endif %}
{% endfor %}"""
        variables = {
            "category": "Fruits",
            "items": [
                {"name": "Apple", "price": 1.50, "available": True},
                {"name": "Banana", "price": 0.75, "available": False},
                {"name": "Cherry", "price": 2.00, "available": True},
            ],
        }
        result = format_prompt(template, variables)
        assert "Apple: $1.5" in result
        assert "Cherry: $2.0" in result
        assert "Banana" not in result

    def test_jinja2_missing_variable_graceful(self):
        """Test that Jinja2 missing variables are handled (rendered as empty by default)."""
        template = "Hello {{ name }}!"
        variables = {}
        result = format_prompt(template, variables)
        # Jinja2 by default renders undefined variables as empty strings
        assert result == "Hello !"

    def test_detection_uses_jinja2_for_double_braces(self):
        """Test that {{ triggers Jinja2 mode."""
        template = "Value: {{ x }}"
        variables = {"x": 42}
        result = format_prompt(template, variables)
        assert result == "Value: 42"

    def test_detection_uses_jinja2_for_statements(self):
        """Test that {% triggers Jinja2 mode."""
        template = "{% if true %}Yes{% endif %}"
        variables = {}
        result = format_prompt(template, variables)
        assert result == "Yes"

    def test_backward_compatibility_no_jinja2_syntax(self):
        """Test that templates without Jinja2 syntax still use simple format."""
        # This ensures backward compatibility
        template = "Simple {var} template"
        variables = {"var": "test"}
        result = format_prompt(template, variables)
        assert result == "Simple test template"

    def test_empty_template(self):
        """Test formatting empty template."""
        template = ""
        variables = {}
        result = format_prompt(template, variables)
        assert result == ""

    def test_template_with_no_placeholders(self):
        """Test template with no variables."""
        template = "Just plain text"
        variables = {"unused": "value"}
        result = format_prompt(template, variables)
        assert result == "Just plain text"