yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,197 @@
1
+ """Tests for YAML state configuration parsing.
2
+
3
+ Tests the parse_state_config function and state: section handling
4
+ in build_state_class.
5
+ """
6
+
7
+ from typing import Any
8
+
9
+ from yamlgraph.models.state_builder import (
10
+ TYPE_MAP,
11
+ build_state_class,
12
+ parse_state_config,
13
+ )
14
+
15
+
16
class TestParseStateConfig:
    """Exercise parse_state_config across the type spellings it is given."""

    def test_empty_config(self):
        """An empty mapping parses to an empty mapping."""
        assert parse_state_config({}) == {}

    def test_simple_str_type(self):
        """The 'str' spelling maps to the builtin str."""
        parsed = parse_state_config({"concept": "str"})
        assert parsed == {"concept": str}

    def test_simple_int_type(self):
        """The 'int' spelling maps to the builtin int."""
        parsed = parse_state_config({"count": "int"})
        assert parsed == {"count": int}

    def test_simple_float_type(self):
        """The 'float' spelling maps to the builtin float."""
        parsed = parse_state_config({"score": "float"})
        assert parsed == {"score": float}

    def test_simple_bool_type(self):
        """The 'bool' spelling maps to the builtin bool."""
        parsed = parse_state_config({"enabled": "bool"})
        assert parsed == {"enabled": bool}

    def test_simple_list_type(self):
        """The 'list' spelling maps to the builtin list."""
        parsed = parse_state_config({"items": "list"})
        assert parsed == {"items": list}

    def test_simple_dict_type(self):
        """The 'dict' spelling maps to the builtin dict."""
        parsed = parse_state_config({"metadata": "dict"})
        assert parsed == {"metadata": dict}

    def test_any_type(self):
        """The 'any' spelling maps to typing.Any."""
        parsed = parse_state_config({"data": "any"})
        assert parsed == {"data": Any}

    def test_type_aliases(self):
        """Long-form aliases ('string', 'integer', 'boolean') are accepted."""
        aliased = {
            "name": "string",
            "age": "integer",
            "active": "boolean",
        }
        parsed = parse_state_config(aliased)
        assert parsed == {"name": str, "age": int, "active": bool}

    def test_case_insensitive(self):
        """Upper- and mixed-case type names resolve like lowercase ones."""
        mixed_case = {
            "a": "STR",
            "b": "Int",
            "c": "FLOAT",
        }
        parsed = parse_state_config(mixed_case)
        assert parsed == {"a": str, "b": int, "c": float}

    def test_multiple_fields(self):
        """Several fields in one config are each parsed."""
        several = {
            "concept": "str",
            "count": "int",
            "score": "float",
        }
        parsed = parse_state_config(several)
        assert parsed == {"concept": str, "count": int, "score": float}

    def test_unknown_type_defaults_to_any(self):
        """An unrecognised type string falls back to Any."""
        parsed = parse_state_config({"custom": "unknown_type"})
        assert parsed == {"custom": Any}

    def test_non_string_value_defaults_to_any(self):
        """Values that are not strings fall back to Any."""
        non_strings = {
            "nested": {"type": "str"},  # a dict, not a type-name string
            "number": 123,  # an int, not a type-name string
        }
        parsed = parse_state_config(non_strings)
        assert parsed == {"nested": Any, "number": Any}
106
+
107
+
108
class TestTypeMap:
    """Membership checks on the TYPE_MAP constant."""

    def test_all_basic_types_present(self):
        """Every basic type spelling is a key of TYPE_MAP."""
        for type_name in ("str", "int", "float", "bool", "list", "dict", "any"):
            assert type_name in TYPE_MAP

    def test_aliases_present(self):
        """The common long-form aliases are keys of TYPE_MAP."""
        for alias in ("string", "integer", "boolean"):
            assert alias in TYPE_MAP
126
+
127
+
128
class TestBuildStateClassWithStateConfig:
    """Check how build_state_class handles the optional ``state`` section."""

    def test_state_section_adds_fields(self):
        """Fields declared under ``state`` appear in the class annotations."""
        graph_config = dict(
            state={"concept": "str", "count": "int"},
            nodes={},
            edges=[],
        )
        fields = build_state_class(graph_config).__annotations__

        assert "concept" in fields
        assert "count" in fields

    def test_state_section_empty(self):
        """An empty ``state`` section still yields a usable class."""
        graph_config = dict(state={}, nodes={}, edges=[])
        built = build_state_class(graph_config)
        # The base field must survive an empty state section.
        assert "thread_id" in built.__annotations__

    def test_state_section_missing(self):
        """A config with no ``state`` key at all still yields a usable class."""
        graph_config = dict(nodes={}, edges=[])
        built = build_state_class(graph_config)
        # The base field must survive a missing state section.
        assert "thread_id" in built.__annotations__

    def test_custom_field_overrides_common(self):
        """A ``state`` entry named like a common field is kept in the class."""
        graph_config = dict(
            state={"topic": "int"},  # declared as int rather than the str default
            nodes={},
            edges=[],
        )
        built = build_state_class(graph_config)
        # Only presence is checked here, not the resulting annotation type.
        assert "topic" in built.__annotations__

    def test_storyboard_example(self):
        """Storyboard-style config: one state field plus one node output."""
        expand_story = {"type": "llm", "state_key": "story"}
        graph_config = dict(
            state={"concept": "str"},
            nodes={"expand_story": expand_story},
            edges=[],
        )
        fields = build_state_class(graph_config).__annotations__

        expected_fields = (
            "concept",  # custom field from the state section
            "story",  # output field named by the node's state_key
            "thread_id",  # base infrastructure field
            "errors",  # base infrastructure field
        )
        for field_name in expected_fields:
            assert field_name in fields
@@ -0,0 +1,307 @@
1
+ """Tests for streaming support - Phase 3 (004).
2
+
3
+ TDD: RED phase - write tests first.
4
+ """
5
+
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ # ==============================================================================
11
+ # execute_prompt_streaming tests
12
+ # ==============================================================================
13
+
14
+
15
@pytest.mark.asyncio
async def test_execute_prompt_streaming_yields_tokens():
    """Each chunk's content from the mocked LLM stream is yielded in order."""
    from yamlgraph.executor_async import execute_prompt_streaming

    # Three fake stream chunks, each exposing a ``content`` attribute.
    chunks = [MagicMock(content=text) for text in ("Hello", " World", "!")]

    async def fake_astream(*args, **kwargs):
        for piece in chunks:
            yield piece

    llm = MagicMock()
    llm.astream = fake_astream

    prompt = {
        "system": "You are helpful.",
        "user": "Say hello",
    }
    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm),
        patch("yamlgraph.executor_async.load_prompt", return_value=prompt),
    ):
        tokens = [tok async for tok in execute_prompt_streaming("greet", variables={})]

    assert tokens == ["Hello", " World", "!"]
49
+
50
+
51
@pytest.mark.asyncio
async def test_execute_prompt_streaming_with_variables():
    """With variables supplied, format_prompt is invoked and tokens stream."""
    from yamlgraph.executor_async import execute_prompt_streaming

    reply = MagicMock(content="Hi Alice!")

    async def fake_astream(*args, **kwargs):
        yield reply

    llm = MagicMock()
    llm.astream = fake_astream

    prompt = {
        "system": "",
        "user": "Say hello to {name}",
    }
    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm),
        patch("yamlgraph.executor_async.load_prompt", return_value=prompt),
        patch("yamlgraph.executor_async.format_prompt", return_value="Say hello to Alice") as mock_format,
    ):
        tokens = [
            tok
            async for tok in execute_prompt_streaming("greet", variables={"name": "Alice"})
        ]

    assert tokens == ["Hi Alice!"]
    mock_format.assert_called()
81
+
82
+
83
@pytest.mark.asyncio
async def test_execute_prompt_streaming_uses_provider():
    """The ``provider`` argument reaches create_llm as a keyword argument."""
    from yamlgraph.executor_async import execute_prompt_streaming

    only_chunk = MagicMock(content="test")

    async def fake_astream(*args, **kwargs):
        yield only_chunk

    llm = MagicMock()
    llm.astream = fake_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm) as mock_create,
        patch("yamlgraph.executor_async.load_prompt", return_value={"system": "", "user": "test"}),
    ):
        # Drain the stream; only the create_llm call is of interest here.
        async for _ in execute_prompt_streaming("test", variables={}, provider="openai"):
            pass

    mock_create.assert_called_once()
    assert mock_create.call_args.kwargs.get("provider") == "openai"
109
+
110
+
111
@pytest.mark.asyncio
async def test_execute_prompt_streaming_handles_empty_chunks():
    """Chunks whose content is '' or None are not yielded."""
    from yamlgraph.executor_async import execute_prompt_streaming

    # Second chunk is an empty string, third is None; both should be skipped.
    contents = ("Hello", "", None, "World")
    chunks = [MagicMock(content=value) for value in contents]

    async def fake_astream(*args, **kwargs):
        for piece in chunks:
            yield piece

    llm = MagicMock()
    llm.astream = fake_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm),
        patch("yamlgraph.executor_async.load_prompt", return_value={"system": "", "user": "test"}),
    ):
        tokens = [tok async for tok in execute_prompt_streaming("test", variables={})]

    # Only the non-empty contents come through.
    assert tokens == ["Hello", "World"]
144
+
145
+
146
@pytest.mark.asyncio
async def test_execute_prompt_streaming_propagates_errors():
    """An exception raised mid-stream by the LLM surfaces to the caller."""
    from yamlgraph.executor_async import execute_prompt_streaming

    async def failing_astream(*args, **kwargs):
        yield MagicMock(content="start")
        raise ValueError("LLM error")

    llm = MagicMock()
    llm.astream = failing_astream

    received = []
    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm),
        patch("yamlgraph.executor_async.load_prompt", return_value={"system": "", "user": "test"}),
    ):
        with pytest.raises(ValueError, match="LLM error"):
            # Append one token at a time so the partial result survives the raise.
            async for tok in execute_prompt_streaming("test", variables={}):
                received.append(tok)

    # The first token arrived before the error was raised.
    assert received == ["start"]
171
+
172
+
173
+ # ==============================================================================
174
+ # Streaming with output collection
175
+ # ==============================================================================
176
+
177
+
178
@pytest.mark.asyncio
async def test_execute_prompt_streaming_collect():
    """Streamed tokens can be joined into one string."""
    from yamlgraph.executor_async import execute_prompt_streaming

    words = ["The ", "quick ", "brown ", "fox"]
    chunks = [MagicMock(content=word) for word in words]

    async def fake_astream(*args, **kwargs):
        for piece in chunks:
            yield piece

    llm = MagicMock()
    llm.astream = fake_astream

    with (
        patch("yamlgraph.executor_async.create_llm", return_value=llm),
        patch("yamlgraph.executor_async.load_prompt", return_value={"system": "", "user": "test"}),
    ):
        # Gather every token first, then join them.
        collected = [tok async for tok in execute_prompt_streaming("test", {})]
        result = "".join(collected)

    assert result == "The quick brown fox"
202
+
203
+
204
+ # ==============================================================================
205
+ # Streaming node factory tests
206
+ # ==============================================================================
207
+
208
+
209
@pytest.mark.asyncio
async def test_create_streaming_node_yields_tokens():
    """The node built by create_streaming_node yields the streamed tokens."""
    from yamlgraph.node_factory import create_streaming_node

    chunks = [MagicMock(content=text) for text in ["Hello", " ", "World"]]

    async def fake_streaming(*args, **kwargs):
        for piece in chunks:
            yield piece.content

    with patch("yamlgraph.executor_async.execute_prompt_streaming", fake_streaming):
        streaming_node = create_streaming_node(
            "generate",
            {
                "prompt": "greet",
                "state_key": "response",
            },
        )

        tokens = [tok async for tok in streaming_node({"input": "test"})]

    assert tokens == ["Hello", " ", "World"]
233
+
234
+
235
@pytest.mark.asyncio
async def test_streaming_node_with_callback():
    """An ``on_token`` callable in the node config receives every token."""
    from yamlgraph.node_factory import create_streaming_node

    async def fake_streaming(*args, **kwargs):
        for letter in ["A", "B", "C"]:
            yield letter

    seen = []

    def token_callback(token: str):
        seen.append(token)

    with patch("yamlgraph.executor_async.execute_prompt_streaming", fake_streaming):
        streaming_node = create_streaming_node(
            "stream_node",
            {
                "prompt": "test",
                "state_key": "output",
                "on_token": token_callback,
            },
        )

        # Drain the generator; the callback does the recording.
        async for _ in streaming_node({}):
            pass

    assert seen == ["A", "B", "C"]
262
+
263
+
264
+ # ==============================================================================
265
+ # YAML config tests
266
+ # ==============================================================================
267
+
268
+
269
def test_node_config_stream_true_creates_streaming_node():
    """With ``stream: True``, create_node_function calls create_streaming_node."""
    from yamlgraph.node_factory import create_node_function

    node_config = dict(prompt="greet", state_key="response", stream=True)

    with patch(
        "yamlgraph.node_factory.create_streaming_node",
        return_value=MagicMock(),
    ) as mock_create:
        # The stream flag should route construction through create_streaming_node.
        create_node_function("generate", node_config, defaults={})

    mock_create.assert_called_once_with("generate", node_config)
286
+
287
+
288
def test_node_config_stream_false_creates_regular_node():
    """With ``stream: False``, create_streaming_node is never called."""
    from yamlgraph.node_factory import create_node_function

    node_config = dict(prompt="greet", state_key="response", stream=False)

    with (
        patch("yamlgraph.node_factory.create_streaming_node") as mock_streaming,
        patch("yamlgraph.node_factory.execute_prompt", return_value="result"),
    ):
        create_node_function("generate", node_config, defaults={})

    # The streaming factory must not have been used.
    mock_streaming.assert_not_called()