yamlgraph 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of yamlgraph might be problematic. Click here for more details.

Files changed (111)
  1. examples/__init__.py +1 -0
  2. examples/storyboard/__init__.py +1 -0
  3. examples/storyboard/generate_videos.py +335 -0
  4. examples/storyboard/nodes/__init__.py +10 -0
  5. examples/storyboard/nodes/animated_character_node.py +248 -0
  6. examples/storyboard/nodes/animated_image_node.py +138 -0
  7. examples/storyboard/nodes/character_node.py +162 -0
  8. examples/storyboard/nodes/image_node.py +118 -0
  9. examples/storyboard/nodes/replicate_tool.py +238 -0
  10. examples/storyboard/retry_images.py +118 -0
  11. tests/__init__.py +1 -0
  12. tests/conftest.py +178 -0
  13. tests/integration/__init__.py +1 -0
  14. tests/integration/test_animated_storyboard.py +63 -0
  15. tests/integration/test_cli_commands.py +242 -0
  16. tests/integration/test_map_demo.py +50 -0
  17. tests/integration/test_memory_demo.py +281 -0
  18. tests/integration/test_pipeline_flow.py +105 -0
  19. tests/integration/test_providers.py +163 -0
  20. tests/integration/test_resume.py +75 -0
  21. tests/unit/__init__.py +1 -0
  22. tests/unit/test_agent_nodes.py +200 -0
  23. tests/unit/test_checkpointer.py +212 -0
  24. tests/unit/test_cli.py +121 -0
  25. tests/unit/test_cli_package.py +81 -0
  26. tests/unit/test_compile_graph_map.py +132 -0
  27. tests/unit/test_conditions_routing.py +253 -0
  28. tests/unit/test_config.py +93 -0
  29. tests/unit/test_conversation_memory.py +270 -0
  30. tests/unit/test_database.py +145 -0
  31. tests/unit/test_deprecation.py +104 -0
  32. tests/unit/test_executor.py +60 -0
  33. tests/unit/test_executor_async.py +179 -0
  34. tests/unit/test_export.py +150 -0
  35. tests/unit/test_expressions.py +178 -0
  36. tests/unit/test_format_prompt.py +145 -0
  37. tests/unit/test_generic_report.py +200 -0
  38. tests/unit/test_graph_commands.py +327 -0
  39. tests/unit/test_graph_loader.py +299 -0
  40. tests/unit/test_graph_schema.py +193 -0
  41. tests/unit/test_inline_schema.py +151 -0
  42. tests/unit/test_issues.py +164 -0
  43. tests/unit/test_jinja2_prompts.py +85 -0
  44. tests/unit/test_langsmith.py +319 -0
  45. tests/unit/test_llm_factory.py +109 -0
  46. tests/unit/test_llm_factory_async.py +118 -0
  47. tests/unit/test_loops.py +403 -0
  48. tests/unit/test_map_node.py +144 -0
  49. tests/unit/test_no_backward_compat.py +56 -0
  50. tests/unit/test_node_factory.py +225 -0
  51. tests/unit/test_prompts.py +166 -0
  52. tests/unit/test_python_nodes.py +198 -0
  53. tests/unit/test_reliability.py +298 -0
  54. tests/unit/test_result_export.py +234 -0
  55. tests/unit/test_router.py +296 -0
  56. tests/unit/test_sanitize.py +99 -0
  57. tests/unit/test_schema_loader.py +295 -0
  58. tests/unit/test_shell_tools.py +229 -0
  59. tests/unit/test_state_builder.py +331 -0
  60. tests/unit/test_state_builder_map.py +104 -0
  61. tests/unit/test_state_config.py +197 -0
  62. tests/unit/test_template.py +190 -0
  63. tests/unit/test_tool_nodes.py +129 -0
  64. yamlgraph/__init__.py +35 -0
  65. yamlgraph/builder.py +110 -0
  66. yamlgraph/cli/__init__.py +139 -0
  67. yamlgraph/cli/__main__.py +6 -0
  68. yamlgraph/cli/commands.py +232 -0
  69. yamlgraph/cli/deprecation.py +92 -0
  70. yamlgraph/cli/graph_commands.py +382 -0
  71. yamlgraph/cli/validators.py +37 -0
  72. yamlgraph/config.py +67 -0
  73. yamlgraph/constants.py +66 -0
  74. yamlgraph/error_handlers.py +226 -0
  75. yamlgraph/executor.py +275 -0
  76. yamlgraph/executor_async.py +122 -0
  77. yamlgraph/graph_loader.py +337 -0
  78. yamlgraph/map_compiler.py +138 -0
  79. yamlgraph/models/__init__.py +36 -0
  80. yamlgraph/models/graph_schema.py +141 -0
  81. yamlgraph/models/schemas.py +124 -0
  82. yamlgraph/models/state_builder.py +236 -0
  83. yamlgraph/node_factory.py +240 -0
  84. yamlgraph/routing.py +87 -0
  85. yamlgraph/schema_loader.py +160 -0
  86. yamlgraph/storage/__init__.py +17 -0
  87. yamlgraph/storage/checkpointer.py +72 -0
  88. yamlgraph/storage/database.py +320 -0
  89. yamlgraph/storage/export.py +269 -0
  90. yamlgraph/tools/__init__.py +1 -0
  91. yamlgraph/tools/agent.py +235 -0
  92. yamlgraph/tools/nodes.py +124 -0
  93. yamlgraph/tools/python_tool.py +178 -0
  94. yamlgraph/tools/shell.py +205 -0
  95. yamlgraph/utils/__init__.py +47 -0
  96. yamlgraph/utils/conditions.py +157 -0
  97. yamlgraph/utils/expressions.py +111 -0
  98. yamlgraph/utils/langsmith.py +308 -0
  99. yamlgraph/utils/llm_factory.py +118 -0
  100. yamlgraph/utils/llm_factory_async.py +105 -0
  101. yamlgraph/utils/logging.py +127 -0
  102. yamlgraph/utils/prompts.py +116 -0
  103. yamlgraph/utils/sanitize.py +98 -0
  104. yamlgraph/utils/template.py +102 -0
  105. yamlgraph/utils/validators.py +181 -0
  106. yamlgraph-0.1.1.dist-info/METADATA +854 -0
  107. yamlgraph-0.1.1.dist-info/RECORD +111 -0
  108. yamlgraph-0.1.1.dist-info/WHEEL +5 -0
  109. yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
  110. yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
  111. yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
@@ -0,0 +1,124 @@
1
+ """Pydantic models for structured LLM outputs.
2
+
3
+ This module contains FRAMEWORK models only - models used by the framework itself.
4
+ Demo-specific output schemas are defined inline in graph YAML files.
5
+ """
6
+
7
+ from datetime import datetime
8
+ from enum import Enum
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+ # =============================================================================
14
+ # Error Types
15
+ # =============================================================================
16
+
17
+
18
class ErrorType(str, Enum):
    """Types of errors that can occur in the pipeline."""

    # Provider/API failures such as rate limits and timeouts.
    LLM_ERROR = "llm_error"
    # Structured-output (Pydantic) validation failures.
    VALIDATION_ERROR = "validation_error"
    # Missing prompt files or template rendering problems.
    PROMPT_ERROR = "prompt_error"
    # Required state data was not present.
    STATE_ERROR = "state_error"
    # Fallback category when nothing more specific applies.
    UNKNOWN_ERROR = "unknown_error"
26
+
27
+
28
class PipelineError(BaseModel):
    """Structured error information for pipeline failures."""

    type: ErrorType = Field(description="Category of error")
    message: str = Field(description="Human-readable error message")
    node: str = Field(description="Node where error occurred")
    timestamp: datetime = Field(default_factory=datetime.now)
    retryable: bool = Field(
        description="Whether this error can be retried", default=False
    )
    details: dict[str, Any] = Field(
        description="Additional error context", default_factory=dict
    )

    @classmethod
    def from_exception(
        cls, e: Exception, node: str, error_type: ErrorType | None = None
    ) -> "PipelineError":
        """Build a PipelineError that describes ``e``.

        When ``error_type`` is not given, it is guessed from the exception's
        class name by case-insensitive substring matching. The rules are
        checked in order and the first match wins:

        * "rate" / "timeout" / "api" -> LLM_ERROR
        * "validation"               -> VALIDATION_ERROR
        * "file" / "prompt"          -> PROMPT_ERROR
        * anything else              -> UNKNOWN_ERROR

        Only LLM_ERROR is marked retryable. This holds whether the type was
        inferred or passed explicitly.

        Args:
            e: The exception that occurred.
            node: Name of the node where the error occurred.
            error_type: Optional explicit error category; skips inference.

        Returns:
            A PipelineError whose details record the exception class name.
        """
        exc_class_name = type(e).__name__

        if error_type is None:
            lowered = exc_class_name.lower()
            keyword_rules = (
                (("rate", "timeout", "api"), ErrorType.LLM_ERROR),
                (("validation",), ErrorType.VALIDATION_ERROR),
                (("file", "prompt"), ErrorType.PROMPT_ERROR),
            )
            error_type = next(
                (
                    category
                    for keywords, category in keyword_rules
                    if any(keyword in lowered for keyword in keywords)
                ),
                ErrorType.UNKNOWN_ERROR,
            )

        return cls(
            type=error_type,
            message=str(e),
            node=node,
            retryable=error_type == ErrorType.LLM_ERROR,
            details={"exception_type": exc_class_name},
        )
81
+
82
+
83
+ # =============================================================================
84
+ # Generic Report Model (Flexible for Any Use Case)
85
+ # =============================================================================
86
+
87
+
88
class GenericReport(BaseModel):
    """Flexible report structure for any use case.

    Use this when you don't need a custom schema - works for most
    analysis and summary tasks. The LLM can populate any combination
    of the optional fields as needed.

    Example usage in graph YAML:
        nodes:
          analyze:
            type: llm
            prompt: my_analysis
            output_model: yamlgraph.models.GenericReport

    Example prompts can request specific sections:
        "Analyze the repository and provide:
        - A summary of findings
        - Key findings as bullet points
        - Recommendations for improvement"
    """

    # NOTE: the class docstring, the field order and each `description`
    # form the schema handed to the model. Keep them stable.

    # Required fields.
    title: str = Field(description="Report title")
    summary: str = Field(description="Executive summary")

    # Optional fields; each defaults to an empty container.
    sections: dict[str, Any] = Field(
        description="Named sections with any content (strings, dicts, lists)",
        default_factory=dict,
    )
    findings: list[str] = Field(
        description="Key findings or bullet points",
        default_factory=list,
    )
    recommendations: list[str] = Field(
        description="Suggested actions or areas to focus on",
        default_factory=list,
    )
    metadata: dict[str, Any] = Field(
        description="Additional key-value data (author, version, tags, etc.)",
        default_factory=dict,
    )
@@ -0,0 +1,236 @@
1
+ """Dynamic state class generation from graph configuration.
2
+
3
+ Builds TypedDict programmatically from YAML graph config, eliminating
4
+ the need for state_class coupling between YAML and Python.
5
+ """
6
+
7
+ import logging
8
+ from operator import add
9
+ from typing import Annotated, Any, TypedDict
10
+
11
# Module-level logger; parse_state_config uses it to warn about state
# fields whose type specification it does not recognize.
logger = logging.getLogger(__name__)
12
+
13
+
14
def sorted_add(existing: list, new: list) -> list:
    """Reducer that concatenates two lists and orders them by ``_map_index``.

    Used for map node fan-in. The collected list then has a stable order
    regardless of which parallel branch finished first.

    Sorting happens when *any* dict item in the combined list carries a
    ``_map_index`` key. It does not depend on the first item alone, so a
    leading item without an index does not disable ordering. Non-dict items
    and dicts without ``_map_index`` sort with key 0. ``sorted`` is stable,
    so their relative order is kept.

    Args:
        existing: Current list in state (``None`` is treated as empty).
        new: New items to add (``None`` is treated as empty).

    Returns:
        A new combined list, sorted by ``_map_index`` when any item has one.
    """
    combined = (existing or []) + (new or [])

    def _index_of(item: Any) -> Any:
        # Non-dicts have no .get(); give them the same default key (0) as
        # dicts that lack the index, instead of raising AttributeError.
        return item.get("_map_index", 0) if isinstance(item, dict) else 0

    if any(isinstance(item, dict) and "_map_index" in item for item in combined):
        combined = sorted(combined, key=_index_of)

    return combined
34
+
35
+
36
+ # =============================================================================
37
+ # Base Fields - Always included in generated state
38
+ # =============================================================================
39
+
40
# Infrastructure fields present in all graphs.
# Values are typing forms, not only classes: `Any` and `Annotated[...]` are
# not `type` instances, so the mapping is annotated as dict[str, Any].
BASE_FIELDS: dict[str, Any] = {
    # Core tracking
    "thread_id": str,
    "current_step": str,
    # Error handling - singular for current error
    "error": Any,
    # Error handling with reducer (accumulates via operator.add)
    "errors": Annotated[list, add],
    # Messages with reducer (accumulates via operator.add)
    "messages": Annotated[list, add],
    # Loop tracking
    "_loop_counts": dict,
    "_loop_limit_reached": bool,
    "_agent_iterations": int,
    "_agent_limit_reached": bool,
    # Timestamps
    "started_at": Any,
    "completed_at": Any,
}

# Common input fields used across graph types.
# These are always included so --var inputs have a declared state field.
COMMON_INPUT_FIELDS: dict[str, type] = {
    "input": str,  # Agent prompt input
    "topic": str,  # Content generation topic
    "style": str,  # Writing style
    "word_count": int,  # Target word count
    "message": str,  # Router message input
}

# Type mapping for YAML state config. Keys are lower-case spellings; the
# "any" entry maps to typing.Any, hence dict[str, Any].
TYPE_MAP: dict[str, Any] = {
    "str": str,
    "string": str,
    "int": int,
    "integer": int,
    "float": float,
    "bool": bool,
    "boolean": bool,
    "list": list,
    "dict": dict,
    "any": Any,
}
84
+
85
+
86
+ def parse_state_config(state_config: dict) -> dict[str, type]:
87
+ """Parse YAML state section into field types.
88
+
89
+ Supports simple type strings:
90
+ state:
91
+ concept: str
92
+ count: int
93
+
94
+ Args:
95
+ state_config: Dict from YAML 'state' section
96
+
97
+ Returns:
98
+ Dict of field_name -> Python type
99
+ """
100
+ fields: dict[str, type] = {}
101
+
102
+ for field_name, type_spec in state_config.items():
103
+ if isinstance(type_spec, str):
104
+ # Simple type: "str", "int", etc.
105
+ normalized = type_spec.lower()
106
+ if normalized not in TYPE_MAP:
107
+ supported = ", ".join(sorted(set(TYPE_MAP.keys())))
108
+ logger.warning(
109
+ f"Unknown type '{type_spec}' for state field '{field_name}'. "
110
+ f"Supported types: {supported}. Defaulting to Any."
111
+ )
112
+ python_type = TYPE_MAP.get(normalized, Any)
113
+ fields[field_name] = python_type
114
+ else:
115
+ # Unknown format, use Any
116
+ logger.warning(
117
+ f"Invalid type specification for state field '{field_name}': "
118
+ f"expected string, got {type(type_spec).__name__}. Defaulting to Any."
119
+ )
120
+ fields[field_name] = Any
121
+
122
+ return fields
123
+
124
+
125
def build_state_class(config: dict) -> type:
    """Build a TypedDict state class from a graph configuration.

    The generated TypedDict combines, in increasing precedence (later
    sources override earlier ones for the same field name):
    - base infrastructure fields (errors, messages, thread_id, etc.)
    - common input fields (topic, style, input, message, etc.)
    - custom fields from the YAML 'state' section
    - fields derived from nodes (state_key, agent/router/map specifics)

    A 'state' or 'nodes' key that is present in the YAML but has no value
    parses to None. Either one is treated as empty.

    Args:
        config: Parsed YAML graph configuration dict.

    Returns:
        TypedDict class named "GraphState" with total=False (all fields
        optional).
    """
    fields: dict[str, Any] = {**BASE_FIELDS, **COMMON_INPUT_FIELDS}

    # `or {}` rather than a .get() default: the key may exist with a None value.
    fields.update(parse_state_config(config.get("state") or {}))
    fields.update(extract_node_fields(config.get("nodes") or {}))

    # Build TypedDict programmatically
    return TypedDict("GraphState", fields, total=False)
158
+
159
+
160
+ def extract_node_fields(nodes: dict) -> dict[str, type]:
161
+ """Extract state fields from node configurations.
162
+
163
+ Analyzes node configs to determine required state fields:
164
+ - state_key: Where node stores its output
165
+ - type: agent → adds input, _tool_results
166
+ - type: router → adds _route
167
+
168
+ Args:
169
+ nodes: Dict of node_name -> node_config
170
+
171
+ Returns:
172
+ Dict of field_name -> type for the state
173
+ """
174
+ fields: dict[str, type] = {}
175
+
176
+ for node_name, node_config in nodes.items():
177
+ if not isinstance(node_config, dict):
178
+ continue
179
+
180
+ # state_key → Any (accepts Pydantic models)
181
+ if state_key := node_config.get("state_key"):
182
+ fields[state_key] = Any
183
+
184
+ # Node type-specific fields
185
+ node_type = node_config.get("type", "llm")
186
+
187
+ if node_type == "agent":
188
+ fields["input"] = str
189
+ fields["_tool_results"] = list
190
+
191
+ elif node_type == "router":
192
+ fields["_route"] = str
193
+
194
+ elif node_type == "map":
195
+ # Map node collect field needs sorted reducer for ordered fan-in
196
+ if collect_key := node_config.get("collect"):
197
+ fields[collect_key] = Annotated[list, sorted_add]
198
+
199
+ return fields
200
+
201
+
202
+ def create_initial_state(
203
+ topic: str = "",
204
+ style: str = "informative",
205
+ word_count: int = 300,
206
+ thread_id: str | None = None,
207
+ **kwargs: Any,
208
+ ) -> dict[str, Any]:
209
+ """Create an initial state for a new pipeline run.
210
+
211
+ Args:
212
+ topic: The topic to generate content about
213
+ style: Writing style (default: informative)
214
+ word_count: Target word count (default: 300)
215
+ thread_id: Optional thread ID (auto-generated if not provided)
216
+ **kwargs: Additional state fields (e.g., input for agents)
217
+
218
+ Returns:
219
+ Initialized state dictionary
220
+ """
221
+ import uuid
222
+ from datetime import datetime
223
+
224
+ return {
225
+ "thread_id": thread_id or uuid.uuid4().hex[:16],
226
+ "topic": topic,
227
+ "style": style,
228
+ "word_count": word_count,
229
+ "current_step": "init",
230
+ "error": None,
231
+ "errors": [],
232
+ "messages": [],
233
+ "started_at": datetime.now(),
234
+ "completed_at": None,
235
+ **kwargs,
236
+ }
@@ -0,0 +1,240 @@
1
+ """Node function factory for YAML-defined graphs.
2
+
3
+ Creates LangGraph node functions from YAML configuration with support for:
4
+ - Resume (skip if output exists)
5
+ - Error handling (skip, retry, fail, fallback)
6
+ - Router nodes with dynamic routing
7
+ - Loop counting and limits
8
+ """
9
+
10
+ import logging
11
+ from typing import Any, Callable
12
+
13
+ from yamlgraph.constants import ErrorHandler, NodeType
14
+ from yamlgraph.error_handlers import (
15
+ check_loop_limit,
16
+ check_requirements,
17
+ handle_default,
18
+ handle_fail,
19
+ handle_fallback,
20
+ handle_retry,
21
+ handle_skip,
22
+ )
23
+ from yamlgraph.executor import execute_prompt
24
+ from yamlgraph.utils.expressions import resolve_template
25
+ from yamlgraph.utils.prompts import resolve_prompt_path
26
+
27
# Type alias for dynamic state. Graph state is handled as a plain dict keyed
# by field name; in this module the alias only appears in type annotations.
GraphState = dict[str, Any]

# Module-level logger for node execution, skip and routing messages.
logger = logging.getLogger(__name__)
31
+
32
+
33
def resolve_class(class_path: str) -> type:
    """Import and return the class named by ``class_path``.

    A dotted path ("package.module.ClassName") is split at its last dot.
    The module part is imported with importlib and the final component is
    fetched from it. A bare name (no dot) is looked up on
    ``yamlgraph.models.schemas``.

    Args:
        class_path: Full path like "yamlgraph.models.GenericReport", or a
            bare class name defined in ``yamlgraph.models.schemas``.

    Returns:
        The resolved class object.

    Raises:
        ValueError: If a bare name cannot be found in
            ``yamlgraph.models.schemas`` (or that module fails to import).
        ImportError / AttributeError: Propagated from importing the module
            or looking up the attribute of a dotted path.
    """
    import importlib

    module_path, separator, class_name = class_path.rpartition(".")

    if separator:
        # Dotted path: import the module portion, then fetch the attribute.
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    # Bare name: fall back to the framework's own schema models.
    try:
        from yamlgraph.models import schemas

        if hasattr(schemas, class_path):
            return getattr(schemas, class_path)
    except ImportError:
        pass
    raise ValueError(f"Invalid class path: {class_path}")
59
+
60
+
61
+ def get_output_model_for_node(
62
+ node_config: dict[str, Any], prompts_dir: str | None = None
63
+ ) -> type | None:
64
+ """Get output model for a node, checking inline schema if no explicit model.
65
+
66
+ Priority:
67
+ 1. Explicit output_model in node config (class path)
68
+ 2. Inline schema in prompt YAML file
69
+ 3. None (raw string output)
70
+
71
+ Args:
72
+ node_config: Node configuration from YAML
73
+ prompts_dir: Base prompts directory
74
+
75
+ Returns:
76
+ Pydantic model class or None
77
+ """
78
+ # 1. Check for explicit output_model
79
+ if model_path := node_config.get("output_model"):
80
+ return resolve_class(model_path)
81
+
82
+ # 2. Check for inline schema in prompt YAML
83
+ prompt_name = node_config.get("prompt")
84
+ if prompt_name:
85
+ try:
86
+ from yamlgraph.schema_loader import load_schema_from_yaml
87
+
88
+ yaml_path = resolve_prompt_path(prompt_name, prompts_dir)
89
+ return load_schema_from_yaml(yaml_path)
90
+ except FileNotFoundError:
91
+ # Prompt file doesn't exist yet - will fail later
92
+ pass
93
+
94
+ # 3. No output model
95
+ return None
96
+
97
+
98
def _compute_route(result: Any, routes: dict, default_route: Any) -> Any:
    """Choose the target node for a router node's ``result``.

    The route key is ``result.tone``, falling back to ``result.intent``
    (first truthy value wins). A key found in ``routes`` maps to its target.
    Otherwise ``default_route`` is used when set, else the first target
    listed in ``routes``.

    Args:
        result: The router node's output (only attribute access is used).
        routes: Mapping of route key -> target node name; must be non-empty.
        default_route: Optional target used for missing/unknown keys.

    Returns:
        The target node name to store under ``_route``.
    """
    route_key = getattr(result, "tone", None) or getattr(result, "intent", None)
    if route_key and route_key in routes:
        return routes[route_key]
    if default_route:
        return default_route
    return next(iter(routes.values()))


def create_node_function(
    node_name: str,
    node_config: dict,
    defaults: dict,
) -> Callable[[GraphState], dict]:
    """Create a LangGraph node function from a YAML node config.

    Each call of the returned function:
    - returns ``_loop_limit_reached: True`` when check_loop_limit reports
      that ``loop_limit`` has been hit, without executing anything;
    - otherwise increments this node's entry in ``_loop_counts``;
    - skips execution when ``skip_if_exists`` is true and
      ``state[state_key]`` is not None (resume support);
    - returns an ``errors`` update when check_requirements reports a problem;
    - renders ``variables`` templates from state and calls execute_prompt;
    - on failure, dispatches to the ``on_error`` strategy
      (skip / fail / retry / fallback / default handler);
    - for router nodes with ``routes``, sets ``_route`` from the result.
      This also applies to results recovered through the retry, fallback or
      default handlers, not only to a first-attempt success.

    Args:
        node_name: Name of the node.
        node_config: Node configuration from YAML.
        defaults: Graph-level defaults (``temperature``, ``provider``).

    Returns:
        Node function compatible with LangGraph, named "<node_name>_node".
    """
    node_type = node_config.get("type", NodeType.LLM)
    prompt_name = node_config.get("prompt")

    # Resolve output model (explicit > inline schema > None).
    # NOTE(review): prompts_dir is not forwarded, so inline-schema lookup
    # uses resolve_prompt_path's default directory.
    output_model = get_output_model_for_node(node_config)

    # Get config values (node > defaults)
    temperature = node_config.get("temperature", defaults.get("temperature", 0.7))
    provider = node_config.get("provider", defaults.get("provider"))
    state_key = node_config.get("state_key", node_name)
    variable_templates = node_config.get("variables", {})
    requires = node_config.get("requires", [])

    # Error handling
    on_error = node_config.get("on_error")
    max_retries = node_config.get("max_retries", 3)
    fallback_config = node_config.get("fallback", {})
    fallback_provider = fallback_config.get("provider") if fallback_config else None

    # Router config
    routes = node_config.get("routes", {})
    default_route = node_config.get("default_route")

    # Loop limit
    loop_limit = node_config.get("loop_limit")

    # Skip if the output already exists (resume support); defaults to True.
    # NOTE(review): nothing here disables this for loop nodes. They
    # presumably need `skip_if_exists: false` in config to run repeatedly.
    skip_if_exists = node_config.get("skip_if_exists", True)

    def apply_route(update: dict, value: Any) -> None:
        # Router nodes only: record the chosen target under "_route".
        if node_type == NodeType.ROUTER and routes:
            update["_route"] = _compute_route(value, routes, default_route)
            logger.info(f"Router {node_name} routing to: {update['_route']}")

    def node_fn(state: dict) -> dict:
        """Generated node function."""
        loop_counts = dict(state.get("_loop_counts") or {})
        current_count = loop_counts.get(node_name, 0)

        # Check loop limit before doing any work.
        if check_loop_limit(node_name, loop_limit, current_count):
            return {"_loop_limit_reached": True, "current_step": node_name}

        loop_counts[node_name] = current_count + 1

        # Skip if output exists (resume support)
        if skip_if_exists and state.get(state_key) is not None:
            logger.info(f"Node {node_name} skipped - {state_key} already in state")
            return {"current_step": node_name, "_loop_counts": loop_counts}

        # Check requirements
        if error := check_requirements(requires, state, node_name):
            return {
                "errors": [error],
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

        # Resolve variables; list values are flattened to "a, b, c".
        variables = {}
        for key, template in variable_templates.items():
            resolved = resolve_template(template, state)
            if isinstance(resolved, list):
                resolved = ", ".join(str(item) for item in resolved)
            variables[key] = resolved

        def attempt_execute(use_provider: str | None) -> tuple[Any, Exception | None]:
            # Run the prompt once; report failure as a value, not an exception.
            try:
                result = execute_prompt(
                    prompt_name=prompt_name,
                    variables=variables,
                    output_model=output_model,
                    temperature=temperature,
                    provider=use_provider,
                )
                return result, None
            except Exception as e:
                return None, e

        def finish(outcome: Any) -> dict:
            # Turn a handler outcome into a state update. A router that
            # recovers still needs a fresh _route. Without one, a stale
            # _route from an earlier pass or make_router_fn's first-target
            # fallback would pick the edge instead of this result.
            update = outcome.to_state_update(state_key, node_name, loop_counts)
            if (
                isinstance(update, dict)
                and "_route" not in update
                and update.get(state_key) is not None
            ):
                apply_route(update, update[state_key])
            return update

        result, error = attempt_execute(provider)

        if error is None:
            logger.info(f"Node {node_name} completed successfully")
            update = {
                state_key: result,
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }
            apply_route(update, result)
            return update

        # Error handling - dispatch to strategy handlers
        if on_error == ErrorHandler.SKIP:
            handle_skip(node_name, error, loop_counts)
            return {"current_step": node_name, "_loop_counts": loop_counts}

        elif on_error == ErrorHandler.FAIL:
            # NOTE(review): handle_fail is expected to raise. If it ever
            # returns, control falls off the end and node_fn returns None.
            handle_fail(node_name, error)

        elif on_error == ErrorHandler.RETRY:
            outcome = handle_retry(
                node_name,
                lambda: attempt_execute(provider),
                max_retries,
            )
            return finish(outcome)

        elif on_error == ErrorHandler.FALLBACK and fallback_provider:
            outcome = handle_fallback(
                node_name,
                attempt_execute,
                fallback_provider,
            )
            return finish(outcome)

        else:
            outcome = handle_default(node_name, error)
            return finish(outcome)

    node_fn.__name__ = f"{node_name}_node"
    return node_fn
yamlgraph/routing.py ADDED
@@ -0,0 +1,87 @@
1
+ """Routing utilities for LangGraph edge conditions.
2
+
3
+ Provides factory functions for creating router functions that determine
4
+ which node to route to based on state values and expressions.
5
+ """
6
+
7
+ import logging
8
+ from collections.abc import Callable
9
+ from typing import Any
10
+
11
+ from langgraph.graph import END
12
+
13
+ from yamlgraph.utils.conditions import evaluate_condition
14
+
15
# Type alias for dynamic state: at runtime this is simply dict[str, Any].
# In this module it is used only in type annotations.
GraphState = dict[str, Any]

# Module-level logger for routing decisions and condition-evaluation warnings.
logger = logging.getLogger(__name__)
19
+
20
+
21
def make_router_fn(targets: list[str]) -> Callable[[dict], str]:
    """Build a router that follows the ``_route`` value stored in state.

    Used for type: router nodes with conditional edges to multiple targets.
    If ``_route`` is missing, empty, or not one of ``targets``, the router
    falls back to ``targets[0]``.

    NOTE: Use `state: dict` not `state: GraphState` - type hints cause
    LangGraph to filter state fields. See docs/debug-router-type-hints.md

    Args:
        targets: List of valid target node names (``targets[0]`` is the
            fallback, so this should not be empty).

    Returns:
        Router function that returns the target node name.
    """

    # The inner function keeps the name `router_fn`; LangGraph may use a
    # path function's __name__ when naming the branch.
    def router_fn(state: dict) -> str:
        chosen = state.get("_route")
        logger.debug(f"Router: _route={chosen}, targets={targets}")
        if not chosen or chosen not in targets:
            logger.debug(f"Router: defaulting to {targets[0]}")
            return targets[0]
        logger.debug(f"Router: matched route {chosen}")
        return chosen

    return router_fn
47
+
48
+
49
def make_expr_router_fn(
    edges: list[tuple[str, str]],
    source_node: str,
) -> Callable[[GraphState], str]:
    """Build a router that picks the first edge whose condition holds.

    Intended for reflexion-style loops with expression conditions such as
    "critique.score < 0.8". The returned router:
    - returns END immediately when ``_loop_limit_reached`` is truthy;
    - otherwise evaluates conditions in ``edges`` order and returns the
      target of the first one that is truthy;
    - logs a warning and skips any condition whose evaluation raises
      ValueError;
    - returns END (with a warning) when no condition matches.

    Args:
        edges: List of (condition, target) tuples, checked in order.
        source_node: Name of the source node (used in log messages).

    Returns:
        Router function mapping state to a target node name or END.
    """

    def expr_router_fn(state: GraphState) -> str:
        # A tripped loop limit short-circuits every condition.
        if state.get("_loop_limit_reached"):
            return END

        for condition, target in edges:
            try:
                matched = evaluate_condition(condition, state)
            except ValueError as e:
                logger.warning(f"Failed to evaluate condition '{condition}': {e}")
                continue
            if matched:
                logger.debug(
                    f"Condition '{condition}' matched, routing to {target}"
                )
                return target

        # No condition matched - this shouldn't happen with well-formed graphs
        logger.warning(f"No condition matched for {source_node}, defaulting to END")
        return END

    return expr_router_fn
85
+
86
+
87
# Public API of yamlgraph.routing (the names exported by `import *`).
__all__ = ["make_router_fn", "make_expr_router_fn"]