yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""YAML Graph Loader - Compile YAML to LangGraph.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to load graph definitions from YAML files
|
|
4
|
+
and compile them into LangGraph StateGraph instances.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import yaml
|
|
13
|
+
from langgraph.graph import END, StateGraph
|
|
14
|
+
|
|
15
|
+
from yamlgraph.constants import NodeType
|
|
16
|
+
from yamlgraph.map_compiler import compile_map_node
|
|
17
|
+
from yamlgraph.models.state_builder import build_state_class
|
|
18
|
+
from yamlgraph.node_factory import (
|
|
19
|
+
create_interrupt_node,
|
|
20
|
+
create_node_function,
|
|
21
|
+
create_passthrough_node,
|
|
22
|
+
create_subgraph_node,
|
|
23
|
+
create_tool_call_node,
|
|
24
|
+
resolve_class,
|
|
25
|
+
)
|
|
26
|
+
from yamlgraph.routing import make_expr_router_fn, make_router_fn
|
|
27
|
+
from yamlgraph.storage.checkpointer_factory import get_checkpointer
|
|
28
|
+
from yamlgraph.tools.agent import create_agent_node
|
|
29
|
+
from yamlgraph.tools.nodes import create_tool_node
|
|
30
|
+
from yamlgraph.tools.python_tool import (
|
|
31
|
+
create_python_node,
|
|
32
|
+
load_python_function,
|
|
33
|
+
parse_python_tools,
|
|
34
|
+
)
|
|
35
|
+
from yamlgraph.tools.shell import parse_tools
|
|
36
|
+
from yamlgraph.tools.websearch import parse_websearch_tools
|
|
37
|
+
from yamlgraph.utils.validators import validate_config
|
|
38
|
+
|
|
39
|
+
# Type alias for dynamic state: a plain str-keyed dict. The concrete TypedDict
# used by StateGraph is generated per graph (see _resolve_state_class).
# NOTE(review): not referenced in the visible part of this module; presumably
# kept for importers — verify before removing.
GraphState = dict[str, Any]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GraphConfig:
    """Parsed graph configuration from YAML.

    Each attribute mirrors a top-level YAML key, with a default applied when
    the key is missing. The full parsed dict is kept in ``raw_config`` so the
    state class can be generated from it later.
    """

    def __init__(self, config: dict, source_path: Path | None = None):
        """Initialize from parsed YAML dict.

        Args:
            config: Parsed YAML configuration dictionary
            source_path: Path to the source YAML file (for subgraph resolution)

        Raises:
            ValueError: If config is invalid (as reported by ``validate_config``)
        """
        # Validate before storing
        validate_config(config)

        def _mapping_section(key: str) -> dict:
            # A key written with no value (e.g. "defaults:" on its own line)
            # parses as None. Treat that like a missing key so later .get()
            # and "in" lookups on the section do not crash.
            value = config.get(key)
            return {} if value is None else value

        self.version = config.get("version", "1.0")
        self.name = config.get("name", "unnamed")
        self.description = config.get("description", "")
        self.defaults = _mapping_section("defaults")
        self.nodes = config.get("nodes", {})
        self.edges = config.get("edges", [])
        self.tools = _mapping_section("tools")
        self.state_class = config.get("state_class", "")
        self.loop_limits = _mapping_section("loop_limits")
        self.checkpointer = config.get("checkpointer")
        # Store raw config for dynamic state building
        self.raw_config = config
        # Store source path for subgraph resolution
        self.source_path = source_path
        # Prompt resolution options (FR-A: graph-relative prompts)
        self.prompts_relative = self.defaults.get("prompts_relative", False)
        self.prompts_dir = self.defaults.get("prompts_dir")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def load_graph_config(path: str | Path) -> GraphConfig:
    """Load and parse a YAML graph definition.

    Args:
        path: Path to the YAML file

    Returns:
        GraphConfig instance whose ``source_path`` is the resolved path

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the document is empty or not a mapping, or if it is
            missing required fields / otherwise rejected by GraphConfig
        yaml.YAMLError: If the file is not syntactically valid YAML
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Graph config not found: {path}")

    # Read bytes and let PyYAML detect the encoding (UTF-8 / UTF-16 via BOM),
    # instead of depending on the platform's locale text encoding.
    with open(path, "rb") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file and a scalar or list for other
    # non-mapping documents. Reject those with a clear message up front.
    if not isinstance(config, dict):
        raise ValueError(
            f"Graph config {path} must be a YAML mapping, "
            f"got {type(config).__name__}"
        )

    return GraphConfig(config, source_path=path.resolve())
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _resolve_state_class(config: GraphConfig) -> type:
    """Return the state type used to build the StateGraph.

    By default the state TypedDict is generated from the raw graph config.
    An explicit ``state_class`` other than ``yamlgraph.models.GraphState`` is
    still resolved and used, but a DeprecationWarning is emitted.

    Args:
        config: Graph configuration

    Returns:
        TypedDict class for graph state
    """
    legacy_name = config.state_class
    wants_legacy = bool(legacy_name) and legacy_name != "yamlgraph.models.GraphState"

    if not wants_legacy:
        return build_state_class(config.raw_config)

    import warnings

    warnings.warn(
        f"state_class '{legacy_name}' is deprecated. "
        "State is now auto-generated from graph config.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_class(legacy_name)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _parse_all_tools(
    config: GraphConfig,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Callable]]:
    """Parse the ``tools`` section into per-kind registries.

    Python tools are also imported eagerly. Any that fail with ImportError or
    AttributeError are logged and left out of the callable registry.

    Args:
        config: Graph configuration

    Returns:
        Tuple of (shell_tools, python_tools, websearch_tools, callable_registry),
        where callable_registry maps tool name -> loaded function for
        tool_call nodes.
    """
    tool_section = config.tools
    shell_tools = parse_tools(tool_section)
    python_tools = parse_python_tools(tool_section)
    websearch_tools = parse_websearch_tools(tool_section)

    # Resolve each Python tool to its actual function object.
    callable_registry: dict[str, Callable] = {}
    for tool_name, python_tool_config in python_tools.items():
        try:
            loaded_fn = load_python_function(python_tool_config)
        except (ImportError, AttributeError) as exc:
            logger.warning(f"Failed to load tool '{tool_name}': {exc}")
        else:
            callable_registry[tool_name] = loaded_fn

    summaries = (
        ("shell", shell_tools),
        ("Python", python_tools),
        ("websearch", websearch_tools),
    )
    for label, registry in summaries:
        if registry:
            logger.info(
                f"Parsed {len(registry)} {label} tools: {', '.join(registry)}"
            )

    return shell_tools, python_tools, websearch_tools, callable_registry
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _compile_node(
    node_name: str,
    node_config: dict[str, Any],
    graph: StateGraph,
    config: GraphConfig,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
    websearch_tools: dict[str, Any],
    callable_registry: dict[str, Callable],
) -> tuple[str, Any] | None:
    """Build the node function for one YAML node and register it on ``graph``.

    Map nodes are the exception: ``compile_map_node`` registers their sub-node
    itself, and the map's wiring info is returned to the caller so edges can
    be rewritten.

    Args:
        node_name: Name of the node
        node_config: Node configuration dict
        graph: StateGraph to add node to
        config: Full graph config for defaults
        tools: Shell tools registry
        python_tools: Python tools registry
        websearch_tools: Web search tools registry (LangChain StructuredTool)
        callable_registry: Loaded callable functions for tool_call nodes

    Returns:
        (node_name, (map_edge_fn, sub_node_name)) for map nodes, else None

    Raises:
        ValueError: For a subgraph node when the parent graph has no source_path
    """
    # Work on a copy so the caller's node config is never mutated.
    node_settings = dict(node_config)
    limits = config.loop_limits
    if node_name in limits:
        node_settings["loop_limit"] = limits[node_name]

    # Prompt path options from defaults (FR-A); only interrupt nodes get
    # them as explicit keyword arguments.
    defaults = config.defaults
    use_relative_prompts = defaults.get("prompts_relative", False)
    raw_prompts_dir = defaults.get("prompts_dir")
    prompts_dir = Path(raw_prompts_dir) if raw_prompts_dir else raw_prompts_dir

    node_type = node_config.get("type", NodeType.LLM)

    if node_type == NodeType.TOOL:
        node_fn = create_tool_node(node_name, node_settings, tools)
    elif node_type == NodeType.PYTHON:
        node_fn = create_python_node(node_name, node_settings, python_tools)
    elif node_type == NodeType.AGENT:
        node_fn = create_agent_node(
            node_name, node_settings, tools, websearch_tools, python_tools
        )
    elif node_type == NodeType.MAP:
        # compile_map_node adds the sub-node to the graph on its own.
        map_edge_fn, sub_node_name = compile_map_node(
            node_name, node_settings, graph, defaults, callable_registry
        )
        logger.info(f"Added node: {node_name} (type={node_type})")
        return (node_name, (map_edge_fn, sub_node_name))
    elif node_type == NodeType.TOOL_CALL:
        # Dynamic tool call resolved from state at runtime
        node_fn = create_tool_call_node(node_name, node_settings, callable_registry)
    elif node_type == NodeType.INTERRUPT:
        # Human-in-the-loop pause point
        node_fn = create_interrupt_node(
            node_name,
            node_settings,
            graph_path=config.source_path,
            prompts_dir=prompts_dir,
            prompts_relative=use_relative_prompts,
        )
    elif node_type == NodeType.PASSTHROUGH:
        # Simple state transformation, no LLM call
        node_fn = create_passthrough_node(node_name, node_settings)
    elif node_type == NodeType.SUBGRAPH:
        # Subgraph paths are resolved relative to the parent YAML file.
        if not config.source_path:
            raise ValueError(
                f"Cannot resolve subgraph path for node '{node_name}': "
                "parent graph has no source_path"
            )
        node_fn = create_subgraph_node(
            node_name,
            node_settings,
            parent_graph_path=config.source_path,
        )
    else:
        # LLM and router nodes
        node_fn = create_node_function(
            node_name,
            node_settings,
            defaults,
            graph_path=config.source_path,
        )

    graph.add_node(node_name, node_fn)
    logger.info(f"Added node: {node_name} (type={node_type})")
    return None
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _compile_nodes(
    config: GraphConfig,
    graph: StateGraph,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
    websearch_tools: dict[str, Any],
    callable_registry: dict[str, Callable],
) -> dict[str, tuple]:
    """Compile every node in ``config.nodes`` onto ``graph``.

    Args:
        config: Graph configuration
        graph: StateGraph to add nodes to
        tools: Shell tools registry
        python_tools: Python tools registry
        websearch_tools: Web search tools registry
        callable_registry: Loaded callable functions for tool_call nodes

    Returns:
        Map-node wiring info: name -> (map_edge_fn, sub_node_name)
    """
    map_nodes: dict[str, tuple] = {}

    for name, settings in config.nodes.items():
        compiled = _compile_node(
            name,
            settings,
            graph,
            config,
            tools,
            python_tools,
            websearch_tools,
            callable_registry,
        )
        # Only map nodes return wiring info; every other type returns None.
        if compiled:
            map_name, map_info = compiled
            map_nodes[map_name] = map_info

    return map_nodes
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _process_edge(
    edge: dict[str, Any],
    graph: StateGraph,
    map_nodes: dict[str, tuple],
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Process a single edge: add it to the graph or to the edge-tracking dicts.

    Map nodes do not exist as graph nodes; only their sub-node is registered.
    Edges touching a map node are therefore rewritten:

    * edges *into* a map node become conditional edges driven by the map's
      ``Send`` fan-out function (including edges from ``START``);
    * edges *out of* a map node leave from its sub-node (fan-in).

    Router-style edges (``type: conditional`` with a list target) and
    expression edges (``condition``) are only collected here. They are wired
    later by the caller, once all edges have been seen.

    Args:
        edge: Edge configuration dict with ``from``/``to`` and optional
            ``condition``/``type``
        graph: StateGraph to add edges to
        map_nodes: Map node tracking dict: name -> (map_edge_fn, sub_node_name)
        router_edges: Dict to collect router edges (mutated in place)
        expression_edges: Dict to collect expression-based edges (mutated in place)

    Raises:
        ValueError: If an edge leaving a map node has a list (router) target.
    """
    from_node = edge["from"]
    to_node = edge["to"]
    condition = edge.get("condition")
    edge_type = edge.get("type")

    # Only string endpoints can name a map node. Lists (router targets) are
    # unhashable and must not be used as dict-membership keys.
    to_is_map = isinstance(to_node, str) and to_node in map_nodes
    from_is_map = isinstance(from_node, str) and from_node in map_nodes

    if from_node == "START":
        if to_is_map:
            # The map has no graph node of its own, so fan out from START
            # directly through its Send function.
            from langgraph.graph import START

            map_edge_fn, sub_node_name = map_nodes[to_node]
            graph.add_conditional_edges(START, map_edge_fn, [sub_node_name])
        else:
            graph.set_entry_point(to_node)
        return

    if condition and (from_is_map or to_is_map):
        # Map edges are wired via Send / fan-in and cannot carry an
        # expression condition; make that visible rather than silent.
        logger.warning(
            "Edge %s -> %s: condition %r is ignored because the edge "
            "touches a map node",
            from_node,
            to_node,
            condition,
        )

    if from_is_map and to_is_map:
        # Map -> map: the upstream sub-node fans out into the downstream map.
        _, from_sub = map_nodes[from_node]
        to_map_edge_fn, to_sub = map_nodes[to_node]
        graph.add_conditional_edges(from_sub, to_map_edge_fn, [to_sub])
    elif to_is_map:
        # Edge TO a map node: conditional edge with the Send function
        map_edge_fn, sub_node_name = map_nodes[to_node]
        graph.add_conditional_edges(from_node, map_edge_fn, [sub_node_name])
    elif from_is_map:
        # Edge FROM a map node: wire sub-node to next node for fan-in
        if isinstance(to_node, list):
            raise ValueError(
                f"Edge from map node '{from_node}' to {to_node!r}: list "
                "(router) targets are not supported on edges leaving a map node"
            )
        _, sub_node_name = map_nodes[from_node]
        graph.add_edge(sub_node_name, END if to_node == "END" else to_node)
    elif edge_type == "conditional" and isinstance(to_node, list):
        # Router-style conditional edge: store for later processing
        router_edges[from_node] = to_node
    elif condition:
        # Expression-based condition (e.g., "critique.score < 0.8")
        target = END if to_node == "END" else to_node
        expression_edges.setdefault(from_node, []).append((condition, target))
    elif to_node == "END":
        graph.add_edge(from_node, END)
    else:
        graph.add_edge(from_node, to_node)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _add_conditional_edges(
    graph: StateGraph,
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Add router and expression conditional edges to graph.

    Args:
        graph: StateGraph to add edges to
        router_edges: Router-style conditional edges: source -> list of targets
        expression_edges: Expression-based conditional edges:
            source -> [(condition, target), ...]
    """
    # Router edges: the router returns a node name, so the path map is the
    # identity over the declared targets (in declaration order).
    for source_node, target_nodes in router_edges.items():
        route_mapping = {target: target for target in target_nodes}
        graph.add_conditional_edges(
            source_node,
            make_router_fn(target_nodes),
            route_mapping,
        )

    # Expression edges: END is always included as the fallback target.
    # The targets are de-duplicated through a set; sorting them gives a
    # deterministic path-map order, since str hashing (and therefore set
    # order) is randomized per interpreter run. key=str keeps sorting safe
    # if YAML produced non-string node names.
    for source_node, expr_edges in expression_edges.items():
        targets = sorted({target for _, target in expr_edges} | {END}, key=str)
        route_mapping = {target: target for target in targets}
        graph.add_conditional_edges(
            source_node,
            make_expr_router_fn(expr_edges, source_node),
            route_mapping,
        )
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def compile_graph(config: GraphConfig) -> StateGraph:
    """Turn a parsed GraphConfig into an (uncompiled) LangGraph StateGraph.

    Steps: resolve the state type, parse tools, add nodes, then wire plain,
    map, router and expression edges.

    Args:
        config: Parsed graph configuration

    Returns:
        StateGraph ready for compilation
    """
    graph = StateGraph(_resolve_state_class(config))

    shell_tools, python_tools, websearch_tools, callable_registry = _parse_all_tools(
        config
    )

    map_nodes = _compile_nodes(
        config,
        graph,
        shell_tools,
        python_tools,
        websearch_tools,
        callable_registry,
    )

    # Conditional edges are collected first and wired once every edge has
    # been seen.
    router_edges: dict[str, list] = {}
    expression_edges: dict[str, list[tuple[str, str]]] = {}
    for edge in config.edges:
        _process_edge(edge, graph, map_nodes, router_edges, expression_edges)
    _add_conditional_edges(graph, router_edges, expression_edges)

    return graph
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def load_and_compile(path: str | Path) -> StateGraph:
    """Read a YAML graph file and build its StateGraph in one call.

    Equivalent to ``compile_graph(load_graph_config(path))``, plus an info
    log of the graph's name and version.

    Args:
        path: Path to YAML graph definition

    Returns:
        StateGraph ready for compilation
    """
    graph_config = load_graph_config(path)
    logger.info(f"Loaded graph config: {graph_config.name} v{graph_config.version}")
    compiled = compile_graph(graph_config)
    return compiled
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def get_checkpointer_for_graph(config: GraphConfig, *, async_mode: bool = False):
    """Build the checkpointer declared in the graph's ``checkpointer`` section.

    Delegates to ``get_checkpointer`` with the raw section value.

    Args:
        config: Graph configuration
        async_mode: If True, request an async-compatible saver

    Returns:
        Configured checkpointer or None if not specified
    """
    checkpointer_spec = config.checkpointer
    return get_checkpointer(checkpointer_spec, async_mode=async_mode)
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Map node compiler - Handles type: map node compilation.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to compile map nodes that fan out
|
|
4
|
+
to sub-nodes for parallel processing using LangGraph's Send mechanism.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from langgraph.graph import StateGraph
|
|
12
|
+
from langgraph.types import Send
|
|
13
|
+
|
|
14
|
+
from yamlgraph.constants import NodeType
|
|
15
|
+
from yamlgraph.node_factory import create_node_function, create_tool_call_node
|
|
16
|
+
from yamlgraph.utils.expressions import resolve_state_expression
|
|
17
|
+
|
|
18
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def wrap_for_reducer(
    node_fn: Callable[[dict], dict],
    collect_key: str,
    state_key: str,
) -> Callable[[dict], dict]:
    """Wrap sub-node output for Annotated reducer aggregation.

    Every call returns ``{collect_key: [item]}`` (a one-element list) so the
    results of parallel map branches can be merged by a list reducer on
    ``collect_key``. The reducer is presumably declared on the generated
    state; that is not visible here.

    Error propagation:
      * If ``node_fn`` raises, the item is ``{"_map_index", "_error",
        "_error_type"}`` and a ``PipelineError`` is added under ``"errors"``.
      * If ``node_fn`` returns a non-empty ``errors`` or a truthy ``error``,
        the item is ``{"_map_index", "_error"}`` and a non-empty ``errors``
        value is forwarded. Empty ``errors`` lists and ``error: None`` count
        as success.

    Args:
        node_fn: The original node function
        collect_key: State key where results are collected
        state_key: Key to extract from node result (the whole result is
            used when the key is absent)

    Returns:
        Wrapped function that outputs in reducer-compatible format
    """

    def wrapped(state: dict) -> dict:
        try:
            result = node_fn(state)
        except Exception as e:
            # Deliberately broad: one failing branch is recorded instead of
            # aborting the whole fan-out.
            from yamlgraph.models import PipelineError

            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(e),
                "_error_type": type(e).__name__,
            }
            return {
                collect_key: [error_result],
                "errors": [PipelineError.from_exception(e, node="map_subnode")],
            }

        # Check truthiness, not key presence: nodes may legitimately return
        # "errors": [] or "error": None on success.
        errors = result.get("errors")
        error = result.get("error")
        if errors or error:
            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(errors or error),
            }
            output = {collect_key: [error_result]}
            if errors:
                # Preserve structured errors for the top-level errors reducer
                output["errors"] = errors
            return output

        extracted = result.get(state_key, result)

        # Convert Pydantic models to dicts
        if hasattr(extracted, "model_dump"):
            extracted = extracted.model_dump()

        # Tag with _map_index (when present) so results can be re-ordered
        if "_map_index" in state:
            if isinstance(extracted, dict):
                extracted = {"_map_index": state["_map_index"], **extracted}
            else:
                extracted = {"_map_index": state["_map_index"], "value": extracted}

        return {collect_key: [extracted]}

    return wrapped
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def compile_map_node(
    name: str,
    config: dict[str, Any],
    builder: StateGraph,
    defaults: dict[str, Any],
    tools_registry: dict[str, Any] | None = None,
) -> tuple[Callable[[dict], list[Send]], str]:
    """Compile a ``type: map`` node into a sub-node plus a Send fan-out function.

    The sub-node ``_map_<name>_sub`` is registered on ``builder``, wrapped so
    its output lands in ``config["collect"]``. The returned edge function
    resolves ``config["over"]`` against state and emits one ``Send`` per item.
    Each Send payload carries the item under ``config["as"]`` and its
    position under ``_map_index``.

    Args:
        name: Name of the map node
        config: Map node configuration with 'over', 'as', 'node', 'collect'
        builder: StateGraph builder to add sub-node to
        defaults: Default configuration for nodes
        tools_registry: Optional tools registry for tool_call sub-nodes

    Returns:
        Tuple of (map_edge_function, sub_node_name)

    Raises:
        ValueError: If the sub-node is a tool_call but no tools_registry is given
    """
    over_expr = config["over"]
    item_var = config["as"]
    sub_node_name = f"_map_{name}_sub"
    collect_key = config["collect"]
    # Shallow copy so the caller's node config is not mutated below.
    sub_node_config = dict(config["node"])
    state_key = sub_node_config.get("state_key", "result")
    sub_node_type = sub_node_config.get("type", "llm")

    # Expose the current item to the sub-node's prompt as {item_var}.
    item_variables = dict(sub_node_config.get("variables", {}))
    item_variables[item_var] = f"{{state.{item_var}}}"
    sub_node_config["variables"] = item_variables

    if sub_node_type != NodeType.TOOL_CALL:
        sub_node = create_node_function(sub_node_name, sub_node_config, defaults)
    elif tools_registry is None:
        raise ValueError(
            f"Map node '{name}' has tool_call sub-node but no tools_registry"
        )
    else:
        sub_node = create_tool_call_node(sub_node_name, sub_node_config, tools_registry)

    builder.add_node(sub_node_name, wrap_for_reducer(sub_node, collect_key, state_key))

    def map_edge(state: dict) -> list[Send]:
        items = resolve_state_expression(over_expr, state)
        if not isinstance(items, list):
            raise TypeError(
                f"Map 'over' must resolve to list, got {type(items).__name__}"
            )

        sends = []
        for position, item in enumerate(items):
            # Each branch sees the full state plus its own item and index.
            branch_state = {**state, item_var: item, "_map_index": position}
            sends.append(Send(sub_node_name, branch_state))
        return sends

    return map_edge, sub_node_name
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Pydantic models and state definitions.
|
|
2
|
+
|
|
3
|
+
Framework models for error handling and generic reports.
|
|
4
|
+
State is now generated dynamically by state_builder.py.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from yamlgraph.models.graph_schema import (
|
|
8
|
+
EdgeConfig,
|
|
9
|
+
GraphConfigSchema,
|
|
10
|
+
NodeConfig,
|
|
11
|
+
validate_graph_schema,
|
|
12
|
+
)
|
|
13
|
+
from yamlgraph.models.schemas import (
|
|
14
|
+
ErrorType,
|
|
15
|
+
GenericReport,
|
|
16
|
+
PipelineError,
|
|
17
|
+
)
|
|
18
|
+
from yamlgraph.models.state_builder import (
|
|
19
|
+
build_state_class,
|
|
20
|
+
create_initial_state,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Public API of yamlgraph.models: names re-exported from the graph_schema,
# schemas and state_builder submodules imported above.
__all__ = [
    # Framework models
    "ErrorType",
    "PipelineError",
    "GenericReport",
    # Graph config schema
    "GraphConfigSchema",
    "NodeConfig",
    "EdgeConfig",
    "validate_graph_schema",
    # Dynamic state generation
    "build_state_class",
    "create_initial_state",
]
|