yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,768 @@
|
|
|
1
|
+
"""Node function factory for YAML-defined graphs.
|
|
2
|
+
|
|
3
|
+
Creates LangGraph node functions from YAML configuration with support for:
|
|
4
|
+
- Resume (skip if output exists)
|
|
5
|
+
- Error handling (skip, retry, fail, fallback)
|
|
6
|
+
- Router nodes with dynamic routing
|
|
7
|
+
- Loop counting and limits
|
|
8
|
+
- Dynamic tool calls from state (type: tool_call)
|
|
9
|
+
- Streaming nodes (type: llm, stream: true)
|
|
10
|
+
- Subgraph nodes (type: subgraph) for composing workflows
|
|
11
|
+
- JSON extraction from LLM output (parse_json: true)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from collections.abc import AsyncIterator, Callable
|
|
16
|
+
from contextvars import ContextVar
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from yamlgraph.constants import ErrorHandler, NodeType
|
|
21
|
+
from yamlgraph.error_handlers import (
|
|
22
|
+
check_loop_limit,
|
|
23
|
+
check_requirements,
|
|
24
|
+
handle_default,
|
|
25
|
+
handle_fail,
|
|
26
|
+
handle_fallback,
|
|
27
|
+
handle_retry,
|
|
28
|
+
handle_skip,
|
|
29
|
+
)
|
|
30
|
+
from yamlgraph.executor import execute_prompt
|
|
31
|
+
from yamlgraph.utils.expressions import resolve_template
|
|
32
|
+
from yamlgraph.utils.json_extract import extract_json
|
|
33
|
+
from yamlgraph.utils.prompts import resolve_prompt_path
|
|
34
|
+
|
|
35
|
+
# Type alias for dynamic state passed between LangGraph nodes.
GraphState = dict[str, Any]

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Thread-safe loading stack to detect circular subgraph references
# Note: Do NOT use default=[] as it shares the same list across contexts
# (a ContextVar default is created once at import time, so a mutable default
# would be silently shared; callers must set/get the list explicitly).
_loading_stack: ContextVar[list[Path]] = ContextVar("loading_stack")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def create_tool_call_node(
    node_name: str,
    node_config: dict[str, Any],
    tools_registry: dict[str, Callable],
) -> Callable[[GraphState], dict]:
    """Create a node that dynamically calls a tool from state.

    This enables YAML-driven tool execution where the tool name and its
    arguments are resolved from state at runtime.

    Args:
        node_name: Name of the node
        node_config: Node configuration with 'tool', 'args', 'state_key'
        tools_registry: Dict mapping tool names to callable functions

    Returns:
        Node function compatible with LangGraph. The node writes a uniform
        result dict (task_id / tool / success / result / error) under
        ``state_key`` and never raises: unknown tools and tool exceptions
        are reported via success=False.
    """
    tool_expr = node_config["tool"]  # e.g., "{state.task.tool}"
    args_expr = node_config["args"]  # e.g., "{state.task.args}"
    state_key = node_config.get("state_key", "result")

    def node_fn(state: dict) -> dict:
        # Resolve tool name and args from state
        tool_name = resolve_template(tool_expr, state)
        args = resolve_template(args_expr, state)

        # Extract task_id if available
        task = state.get("task", {})
        task_id = task.get("id") if isinstance(task, dict) else None

        def _update(success: bool, result: Any, error: str | None) -> dict:
            """Build the uniform state update shared by all outcomes."""
            return {
                state_key: {
                    "task_id": task_id,
                    "tool": tool_name,
                    "success": success,
                    "result": result,
                    "error": error,
                },
                "current_step": node_name,
            }

        # Look up tool in registry
        tool_func = tools_registry.get(tool_name)
        if tool_func is None:
            return _update(False, None, f"Unknown tool: {tool_name}")

        # Execute tool; a non-dict args value is treated as "no arguments"
        try:
            if not isinstance(args, dict):
                args = {}
            result = tool_func(**args)
            return _update(True, result, None)
        except Exception as e:
            return _update(False, None, str(e))

    node_fn.__name__ = f"{node_name}_tool_call"
    return node_fn
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def resolve_class(class_path: str) -> type:
    """Dynamically import and return a class from a module path.

    Args:
        class_path: Full path like "yamlgraph.models.GenericReport" or short name

    Returns:
        The class object

    Raises:
        ValueError: If a bare (dotless) name cannot be found in
            yamlgraph.models.schemas.
    """
    import importlib

    module_path, sep, class_name = class_path.rpartition(".")
    if not sep:
        # Bare name: try the built-in schema module before giving up.
        try:
            from yamlgraph.models import schemas
        except ImportError:
            schemas = None
        if schemas is not None and hasattr(schemas, class_path):
            return getattr(schemas, class_path)
        raise ValueError(f"Invalid class path: {class_path}")

    module = importlib.import_module(module_path)
    return getattr(module, class_name)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def get_output_model_for_node(
    node_config: dict[str, Any],
    prompts_dir: Path | None = None,
    graph_path: Path | None = None,
    prompts_relative: bool = False,
) -> type | None:
    """Get output model for a node, checking inline schema if no explicit model.

    Priority:
    1. Explicit output_model in node config (class path)
    2. Inline schema in prompt YAML file
    3. None (raw string output)

    Args:
        node_config: Node configuration from YAML
        prompts_dir: Base prompts directory
        graph_path: Path to graph YAML file (for relative prompt resolution)
        prompts_relative: If True, resolve prompts relative to graph_path

    Returns:
        Pydantic model class or None
    """
    # 1. Explicit class path wins over everything else.
    explicit = node_config.get("output_model")
    if explicit:
        return resolve_class(explicit)

    # 2. Fall back to an inline schema embedded in the prompt YAML.
    prompt_name = node_config.get("prompt")
    if not prompt_name:
        # 3. No prompt at all - raw string output.
        return None

    try:
        from yamlgraph.schema_loader import load_schema_from_yaml

        yaml_path = resolve_prompt_path(
            prompt_name,
            prompts_dir=prompts_dir,
            graph_path=graph_path,
            prompts_relative=prompts_relative,
        )
        return load_schema_from_yaml(yaml_path)
    except FileNotFoundError:
        # Prompt file doesn't exist yet - execution will fail later with a
        # clearer error; treat as "no output model" here.
        return None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def create_node_function(
    node_name: str,
    node_config: dict,
    defaults: dict,
    graph_path: Path | None = None,
) -> Callable[[GraphState], dict]:
    """Create a node function from YAML config.

    The returned function executes a prompt via ``execute_prompt`` and writes
    the result under ``state_key``. It also implements:
    - streaming delegation (``stream: true`` -> create_streaming_node)
    - loop-limit checks and per-node loop counting (``_loop_counts``)
    - resume support (skip when ``state_key`` is already populated)
    - requirement checks (``requires``)
    - error-handling strategies (skip / fail / retry / fallback / default)
    - router behavior (``type: router`` sets ``_route`` in state)
    - optional JSON extraction from raw string output (``parse_json``)

    Args:
        node_name: Name of the node
        node_config: Node configuration from YAML
        defaults: Default configuration values
        graph_path: Path to graph YAML file (for relative prompt resolution)

    Returns:
        Node function compatible with LangGraph
    """
    node_type = node_config.get("type", NodeType.LLM)
    prompt_name = node_config.get("prompt")

    # Prompt resolution options from defaults (FR-A)
    prompts_relative = defaults.get("prompts_relative", False)
    prompts_dir = defaults.get("prompts_dir")
    if prompts_dir:
        prompts_dir = Path(prompts_dir)

    # Check for streaming mode - streaming nodes are async generators and
    # bypass the synchronous machinery below entirely.
    if node_config.get("stream", False):
        return create_streaming_node(node_name, node_config)

    # Resolve output model (explicit > inline schema > None)
    output_model = get_output_model_for_node(
        node_config,
        prompts_dir=prompts_dir,
        graph_path=graph_path,
        prompts_relative=prompts_relative,
    )

    # Get config values (node > defaults)
    temperature = node_config.get("temperature", defaults.get("temperature", 0.7))
    provider = node_config.get("provider", defaults.get("provider"))
    state_key = node_config.get("state_key", node_name)
    variable_templates = node_config.get("variables", {})
    requires = node_config.get("requires", [])

    # Error handling
    on_error = node_config.get("on_error")
    max_retries = node_config.get("max_retries", 3)
    fallback_config = node_config.get("fallback", {})
    fallback_provider = fallback_config.get("provider") if fallback_config else None

    # Router config
    routes = node_config.get("routes", {})
    default_route = node_config.get("default_route")

    # Loop limit
    loop_limit = node_config.get("loop_limit")

    # Skip if exists (default true for resume support, false for loop nodes)
    skip_if_exists = node_config.get("skip_if_exists", True)

    # JSON extraction (FR-B)
    parse_json = node_config.get("parse_json", False)

    def node_fn(state: dict) -> dict:
        """Generated node function."""
        # Copy so this node's increment doesn't mutate shared state in place.
        loop_counts = dict(state.get("_loop_counts") or {})
        current_count = loop_counts.get(node_name, 0)

        # Check loop limit
        if check_loop_limit(node_name, loop_limit, current_count):
            return {"_loop_limit_reached": True, "current_step": node_name}

        loop_counts[node_name] = current_count + 1

        # Skip if output exists (resume support) - disabled for loop nodes
        if skip_if_exists and state.get(state_key) is not None:
            logger.info(f"Node {node_name} skipped - {state_key} already in state")
            return {"current_step": node_name, "_loop_counts": loop_counts}

        # Check requirements; a non-None error string short-circuits execution.
        if error := check_requirements(requires, state, node_name):
            return {
                "errors": [error],
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

        # Resolve variables from templates OR use state directly
        if variable_templates:
            variables = {}
            for key, template in variable_templates.items():
                resolved = resolve_template(template, state)
                # Preserve original types (lists, dicts) for Jinja2 templates
                variables[key] = resolved
        else:
            # No explicit variable mapping - pass state as variables
            # Filter out internal keys and None values
            variables = {
                k: v
                for k, v in state.items()
                if not k.startswith("_") and v is not None
            }

        def attempt_execute(use_provider: str | None) -> tuple[Any, Exception | None]:
            # Returns (result, None) on success or (None, exception) on failure
            # so the error-handling dispatch below can decide what to do.
            try:
                result = execute_prompt(
                    prompt_name=prompt_name,
                    variables=variables,
                    output_model=output_model,
                    temperature=temperature,
                    provider=use_provider,
                    graph_path=graph_path,
                    prompts_dir=prompts_dir,
                    prompts_relative=prompts_relative,
                )
                return result, None
            except Exception as e:
                return None, e

        result, error = attempt_execute(provider)

        if error is None:
            # Post-process: JSON extraction if enabled (FR-B)
            if parse_json and isinstance(result, str):
                result = extract_json(result)

            logger.info(f"Node {node_name} completed successfully")
            update = {
                state_key: result,
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

            # Router: add _route to state. The route key is read from the
            # result's "tone" or "intent" attribute (in that order).
            if node_type == NodeType.ROUTER and routes:
                route_key = getattr(result, "tone", None) or getattr(
                    result, "intent", None
                )
                if route_key and route_key in routes:
                    update["_route"] = routes[route_key]
                elif default_route:
                    update["_route"] = default_route
                else:
                    # Last resort: first declared route (routes is non-empty
                    # here, guarded by the enclosing `if`).
                    update["_route"] = list(routes.values())[0]
                logger.info(f"Router {node_name} routing to: {update['_route']}")
            return update

        # Error handling - dispatch to strategy handlers
        if on_error == ErrorHandler.SKIP:
            handle_skip(node_name, error, loop_counts)
            return {"current_step": node_name, "_loop_counts": loop_counts}

        elif on_error == ErrorHandler.FAIL:
            # handle_fail is expected to raise; no return afterwards.
            handle_fail(node_name, error)

        elif on_error == ErrorHandler.RETRY:
            result = handle_retry(
                node_name,
                lambda: attempt_execute(provider),
                max_retries,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        elif on_error == ErrorHandler.FALLBACK and fallback_provider:
            result = handle_fallback(
                node_name,
                attempt_execute,
                fallback_provider,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        else:
            # FALLBACK without a provider also lands here.
            result = handle_default(node_name, error)
            return result.to_state_update(state_key, node_name, loop_counts)

    node_fn.__name__ = f"{node_name}_node"
    return node_fn
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def create_interrupt_node(
    node_name: str,
    config: dict[str, Any],
    graph_path: Path | None = None,
    prompts_dir: Path | None = None,
    prompts_relative: bool = False,
) -> Callable[[GraphState], dict]:
    """Create an interrupt node that pauses for human input.

    Uses LangGraph's native interrupt() function for human-in-the-loop.
    Handles idempotency by checking state_key before re-executing prompts.

    Args:
        node_name: Name of the node
        config: Node configuration with optional keys:
            - message: Static interrupt payload (string or dict)
            - prompt: Prompt name to generate dynamic payload
            - state_key: Where to store payload (default: "interrupt_message")
            - resume_key: Where to store resume value (default: "user_input")
        graph_path: Path to graph file for relative prompt resolution
        prompts_dir: Explicit prompts directory override
        prompts_relative: If True, resolve prompts relative to graph_path

    Returns:
        Node function compatible with LangGraph
    """
    from langgraph.types import interrupt

    message = config.get("message")
    prompt_name = config.get("prompt")
    state_key = config.get("state_key", "interrupt_message")
    resume_key = config.get("resume_key", "user_input")

    def _choose_payload(state: dict) -> Any:
        """Pick the interrupt payload without re-running prompts on resume."""
        stored = state.get(state_key)
        if stored is not None:
            # Resuming - reuse the stored payload (idempotency).
            return stored
        if prompt_name:
            # First execution with prompt: state doubles as the variables.
            return execute_prompt(
                prompt_name,
                state,
                graph_path=graph_path,
                prompts_dir=prompts_dir,
                prompts_relative=prompts_relative,
            )
        if message is not None:
            # Static message configured in YAML.
            return message
        # Fallback: identify the interrupt by node name.
        return {"node": node_name}

    def interrupt_fn(state: dict) -> dict:
        payload = _choose_payload(state)

        # Native LangGraph interrupt - pauses here on first run.
        # On resume, returns the Command(resume=...) value.
        response = interrupt(payload)

        return {
            state_key: payload,  # Store for idempotency check
            resume_key: response,  # User's response
            "current_step": node_name,
        }

    interrupt_fn.__name__ = f"{node_name}_interrupt"
    return interrupt_fn
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def create_passthrough_node(
    node_name: str,
    config: dict[str, Any],
) -> Callable[[GraphState], dict]:
    """Create a passthrough node that transforms state without external calls.

    Useful for:
    - Loop counters (increment values)
    - State accumulation (append to lists)
    - Simple data transformations
    - Clean transition points in graphs

    Args:
        node_name: Name of the node
        config: Node configuration with:
            - output: Dict of state_key -> expression mappings
              Expressions use {state.field} syntax
              Supports arithmetic: {state.count + 1}
              Supports list append: {state.history + [state.current]}

    Returns:
        Node function compatible with LangGraph

    Example:
        ```yaml
        next_turn:
          type: passthrough
          output:
            turn_number: "{state.turn_number + 1}"
            history: "{state.history + [state.narration]}"
        ```
    """
    from yamlgraph.utils.expressions import resolve_template

    output_templates = config.get("output", {})

    def passthrough_fn(state: dict) -> dict:
        updates: dict = {"current_step": node_name}

        for key, template in output_templates.items():
            try:
                value = resolve_template(template, state)
            except Exception as e:
                logger.warning(
                    f"Passthrough node {node_name}: failed to resolve {key}: {e}"
                )
                # Keep original value on error; drop the key otherwise.
                if key in state:
                    updates[key] = state[key]
                continue
            # A None result for a key already in state means "keep original".
            if value is None and key in state:
                updates[key] = state[key]
            else:
                updates[key] = value

        logger.info(f"Node {node_name} completed successfully")
        return updates

    passthrough_fn.__name__ = f"{node_name}_passthrough"
    return passthrough_fn
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def create_streaming_node(
    node_name: str,
    node_config: dict[str, Any],
) -> Callable[[GraphState], AsyncIterator[str]]:
    """Create a streaming node that yields tokens.

    Streaming nodes are async generators that yield tokens as they
    are produced by the LLM. They do not support structured output.

    Args:
        node_name: Name of the node
        node_config: Node configuration with:
            - prompt: Prompt name to execute
            - state_key: Where to store final result (optional)
            - on_token: Optional callback function for each token
            - provider: LLM provider
            - temperature: LLM temperature

    Returns:
        Async generator function compatible with streaming execution
    """
    from yamlgraph.executor_async import execute_prompt_streaming
    from yamlgraph.utils.expressions import resolve_template

    prompt_name = node_config.get("prompt")
    variable_templates = node_config.get("variables", {})
    provider = node_config.get("provider")
    temperature = node_config.get("temperature", 0.7)
    on_token = node_config.get("on_token")

    async def streaming_node(state: dict) -> AsyncIterator[str]:
        # Explicit variable mapping takes precedence; otherwise forward the
        # public (non-underscore), non-None parts of state directly.
        if variable_templates:
            # Resolved values keep their original types (lists, dicts) so
            # Jinja2 templates can iterate over them.
            variables = {
                key: resolve_template(template, state)
                for key, template in variable_templates.items()
            }
        else:
            variables = {
                key: value
                for key, value in state.items()
                if not key.startswith("_") and value is not None
            }

        async for token in execute_prompt_streaming(
            prompt_name,
            variables=variables,
            provider=provider,
            temperature=temperature,
        ):
            if on_token:
                on_token(token)
            yield token

    streaming_node.__name__ = f"{node_name}_streaming"
    return streaming_node
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
# =============================================================================
|
|
570
|
+
# Subgraph Node Support
|
|
571
|
+
# =============================================================================
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _map_input_state(
|
|
575
|
+
parent_state: dict[str, Any],
|
|
576
|
+
input_mapping: dict[str, str] | str,
|
|
577
|
+
) -> dict[str, Any]:
|
|
578
|
+
"""Map parent state to child input based on mapping config.
|
|
579
|
+
|
|
580
|
+
Args:
|
|
581
|
+
parent_state: Current state from parent graph
|
|
582
|
+
input_mapping: Mapping configuration:
|
|
583
|
+
- dict: explicit {parent_key: child_key} mapping
|
|
584
|
+
- "auto": copy all fields
|
|
585
|
+
- "*": pass state reference directly
|
|
586
|
+
|
|
587
|
+
Returns:
|
|
588
|
+
Input state for child graph
|
|
589
|
+
"""
|
|
590
|
+
if input_mapping == "auto":
|
|
591
|
+
return parent_state.copy()
|
|
592
|
+
elif input_mapping == "*":
|
|
593
|
+
return parent_state
|
|
594
|
+
else:
|
|
595
|
+
return {
|
|
596
|
+
child_key: parent_state.get(parent_key)
|
|
597
|
+
for parent_key, child_key in input_mapping.items()
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def _map_output_state(
|
|
602
|
+
child_output: dict[str, Any],
|
|
603
|
+
output_mapping: dict[str, str] | str,
|
|
604
|
+
) -> dict[str, Any]:
|
|
605
|
+
"""Map child output to parent state updates based on mapping config.
|
|
606
|
+
|
|
607
|
+
Args:
|
|
608
|
+
child_output: Output state from child graph
|
|
609
|
+
output_mapping: Mapping configuration:
|
|
610
|
+
- dict: explicit {parent_key: child_key} mapping
|
|
611
|
+
- "auto": pass all fields
|
|
612
|
+
- "*": pass output directly
|
|
613
|
+
|
|
614
|
+
Returns:
|
|
615
|
+
Updates to apply to parent state
|
|
616
|
+
"""
|
|
617
|
+
if output_mapping in ("auto", "*"):
|
|
618
|
+
return child_output
|
|
619
|
+
else:
|
|
620
|
+
return {
|
|
621
|
+
parent_key: child_output.get(child_key)
|
|
622
|
+
for parent_key, child_key in output_mapping.items()
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def _build_child_config(
|
|
627
|
+
parent_config: dict[str, Any],
|
|
628
|
+
node_name: str,
|
|
629
|
+
) -> dict[str, Any]:
|
|
630
|
+
"""Build child graph config with propagated thread ID.
|
|
631
|
+
|
|
632
|
+
Args:
|
|
633
|
+
parent_config: RunnableConfig from parent graph
|
|
634
|
+
node_name: Name of the subgraph node
|
|
635
|
+
|
|
636
|
+
Returns:
|
|
637
|
+
Config for child graph with thread_id: parent_thread:node_name
|
|
638
|
+
"""
|
|
639
|
+
configurable = parent_config.get("configurable", {})
|
|
640
|
+
parent_thread_id = configurable.get("thread_id")
|
|
641
|
+
|
|
642
|
+
child_thread_id = (
|
|
643
|
+
f"{parent_thread_id}:{node_name}" if parent_thread_id else node_name
|
|
644
|
+
)
|
|
645
|
+
|
|
646
|
+
return {
|
|
647
|
+
**parent_config,
|
|
648
|
+
"configurable": {
|
|
649
|
+
**configurable,
|
|
650
|
+
"thread_id": child_thread_id,
|
|
651
|
+
},
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def create_subgraph_node(
    node_name: str,
    node_config: dict[str, Any],
    parent_graph_path: Path,
    parent_checkpointer: Any | None = None,
) -> Callable[[dict, dict], dict] | Any:
    """Create a node that invokes a compiled subgraph.

    Loads and compiles the child graph eagerly (at node-creation time, not
    at invocation time), so configuration errors and circular references
    surface when the parent graph is built.

    Args:
        node_name: Name of this node in parent graph
        node_config: Subgraph configuration from YAML
        parent_graph_path: Path to parent graph (for relative resolution)
        parent_checkpointer: Checkpointer to inherit (if any)

    Returns:
        Node function that invokes subgraph (or CompiledGraph for mode=direct)

    Raises:
        FileNotFoundError: If subgraph YAML doesn't exist
        ValueError: If circular reference detected
    """
    # Local import to avoid a circular import with the graph loader module.
    from yamlgraph.graph_loader import compile_graph, load_graph_config

    # Resolve path relative to parent graph file
    graph_rel_path = node_config["graph"]
    graph_path = (parent_graph_path.parent / graph_rel_path).resolve()

    # Mapping configs default to {} (falsy), which downstream means
    # "map nothing" for input/output and "disabled" for interrupt mapping.
    mode = node_config.get("mode", "invoke")
    input_mapping = node_config.get("input_mapping", {})
    output_mapping = node_config.get("output_mapping", {})
    interrupt_output_mapping = node_config.get("interrupt_output_mapping", {})

    # Validate graph exists
    if not graph_path.exists():
        raise FileNotFoundError(f"Subgraph not found: {graph_path}")

    # Circular reference detection (thread-safe)
    # Use .get([]) to provide default without sharing mutable state
    stack = _loading_stack.get([])
    if graph_path in stack:
        cycle = " -> ".join(str(p) for p in [*stack, graph_path])
        raise ValueError(f"Circular subgraph reference: {cycle}")

    # Push onto loading stack for this context
    token = _loading_stack.set([*stack, graph_path])
    try:
        subgraph_config = load_graph_config(graph_path)
        state_graph = compile_graph(subgraph_config)
        # Compile with checkpointer (if provided)
        compiled = state_graph.compile(checkpointer=parent_checkpointer)
    finally:
        # Always pop, even if loading/compiling raised, so the contextvar
        # stack stays balanced for subsequent loads in this context.
        _loading_stack.reset(token)

    if mode == "direct":
        # Mode: Direct - shared schema, LangGraph handles state mapping
        # Return compiled graph directly - LangGraph's add_node() accepts
        # CompiledStateGraph objects and handles them natively
        return compiled

    # Mode: Invoke - explicit state mapping
    from langchain_core.runnables import RunnableConfig

    def subgraph_node(state: dict, config: RunnableConfig | None = None) -> dict:
        """Execute the subgraph with mapped state."""
        # Imported lazily so the closure only needs langgraph at run time.
        from langgraph.errors import GraphInterrupt

        config = config or {}

        # Build child input from parent state
        child_input = _map_input_state(state, input_mapping)

        # Build child config with propagated thread ID
        child_config = _build_child_config(config, node_name)

        # Invoke subgraph - may raise GraphInterrupt
        try:
            child_output = compiled.invoke(child_input, child_config)
            # Some interrupts surface as an "__interrupt__" key in the
            # returned state rather than as a raised GraphInterrupt.
            is_interrupted = "__interrupt__" in child_output
        except GraphInterrupt:
            # FR-006: Child hit an interrupt
            if interrupt_output_mapping:
                # Get child state from checkpointer
                # NOTE(review): get_state presumably requires the child to
                # have been compiled with a checkpointer — confirm that
                # parent_checkpointer is set whenever this mapping is used.
                child_state = compiled.get_state(child_config)
                child_output = dict(child_state.values) if child_state else {}

                # Apply interrupt_output_mapping
                parent_updates = _map_output_state(child_output, interrupt_output_mapping)
                parent_updates["current_step"] = node_name

                # Use __pregel_send to update parent state before re-raising
                # This allows the mapped state to be included in the result
                send = config.get("configurable", {}).get("__pregel_send")
                if send:
                    # Convert dict to list of (key, value) tuples
                    updates = [(k, v) for k, v in parent_updates.items()]
                    send(updates)
                    logger.info(f"FR-006: Subgraph {node_name} mapped state: {list(parent_updates.keys())}")

            # Re-raise to pause the graph
            raise

        # Normal completion path
        if is_interrupted and interrupt_output_mapping:
            # Child finished with an in-state interrupt marker: map via the
            # interrupt mapping and forward the marker to the parent.
            parent_updates = _map_output_state(child_output, interrupt_output_mapping)
            parent_updates["__interrupt__"] = child_output["__interrupt__"]
        else:
            parent_updates = _map_output_state(child_output, output_mapping)

        # Record which node produced these updates for resume/debugging.
        parent_updates["current_step"] = node_name

        return parent_updates

    # Distinct __name__ so LangGraph/tracing shows a per-node identity.
    subgraph_node.__name__ = f"{node_name}_subgraph"
    return subgraph_node
|