yamlgraph 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of yamlgraph might be problematic. Click here for more details.

Files changed (111)
  1. examples/__init__.py +1 -0
  2. examples/storyboard/__init__.py +1 -0
  3. examples/storyboard/generate_videos.py +335 -0
  4. examples/storyboard/nodes/__init__.py +10 -0
  5. examples/storyboard/nodes/animated_character_node.py +248 -0
  6. examples/storyboard/nodes/animated_image_node.py +138 -0
  7. examples/storyboard/nodes/character_node.py +162 -0
  8. examples/storyboard/nodes/image_node.py +118 -0
  9. examples/storyboard/nodes/replicate_tool.py +238 -0
  10. examples/storyboard/retry_images.py +118 -0
  11. tests/__init__.py +1 -0
  12. tests/conftest.py +178 -0
  13. tests/integration/__init__.py +1 -0
  14. tests/integration/test_animated_storyboard.py +63 -0
  15. tests/integration/test_cli_commands.py +242 -0
  16. tests/integration/test_map_demo.py +50 -0
  17. tests/integration/test_memory_demo.py +281 -0
  18. tests/integration/test_pipeline_flow.py +105 -0
  19. tests/integration/test_providers.py +163 -0
  20. tests/integration/test_resume.py +75 -0
  21. tests/unit/__init__.py +1 -0
  22. tests/unit/test_agent_nodes.py +200 -0
  23. tests/unit/test_checkpointer.py +212 -0
  24. tests/unit/test_cli.py +121 -0
  25. tests/unit/test_cli_package.py +81 -0
  26. tests/unit/test_compile_graph_map.py +132 -0
  27. tests/unit/test_conditions_routing.py +253 -0
  28. tests/unit/test_config.py +93 -0
  29. tests/unit/test_conversation_memory.py +270 -0
  30. tests/unit/test_database.py +145 -0
  31. tests/unit/test_deprecation.py +104 -0
  32. tests/unit/test_executor.py +60 -0
  33. tests/unit/test_executor_async.py +179 -0
  34. tests/unit/test_export.py +150 -0
  35. tests/unit/test_expressions.py +178 -0
  36. tests/unit/test_format_prompt.py +145 -0
  37. tests/unit/test_generic_report.py +200 -0
  38. tests/unit/test_graph_commands.py +327 -0
  39. tests/unit/test_graph_loader.py +299 -0
  40. tests/unit/test_graph_schema.py +193 -0
  41. tests/unit/test_inline_schema.py +151 -0
  42. tests/unit/test_issues.py +164 -0
  43. tests/unit/test_jinja2_prompts.py +85 -0
  44. tests/unit/test_langsmith.py +319 -0
  45. tests/unit/test_llm_factory.py +109 -0
  46. tests/unit/test_llm_factory_async.py +118 -0
  47. tests/unit/test_loops.py +403 -0
  48. tests/unit/test_map_node.py +144 -0
  49. tests/unit/test_no_backward_compat.py +56 -0
  50. tests/unit/test_node_factory.py +225 -0
  51. tests/unit/test_prompts.py +166 -0
  52. tests/unit/test_python_nodes.py +198 -0
  53. tests/unit/test_reliability.py +298 -0
  54. tests/unit/test_result_export.py +234 -0
  55. tests/unit/test_router.py +296 -0
  56. tests/unit/test_sanitize.py +99 -0
  57. tests/unit/test_schema_loader.py +295 -0
  58. tests/unit/test_shell_tools.py +229 -0
  59. tests/unit/test_state_builder.py +331 -0
  60. tests/unit/test_state_builder_map.py +104 -0
  61. tests/unit/test_state_config.py +197 -0
  62. tests/unit/test_template.py +190 -0
  63. tests/unit/test_tool_nodes.py +129 -0
  64. yamlgraph/__init__.py +35 -0
  65. yamlgraph/builder.py +110 -0
  66. yamlgraph/cli/__init__.py +139 -0
  67. yamlgraph/cli/__main__.py +6 -0
  68. yamlgraph/cli/commands.py +232 -0
  69. yamlgraph/cli/deprecation.py +92 -0
  70. yamlgraph/cli/graph_commands.py +382 -0
  71. yamlgraph/cli/validators.py +37 -0
  72. yamlgraph/config.py +67 -0
  73. yamlgraph/constants.py +66 -0
  74. yamlgraph/error_handlers.py +226 -0
  75. yamlgraph/executor.py +275 -0
  76. yamlgraph/executor_async.py +122 -0
  77. yamlgraph/graph_loader.py +337 -0
  78. yamlgraph/map_compiler.py +138 -0
  79. yamlgraph/models/__init__.py +36 -0
  80. yamlgraph/models/graph_schema.py +141 -0
  81. yamlgraph/models/schemas.py +124 -0
  82. yamlgraph/models/state_builder.py +236 -0
  83. yamlgraph/node_factory.py +240 -0
  84. yamlgraph/routing.py +87 -0
  85. yamlgraph/schema_loader.py +160 -0
  86. yamlgraph/storage/__init__.py +17 -0
  87. yamlgraph/storage/checkpointer.py +72 -0
  88. yamlgraph/storage/database.py +320 -0
  89. yamlgraph/storage/export.py +269 -0
  90. yamlgraph/tools/__init__.py +1 -0
  91. yamlgraph/tools/agent.py +235 -0
  92. yamlgraph/tools/nodes.py +124 -0
  93. yamlgraph/tools/python_tool.py +178 -0
  94. yamlgraph/tools/shell.py +205 -0
  95. yamlgraph/utils/__init__.py +47 -0
  96. yamlgraph/utils/conditions.py +157 -0
  97. yamlgraph/utils/expressions.py +111 -0
  98. yamlgraph/utils/langsmith.py +308 -0
  99. yamlgraph/utils/llm_factory.py +118 -0
  100. yamlgraph/utils/llm_factory_async.py +105 -0
  101. yamlgraph/utils/logging.py +127 -0
  102. yamlgraph/utils/prompts.py +116 -0
  103. yamlgraph/utils/sanitize.py +98 -0
  104. yamlgraph/utils/template.py +102 -0
  105. yamlgraph/utils/validators.py +181 -0
  106. yamlgraph-0.1.1.dist-info/METADATA +854 -0
  107. yamlgraph-0.1.1.dist-info/RECORD +111 -0
  108. yamlgraph-0.1.1.dist-info/WHEEL +5 -0
  109. yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
  110. yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
  111. yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
@@ -0,0 +1,337 @@
1
+ """YAML Graph Loader - Compile YAML to LangGraph.
2
+
3
+ This module provides functionality to load graph definitions from YAML files
4
+ and compile them into LangGraph StateGraph instances.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import yaml
12
+ from langgraph.graph import END, StateGraph
13
+
14
+ from yamlgraph.constants import NodeType
15
+ from yamlgraph.map_compiler import compile_map_node
16
+ from yamlgraph.models.state_builder import build_state_class
17
+ from yamlgraph.node_factory import create_node_function, resolve_class
18
+ from yamlgraph.routing import make_expr_router_fn, make_router_fn
19
+ from yamlgraph.tools.agent import create_agent_node
20
+ from yamlgraph.tools.nodes import create_tool_node
21
+ from yamlgraph.tools.python_tool import create_python_node, parse_python_tools
22
+ from yamlgraph.tools.shell import parse_tools
23
+ from yamlgraph.utils.validators import validate_config
24
+
25
# Type alias for dynamic state: graph state is handled as a plain
# string-keyed dict. The concrete state class passed to StateGraph is
# resolved separately (see _resolve_state_class).
GraphState = dict[str, Any]

# Module-level logger for node/tool compilation progress messages.
logger = logging.getLogger(__name__)
29
+
30
+
31
class GraphConfig:
    """Parsed graph configuration from YAML.

    Thin attribute wrapper around a validated YAML mapping. Every section is
    optional and falls back to an empty default, so downstream code can use
    the container attributes without None checks.

    Attributes:
        version: Config format version (default ``"1.0"``).
        name: Graph name (default ``"unnamed"``).
        description: Free-text description.
        defaults: Default configuration passed to node factories.
        nodes: Mapping of node name -> node configuration dict.
        edges: List of edge configuration dicts.
        tools: Raw ``tools`` section (parsed into shell/Python registries).
        state_class: Deprecated dotted path to an explicit state class.
        loop_limits: Mapping of node name -> limit, forwarded to the node
            config as ``loop_limit``.
        raw_config: The original dict, kept for dynamic state building.
    """

    def __init__(self, config: dict):
        """Initialize from parsed YAML dict.

        Args:
            config: Parsed YAML configuration dictionary

        Raises:
            ValueError: If config is invalid
        """
        # Validate before storing
        validate_config(config)

        self.version = config.get("version", "1.0")
        self.name = config.get("name", "unnamed")
        self.description = config.get("description", "")
        # YAML turns an empty section (e.g. a bare ``loop_limits:``) into None
        # instead of omitting the key. Coalesce so callers can always iterate
        # these containers or use ``in`` on them (``node_name in None`` would
        # raise TypeError during node compilation).
        self.defaults = config.get("defaults") or {}
        self.nodes = config.get("nodes") or {}
        self.edges = config.get("edges") or []
        self.tools = config.get("tools") or {}
        self.state_class = config.get("state_class") or ""
        self.loop_limits = config.get("loop_limits") or {}
        # Store raw config for dynamic state building
        self.raw_config = config
57
+
58
+
59
def load_graph_config(path: str | Path) -> GraphConfig:
    """Load and parse a YAML graph definition.

    Args:
        path: Path to the YAML file

    Returns:
        GraphConfig instance

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the document is not a YAML mapping at the top level
            (e.g. an empty file), or is missing required fields
        yaml.YAMLError: If the file is not well-formed YAML
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Graph config not found: {path}")

    # safe_load only constructs plain Python objects (no arbitrary tags),
    # which matters because graph files may come from untrusted sources.
    # An explicit encoding keeps parsing independent of the platform locale.
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty file yields None and a list/scalar document yields a non-dict.
    # Reject both here with a clear message instead of relying on downstream
    # code that expects dict methods.
    if not isinstance(config, dict):
        raise ValueError(
            f"Graph config must be a YAML mapping, got "
            f"{type(config).__name__}: {path}"
        )

    return GraphConfig(config)
80
+
81
+
82
def _resolve_state_class(config: GraphConfig) -> type:
    """Pick the state class used to construct the StateGraph.

    The normal path generates the state class from the raw graph config.
    An explicit ``state_class`` other than the built-in
    ``"yamlgraph.models.GraphState"`` is still honoured, but triggers a
    DeprecationWarning.

    Args:
        config: Graph configuration

    Returns:
        The state class to pass to ``StateGraph``.
    """
    explicit = config.state_class
    if not explicit or explicit == "yamlgraph.models.GraphState":
        return build_state_class(config.raw_config)

    import warnings

    warnings.warn(
        f"state_class '{explicit}' is deprecated. "
        "State is now auto-generated from graph config.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_class(explicit)
105
+
106
+
107
def _parse_all_tools(
    config: GraphConfig,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the shell and Python tool registries from the ``tools`` section.

    The same raw ``config.tools`` mapping is handed to both parsers. Each
    non-empty registry is logged together with its tool names.

    Args:
        config: Graph configuration

    Returns:
        ``(shell_tools, python_tools)`` registries, in that order.
    """
    raw_tools = config.tools
    shell_registry = parse_tools(raw_tools)
    python_registry = parse_python_tools(raw_tools)

    if shell_registry:
        shell_names = ", ".join(shell_registry.keys())
        logger.info(f"Parsed {len(shell_registry)} shell tools: {shell_names}")
    if python_registry:
        python_names = ", ".join(python_registry.keys())
        logger.info(f"Parsed {len(python_registry)} Python tools: {python_names}")

    return shell_registry, python_registry
129
+
130
+
131
+ def _compile_node(
132
+ node_name: str,
133
+ node_config: dict[str, Any],
134
+ graph: StateGraph,
135
+ config: GraphConfig,
136
+ tools: dict[str, Any],
137
+ python_tools: dict[str, Any],
138
+ ) -> tuple[str, Any] | None:
139
+ """Compile a single node and add to graph.
140
+
141
+ Args:
142
+ node_name: Name of the node
143
+ node_config: Node configuration dict
144
+ graph: StateGraph to add node to
145
+ config: Full graph config for defaults
146
+ tools: Shell tools registry
147
+ python_tools: Python tools registry
148
+
149
+ Returns:
150
+ Tuple of (node_name, map_info) for map nodes, None otherwise
151
+ """
152
+ # Copy node config and add loop_limit if specified
153
+ enriched_config = dict(node_config)
154
+ if node_name in config.loop_limits:
155
+ enriched_config["loop_limit"] = config.loop_limits[node_name]
156
+
157
+ node_type = node_config.get("type", NodeType.LLM)
158
+
159
+ if node_type == NodeType.TOOL:
160
+ node_fn = create_tool_node(node_name, enriched_config, tools)
161
+ graph.add_node(node_name, node_fn)
162
+ elif node_type == NodeType.PYTHON:
163
+ node_fn = create_python_node(node_name, enriched_config, python_tools)
164
+ graph.add_node(node_name, node_fn)
165
+ elif node_type == NodeType.AGENT:
166
+ node_fn = create_agent_node(node_name, enriched_config, tools)
167
+ graph.add_node(node_name, node_fn)
168
+ elif node_type == NodeType.MAP:
169
+ map_edge_fn, sub_node_name = compile_map_node(
170
+ node_name, enriched_config, graph, config.defaults
171
+ )
172
+ logger.info(f"Added node: {node_name} (type={node_type})")
173
+ return (node_name, (map_edge_fn, sub_node_name))
174
+ else:
175
+ # LLM and router nodes
176
+ node_fn = create_node_function(node_name, enriched_config, config.defaults)
177
+ graph.add_node(node_name, node_fn)
178
+
179
+ logger.info(f"Added node: {node_name} (type={node_type})")
180
+ return None
181
+
182
+
183
def _compile_nodes(
    config: GraphConfig,
    graph: StateGraph,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
) -> dict[str, tuple]:
    """Add every configured node to ``graph`` and collect map-node wiring.

    Args:
        config: Graph configuration
        graph: StateGraph to add nodes to
        tools: Shell tools registry
        python_tools: Python tools registry

    Returns:
        Mapping of map node name -> ``(map_edge_fn, sub_node_name)``; nodes
        of other types do not appear in it.
    """
    map_nodes: dict[str, tuple] = {}
    for node_name, node_config in config.nodes.items():
        compiled = _compile_node(
            node_name, node_config, graph, config, tools, python_tools
        )
        if compiled is None:
            continue
        map_name, map_info = compiled
        map_nodes[map_name] = map_info
    return map_nodes
210
+
211
+
212
def _process_edge(
    edge: dict[str, Any],
    graph: StateGraph,
    map_nodes: dict[str, tuple],
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Process a single edge and add to graph or edge tracking dicts.

    Plain edges are added to ``graph`` immediately. Router-style and
    expression-based conditional edges are only recorded in ``router_edges``
    / ``expression_edges``. The caller adds those later, once every edge for
    a source node has been seen.

    Map nodes are never registered in the graph under their own name. Only
    their worker sub-node exists, together with a Send-based fan-out
    function, so edges touching a map node are rewired through those pieces.

    Args:
        edge: Edge configuration dict (``from``, ``to``, optional
            ``condition`` and ``type``)
        graph: StateGraph to add edges to
        map_nodes: Map node tracking dict: name -> (map_edge_fn, sub_node_name)
        router_edges: Dict to collect router edges
        expression_edges: Dict to collect expression-based edges
    """
    from_node = edge["from"]
    to_node = edge["to"]
    condition = edge.get("condition")
    edge_type = edge.get("type")
    to_map_node = isinstance(to_node, str) and to_node in map_nodes

    if from_node == "START":
        if to_map_node:
            # The map node has no graph node of its own, so set_entry_point()
            # would point at an unknown node and compilation would fail.
            # Fan out to the map's sub-node straight from START instead.
            from langgraph.graph import START

            map_edge_fn, sub_node_name = map_nodes[to_node]
            graph.add_conditional_edges(START, map_edge_fn, [sub_node_name])
        else:
            graph.set_entry_point(to_node)
    elif to_map_node:
        # Edge TO a map node: use conditional edge with Send function
        map_edge_fn, sub_node_name = map_nodes[to_node]
        graph.add_conditional_edges(from_node, map_edge_fn, [sub_node_name])
    elif from_node in map_nodes:
        # Edge FROM a map node: wire sub_node to next_node for fan-in
        _, sub_node_name = map_nodes[from_node]
        target = END if to_node == "END" else to_node
        graph.add_edge(sub_node_name, target)
    elif edge_type == "conditional" and isinstance(to_node, list):
        # Router-style conditional edge: store for later processing
        router_edges[from_node] = to_node
    elif condition:
        # Expression-based condition (e.g., "critique.score < 0.8");
        # conditions are kept in declaration order per source node.
        target = END if to_node == "END" else to_node
        expression_edges.setdefault(from_node, []).append((condition, target))
    elif to_node == "END":
        graph.add_edge(from_node, END)
    else:
        graph.add_edge(from_node, to_node)
257
+
258
+
259
def _add_conditional_edges(
    graph: StateGraph,
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Add router and expression conditional edges to graph.

    Args:
        graph: StateGraph to add edges to
        router_edges: Router-style conditional edges: source node -> list of
            candidate target names as written in the YAML
        expression_edges: Expression-based conditional edges: source node ->
            ordered ``(condition, target)`` pairs
    """
    # Add router conditional edges
    for source_node, target_nodes in router_edges.items():
        # Router targets come straight from YAML, where the terminal node is
        # spelled "END". Route that key to LangGraph's END sentinel (as plain
        # edges already do) rather than to an undefined node named "END".
        route_mapping = {
            target: END if target == "END" else target for target in target_nodes
        }
        graph.add_conditional_edges(
            source_node,
            make_router_fn(target_nodes),
            route_mapping,
        )

    # Add expression-based conditional edges
    for source_node, expr_edges in expression_edges.items():
        targets = {target for _, target in expr_edges}
        # Always include END as fallback (presumably what the expression
        # router returns when no condition matches).
        targets.add(END)
        route_mapping = {target: target for target in targets}
        graph.add_conditional_edges(
            source_node,
            make_expr_router_fn(expr_edges, source_node),
            route_mapping,
        )
290
+
291
+
292
def compile_graph(config: GraphConfig) -> StateGraph:
    """Turn a parsed GraphConfig into an uncompiled LangGraph StateGraph.

    The state schema is resolved first, then the tool registries are parsed
    and every node is added. Plain edges are wired while walking the edge
    list; conditional edges are added last.

    Args:
        config: Parsed graph configuration

    Returns:
        StateGraph ready for compilation
    """
    graph = StateGraph(_resolve_state_class(config))
    tools, python_tools = _parse_all_tools(config)
    map_nodes = _compile_nodes(config, graph, tools, python_tools)

    # Conditional edges are gathered per source node first, so all of one
    # source's branches end up in a single add_conditional_edges call.
    router_edges: dict[str, list] = {}
    expression_edges: dict[str, list[tuple[str, str]]] = {}
    for edge in config.edges:
        _process_edge(edge, graph, map_nodes, router_edges, expression_edges)

    _add_conditional_edges(graph, router_edges, expression_edges)
    return graph
322
+
323
+
324
def load_and_compile(path: str | Path) -> StateGraph:
    """Read a YAML graph file and build its StateGraph in one step.

    Equivalent to ``compile_graph(load_graph_config(path))``, with an info
    log naming the loaded graph and its version.

    Args:
        path: Path to YAML graph definition

    Returns:
        StateGraph ready for compilation
    """
    graph_config = load_graph_config(path)
    logger.info(f"Loaded graph config: {graph_config.name} v{graph_config.version}")
    return compile_graph(graph_config)
@@ -0,0 +1,138 @@
1
+ """Map node compiler - Handles type: map node compilation.
2
+
3
+ This module provides functionality to compile map nodes that fan out
4
+ to sub-nodes for parallel processing using LangGraph's Send mechanism.
5
+ """
6
+
7
+ import logging
8
+ from collections.abc import Callable
9
+ from typing import Any
10
+
11
+ from langgraph.graph import StateGraph
12
+ from langgraph.types import Send
13
+
14
+ from yamlgraph.node_factory import create_node_function
15
+ from yamlgraph.utils.expressions import resolve_state_expression
16
+
17
# Module-level logger. NOTE(review): not referenced by the functions in this
# module as shown; confirm whether it is still needed.
logger = logging.getLogger(__name__)
18
+
19
+
20
def wrap_for_reducer(
    node_fn: Callable[[dict], dict],
    collect_key: str,
    state_key: str,
) -> Callable[[dict], dict]:
    """Wrap sub-node output for Annotated reducer aggregation.

    Each map branch returns ``{collect_key: [item]}``, a one-element list, so
    a list-merging reducer on ``collect_key`` (assumed to be declared on the
    state, TODO confirm) can combine the parallel branches.

    Handles error propagation: if a map branch fails, an error marker with
    the branch's ``_map_index`` is collected instead of a result. A branch
    fails when it raises, or when it reports a non-empty ``errors`` value or
    a truthy ``error`` value. An empty ``errors`` list or ``error: None``
    counts as success.

    Args:
        node_fn: The original node function
        collect_key: State key where results are collected
        state_key: Key to extract from node result; when absent, the whole
            result dict is collected

    Returns:
        Wrapped function that outputs in reducer-compatible format
    """

    def wrapped(state: dict) -> dict:
        map_index = state.get("_map_index", 0)
        try:
            result = node_fn(state)
        except Exception as e:
            # Record the failure as data instead of raising, so it shows up in
            # both the collected results and the graph-level errors channel.
            from yamlgraph.models import PipelineError

            error_result = {
                "_map_index": map_index,
                "_error": str(e),
                "_error_type": type(e).__name__,
            }
            return {
                collect_key: [error_result],
                "errors": [PipelineError.from_exception(e, node="map_subnode")],
            }

        # Only treat the branch as failed when an error is actually reported.
        # Merely having the keys present (e.g. errors=[]) is not a failure.
        errors = result.get("errors")
        error = result.get("error")
        if errors or error:
            error_result = {
                "_map_index": map_index,
                "_error": str(errors or error),
            }
            output = {collect_key: [error_result]}
            # Preserve errors in output
            if errors:
                output["errors"] = errors
            return output

        extracted = result.get(state_key, result)

        # Convert Pydantic models to dicts
        if hasattr(extracted, "model_dump"):
            extracted = extracted.model_dump()

        # Include _map_index if present for ordering
        if "_map_index" in state:
            if isinstance(extracted, dict):
                extracted = {"_map_index": state["_map_index"], **extracted}
            else:
                extracted = {"_map_index": state["_map_index"], "value": extracted}

        return {collect_key: [extracted]}

    return wrapped
84
+
85
+
86
def compile_map_node(
    name: str,
    config: dict[str, Any],
    builder: StateGraph,
    defaults: dict[str, Any],
) -> tuple[Callable[[dict], list[Send]], str]:
    """Compile a ``type: map`` node into a worker sub-node plus a fan-out edge.

    The worker is registered on ``builder`` as ``_map_<name>_sub``. The
    returned edge function resolves the ``over`` expression against the
    current state and emits one ``Send`` per item. Each Send carries a
    shallow copy of the state with the item bound under the ``as`` name and
    its position under ``_map_index``.

    Args:
        name: Name of the map node
        config: Map node configuration with 'over', 'as', 'node', 'collect'
        builder: StateGraph builder to add sub-node to
        defaults: Default configuration for nodes

    Returns:
        Tuple of (map_edge_function, sub_node_name)

    Raises:
        KeyError: If 'over', 'as', 'collect' or 'node' is missing
    """
    over_expr = config["over"]
    item_var = config["as"]
    sub_node_name = f"_map_{name}_sub"
    collect_key = config["collect"]
    # Copy so the caller's map config is left untouched.
    worker_config = dict(config["node"])
    state_key = worker_config.get("state_key", "result")

    # Expose the current item to the worker's prompt as a variable that reads
    # it back out of the per-branch state.
    prompt_vars = dict(worker_config.get("variables", {}))
    prompt_vars[item_var] = f"{{state.{item_var}}}"
    worker_config["variables"] = prompt_vars

    builder.add_node(
        sub_node_name,
        wrap_for_reducer(
            create_node_function(sub_node_name, worker_config, defaults),
            collect_key,
            state_key,
        ),
    )

    def map_edge(state: dict) -> list[Send]:
        items = resolve_state_expression(over_expr, state)
        if not isinstance(items, list):
            raise TypeError(
                f"Map 'over' must resolve to list, got {type(items).__name__}"
            )

        sends = []
        for index, item in enumerate(items):
            branch_state = {**state, item_var: item, "_map_index": index}
            sends.append(Send(sub_node_name, branch_state))
        return sends

    return map_edge, sub_node_name
@@ -0,0 +1,36 @@
1
+ """Pydantic models and state definitions.
2
+
3
+ Framework models for error handling and generic reports.
4
+ State is now generated dynamically by state_builder.py.
5
+ """
6
+
7
+ from yamlgraph.models.graph_schema import (
8
+ EdgeConfig,
9
+ GraphConfigSchema,
10
+ NodeConfig,
11
+ validate_graph_schema,
12
+ )
13
+ from yamlgraph.models.schemas import (
14
+ ErrorType,
15
+ GenericReport,
16
+ PipelineError,
17
+ )
18
+ from yamlgraph.models.state_builder import (
19
+ build_state_class,
20
+ create_initial_state,
21
+ )
22
+
23
__all__ = [
    # Framework models (error handling and generic reports)
    "ErrorType",
    "PipelineError",
    "GenericReport",
    # Graph config schema (Pydantic validation of graph YAML)
    "GraphConfigSchema",
    "NodeConfig",
    "EdgeConfig",
    "validate_graph_schema",
    # Dynamic state generation (state built from the graph config)
    "build_state_class",
    "create_initial_state",
]
@@ -0,0 +1,141 @@
1
+ """Pydantic schemas for YAML graph configuration validation.
2
+
3
+ Provides structured validation for graph YAML files with clear error messages.
4
+ """
5
+
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel, Field, field_validator, model_validator
9
+
10
+ from yamlgraph.constants import ErrorHandler, NodeType
11
+
12
+
13
class NodeConfig(BaseModel):
    """Schema for one entry under ``nodes:`` in a graph YAML file.

    Unknown keys are kept (``extra="allow"``). The YAML key ``as`` populates
    ``item_var`` because ``as`` is a Python keyword; ``populate_by_name``
    also accepts ``item_var`` directly.
    """

    type: str = Field(default=NodeType.LLM, description="Node type")
    prompt: str | None = Field(default=None, description="Prompt template name")
    state_key: str | None = Field(default=None, description="State key for output")
    temperature: float | None = Field(default=None, ge=0, le=2)
    provider: str | None = Field(default=None)
    on_error: str | None = Field(default=None)
    fallback: dict[str, Any] | None = Field(default=None)
    variables: dict[str, str] = Field(default_factory=dict)
    requires: list[str] = Field(default_factory=list)
    routes: dict[str, str] | None = Field(default=None, description="Router routes")

    # Fields used by map nodes
    over: str | None = Field(default=None, description="Map over expression")
    # YAML key 'as' (a Python keyword) is exposed as item_var
    item_var: str | None = Field(default=None, alias="as")
    node: dict[str, Any] | None = Field(default=None, description="Map sub-node")
    collect: str | None = Field(default=None, description="Map collect key")

    # Fields used by tool / agent nodes
    tools: list[str] = Field(default_factory=list)
    max_iterations: int = Field(default=10, ge=1)

    model_config = {"extra": "allow", "populate_by_name": True}

    @field_validator("on_error")
    @classmethod
    def validate_on_error(cls, v: str | None) -> str | None:
        """Reject ``on_error`` values that are not registered error handlers."""
        if v is None or v in ErrorHandler.all_values():
            return v
        valid = ", ".join(ErrorHandler.all_values())
        raise ValueError(f"Invalid on_error '{v}'. Valid: {valid}")

    @model_validator(mode="after")
    def validate_node_requirements(self) -> "NodeConfig":
        """Enforce the fields each node type cannot work without."""
        if NodeType.requires_prompt(self.type) and not self.prompt:
            raise ValueError(f"Node type '{self.type}' requires 'prompt' field")

        if self.type == NodeType.ROUTER and not self.routes:
            raise ValueError("Router node requires 'routes' field")

        if self.type == NodeType.MAP:
            # (attribute, YAML key) pairs, checked in this order.
            required = (
                ("over", "over"),
                ("item_var", "as"),
                ("node", "node"),
                ("collect", "collect"),
            )
            for attr, yaml_key in required:
                if not getattr(self, attr):
                    raise ValueError(f"Map node requires '{yaml_key}' field")

        return self
69
+
70
+
71
class EdgeConfig(BaseModel):
    """Configuration for a graph edge.

    Covers the keys the graph loader reads from each ``edges:`` entry:
    ``from``, ``to`` (a single node or a list of nodes), an optional
    ``condition`` expression, and an optional ``type``. The value
    ``"conditional"`` for ``type``, combined with a list ``to``, marks a
    router-style edge. Previously ``type`` was not declared and was silently
    dropped during validation.
    """

    from_node: str = Field(..., alias="from", description="Source node")
    to: str | list[str] = Field(..., description="Target node(s)")
    condition: str | None = Field(default=None, description="Condition expression")
    type: str | None = Field(
        default=None,
        description="Edge type, e.g. 'conditional' for router edges",
    )

    model_config = {"populate_by_name": True}
79
+
80
+
81
class GraphConfigSchema(BaseModel):
    """Full YAML graph configuration schema.

    Use this for validating graph YAML files with Pydantic. Besides per-field
    checks, it validates cross-references:
    - router routes must name declared nodes;
    - edge endpoints must name declared nodes, or ``START`` (source only) /
      ``END`` (target only).
    """

    version: str = Field(default="1.0")
    name: str = Field(default="unnamed")
    description: str = Field(default="")
    defaults: dict[str, Any] = Field(default_factory=dict)
    nodes: dict[str, NodeConfig] = Field(...)
    edges: list[EdgeConfig] = Field(...)
    tools: dict[str, Any] = Field(default_factory=dict)
    state_class: str = Field(default="")
    loop_limits: dict[str, int] = Field(default_factory=dict)

    model_config = {"extra": "allow"}

    @model_validator(mode="after")
    def validate_router_targets(self) -> "GraphConfigSchema":
        """Validate router routes point to existing nodes."""
        for node_name, node in self.nodes.items():
            if node.type == NodeType.ROUTER and node.routes:
                for route_key, target in node.routes.items():
                    if target not in self.nodes:
                        raise ValueError(
                            f"Router '{node_name}' route '{route_key}' "
                            f"targets nonexistent node '{target}'"
                        )
        return self

    @model_validator(mode="after")
    def validate_edge_nodes(self) -> "GraphConfigSchema":
        """Validate edge sources and targets exist.

        ``START`` is only meaningful as an edge source and ``END`` only as an
        edge target. An edge leaving END or entering START can never be wired
        into a working graph, so it is rejected here with a clear message.
        """
        node_names = set(self.nodes)
        valid_sources = node_names | {"START"}
        valid_targets = node_names | {"END"}

        for edge in self.edges:
            if edge.from_node == "END":
                raise ValueError("Edge 'from' node cannot be 'END'")
            if edge.from_node not in valid_sources:
                raise ValueError(f"Edge 'from' node '{edge.from_node}' not found")

            targets = edge.to if isinstance(edge.to, list) else [edge.to]
            for target in targets:
                if target == "START":
                    raise ValueError("Edge 'to' node cannot be 'START'")
                if target not in valid_targets:
                    raise ValueError(f"Edge 'to' node '{target}' not found")

        return self
127
+
128
+
129
def validate_graph_schema(config: dict[str, Any]) -> GraphConfigSchema:
    """Run full Pydantic validation over a raw graph config mapping.

    Args:
        config: Raw parsed YAML configuration

    Returns:
        The validated ``GraphConfigSchema`` instance.

    Raises:
        pydantic.ValidationError: If any field or cross-reference check fails
    """
    validated = GraphConfigSchema.model_validate(config)
    return validated