yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,451 @@
1
+ """YAML Graph Loader - Compile YAML to LangGraph.
2
+
3
+ This module provides functionality to load graph definitions from YAML files
4
+ and compile them into LangGraph StateGraph instances.
5
+ """
6
+
7
+ import logging
8
+ from collections.abc import Callable
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import yaml
13
+ from langgraph.graph import END, StateGraph
14
+
15
+ from yamlgraph.constants import NodeType
16
+ from yamlgraph.map_compiler import compile_map_node
17
+ from yamlgraph.models.state_builder import build_state_class
18
+ from yamlgraph.node_factory import (
19
+ create_interrupt_node,
20
+ create_node_function,
21
+ create_passthrough_node,
22
+ create_subgraph_node,
23
+ create_tool_call_node,
24
+ resolve_class,
25
+ )
26
+ from yamlgraph.routing import make_expr_router_fn, make_router_fn
27
+ from yamlgraph.storage.checkpointer_factory import get_checkpointer
28
+ from yamlgraph.tools.agent import create_agent_node
29
+ from yamlgraph.tools.nodes import create_tool_node
30
+ from yamlgraph.tools.python_tool import (
31
+ create_python_node,
32
+ load_python_function,
33
+ parse_python_tools,
34
+ )
35
+ from yamlgraph.tools.shell import parse_tools
36
+ from yamlgraph.tools.websearch import parse_websearch_tools
37
+ from yamlgraph.utils.validators import validate_config
38
+
39
+ # Type alias for dynamic state
40
+ GraphState = dict[str, Any]
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
class GraphConfig:
    """Parsed graph configuration from YAML.

    Exposes the top-level sections of a graph definition as attributes and
    keeps the untouched input dict in ``raw_config``.
    """

    def __init__(self, config: dict, source_path: Path | None = None):
        """Initialize from parsed YAML dict.

        Args:
            config: Parsed YAML configuration dictionary
            source_path: Path to the source YAML file (for subgraph resolution)

        Raises:
            ValueError: If config is invalid (raised by ``validate_config``)
        """
        # Validate before storing
        validate_config(config)

        self.version = config.get("version", "1.0")
        self.name = config.get("name", "unnamed")
        self.description = config.get("description", "")
        # A key written without a value in YAML (e.g. a bare ``edges:``) is
        # parsed as None, and dict.get() only applies its default when the key
        # is absent.  Coalesce container sections so later code can iterate
        # them and call .get() on them safely.
        self.defaults = config.get("defaults") or {}
        self.nodes = config.get("nodes") or {}
        self.edges = config.get("edges") or []
        self.tools = config.get("tools") or {}
        self.state_class = config.get("state_class", "")
        self.loop_limits = config.get("loop_limits") or {}
        self.checkpointer = config.get("checkpointer")
        # Store raw config for dynamic state building
        self.raw_config = config
        # Store source path for subgraph resolution
        self.source_path = source_path
        # Prompt resolution options (FR-A: graph-relative prompts)
        self.prompts_relative = self.defaults.get("prompts_relative", False)
        self.prompts_dir = self.defaults.get("prompts_dir")
78
+
79
+
80
def load_graph_config(path: str | Path) -> GraphConfig:
    """Load and parse a YAML graph definition.

    Args:
        path: Path to the YAML file

    Returns:
        GraphConfig instance whose ``source_path`` is the resolved file path

    Raises:
        FileNotFoundError: If the file doesn't exist
        ValueError: If the YAML is invalid or missing required fields,
            including when the document is empty or is not a mapping
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Graph config not found: {path}")

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty document loads as None, and a top-level list or scalar is not a
    # usable graph definition either.  Fail here with the documented exception
    # type instead of an AttributeError further down.
    if not isinstance(config, dict):
        raise ValueError(
            f"Graph config must be a YAML mapping, "
            f"got {type(config).__name__}: {path}"
        )

    return GraphConfig(config, source_path=path.resolve())
101
+
102
+
103
def _resolve_state_class(config: GraphConfig) -> type:
    """Pick the state class used to construct the graph.

    The state class is generated from the raw graph config unless the config
    names an explicit ``state_class`` other than
    ``yamlgraph.models.GraphState``; that explicit form is deprecated and
    triggers a DeprecationWarning before being resolved.

    Args:
        config: Graph configuration

    Returns:
        Class to use as the graph state
    """
    explicit = config.state_class
    if not explicit or explicit == "yamlgraph.models.GraphState":
        # Default path: derive the state from the graph definition itself.
        return build_state_class(config.raw_config)

    import warnings

    warnings.warn(
        f"state_class '{config.state_class}' is deprecated. "
        "State is now auto-generated from graph config.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_class(explicit)
126
+
127
+
128
def _parse_all_tools(
    config: GraphConfig,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Callable]]:
    """Build the tool registries declared in the graph config.

    Args:
        config: Graph configuration

    Returns:
        Tuple of (shell_tools, python_tools, websearch_tools, callable_registry).
        callable_registry maps Python tool names to the loaded callables used
        by tool_call nodes; tools that fail to load are logged and left out.
    """
    tools = parse_tools(config.tools)
    python_tools = parse_python_tools(config.tools)
    websearch_tools = parse_websearch_tools(config.tools)

    # Resolve each Python tool to its callable; a failed load is only a warning.
    loaded: dict[str, Callable] = {}
    for name, spec in python_tools.items():
        try:
            fn = load_python_function(spec)
        except (ImportError, AttributeError) as e:
            logger.warning(f"Failed to load tool '{name}': {e}")
        else:
            loaded[name] = fn

    if tools:
        logger.info(f"Parsed {len(tools)} shell tools: {', '.join(tools.keys())}")
    if python_tools:
        logger.info(
            f"Parsed {len(python_tools)} Python tools: {', '.join(python_tools.keys())}"
        )
    if websearch_tools:
        logger.info(
            f"Parsed {len(websearch_tools)} websearch tools: {', '.join(websearch_tools.keys())}"
        )

    return tools, python_tools, websearch_tools, loaded
164
+
165
+
166
def _compile_node(
    node_name: str,
    node_config: dict[str, Any],
    graph: StateGraph,
    config: GraphConfig,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
    websearch_tools: dict[str, Any],
    callable_registry: dict[str, Callable],
) -> tuple[str, Any] | None:
    """Build one node from its config and register it on the graph.

    Args:
        node_name: Name of the node
        node_config: Node configuration dict
        graph: StateGraph to add node to
        config: Full graph config (defaults, loop limits, source path)
        tools: Shell tools registry
        python_tools: Python tools registry
        websearch_tools: Web search tools registry
        callable_registry: Loaded callable functions for tool_call nodes

    Returns:
        ``(node_name, (map_edge_fn, sub_node_name))`` for map nodes,
        None for every other node type

    Raises:
        ValueError: If a subgraph node is compiled without a source_path
    """
    # Work on a copy so the per-node loop limit never leaks into the config.
    enriched_config = dict(node_config)
    if node_name in config.loop_limits:
        enriched_config["loop_limit"] = config.loop_limits[node_name]

    # Prompt path options come from the graph defaults (FR-A)
    prompts_relative = config.defaults.get("prompts_relative", False)
    prompts_dir = config.defaults.get("prompts_dir")
    if prompts_dir:
        prompts_dir = Path(prompts_dir)

    node_type = node_config.get("type", NodeType.LLM)

    # Map nodes register their own sub-node via compile_map_node and are the
    # only kind that hands edge-wiring information back to the caller.
    if node_type == NodeType.MAP:
        map_edge_fn, sub_node_name = compile_map_node(
            node_name, enriched_config, graph, config.defaults, callable_registry
        )
        logger.info(f"Added node: {node_name} (type={node_type})")
        return (node_name, (map_edge_fn, sub_node_name))

    if node_type == NodeType.TOOL:
        node_fn = create_tool_node(node_name, enriched_config, tools)
    elif node_type == NodeType.PYTHON:
        node_fn = create_python_node(node_name, enriched_config, python_tools)
    elif node_type == NodeType.AGENT:
        node_fn = create_agent_node(
            node_name, enriched_config, tools, websearch_tools, python_tools
        )
    elif node_type == NodeType.TOOL_CALL:
        # Tool is chosen dynamically from state at run time
        node_fn = create_tool_call_node(node_name, enriched_config, callable_registry)
    elif node_type == NodeType.INTERRUPT:
        # Human-in-the-loop interrupt
        node_fn = create_interrupt_node(
            node_name,
            enriched_config,
            graph_path=config.source_path,
            prompts_dir=prompts_dir,
            prompts_relative=prompts_relative,
        )
    elif node_type == NodeType.PASSTHROUGH:
        # Plain state transformation
        node_fn = create_passthrough_node(node_name, enriched_config)
    elif node_type == NodeType.SUBGRAPH:
        # Subgraph paths are resolved against the parent graph's file
        if not config.source_path:
            raise ValueError(
                f"Cannot resolve subgraph path for node '{node_name}': "
                "parent graph has no source_path"
            )
        node_fn = create_subgraph_node(
            node_name,
            enriched_config,
            parent_graph_path=config.source_path,
        )
    else:
        # LLM and router nodes
        node_fn = create_node_function(
            node_name,
            enriched_config,
            config.defaults,
            graph_path=config.source_path,
        )

    graph.add_node(node_name, node_fn)
    logger.info(f"Added node: {node_name} (type={node_type})")
    return None
264
+
265
+
266
def _compile_nodes(
    config: GraphConfig,
    graph: StateGraph,
    tools: dict[str, Any],
    python_tools: dict[str, Any],
    websearch_tools: dict[str, Any],
    callable_registry: dict[str, Callable],
) -> dict[str, tuple]:
    """Compile every node in the config, in declaration order.

    Args:
        config: Graph configuration
        graph: StateGraph to add nodes to
        tools: Shell tools registry
        python_tools: Python tools registry
        websearch_tools: Web search tools registry
        callable_registry: Loaded callable functions for tool_call nodes

    Returns:
        Dict of map nodes: name -> (map_edge_fn, sub_node_name)
    """
    # Lazily compile each node; the comprehension below drives the generator
    # one node at a time and keeps only the truthy (map node) results.
    outcomes = (
        _compile_node(
            node_name,
            node_config,
            graph,
            config,
            tools,
            python_tools,
            websearch_tools,
            callable_registry,
        )
        for node_name, node_config in config.nodes.items()
    )
    return {outcome[0]: outcome[1] for outcome in outcomes if outcome}
304
+
305
+
306
def _process_edge(
    edge: dict[str, Any],
    graph: StateGraph,
    map_nodes: dict[str, tuple],
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Wire one edge into the graph, or queue it for conditional routing.

    Plain edges, the entry point and map-node wiring are applied to ``graph``
    immediately.  Router edges and expression edges are only recorded in the
    tracking dicts passed in.

    Args:
        edge: Edge configuration dict (requires "from" and "to")
        graph: StateGraph to add edges to
        map_nodes: Map node tracking dict, name -> (map_edge_fn, sub_node_name)
        router_edges: Dict collecting router edges (source -> target list)
        expression_edges: Dict collecting (condition, target) pairs per source
    """
    source = edge["from"]
    dest = edge["to"]
    condition = edge.get("condition")
    edge_type = edge.get("type")

    if source == "START":
        graph.set_entry_point(dest)
        return

    source_is_map = source in map_nodes

    if source_is_map and dest in map_nodes:
        # Map node -> map node: the upstream sub-node fans out into the
        # downstream map via its Send function.
        _, upstream_sub = map_nodes[source]
        downstream_edge_fn, downstream_sub = map_nodes[dest]
        graph.add_conditional_edges(upstream_sub, downstream_edge_fn, [downstream_sub])
        return

    if isinstance(dest, str) and dest in map_nodes:
        # Regular node -> map node: conditional edge with the Send function
        fan_out_fn, sub_node_name = map_nodes[dest]
        graph.add_conditional_edges(source, fan_out_fn, [sub_node_name])
        return

    if source_is_map:
        # Map node -> regular node: the sub-node feeds the next node (fan-in)
        _, sub_node_name = map_nodes[source]
        graph.add_edge(sub_node_name, END if dest == "END" else dest)
        return

    if edge_type == "conditional" and isinstance(dest, list):
        # Router-style conditional edge: handled after all edges are read
        router_edges[source] = dest
        return

    if condition:
        # Expression-based condition (e.g., "critique.score < 0.8")
        resolved = END if dest == "END" else dest
        expression_edges.setdefault(source, []).append((condition, resolved))
        return

    # Unconditional edge; the literal "END" maps to the graph's END marker.
    graph.add_edge(source, END if dest == "END" else dest)
356
+
357
+
358
def _add_conditional_edges(
    graph: StateGraph,
    router_edges: dict[str, list],
    expression_edges: dict[str, list[tuple[str, str]]],
) -> None:
    """Register the collected router and expression edges on the graph.

    Args:
        graph: StateGraph to add edges to
        router_edges: Router-style conditional edges (source -> target list)
        expression_edges: Expression-based edges (source -> (condition, target) pairs)
    """
    # Router edges: every target maps to itself.
    for source_node, target_nodes in router_edges.items():
        identity_map = {}
        for dest in target_nodes:
            identity_map[dest] = dest
        router_fn = make_router_fn(target_nodes)
        graph.add_conditional_edges(source_node, router_fn, identity_map)

    # Expression edges: targets named by the conditions, plus END, which is
    # always included as the fallback destination.
    for source_node, expr_edges in expression_edges.items():
        destinations = {dest for _, dest in expr_edges}
        destinations.add(END)
        fallback_map = {}
        for dest in destinations:
            fallback_map[dest] = END if dest == END else dest
        expr_router = make_expr_router_fn(expr_edges, source_node)
        graph.add_conditional_edges(source_node, expr_router, fallback_map)
389
+
390
+
391
def compile_graph(config: GraphConfig) -> StateGraph:
    """Turn a GraphConfig into a LangGraph StateGraph.

    Steps, in order: resolve the state class, parse the tool registries,
    compile the nodes, process each edge, then attach the conditional edges
    collected along the way.

    Args:
        config: Parsed graph configuration

    Returns:
        StateGraph ready for compilation
    """
    graph = StateGraph(_resolve_state_class(config))

    # (shell tools, python tools, websearch tools, callable registry)
    registries = _parse_all_tools(config)
    map_nodes = _compile_nodes(config, graph, *registries)

    # Conditional edges are gathered first and added after all edges are seen.
    router_edges: dict[str, list] = {}
    expression_edges: dict[str, list[tuple[str, str]]] = {}
    for edge in config.edges:
        _process_edge(edge, graph, map_nodes, router_edges, expression_edges)

    _add_conditional_edges(graph, router_edges, expression_edges)
    return graph
423
+
424
+
425
def load_and_compile(path: str | Path) -> StateGraph:
    """Read a YAML graph definition and build its StateGraph.

    Shortcut for ``load_graph_config`` followed by ``compile_graph``; the
    loaded graph's name and version are logged in between.

    Args:
        path: Path to YAML graph definition

    Returns:
        StateGraph ready for compilation
    """
    config = load_graph_config(path)
    logger.info(f"Loaded graph config: {config.name} v{config.version}")
    compiled = compile_graph(config)
    return compiled
439
+
440
+
441
def get_checkpointer_for_graph(config: GraphConfig, *, async_mode: bool = False):
    """Build the checkpointer described by the graph config.

    Args:
        config: Graph configuration
        async_mode: If True, return async-compatible saver

    Returns:
        Whatever ``get_checkpointer`` returns for the config's ``checkpointer``
        section (documented upstream as a configured checkpointer, or None
        when none is specified)
    """
    checkpointer_settings = config.checkpointer
    return get_checkpointer(checkpointer_settings, async_mode=async_mode)
@@ -0,0 +1,150 @@
1
+ """Map node compiler - Handles type: map node compilation.
2
+
3
+ This module provides functionality to compile map nodes that fan out
4
+ to sub-nodes for parallel processing using LangGraph's Send mechanism.
5
+ """
6
+
7
+ import logging
8
+ from collections.abc import Callable
9
+ from typing import Any
10
+
11
+ from langgraph.graph import StateGraph
12
+ from langgraph.types import Send
13
+
14
+ from yamlgraph.constants import NodeType
15
+ from yamlgraph.node_factory import create_node_function, create_tool_call_node
16
+ from yamlgraph.utils.expressions import resolve_state_expression
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def wrap_for_reducer(
    node_fn: Callable[[dict], dict],
    collect_key: str,
    state_key: str,
) -> Callable[[dict], dict]:
    """Wrap sub-node output for Annotated reducer aggregation.

    The wrapper always returns ``{collect_key: [item]}`` so that results from
    parallel map branches can be concatenated by a list reducer.

    Handles error propagation: if a map branch raises, or returns a non-empty
    ``errors`` / ``error`` value, the collected item is an error record
    carrying the branch's ``_map_index`` (0 when the state has none).

    Args:
        node_fn: The original node function
        collect_key: State key where results are collected
        state_key: Key to extract from node result; when the result lacks it,
            the whole result is collected instead

    Returns:
        Wrapped function that outputs in reducer-compatible format
    """

    def wrapped(state: dict) -> dict:
        try:
            result = node_fn(state)
        except Exception as e:
            # Propagate error with map index
            from yamlgraph.models import PipelineError

            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(e),
                "_error_type": type(e).__name__,
            }
            return {
                collect_key: [error_result],
                "errors": [PipelineError.from_exception(e, node="map_subnode")],
            }

        # A branch counts as failed only when it reports a non-empty error
        # value.  Testing key presence alone would turn a successful result
        # carrying ``errors: []`` or ``error: None`` into an error record
        # with the text "None" and drop the real payload.
        failure = result.get("errors") or result.get("error")
        if failure:
            error_result = {
                "_map_index": state.get("_map_index", 0),
                "_error": str(failure),
            }
            # Preserve errors in output
            output = {collect_key: [error_result]}
            if "errors" in result:
                output["errors"] = result["errors"]
            return output

        extracted = result.get(state_key, result)

        # Convert Pydantic models to dicts
        if hasattr(extracted, "model_dump"):
            extracted = extracted.model_dump()

        # Include _map_index if present for ordering
        if "_map_index" in state:
            if isinstance(extracted, dict):
                extracted = {"_map_index": state["_map_index"], **extracted}
            else:
                # Non-dict payloads are boxed so the index can travel with them.
                extracted = {"_map_index": state["_map_index"], "value": extracted}

        return {collect_key: [extracted]}

    return wrapped
85
+
86
+
87
def compile_map_node(
    name: str,
    config: dict[str, Any],
    builder: StateGraph,
    defaults: dict[str, Any],
    tools_registry: dict[str, Any] | None = None,
) -> tuple[Callable[[dict], list[Send]], str]:
    """Compile a ``type: map`` node into a sub-node plus a fan-out function.

    The sub-node (named ``_map_<name>_sub``) is added to ``builder`` wrapped
    by ``wrap_for_reducer``.  The returned edge function emits one ``Send``
    per item of the list that ``over`` resolves to.

    Args:
        name: Name of the map node
        config: Map node configuration with 'over', 'as', 'node', 'collect'
        builder: StateGraph builder to add sub-node to
        defaults: Default configuration for nodes
        tools_registry: Optional tools registry for tool_call sub-nodes

    Returns:
        Tuple of (map_edge_function, sub_node_name)

    Raises:
        ValueError: If the sub-node is a tool_call but tools_registry is None
    """
    over_expr = config["over"]
    item_var = config["as"]
    sub_node_name = f"_map_{name}_sub"
    collect_key = config["collect"]
    # Copy so the injected variable never mutates the caller's config.
    sub_node_config = dict(config["node"])
    state_key = sub_node_config.get("state_key", "result")
    sub_node_type = sub_node_config.get("type", "llm")

    # Expose the 'as' variable to the sub-node as a state reference.
    injected_variables = dict(sub_node_config.get("variables", {}))
    injected_variables.update({item_var: f"{{state.{item_var}}}"})
    sub_node_config["variables"] = injected_variables

    # Build the sub-node according to its type.
    if sub_node_type != NodeType.TOOL_CALL:
        sub_node = create_node_function(sub_node_name, sub_node_config, defaults)
    elif tools_registry is None:
        raise ValueError(
            f"Map node '{name}' has tool_call sub-node but no tools_registry"
        )
    else:
        sub_node = create_tool_call_node(sub_node_name, sub_node_config, tools_registry)

    builder.add_node(sub_node_name, wrap_for_reducer(sub_node, collect_key, state_key))

    def map_edge(state: dict) -> list[Send]:
        """Return one Send per item, each with the item and its index in state."""
        items = resolve_state_expression(over_expr, state)

        if not isinstance(items, list):
            raise TypeError(
                f"Map 'over' must resolve to list, got {type(items).__name__}"
            )

        sends = []
        for i, item in enumerate(items):
            branch_state = {**state, item_var: item, "_map_index": i}
            sends.append(Send(sub_node_name, branch_state))
        return sends

    return map_edge, sub_node_name
@@ -0,0 +1,36 @@
1
+ """Pydantic models and state definitions.
2
+
3
+ Framework models for error handling and generic reports.
4
+ State is now generated dynamically by state_builder.py.
5
+ """
6
+
7
+ from yamlgraph.models.graph_schema import (
8
+ EdgeConfig,
9
+ GraphConfigSchema,
10
+ NodeConfig,
11
+ validate_graph_schema,
12
+ )
13
+ from yamlgraph.models.schemas import (
14
+ ErrorType,
15
+ GenericReport,
16
+ PipelineError,
17
+ )
18
+ from yamlgraph.models.state_builder import (
19
+ build_state_class,
20
+ create_initial_state,
21
+ )
22
+
23
+ __all__ = [
24
+ # Framework models
25
+ "ErrorType",
26
+ "PipelineError",
27
+ "GenericReport",
28
+ # Graph config schema
29
+ "GraphConfigSchema",
30
+ "NodeConfig",
31
+ "EdgeConfig",
32
+ "validate_graph_schema",
33
+ # Dynamic state generation
34
+ "build_state_class",
35
+ "create_initial_state",
36
+ ]