yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,768 @@
1
+ """Node function factory for YAML-defined graphs.
2
+
3
+ Creates LangGraph node functions from YAML configuration with support for:
4
+ - Resume (skip if output exists)
5
+ - Error handling (skip, retry, fail, fallback)
6
+ - Router nodes with dynamic routing
7
+ - Loop counting and limits
8
+ - Dynamic tool calls from state (type: tool_call)
9
+ - Streaming nodes (type: llm, stream: true)
10
+ - Subgraph nodes (type: subgraph) for composing workflows
11
+ - JSON extraction from LLM output (parse_json: true)
12
+ """
13
+
14
+ import logging
15
+ from collections.abc import AsyncIterator, Callable
16
+ from contextvars import ContextVar
17
+ from pathlib import Path
18
+ from typing import Any
19
+
20
+ from yamlgraph.constants import ErrorHandler, NodeType
21
+ from yamlgraph.error_handlers import (
22
+ check_loop_limit,
23
+ check_requirements,
24
+ handle_default,
25
+ handle_fail,
26
+ handle_fallback,
27
+ handle_retry,
28
+ handle_skip,
29
+ )
30
+ from yamlgraph.executor import execute_prompt
31
+ from yamlgraph.utils.expressions import resolve_template
32
+ from yamlgraph.utils.json_extract import extract_json
33
+ from yamlgraph.utils.prompts import resolve_prompt_path
34
+
35
+ # Type alias for dynamic state
36
+ GraphState = dict[str, Any]
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Thread-safe loading stack to detect circular subgraph references
41
+ # Note: Do NOT use default=[] as it shares the same list across contexts
42
+ _loading_stack: ContextVar[list[Path]] = ContextVar("loading_stack")
43
+
44
+
45
def create_tool_call_node(
    node_name: str,
    node_config: dict[str, Any],
    tools_registry: dict[str, Callable],
) -> Callable[[GraphState], dict]:
    """Create a node that dynamically calls a tool from state.

    Enables YAML-driven tool execution: the tool name and its arguments
    are template expressions resolved against the state at runtime.

    Args:
        node_name: Name of the node
        node_config: Node configuration with 'tool', 'args', 'state_key'
        tools_registry: Dict mapping tool names to callable functions

    Returns:
        Node function compatible with LangGraph
    """
    tool_expr = node_config["tool"]  # e.g., "{state.task.tool}"
    args_expr = node_config["args"]  # e.g., "{state.task.args}"
    state_key = node_config.get("state_key", "result")

    def node_fn(state: dict) -> dict:
        # Resolve the tool name and its arguments from the current state.
        tool_name = resolve_template(tool_expr, state)
        call_args = resolve_template(args_expr, state)

        # Pull task_id through for traceability, when the state carries a task.
        task = state.get("task", {})
        task_id = task.get("id") if isinstance(task, dict) else None

        def _envelope(success: bool, result: Any, error: str | None) -> dict:
            # Uniform result payload stored under state_key.
            return {
                state_key: {
                    "task_id": task_id,
                    "tool": tool_name,
                    "success": success,
                    "result": result,
                    "error": error,
                },
                "current_step": node_name,
            }

        tool_func = tools_registry.get(tool_name)
        if tool_func is None:
            return _envelope(False, None, f"Unknown tool: {tool_name}")

        # Guard against non-dict args (template may have resolved to anything).
        if not isinstance(call_args, dict):
            call_args = {}

        try:
            return _envelope(True, tool_func(**call_args), None)
        except Exception as exc:
            # Tool failures are reported in-band rather than crashing the graph.
            return _envelope(False, None, str(exc))

    node_fn.__name__ = f"{node_name}_tool_call"
    return node_fn
120
+
121
+
122
def resolve_class(class_path: str) -> type:
    """Dynamically import and return a class from a module path.

    Args:
        class_path: Full dotted path like "yamlgraph.models.GenericReport",
            or a bare class name looked up in yamlgraph.models.schemas.

    Returns:
        The class object

    Raises:
        ValueError: If the path has no dot and no matching schema class exists.
    """
    import importlib

    if "." in class_path:
        # Dotted path: import the module and pull the attribute off it.
        module_path, _, class_name = class_path.rpartition(".")
        return getattr(importlib.import_module(module_path), class_name)

    # Bare name: fall back to the bundled schemas module, if importable.
    try:
        from yamlgraph.models import schemas
    except ImportError:
        pass
    else:
        if hasattr(schemas, class_path):
            return getattr(schemas, class_path)
    raise ValueError(f"Invalid class path: {class_path}")
148
+
149
+
150
+ def get_output_model_for_node(
151
+ node_config: dict[str, Any],
152
+ prompts_dir: Path | None = None,
153
+ graph_path: Path | None = None,
154
+ prompts_relative: bool = False,
155
+ ) -> type | None:
156
+ """Get output model for a node, checking inline schema if no explicit model.
157
+
158
+ Priority:
159
+ 1. Explicit output_model in node config (class path)
160
+ 2. Inline schema in prompt YAML file
161
+ 3. None (raw string output)
162
+
163
+ Args:
164
+ node_config: Node configuration from YAML
165
+ prompts_dir: Base prompts directory
166
+ graph_path: Path to graph YAML file (for relative prompt resolution)
167
+ prompts_relative: If True, resolve prompts relative to graph_path
168
+
169
+ Returns:
170
+ Pydantic model class or None
171
+ """
172
+ # 1. Check for explicit output_model
173
+ if model_path := node_config.get("output_model"):
174
+ return resolve_class(model_path)
175
+
176
+ # 2. Check for inline schema in prompt YAML
177
+ prompt_name = node_config.get("prompt")
178
+ if prompt_name:
179
+ try:
180
+ from yamlgraph.schema_loader import load_schema_from_yaml
181
+
182
+ yaml_path = resolve_prompt_path(
183
+ prompt_name,
184
+ prompts_dir=prompts_dir,
185
+ graph_path=graph_path,
186
+ prompts_relative=prompts_relative,
187
+ )
188
+ return load_schema_from_yaml(yaml_path)
189
+ except FileNotFoundError:
190
+ # Prompt file doesn't exist yet - will fail later
191
+ pass
192
+
193
+ # 3. No output model
194
+ return None
195
+
196
+
197
def create_node_function(
    node_name: str,
    node_config: dict,
    defaults: dict,
    graph_path: Path | None = None,
) -> Callable[[GraphState], dict]:
    """Create a node function from YAML config.

    The returned closure implements resume (skip-if-exists), loop counting,
    requirement checks, variable resolution, optional JSON extraction,
    router-route selection, and the configured on_error strategy.

    Args:
        node_name: Name of the node
        node_config: Node configuration from YAML
        defaults: Default configuration values (provider, temperature,
            prompts_dir, prompts_relative)
        graph_path: Path to graph YAML file (for relative prompt resolution)

    Returns:
        Node function compatible with LangGraph
    """
    node_type = node_config.get("type", NodeType.LLM)
    prompt_name = node_config.get("prompt")

    # Prompt resolution options from defaults (FR-A)
    prompts_relative = defaults.get("prompts_relative", False)
    prompts_dir = defaults.get("prompts_dir")
    if prompts_dir:
        prompts_dir = Path(prompts_dir)

    # Streaming nodes are a different shape (async generator); delegate early.
    if node_config.get("stream", False):
        return create_streaming_node(node_name, node_config)

    # Resolve output model (explicit > inline schema > None)
    output_model = get_output_model_for_node(
        node_config,
        prompts_dir=prompts_dir,
        graph_path=graph_path,
        prompts_relative=prompts_relative,
    )

    # Get config values (node > defaults)
    temperature = node_config.get("temperature", defaults.get("temperature", 0.7))
    provider = node_config.get("provider", defaults.get("provider"))
    state_key = node_config.get("state_key", node_name)
    variable_templates = node_config.get("variables", {})
    requires = node_config.get("requires", [])

    # Error handling
    on_error = node_config.get("on_error")
    max_retries = node_config.get("max_retries", 3)
    fallback_config = node_config.get("fallback", {})
    fallback_provider = fallback_config.get("provider") if fallback_config else None

    # Router config
    routes = node_config.get("routes", {})
    default_route = node_config.get("default_route")

    # Loop limit
    loop_limit = node_config.get("loop_limit")

    # Skip if exists (default true for resume support, false for loop nodes)
    skip_if_exists = node_config.get("skip_if_exists", True)

    # JSON extraction (FR-B)
    parse_json = node_config.get("parse_json", False)

    def node_fn(state: dict) -> dict:
        """Generated node function."""
        # Copy so we never mutate the state dict LangGraph handed us.
        loop_counts = dict(state.get("_loop_counts") or {})
        current_count = loop_counts.get(node_name, 0)

        # Check loop limit
        if check_loop_limit(node_name, loop_limit, current_count):
            return {"_loop_limit_reached": True, "current_step": node_name}

        loop_counts[node_name] = current_count + 1

        # Skip if output exists (resume support) - disabled for loop nodes
        if skip_if_exists and state.get(state_key) is not None:
            logger.info(f"Node {node_name} skipped - {state_key} already in state")
            return {"current_step": node_name, "_loop_counts": loop_counts}

        # Check requirements
        if error := check_requirements(requires, state, node_name):
            return {
                "errors": [error],
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

        # Resolve variables from templates OR use state directly
        if variable_templates:
            variables = {}
            for key, template in variable_templates.items():
                resolved = resolve_template(template, state)
                # Preserve original types (lists, dicts) for Jinja2 templates
                variables[key] = resolved
        else:
            # No explicit variable mapping - pass state as variables
            # Filter out internal keys and None values
            variables = {
                k: v
                for k, v in state.items()
                if not k.startswith("_") and v is not None
            }

        def attempt_execute(use_provider: str | None) -> tuple[Any, Exception | None]:
            # Returns (result, None) on success, (None, exception) on failure,
            # so the error-strategy dispatch below never sees a raised exception.
            try:
                result = execute_prompt(
                    prompt_name=prompt_name,
                    variables=variables,
                    output_model=output_model,
                    temperature=temperature,
                    provider=use_provider,
                    graph_path=graph_path,
                    prompts_dir=prompts_dir,
                    prompts_relative=prompts_relative,
                )
                return result, None
            except Exception as e:
                return None, e

        result, error = attempt_execute(provider)

        if error is None:
            # Post-process: JSON extraction if enabled (FR-B)
            if parse_json and isinstance(result, str):
                result = extract_json(result)

            logger.info(f"Node {node_name} completed successfully")
            update = {
                state_key: result,
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

            # Router: add _route to state
            if node_type == NodeType.ROUTER and routes:
                # Route key is taken from the result's "tone" or "intent"
                # attribute (whichever is set), falling back to the default
                # route, then the first configured route.
                route_key = getattr(result, "tone", None) or getattr(
                    result, "intent", None
                )
                if route_key and route_key in routes:
                    update["_route"] = routes[route_key]
                elif default_route:
                    update["_route"] = default_route
                else:
                    update["_route"] = list(routes.values())[0]
                logger.info(f"Router {node_name} routing to: {update['_route']}")
            return update

        # Error handling - dispatch to strategy handlers
        if on_error == ErrorHandler.SKIP:
            handle_skip(node_name, error, loop_counts)
            return {"current_step": node_name, "_loop_counts": loop_counts}

        elif on_error == ErrorHandler.FAIL:
            # NOTE(review): handle_fail presumably raises; otherwise this
            # branch falls off the end and returns None - confirm in
            # yamlgraph.error_handlers.
            handle_fail(node_name, error)

        elif on_error == ErrorHandler.RETRY:
            # The thunk returns (result, error) tuples; handle_retry is
            # expected to unwrap them and return an object exposing
            # to_state_update().
            result = handle_retry(
                node_name,
                lambda: attempt_execute(provider),
                max_retries,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        elif on_error == ErrorHandler.FALLBACK and fallback_provider:
            result = handle_fallback(
                node_name,
                attempt_execute,
                fallback_provider,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        else:
            # No (or unrecognized) strategy: default handling.
            result = handle_default(node_name, error)
            return result.to_state_update(state_key, node_name, loop_counts)

    node_fn.__name__ = f"{node_name}_node"
    return node_fn
375
+
376
+
377
def create_interrupt_node(
    node_name: str,
    config: dict[str, Any],
    graph_path: Path | None = None,
    prompts_dir: Path | None = None,
    prompts_relative: bool = False,
) -> Callable[[GraphState], dict]:
    """Create an interrupt node that pauses for human input.

    Uses LangGraph's native interrupt() function for human-in-the-loop.
    Idempotent across resume: when a payload is already stored under
    ``state_key``, it is reused instead of re-executing the prompt.

    Args:
        node_name: Name of the node
        config: Node configuration with optional keys:
            - message: Static interrupt payload (string or dict)
            - prompt: Prompt name to generate dynamic payload
            - state_key: Where to store payload (default: "interrupt_message")
            - resume_key: Where to store resume value (default: "user_input")
        graph_path: Path to graph file for relative prompt resolution
        prompts_dir: Explicit prompts directory override
        prompts_relative: If True, resolve prompts relative to graph_path

    Returns:
        Node function compatible with LangGraph
    """
    from langgraph.types import interrupt

    static_message = config.get("message")
    prompt_name = config.get("prompt")
    state_key = config.get("state_key", "interrupt_message")
    resume_key = config.get("resume_key", "user_input")

    def _payload_for(state: dict) -> Any:
        """Pick the interrupt payload: stored > prompt > static > node name."""
        stored = state.get(state_key)
        if stored is not None:
            # Resuming - reuse the stored payload, do not re-run the prompt.
            return stored
        if prompt_name:
            # First execution with a prompt: generate the payload dynamically.
            return execute_prompt(
                prompt_name,
                state,
                graph_path=graph_path,
                prompts_dir=prompts_dir,
                prompts_relative=prompts_relative,
            )
        if static_message is not None:
            return static_message
        # Nothing configured - fall back to identifying the node.
        return {"node": node_name}

    def interrupt_fn(state: dict) -> dict:
        payload = _payload_for(state)

        # Native LangGraph interrupt - pauses here on the first run.
        # On resume, interrupt() returns the Command(resume=...) value.
        response = interrupt(payload)

        return {
            state_key: payload,  # stored for the idempotency check above
            resume_key: response,  # user's response
            "current_step": node_name,
        }

    interrupt_fn.__name__ = f"{node_name}_interrupt"
    return interrupt_fn
445
+
446
+
447
def create_passthrough_node(
    node_name: str,
    config: dict[str, Any],
) -> Callable[[GraphState], dict]:
    """Create a passthrough node that transforms state without external calls.

    Useful for:
    - Loop counters (increment values)
    - State accumulation (append to lists)
    - Simple data transformations
    - Clean transition points in graphs

    Args:
        node_name: Name of the node
        config: Node configuration with:
            - output: Dict of state_key -> expression mappings
              Expressions use {state.field} syntax
              Supports arithmetic: {state.count + 1}
              Supports list append: {state.history + [state.current]}

    Returns:
        Node function compatible with LangGraph

    Example:
        ```yaml
        next_turn:
          type: passthrough
          output:
            turn_number: "{state.turn_number + 1}"
            history: "{state.history + [state.narration]}"
        ```
    """
    # resolve_template is imported at module level; the previous
    # function-local re-import was redundant and has been removed.
    output_templates = config.get("output", {})

    def passthrough_fn(state: dict) -> dict:
        result: dict[str, Any] = {"current_step": node_name}

        for key, template in output_templates.items():
            try:
                resolved = resolve_template(template, state)
                # If resolution failed (None) and key exists in state,
                # keep the original value rather than clobbering it.
                if resolved is None and key in state:
                    result[key] = state[key]
                else:
                    result[key] = resolved
            except Exception as e:
                logger.warning(
                    f"Passthrough node {node_name}: failed to resolve {key}: {e}"
                )
                # Keep original value on error
                if key in state:
                    result[key] = state[key]

        logger.info(f"Node {node_name} completed successfully")
        return result

    passthrough_fn.__name__ = f"{node_name}_passthrough"
    return passthrough_fn
507
+
508
+
509
def create_streaming_node(
    node_name: str,
    node_config: dict[str, Any],
) -> Callable[[GraphState], AsyncIterator[str]]:
    """Create a streaming node that yields tokens.

    Streaming nodes are async generators that yield tokens as they
    are produced by the LLM. They do not support structured output.

    Args:
        node_name: Name of the node
        node_config: Node configuration with:
            - prompt: Prompt name to execute
            - state_key: Where to store final result (optional)
            - on_token: Optional callback function for each token
            - provider: LLM provider
            - temperature: LLM temperature

    Returns:
        Async generator function compatible with streaming execution
    """
    # Imported lazily - presumably to avoid a circular import with the
    # async executor module; confirm before hoisting to module level.
    # NOTE: the previous function-local re-import of resolve_template was
    # redundant (it is already imported at module level) and was removed.
    from yamlgraph.executor_async import execute_prompt_streaming

    prompt_name = node_config.get("prompt")
    variable_templates = node_config.get("variables", {})
    provider = node_config.get("provider")
    temperature = node_config.get("temperature", 0.7)
    on_token = node_config.get("on_token")

    async def streaming_node(state: dict) -> AsyncIterator[str]:
        # Resolve variables from templates OR use state directly.
        if variable_templates:
            # Preserve original types (lists, dicts) for Jinja2 templates.
            variables = {
                key: resolve_template(template, state)
                for key, template in variable_templates.items()
            }
        else:
            # No explicit variable mapping - pass state as variables,
            # filtering out internal keys and None values.
            variables = {
                k: v
                for k, v in state.items()
                if not k.startswith("_") and v is not None
            }

        async for token in execute_prompt_streaming(
            prompt_name,
            variables=variables,
            provider=provider,
            temperature=temperature,
        ):
            if on_token:
                on_token(token)
            yield token

    streaming_node.__name__ = f"{node_name}_streaming"
    return streaming_node
567
+
568
+
569
+ # =============================================================================
570
+ # Subgraph Node Support
571
+ # =============================================================================
572
+
573
+
574
+ def _map_input_state(
575
+ parent_state: dict[str, Any],
576
+ input_mapping: dict[str, str] | str,
577
+ ) -> dict[str, Any]:
578
+ """Map parent state to child input based on mapping config.
579
+
580
+ Args:
581
+ parent_state: Current state from parent graph
582
+ input_mapping: Mapping configuration:
583
+ - dict: explicit {parent_key: child_key} mapping
584
+ - "auto": copy all fields
585
+ - "*": pass state reference directly
586
+
587
+ Returns:
588
+ Input state for child graph
589
+ """
590
+ if input_mapping == "auto":
591
+ return parent_state.copy()
592
+ elif input_mapping == "*":
593
+ return parent_state
594
+ else:
595
+ return {
596
+ child_key: parent_state.get(parent_key)
597
+ for parent_key, child_key in input_mapping.items()
598
+ }
599
+
600
+
601
+ def _map_output_state(
602
+ child_output: dict[str, Any],
603
+ output_mapping: dict[str, str] | str,
604
+ ) -> dict[str, Any]:
605
+ """Map child output to parent state updates based on mapping config.
606
+
607
+ Args:
608
+ child_output: Output state from child graph
609
+ output_mapping: Mapping configuration:
610
+ - dict: explicit {parent_key: child_key} mapping
611
+ - "auto": pass all fields
612
+ - "*": pass output directly
613
+
614
+ Returns:
615
+ Updates to apply to parent state
616
+ """
617
+ if output_mapping in ("auto", "*"):
618
+ return child_output
619
+ else:
620
+ return {
621
+ parent_key: child_output.get(child_key)
622
+ for parent_key, child_key in output_mapping.items()
623
+ }
624
+
625
+
626
+ def _build_child_config(
627
+ parent_config: dict[str, Any],
628
+ node_name: str,
629
+ ) -> dict[str, Any]:
630
+ """Build child graph config with propagated thread ID.
631
+
632
+ Args:
633
+ parent_config: RunnableConfig from parent graph
634
+ node_name: Name of the subgraph node
635
+
636
+ Returns:
637
+ Config for child graph with thread_id: parent_thread:node_name
638
+ """
639
+ configurable = parent_config.get("configurable", {})
640
+ parent_thread_id = configurable.get("thread_id")
641
+
642
+ child_thread_id = (
643
+ f"{parent_thread_id}:{node_name}" if parent_thread_id else node_name
644
+ )
645
+
646
+ return {
647
+ **parent_config,
648
+ "configurable": {
649
+ **configurable,
650
+ "thread_id": child_thread_id,
651
+ },
652
+ }
653
+
654
+
655
def create_subgraph_node(
    node_name: str,
    node_config: dict[str, Any],
    parent_graph_path: Path,
    parent_checkpointer: Any | None = None,
) -> Callable[[dict, dict], dict] | Any:
    """Create a node that invokes a compiled subgraph.

    The subgraph YAML is loaded and compiled eagerly (at graph-build time),
    with circular references detected via a ContextVar loading stack.

    Args:
        node_name: Name of this node in parent graph
        node_config: Subgraph configuration from YAML (keys: graph, mode,
            input_mapping, output_mapping, interrupt_output_mapping)
        parent_graph_path: Path to parent graph (for relative resolution)
        parent_checkpointer: Checkpointer to inherit (if any)

    Returns:
        Node function that invokes subgraph (or CompiledGraph for mode=direct)

    Raises:
        FileNotFoundError: If subgraph YAML doesn't exist
        ValueError: If circular reference detected
    """
    from yamlgraph.graph_loader import compile_graph, load_graph_config

    # Resolve path relative to parent graph file
    graph_rel_path = node_config["graph"]
    graph_path = (parent_graph_path.parent / graph_rel_path).resolve()

    mode = node_config.get("mode", "invoke")
    input_mapping = node_config.get("input_mapping", {})
    output_mapping = node_config.get("output_mapping", {})
    interrupt_output_mapping = node_config.get("interrupt_output_mapping", {})

    # Validate graph exists
    if not graph_path.exists():
        raise FileNotFoundError(f"Subgraph not found: {graph_path}")

    # Circular reference detection (thread-safe)
    # Use .get([]) to provide default without sharing mutable state
    stack = _loading_stack.get([])
    if graph_path in stack:
        cycle = " -> ".join(str(p) for p in [*stack, graph_path])
        raise ValueError(f"Circular subgraph reference: {cycle}")

    # Push onto loading stack for this context; always popped in finally
    # so a failed compile doesn't poison later loads in this context.
    token = _loading_stack.set([*stack, graph_path])
    try:
        subgraph_config = load_graph_config(graph_path)
        state_graph = compile_graph(subgraph_config)
        # Compile with checkpointer (if provided)
        compiled = state_graph.compile(checkpointer=parent_checkpointer)
    finally:
        _loading_stack.reset(token)

    if mode == "direct":
        # Mode: Direct - shared schema, LangGraph handles state mapping
        # Return compiled graph directly - LangGraph's add_node() accepts
        # CompiledStateGraph objects and handles them natively
        return compiled

    # Mode: Invoke - explicit state mapping
    from langchain_core.runnables import RunnableConfig

    def subgraph_node(state: dict, config: RunnableConfig | None = None) -> dict:
        """Execute the subgraph with mapped state."""
        from langgraph.errors import GraphInterrupt

        config = config or {}

        # Build child input from parent state
        child_input = _map_input_state(state, input_mapping)

        # Build child config with propagated thread ID
        child_config = _build_child_config(config, node_name)

        # Invoke subgraph - may raise GraphInterrupt
        try:
            child_output = compiled.invoke(child_input, child_config)
            # Some interrupt setups surface as a "__interrupt__" key in the
            # returned state rather than an exception.
            is_interrupted = "__interrupt__" in child_output
        except GraphInterrupt:
            # FR-006: Child hit an interrupt
            if interrupt_output_mapping:
                # Get child state from checkpointer
                child_state = compiled.get_state(child_config)
                child_output = dict(child_state.values) if child_state else {}

                # Apply interrupt_output_mapping
                parent_updates = _map_output_state(child_output, interrupt_output_mapping)
                parent_updates["current_step"] = node_name

                # Use __pregel_send to update parent state before re-raising
                # This allows the mapped state to be included in the result
                # NOTE(review): __pregel_send is a private LangGraph channel;
                # confirm it is still provided under this key in the pinned
                # langgraph version.
                send = config.get("configurable", {}).get("__pregel_send")
                if send:
                    # Convert dict to list of (key, value) tuples
                    updates = [(k, v) for k, v in parent_updates.items()]
                    send(updates)
                logger.info(f"FR-006: Subgraph {node_name} mapped state: {list(parent_updates.keys())}")

            # Re-raise to pause the graph
            raise

        # Normal completion path (is_interrupted is always bound here:
        # the except branch above re-raises unconditionally).
        if is_interrupted and interrupt_output_mapping:
            parent_updates = _map_output_state(child_output, interrupt_output_mapping)
            parent_updates["__interrupt__"] = child_output["__interrupt__"]
        else:
            parent_updates = _map_output_state(child_output, output_mapping)

        parent_updates["current_step"] = node_name

        return parent_updates

    subgraph_node.__name__ = f"{node_name}_subgraph"
    return subgraph_node