yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""Tests for YAML state configuration parsing.
|
|
2
|
+
|
|
3
|
+
Tests the parse_state_config function and state: section handling
|
|
4
|
+
in build_state_class.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from yamlgraph.models.state_builder import (
|
|
10
|
+
TYPE_MAP,
|
|
11
|
+
build_state_class,
|
|
12
|
+
parse_state_config,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestParseStateConfig:
    """Unit tests for the parse_state_config helper."""

    def test_empty_config(self):
        """An empty state mapping parses to an empty mapping."""
        assert parse_state_config({}) == {}

    def test_simple_str_type(self):
        """The type name 'str' resolves to str."""
        expected = {"concept": str}
        assert parse_state_config({"concept": "str"}) == expected

    def test_simple_int_type(self):
        """The type name 'int' resolves to int."""
        expected = {"count": int}
        assert parse_state_config({"count": "int"}) == expected

    def test_simple_float_type(self):
        """The type name 'float' resolves to float."""
        expected = {"score": float}
        assert parse_state_config({"score": "float"}) == expected

    def test_simple_bool_type(self):
        """The type name 'bool' resolves to bool."""
        expected = {"enabled": bool}
        assert parse_state_config({"enabled": "bool"}) == expected

    def test_simple_list_type(self):
        """The type name 'list' resolves to list."""
        expected = {"items": list}
        assert parse_state_config({"items": "list"}) == expected

    def test_simple_dict_type(self):
        """The type name 'dict' resolves to dict."""
        expected = {"metadata": dict}
        assert parse_state_config({"metadata": "dict"}) == expected

    def test_any_type(self):
        """The type name 'any' resolves to typing.Any."""
        expected = {"data": Any}
        assert parse_state_config({"data": "any"}) == expected

    def test_type_aliases(self):
        """Spelled-out aliases ('string', 'integer', 'boolean') also resolve."""
        config = {"name": "string", "age": "integer", "active": "boolean"}
        expected = {"name": str, "age": int, "active": bool}
        assert parse_state_config(config) == expected

    def test_case_insensitive(self):
        """Type names match regardless of letter case."""
        config = {"a": "STR", "b": "Int", "c": "FLOAT"}
        expected = {"a": str, "b": int, "c": float}
        assert parse_state_config(config) == expected

    def test_multiple_fields(self):
        """Several fields are parsed in a single call."""
        config = {"concept": "str", "count": "int", "score": "float"}
        expected = {"concept": str, "count": int, "score": float}
        assert parse_state_config(config) == expected

    def test_unknown_type_defaults_to_any(self):
        """An unrecognised type name falls back to Any."""
        expected = {"custom": Any}
        assert parse_state_config({"custom": "unknown_type"}) == expected

    def test_non_string_value_defaults_to_any(self):
        """Values that are not strings fall back to Any."""
        config = {
            "nested": {"type": "str"},  # Dict value, not string
            "number": 123,  # Int value, not string
        }
        expected = {"nested": Any, "number": Any}
        assert parse_state_config(config) == expected
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class TestTypeMap:
    """Checks on the TYPE_MAP lookup table."""

    def test_all_basic_types_present(self):
        """Every basic Python type name has an entry in TYPE_MAP."""
        basic_names = ("str", "int", "float", "bool", "list", "dict", "any")
        for type_name in basic_names:
            assert type_name in TYPE_MAP

    def test_aliases_present(self):
        """The spelled-out aliases have entries in TYPE_MAP."""
        for alias in ("string", "integer", "boolean"):
            assert alias in TYPE_MAP
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class TestBuildStateClassWithStateConfig:
    """build_state_class behaviour when the graph config has a state: section."""

    def test_state_section_adds_fields(self):
        """Fields declared under state: appear in the generated annotations."""
        fields = build_state_class(
            {
                "state": {"concept": "str", "count": "int"},
                "nodes": {},
                "edges": [],
            }
        ).__annotations__

        for field_name in ("concept", "count"):
            assert field_name in fields

    def test_state_section_empty(self):
        """An empty state: section still yields the base fields."""
        generated = build_state_class({"state": {}, "nodes": {}, "edges": []})
        # Base infrastructure field must survive an empty state section
        assert "thread_id" in generated.__annotations__

    def test_state_section_missing(self):
        """A config without any state: key still yields the base fields."""
        generated = build_state_class({"nodes": {}, "edges": []})
        # Base infrastructure field must survive a missing state section
        assert "thread_id" in generated.__annotations__

    def test_custom_field_overrides_common(self):
        """A state: field may reuse a common field name ('topic').

        Only the field's presence is asserted; the resulting type is not
        checked here.
        """
        generated = build_state_class(
            {
                "state": {"topic": "int"},  # Override str default
                "nodes": {},
                "edges": [],
            }
        )
        assert "topic" in generated.__annotations__

    def test_storyboard_example(self):
        """Storyboard-style config: custom, node-output and base fields coexist."""
        storyboard_config = {
            "state": {"concept": "str"},
            "nodes": {
                "expand_story": {"type": "llm", "state_key": "story"},
            },
            "edges": [],
        }
        fields = build_state_class(storyboard_config).__annotations__

        # concept: custom state field; story: node output;
        # thread_id / errors: base infrastructure fields
        for field_name in ("concept", "story", "thread_id", "errors"):
            assert field_name in fields
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
"""Tests for streaming support - Phase 3 (004).
|
|
2
|
+
|
|
3
|
+
TDD: RED phase - write tests first.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from unittest.mock import MagicMock, patch
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
# ==============================================================================
|
|
11
|
+
# execute_prompt_streaming tests
|
|
12
|
+
# ==============================================================================
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_yields_tokens():
    """Each chunk's content from the LLM stream is yielded in order."""
    from yamlgraph.executor_async import execute_prompt_streaming

    # Fake stream chunks; only the .content attribute is relevant.
    chunks = [MagicMock(content=text) for text in ("Hello", " World", "!")]

    async def fake_astream(*_args, **_kwargs):
        for chunk in chunks:
            yield chunk

    fake_llm = MagicMock()
    fake_llm.astream = fake_astream
    prompt = {"system": "You are helpful.", "user": "Say hello"}

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=fake_llm),
        patch("yamlgraph.executor_async.load_prompt", return_value=prompt),
    ):
        received = [
            token async for token in execute_prompt_streaming("greet", variables={})
        ]

        assert received == ["Hello", " World", "!"]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_with_variables():
    """Template variables are routed through format_prompt before streaming."""
    from yamlgraph.executor_async import execute_prompt_streaming

    reply = MagicMock(content="Hi Alice!")

    async def fake_astream(*_args, **_kwargs):
        yield reply

    fake_llm = MagicMock()
    fake_llm.astream = fake_astream
    prompt = {"system": "", "user": "Say hello to {name}"}

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=fake_llm),
        patch("yamlgraph.executor_async.load_prompt", return_value=prompt),
        patch(
            "yamlgraph.executor_async.format_prompt",
            return_value="Say hello to Alice",
        ) as fake_format,
    ):
        received = [
            token
            async for token in execute_prompt_streaming(
                "greet", variables={"name": "Alice"}
            )
        ]

        assert received == ["Hi Alice!"]
        fake_format.assert_called()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_uses_provider():
    """The provider argument is forwarded to create_llm as a keyword."""
    from yamlgraph.executor_async import execute_prompt_streaming

    chunk = MagicMock(content="test")

    async def fake_astream(*_args, **_kwargs):
        yield chunk

    fake_llm = MagicMock()
    fake_llm.astream = fake_astream

    with (
        patch(
            "yamlgraph.executor_async.create_llm", return_value=fake_llm
        ) as fake_create,
        patch(
            "yamlgraph.executor_async.load_prompt",
            return_value={"system": "", "user": "test"},
        ),
    ):
        # Drain the stream; only the create_llm call matters here.
        stream = execute_prompt_streaming("test", variables={}, provider="openai")
        async for _token in stream:
            pass

        fake_create.assert_called_once()
        assert fake_create.call_args.kwargs.get("provider") == "openai"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_handles_empty_chunks():
    """Chunks whose content is empty ("") or None are not yielded."""
    from yamlgraph.executor_async import execute_prompt_streaming

    # Middle two entries are the empty string and None respectively.
    chunks = [MagicMock(content=value) for value in ("Hello", "", None, "World")]

    async def fake_astream(*_args, **_kwargs):
        for chunk in chunks:
            yield chunk

    fake_llm = MagicMock()
    fake_llm.astream = fake_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=fake_llm),
        patch(
            "yamlgraph.executor_async.load_prompt",
            return_value={"system": "", "user": "test"},
        ),
    ):
        received = [
            token async for token in execute_prompt_streaming("test", variables={})
        ]

        # Only the non-empty tokens survive
        assert received == ["Hello", "World"]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_propagates_errors():
    """An exception raised mid-stream by the LLM reaches the consumer."""
    from yamlgraph.executor_async import execute_prompt_streaming

    async def failing_astream(*_args, **_kwargs):
        yield MagicMock(content="start")
        raise ValueError("LLM error")

    fake_llm = MagicMock()
    fake_llm.astream = failing_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=fake_llm),
        patch(
            "yamlgraph.executor_async.load_prompt",
            return_value={"system": "", "user": "test"},
        ),
    ):
        received = []
        # An explicit loop (not a comprehension) keeps the tokens that were
        # delivered before the exception interrupted iteration.
        with pytest.raises(ValueError, match="LLM error"):
            async for token in execute_prompt_streaming("test", variables={}):
                received.append(token)

        # The token emitted before the failure was still delivered
        assert received == ["start"]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# ==============================================================================
|
|
174
|
+
# Streaming with output collection
|
|
175
|
+
# ==============================================================================
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@pytest.mark.asyncio
async def test_execute_prompt_streaming_collect():
    """Joining every streamed token reproduces the full response text."""
    from yamlgraph.executor_async import execute_prompt_streaming

    words = ["The ", "quick ", "brown ", "fox"]
    chunks = [MagicMock(content=word) for word in words]

    async def fake_astream(*_args, **_kwargs):
        for chunk in chunks:
            yield chunk

    fake_llm = MagicMock()
    fake_llm.astream = fake_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=fake_llm),
        patch(
            "yamlgraph.executor_async.load_prompt",
            return_value={"system": "", "user": "test"},
        ),
    ):
        # Variables passed positionally here.
        pieces = [token async for token in execute_prompt_streaming("test", {})]

        assert "".join(pieces) == "The quick brown fox"
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# ==============================================================================
|
|
205
|
+
# Streaming node factory tests
|
|
206
|
+
# ==============================================================================
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@pytest.mark.asyncio
async def test_create_streaming_node_yields_tokens():
    """A node built by create_streaming_node re-yields the streamed tokens."""
    from yamlgraph.node_factory import create_streaming_node

    async def fake_streaming(*_args, **_kwargs):
        for piece in ("Hello", " ", "World"):
            yield piece

    with patch("yamlgraph.executor_async.execute_prompt_streaming", fake_streaming):
        node = create_streaming_node(
            "generate", {"prompt": "greet", "state_key": "response"}
        )

        received = [token async for token in node({"input": "test"})]

        assert received == ["Hello", " ", "World"]
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@pytest.mark.asyncio
async def test_streaming_node_with_callback():
    """The on_token callback from the node config receives every token."""
    from yamlgraph.node_factory import create_streaming_node

    seen_tokens = []

    def record(token: str):
        seen_tokens.append(token)

    async def fake_streaming(*_args, **_kwargs):
        for letter in ("A", "B", "C"):
            yield letter

    with patch("yamlgraph.executor_async.execute_prompt_streaming", fake_streaming):
        node = create_streaming_node(
            "stream_node",
            {"prompt": "test", "state_key": "output", "on_token": record},
        )

        # Drain the node so every token passes through the callback
        async for _token in node({}):
            pass

    assert seen_tokens == ["A", "B", "C"]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# ==============================================================================
|
|
265
|
+
# YAML config tests
|
|
266
|
+
# ==============================================================================
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def test_node_config_stream_true_creates_streaming_node():
    """A node config with stream: True is delegated to create_streaming_node."""
    from yamlgraph.node_factory import create_node_function

    node_config = {"prompt": "greet", "state_key": "response", "stream": True}

    with patch(
        "yamlgraph.node_factory.create_streaming_node", return_value=MagicMock()
    ) as fake_create:
        # The factory should detect stream: True and hand off
        create_node_function("generate", node_config, defaults={})

        fake_create.assert_called_once_with("generate", node_config)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def test_node_config_stream_false_creates_regular_node():
    """A node config with stream: False builds a regular, non-streaming node.

    create_streaming_node must not be consulted, and the factory must still
    return a callable node function.
    """
    from yamlgraph.node_factory import create_node_function

    node_config = {"prompt": "greet", "state_key": "response", "stream": False}

    with (
        patch("yamlgraph.node_factory.create_streaming_node") as mock_streaming,
        patch("yamlgraph.node_factory.execute_prompt") as mock_execute,
    ):
        # Stub the LLM call so nothing external is reached if it is invoked.
        mock_execute.return_value = "result"

        node_fn = create_node_function("generate", node_config, defaults={})

        # Should NOT call create_streaming_node
        mock_streaming.assert_not_called()
        # A usable node must still come back from the factory.
        assert callable(node_fn)
|