yamlgraph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic. Click here for more details.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""YAML Graph Loader - Compile YAML to LangGraph.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to load graph definitions from YAML files
|
|
4
|
+
and compile them into LangGraph StateGraph instances.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import yaml
|
|
12
|
+
from langgraph.graph import END, StateGraph
|
|
13
|
+
|
|
14
|
+
from yamlgraph.constants import NodeType
|
|
15
|
+
from yamlgraph.map_compiler import compile_map_node
|
|
16
|
+
from yamlgraph.models.state_builder import build_state_class
|
|
17
|
+
from yamlgraph.node_factory import create_node_function, resolve_class
|
|
18
|
+
from yamlgraph.routing import make_expr_router_fn, make_router_fn
|
|
19
|
+
from yamlgraph.tools.agent import create_agent_node
|
|
20
|
+
from yamlgraph.tools.nodes import create_tool_node
|
|
21
|
+
from yamlgraph.tools.python_tool import create_python_node, parse_python_tools
|
|
22
|
+
from yamlgraph.tools.shell import parse_tools
|
|
23
|
+
from yamlgraph.utils.validators import validate_config
|
|
24
|
+
|
|
25
|
+
# Type alias for dynamic state: graph state is handled as a plain mapping of
# state keys to values. The concrete per-graph state class is produced by
# _resolve_state_class (via build_state_class or a legacy state_class).
GraphState = dict[str, Any]

# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GraphConfig:
    """Parsed graph configuration from YAML.

    Attributes mirror the top-level YAML keys. Sections that are present but
    empty in YAML (for example ``tools:`` with nothing under it) parse as
    ``None``. They are normalized to empty containers here, so downstream
    code can iterate and test membership without ``None`` checks.
    """

    def __init__(self, config: dict):
        """Initialize from parsed YAML dict.

        Args:
            config: Parsed YAML configuration dictionary

        Raises:
            ValueError: If config is invalid
        """
        # Validate before storing
        validate_config(config)

        self.version = config.get("version", "1.0")
        self.name = config.get("name", "unnamed")
        self.description = config.get("description", "")
        # `or` (rather than a .get default) also covers keys explicitly set
        # to null in YAML, which would otherwise leak None downstream
        # (e.g. `node_name in config.loop_limits` raising TypeError).
        self.defaults = config.get("defaults") or {}
        self.nodes = config.get("nodes") or {}
        self.edges = config.get("edges") or []
        self.tools = config.get("tools") or {}
        self.state_class = config.get("state_class", "")
        self.loop_limits = config.get("loop_limits") or {}
        # Store raw config for dynamic state building
        self.raw_config = config
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def load_graph_config(path: str | Path) -> GraphConfig:
    """Load and parse a YAML graph definition.

    Args:
        path: Path to the YAML file

    Returns:
        GraphConfig instance

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the YAML document is empty or is not a mapping, or
            if it is missing required fields (raised by GraphConfig)
        yaml.YAMLError: If the file is not well-formed YAML
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Graph config not found: {path}")

    # Read as UTF-8 explicitly. Prompts and descriptions often contain
    # non-ASCII text, and the platform default encoding is not UTF-8 everywhere.
    with path.open(encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list/scalar for
    # non-mapping documents. Reject those with the documented ValueError
    # instead of letting them fail later with an unrelated error.
    if not isinstance(config, dict):
        raise ValueError(
            f"Graph config must be a YAML mapping, got "
            f"{type(config).__name__}: {path}"
        )

    return GraphConfig(config)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _resolve_state_class(config: GraphConfig) -> type:
    """Pick the TypedDict state class used to build the StateGraph.

    By default the state class is generated from the raw graph config. A
    non-empty ``state_class`` other than the legacy default path is still
    honoured, but it triggers a DeprecationWarning.

    Args:
        config: Graph configuration

    Returns:
        TypedDict class for graph state
    """
    explicit_class = config.state_class
    legacy_default = "yamlgraph.models.GraphState"

    # Guard clause: no custom class requested -> auto-generate state.
    if not explicit_class or explicit_class == legacy_default:
        return build_state_class(config.raw_config)

    import warnings

    warnings.warn(
        f"state_class '{explicit_class}' is deprecated. "
        "State is now auto-generated from graph config.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_class(explicit_class)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _parse_all_tools(
    config: GraphConfig,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the shell-tool and Python-tool registries from ``config.tools``.

    Args:
        config: Graph configuration

    Returns:
        Tuple of (shell_tools, python_tools) dictionaries
    """
    raw_tools = config.tools
    shell_registry = parse_tools(raw_tools)
    python_registry = parse_python_tools(raw_tools)

    # Log each non-empty registry; the labels keep the original wording.
    for label, registry in (("shell", shell_registry), ("Python", python_registry)):
        if registry:
            logger.info(
                f"Parsed {len(registry)} {label} tools: {', '.join(registry.keys())}"
            )

    return shell_registry, python_registry
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _compile_node(
|
|
132
|
+
node_name: str,
|
|
133
|
+
node_config: dict[str, Any],
|
|
134
|
+
graph: StateGraph,
|
|
135
|
+
config: GraphConfig,
|
|
136
|
+
tools: dict[str, Any],
|
|
137
|
+
python_tools: dict[str, Any],
|
|
138
|
+
) -> tuple[str, Any] | None:
|
|
139
|
+
"""Compile a single node and add to graph.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
node_name: Name of the node
|
|
143
|
+
node_config: Node configuration dict
|
|
144
|
+
graph: StateGraph to add node to
|
|
145
|
+
config: Full graph config for defaults
|
|
146
|
+
tools: Shell tools registry
|
|
147
|
+
python_tools: Python tools registry
|
|
148
|
+
|
|
149
|
+
Returns:
|
|
150
|
+
Tuple of (node_name, map_info) for map nodes, None otherwise
|
|
151
|
+
"""
|
|
152
|
+
# Copy node config and add loop_limit if specified
|
|
153
|
+
enriched_config = dict(node_config)
|
|
154
|
+
if node_name in config.loop_limits:
|
|
155
|
+
enriched_config["loop_limit"] = config.loop_limits[node_name]
|
|
156
|
+
|
|
157
|
+
node_type = node_config.get("type", NodeType.LLM)
|
|
158
|
+
|
|
159
|
+
if node_type == NodeType.TOOL:
|
|
160
|
+
node_fn = create_tool_node(node_name, enriched_config, tools)
|
|
161
|
+
graph.add_node(node_name, node_fn)
|
|
162
|
+
elif node_type == NodeType.PYTHON:
|
|
163
|
+
node_fn = create_python_node(node_name, enriched_config, python_tools)
|
|
164
|
+
graph.add_node(node_name, node_fn)
|
|
165
|
+
elif node_type == NodeType.AGENT:
|
|
166
|
+
node_fn = create_agent_node(node_name, enriched_config, tools)
|
|
167
|
+
graph.add_node(node_name, node_fn)
|
|
168
|
+
elif node_type == NodeType.MAP:
|
|
169
|
+
map_edge_fn, sub_node_name = compile_map_node(
|
|
170
|
+
node_name, enriched_config, graph, config.defaults
|
|
171
|
+
)
|
|
172
|
+
logger.info(f"Added node: {node_name} (type={node_type})")
|
|
173
|
+
return (node_name, (map_edge_fn, sub_node_name))
|
|
174
|
+
else:
|
|
175
|
+
# LLM and router nodes
|
|
176
|
+
node_fn = create_node_function(node_name, enriched_config, config.defaults)
|
|
177
|
+
graph.add_node(node_name, node_fn)
|
|
178
|
+
|
|
179
|
+
logger.info(f"Added node: {node_name} (type={node_type})")
|
|
180
|
+
return None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _compile_nodes(
    config: GraphConfig,
    graph: StateGraph,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
) -> dict[str, tuple]:
    """Register every configured node on ``graph``, collecting map nodes.

    Args:
        config: Graph configuration
        graph: StateGraph to add nodes to
        tools: Shell tools registry
        python_tools: Python tools registry

    Returns:
        Dict of map_nodes: name -> (map_edge_fn, sub_node_name)
    """
    collected: dict[str, tuple] = {}

    for name, node_cfg in config.nodes.items():
        outcome = _compile_node(name, node_cfg, graph, config, tools, python_tools)
        if not outcome:
            continue
        map_name, map_info = outcome
        collected[map_name] = map_info

    return collected
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _process_edge(
    edge: dict[str, Any],
    graph: StateGraph,
    map_nodes: dict[str, tuple],
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Process a single edge and add it to the graph or the edge-tracking dicts.

    The YAML names ``"START"`` and ``"END"`` are translated to LangGraph's
    entry point and ``END`` sentinel. Router-style and expression-based
    conditional edges are only collected here. They are wired afterwards,
    because several YAML edges may share one source node.

    Args:
        edge: Edge configuration dict with ``from``, ``to`` and optional
            ``condition`` / ``type`` keys
        graph: StateGraph to add edges to
        map_nodes: Map node tracking dict: name -> (map_edge_fn, sub_node_name)
        router_edges: Dict to collect router edges (source -> target list)
        expression_edges: Dict to collect expression-based edges
            (source -> [(condition, target), ...])
    """
    # Local import keeps the module's top-level import block unchanged.
    from langgraph.graph import START

    from_node = edge["from"]
    to_node = edge["to"]
    condition = edge.get("condition")
    edge_type = edge.get("type")

    if isinstance(to_node, str) and to_node in map_nodes:
        # Edge TO a map node. The map node itself is never registered with
        # the graph (only its sub-node is), so fan out through a conditional
        # edge whose function returns Send objects. This must be checked
        # before the START branch: set_entry_point(<map node>) would point the
        # entry at a node that does not exist.
        map_edge_fn, sub_node_name = map_nodes[to_node]
        source = START if from_node == "START" else from_node
        graph.add_conditional_edges(source, map_edge_fn, [sub_node_name])
    elif from_node == "START":
        graph.set_entry_point(to_node)
    elif from_node in map_nodes:
        # Edge FROM a map node: wire sub_node to next_node for fan-in
        _, sub_node_name = map_nodes[from_node]
        target = END if to_node == "END" else to_node
        graph.add_edge(sub_node_name, target)
    elif edge_type == "conditional" and isinstance(to_node, list):
        # Router-style conditional edge: store for later processing
        router_edges[from_node] = to_node
    elif condition:
        # Expression-based condition (e.g., "critique.score < 0.8")
        target = END if to_node == "END" else to_node
        expression_edges.setdefault(from_node, []).append((condition, target))
    elif to_node == "END":
        graph.add_edge(from_node, END)
    else:
        graph.add_edge(from_node, to_node)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _add_conditional_edges(
    graph: StateGraph,
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Add router and expression conditional edges to graph.

    Args:
        graph: StateGraph to add edges to
        router_edges: Router-style conditional edges (source -> YAML target
            names, which may include the literal ``"END"``)
        expression_edges: Expression-based conditional edges
            (source -> [(condition, target), ...]); targets already use the
            ``END`` sentinel where applicable
    """
    # Router conditional edges. Targets come straight from YAML, so the
    # literal "END" must map to LangGraph's END sentinel. Otherwise the path
    # map would point at a nonexistent node named "END". Keys stay the raw
    # YAML names, which is what the router function is given.
    for source_node, target_nodes in router_edges.items():
        route_mapping = {
            target: (END if target == "END" else target) for target in target_nodes
        }
        graph.add_conditional_edges(
            source_node,
            make_router_fn(target_nodes),
            route_mapping,
        )

    # Add expression-based conditional edges
    for source_node, expr_edges in expression_edges.items():
        targets = {target for _, target in expr_edges}
        targets.add(END)  # Always include END as fallback
        # Targets were already normalized to END in _process_edge, so the
        # mapping is the identity.
        route_mapping = {t: t for t in targets}
        graph.add_conditional_edges(
            source_node,
            make_expr_router_fn(expr_edges, source_node),
            route_mapping,
        )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def compile_graph(config: GraphConfig) -> StateGraph:
    """Turn a parsed GraphConfig into an uncompiled LangGraph StateGraph.

    Steps: resolve the state class, parse tool registries, register nodes,
    wire direct edges, then wire the collected conditional edges.

    Args:
        config: Parsed graph configuration

    Returns:
        StateGraph ready for compilation
    """
    graph = StateGraph(_resolve_state_class(config))

    tools, python_tools = _parse_all_tools(config)
    map_nodes = _compile_nodes(config, graph, tools, python_tools)

    # Conditional edges are gathered first because several YAML edges can
    # share one source node.
    router_edges: dict[str, list] = {}
    expression_edges: dict[str, list[tuple[str, str]]] = {}
    for edge in config.edges:
        _process_edge(edge, graph, map_nodes, router_edges, expression_edges)

    _add_conditional_edges(graph, router_edges, expression_edges)
    return graph
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def load_and_compile(path: str | Path) -> StateGraph:
    """Read a YAML graph definition and build its StateGraph in one call.

    Shorthand for ``compile_graph(load_graph_config(path))`` that also logs
    which graph was loaded.

    Args:
        path: Path to YAML graph definition

    Returns:
        StateGraph ready for compilation
    """
    parsed = load_graph_config(path)
    logger.info(f"Loaded graph config: {parsed.name} v{parsed.version}")
    graph = compile_graph(parsed)
    return graph
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Map node compiler - Handles type: map node compilation.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to compile map nodes that fan out
|
|
4
|
+
to sub-nodes for parallel processing using LangGraph's Send mechanism.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from langgraph.graph import StateGraph
|
|
12
|
+
from langgraph.types import Send
|
|
13
|
+
|
|
14
|
+
from yamlgraph.node_factory import create_node_function
|
|
15
|
+
from yamlgraph.utils.expressions import resolve_state_expression
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def wrap_for_reducer(
    node_fn: Callable[[dict], dict],
    collect_key: str,
    state_key: str,
) -> Callable[[dict], dict]:
    """Wrap sub-node output for Annotated reducer aggregation.

    Every call of the wrapper returns ``{collect_key: [item]}``, so parallel
    map branches can be concatenated by a list reducer on ``collect_key``.
    Failures are reported in-band instead of aborting the whole fan-out:

    * If ``node_fn`` raises, the item is ``{"_map_index", "_error",
      "_error_type"}`` and a ``PipelineError`` is added under ``"errors"``.
    * If ``node_fn`` returns a non-empty ``"errors"`` or ``"error"`` value,
      the item is ``{"_map_index", "_error"}``. A non-empty ``"errors"``
      value is passed through unchanged.

    Otherwise ``result[state_key]`` (or the whole result when that key is
    absent) is collected. It is converted with ``model_dump()`` when it has
    one, and tagged with ``_map_index`` when the branch state carries one.

    Args:
        node_fn: The original node function
        collect_key: State key where results are collected
        state_key: Key to extract from node result

    Returns:
        Wrapped function that outputs in reducer-compatible format
    """

    def wrapped(state: dict) -> dict:
        try:
            result = node_fn(state)
        except Exception as e:
            # Deliberately broad: one failing branch must not abort the
            # others. Propagate the error with its map index instead.
            from yamlgraph.models import PipelineError

            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(e),
                "_error_type": type(e).__name__,
            }
            return {
                collect_key: [error_result],
                "errors": [PipelineError.from_exception(e, node="map_subnode")],
            }

        # Only a non-empty error value marks the branch as failed. A success
        # result carrying `"errors": []` or `"error": None` must not be
        # collected as an error entry.
        errors = result.get("errors")
        error = result.get("error")
        if errors or error:
            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(errors or error),
            }
            # Preserve errors in output
            output = {collect_key: [error_result]}
            if errors:
                output["errors"] = errors
            return output

        extracted = result.get(state_key, result)

        # Convert Pydantic models to dicts
        if hasattr(extracted, "model_dump"):
            extracted = extracted.model_dump()

        # Include _map_index if present for ordering
        if "_map_index" in state:
            if isinstance(extracted, dict):
                extracted = {"_map_index": state["_map_index"], **extracted}
            else:
                extracted = {"_map_index": state["_map_index"], "value": extracted}

        return {collect_key: [extracted]}

    return wrapped
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def compile_map_node(
    name: str,
    config: dict[str, Any],
    builder: StateGraph,
    defaults: dict[str, Any],
) -> tuple[Callable[[dict], list[Send]], str]:
    """Register the sub-node of a ``type: map`` node and build its fan-out.

    The sub-node is named ``_map_<name>_sub`` and is wrapped so that its
    output is appended to ``config["collect"]``. The returned edge function
    resolves ``config["over"]`` against the state and emits one ``Send`` per
    list item. Each Send carries the full state plus the item under
    ``config["as"]`` and its position under ``_map_index``.

    Args:
        name: Name of the map node
        config: Map node configuration with 'over', 'as', 'node', 'collect'
        builder: StateGraph builder to add sub-node to
        defaults: Default configuration for nodes

    Returns:
        Tuple of (map_edge_function, sub_node_name)

    Raises:
        KeyError: If a required map key is missing from ``config``
    """
    over_expr, item_var = config["over"], config["as"]
    sub_node_name = f"_map_{name}_sub"
    collect_key = config["collect"]
    sub_cfg = dict(config["node"])  # shallow copy; original config untouched
    state_key = sub_cfg.get("state_key", "result")

    # Expose the current item to the sub-node's prompt as {item_var}.
    sub_cfg["variables"] = {
        **sub_cfg.get("variables", {}),
        item_var: f"{{state.{item_var}}}",
    }

    inner_fn = create_node_function(sub_node_name, sub_cfg, defaults)
    builder.add_node(sub_node_name, wrap_for_reducer(inner_fn, collect_key, state_key))

    def map_edge(state: dict) -> list[Send]:
        items = resolve_state_expression(over_expr, state)

        if isinstance(items, list):
            # NOTE(review): an empty list yields no Send objects. Confirm that
            # downstream nodes still run in that case.
            return [
                Send(sub_node_name, {**state, item_var: item, "_map_index": idx})
                for idx, item in enumerate(items)
            ]

        raise TypeError(
            f"Map 'over' must resolve to list, got {type(items).__name__}"
        )

    return map_edge, sub_node_name
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Pydantic models and state definitions.
|
|
2
|
+
|
|
3
|
+
Framework models for error handling and generic reports.
|
|
4
|
+
State is now generated dynamically by state_builder.py.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from yamlgraph.models.graph_schema import (
|
|
8
|
+
EdgeConfig,
|
|
9
|
+
GraphConfigSchema,
|
|
10
|
+
NodeConfig,
|
|
11
|
+
validate_graph_schema,
|
|
12
|
+
)
|
|
13
|
+
from yamlgraph.models.schemas import (
|
|
14
|
+
ErrorType,
|
|
15
|
+
GenericReport,
|
|
16
|
+
PipelineError,
|
|
17
|
+
)
|
|
18
|
+
from yamlgraph.models.state_builder import (
|
|
19
|
+
build_state_class,
|
|
20
|
+
create_initial_state,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Public API of yamlgraph.models: the names exported by
# `from yamlgraph.models import *`, all imported above from submodules.
__all__ = [
    # Framework models
    "ErrorType",
    "PipelineError",
    "GenericReport",
    # Graph config schema
    "GraphConfigSchema",
    "NodeConfig",
    "EdgeConfig",
    "validate_graph_schema",
    # Dynamic state generation
    "build_state_class",
    "create_initial_state",
]
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Pydantic schemas for YAML graph configuration validation.
|
|
2
|
+
|
|
3
|
+
Provides structured validation for graph YAML files with clear error messages.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
9
|
+
|
|
10
|
+
from yamlgraph.constants import ErrorHandler, NodeType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeConfig(BaseModel):
    """Schema for one entry under a graph's ``nodes`` mapping.

    Common fields apply to every node type. Map fields (``over``, ``as``,
    ``node``, ``collect``) and router ``routes`` are required only for their
    respective types. Unknown keys are allowed (``extra: allow``).
    """

    type: str = Field(default=NodeType.LLM, description="Node type")
    prompt: str | None = Field(default=None, description="Prompt template name")
    state_key: str | None = Field(default=None, description="State key for output")
    temperature: float | None = Field(default=None, ge=0, le=2)
    provider: str | None = Field(default=None)
    on_error: str | None = Field(default=None)
    fallback: dict[str, Any] | None = Field(default=None)
    variables: dict[str, str] = Field(default_factory=dict)
    requires: list[str] = Field(default_factory=list)
    routes: dict[str, str] | None = Field(default=None, description="Router routes")

    # Map node fields
    over: str | None = Field(default=None, description="Map over expression")
    # 'as' is a Python keyword, so the YAML key is mapped via an alias.
    item_var: str | None = Field(default=None, alias="as")
    node: dict[str, Any] | None = Field(default=None, description="Map sub-node")
    collect: str | None = Field(default=None, description="Map collect key")

    # Tool/Agent fields
    tools: list[str] = Field(default_factory=list)
    max_iterations: int = Field(default=10, ge=1)

    model_config = {"extra": "allow", "populate_by_name": True}

    @field_validator("on_error")
    @classmethod
    def validate_on_error(cls, v: str | None) -> str | None:
        """Reject ``on_error`` values that are not registered error handlers."""
        if v is None:
            return v
        known_handlers = list(ErrorHandler.all_values())
        if v in known_handlers:
            return v
        raise ValueError(f"Invalid on_error '{v}'. Valid: {', '.join(known_handlers)}")

    @model_validator(mode="after")
    def validate_node_requirements(self) -> "NodeConfig":
        """Check that the fields required by this node's ``type`` are set."""
        node_type = self.type

        if NodeType.requires_prompt(node_type) and not self.prompt:
            raise ValueError(f"Node type '{node_type}' requires 'prompt' field")

        if node_type == NodeType.ROUTER and not self.routes:
            raise ValueError("Router node requires 'routes' field")

        if node_type == NodeType.MAP:
            # (model attribute, YAML key used in the error message)
            required = (
                ("over", "over"),
                ("item_var", "as"),
                ("node", "node"),
                ("collect", "collect"),
            )
            for attr, yaml_key in required:
                if not getattr(self, attr):
                    raise ValueError(f"Map node requires '{yaml_key}' field")

        return self
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class EdgeConfig(BaseModel):
    """Configuration for a graph edge.

    The YAML key ``from`` is a Python keyword, so it is mapped to
    ``from_node`` through an alias. ``populate_by_name`` also allows
    ``from_node=`` when constructing the model in code.
    """

    from_node: str = Field(..., alias="from", description="Source node")
    # A single target name, or a list of target names.
    to: str | list[str] = Field(..., description="Target node(s)")
    condition: str | None = Field(default=None, description="Condition expression")

    # NOTE(review): no `type` field is declared here, so an edge's
    # `type: conditional` key is presumably dropped by this model (pydantic's
    # default `extra` behaviour is "ignore"). Confirm this is intended.
    model_config = {"populate_by_name": True}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class GraphConfigSchema(BaseModel):
    """Full YAML graph configuration schema.

    Validates a whole graph YAML file with Pydantic. It checks per-node
    requirements (through ``NodeConfig``) and cross-references between
    routers, edges and declared nodes.
    """

    version: str = Field(default="1.0")
    name: str = Field(default="unnamed")
    description: str = Field(default="")
    defaults: dict[str, Any] = Field(default_factory=dict)
    nodes: dict[str, NodeConfig] = Field(...)
    edges: list[EdgeConfig] = Field(...)
    tools: dict[str, Any] = Field(default_factory=dict)
    state_class: str = Field(default="")
    loop_limits: dict[str, int] = Field(default_factory=dict)

    model_config = {"extra": "allow"}

    @model_validator(mode="after")
    def validate_router_targets(self) -> "GraphConfigSchema":
        """Ensure every router route points at a node declared in ``nodes``."""
        for router_name, router in self.nodes.items():
            if not (router.type == NodeType.ROUTER and router.routes):
                continue
            for route_key, target in router.routes.items():
                if target in self.nodes:
                    continue
                raise ValueError(
                    f"Router '{router_name}' route '{route_key}' "
                    f"targets nonexistent node '{target}'"
                )
        return self

    @model_validator(mode="after")
    def validate_edge_nodes(self) -> "GraphConfigSchema":
        """Ensure edge endpoints are declared nodes or START/END."""
        known = {"START", "END", *self.nodes}

        for edge in self.edges:
            if edge.from_node not in known:
                raise ValueError(f"Edge 'from' node '{edge.from_node}' not found")

            destinations = edge.to if isinstance(edge.to, list) else [edge.to]
            unknown = [dest for dest in destinations if dest not in known]
            if unknown:
                raise ValueError(f"Edge 'to' node '{unknown[0]}' not found")

        return self
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def validate_graph_schema(config: dict[str, Any]) -> GraphConfigSchema:
    """Check a raw graph configuration dict against ``GraphConfigSchema``.

    Args:
        config: Raw parsed YAML configuration

    Returns:
        Validated GraphConfigSchema

    Raises:
        pydantic.ValidationError: If validation fails
    """
    validated = GraphConfigSchema.model_validate(config)
    return validated
|