yamlgraph-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
yamlgraph/models/schemas.py
ADDED
@@ -0,0 +1,124 @@
"""Pydantic models for structured LLM outputs.

This module contains FRAMEWORK models only - models used by the framework itself.
Demo-specific output schemas are defined inline in graph YAML files.
"""

from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field

# =============================================================================
# Error Types
# =============================================================================


class ErrorType(str, Enum):
    """Types of errors that can occur in the pipeline."""

    LLM_ERROR = "llm_error"  # LLM API errors (rate limit, timeout, etc.)
    VALIDATION_ERROR = "validation_error"  # Pydantic validation failures
    PROMPT_ERROR = "prompt_error"  # Missing prompt, template errors
    STATE_ERROR = "state_error"  # Missing required state data
    UNKNOWN_ERROR = "unknown_error"  # Catch-all


class PipelineError(BaseModel):
    """Structured error information for pipeline failures."""

    type: ErrorType = Field(description="Category of error")
    message: str = Field(description="Human-readable error message")
    node: str = Field(description="Node where error occurred")
    timestamp: datetime = Field(default_factory=datetime.now)
    retryable: bool = Field(
        default=False, description="Whether this error can be retried"
    )
    details: dict[str, Any] = Field(
        default_factory=dict, description="Additional error context"
    )

    @classmethod
    def from_exception(
        cls, e: Exception, node: str, error_type: ErrorType | None = None
    ) -> "PipelineError":
        """Create a PipelineError from an exception.

        Args:
            e: The exception that occurred
            node: Name of the node where error occurred
            error_type: Optional explicit error type

        Returns:
            PipelineError instance
        """
        # Infer error type from exception
        if error_type is None:
            exc_name = type(e).__name__.lower()
            if "rate" in exc_name or "timeout" in exc_name or "api" in exc_name:
                error_type = ErrorType.LLM_ERROR
                retryable = True
            elif "validation" in exc_name:
                error_type = ErrorType.VALIDATION_ERROR
                retryable = False
            elif "file" in exc_name or "prompt" in exc_name:
                error_type = ErrorType.PROMPT_ERROR
                retryable = False
            else:
                error_type = ErrorType.UNKNOWN_ERROR
                retryable = False
        else:
            retryable = error_type == ErrorType.LLM_ERROR

        return cls(
            type=error_type,
            message=str(e),
            node=node,
            retryable=retryable,
            details={"exception_type": type(e).__name__},
        )


# =============================================================================
# Generic Report Model (Flexible for Any Use Case)
# =============================================================================


class GenericReport(BaseModel):
    """Flexible report structure for any use case.

    Use this when you don't need a custom schema - works for most
    analysis and summary tasks. The LLM can populate any combination
    of the optional fields as needed.

    Example usage in graph YAML:
        nodes:
          analyze:
            type: llm
            prompt: my_analysis
            output_model: yamlgraph.models.GenericReport

    Example prompts can request specific sections:
        "Analyze the repository and provide:
        - A summary of findings
        - Key findings as bullet points
        - Recommendations for improvement"
    """

    title: str = Field(description="Report title")
    summary: str = Field(description="Executive summary")
    sections: dict[str, Any] = Field(
        default_factory=dict,
        description="Named sections with any content (strings, dicts, lists)",
    )
    findings: list[str] = Field(
        default_factory=list, description="Key findings or bullet points"
    )
    recommendations: list[str] = Field(
        default_factory=list, description="Suggested actions or areas to focus on"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional key-value data (author, version, tags, etc.)",
    )
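For orientation, a minimal usage sketch of these framework models (not part of the package; assumes yamlgraph 0.1.1 and Pydantic v2 are installed):

from yamlgraph.models.schemas import ErrorType, GenericReport, PipelineError

# Error type is inferred from the exception class name: "timeout" -> LLM_ERROR, retryable.
err = PipelineError.from_exception(TimeoutError("provider timed out"), node="analyze")
print(err.type, err.retryable)  # ErrorType.LLM_ERROR True

# An explicit type other than LLM_ERROR is treated as non-retryable.
err2 = PipelineError.from_exception(
    KeyError("topic"), node="outline", error_type=ErrorType.STATE_ERROR
)
print(err2.retryable)  # False

# GenericReport only requires title and summary; the other fields default to empty.
report = GenericReport(title="Repo review", summary="Mostly healthy", findings=["No docs CI"])
print(report.model_dump()["recommendations"])  # []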
yamlgraph/models/state_builder.py
ADDED
@@ -0,0 +1,236 @@
"""Dynamic state class generation from graph configuration.

Builds TypedDict programmatically from YAML graph config, eliminating
the need for state_class coupling between YAML and Python.
"""

import logging
from operator import add
from typing import Annotated, Any, TypedDict

logger = logging.getLogger(__name__)


def sorted_add(existing: list, new: list) -> list:
    """Reducer that adds items and sorts by _map_index if present.

    Used for map node fan-in to guarantee order regardless of
    parallel execution timing.

    Args:
        existing: Current list in state
        new: New items to add

    Returns:
        Combined list sorted by _map_index (if items have it)
    """
    combined = (existing or []) + (new or [])

    # Sort by _map_index if items have it
    if combined and isinstance(combined[0], dict) and "_map_index" in combined[0]:
        combined = sorted(combined, key=lambda x: x.get("_map_index", 0))

    return combined


# =============================================================================
# Base Fields - Always included in generated state
# =============================================================================

# Infrastructure fields present in all graphs
BASE_FIELDS: dict[str, type] = {
    # Core tracking
    "thread_id": str,
    "current_step": str,
    # Error handling - singular for current error
    "error": Any,
    # Error handling with reducer (accumulates)
    "errors": Annotated[list, add],
    # Messages with reducer (accumulates)
    "messages": Annotated[list, add],
    # Loop tracking
    "_loop_counts": dict,
    "_loop_limit_reached": bool,
    "_agent_iterations": int,
    "_agent_limit_reached": bool,
    # Timestamps
    "started_at": Any,
    "completed_at": Any,
}

# Common input fields used across graph types
# These are always included to support --var inputs
COMMON_INPUT_FIELDS: dict[str, type] = {
    "input": str,  # Agent prompt input
    "topic": str,  # Content generation topic
    "style": str,  # Writing style
    "word_count": int,  # Target word count
    "message": str,  # Router message input
}

# Type mapping for YAML state config
TYPE_MAP: dict[str, type] = {
    "str": str,
    "string": str,
    "int": int,
    "integer": int,
    "float": float,
    "bool": bool,
    "boolean": bool,
    "list": list,
    "dict": dict,
    "any": Any,
}


def parse_state_config(state_config: dict) -> dict[str, type]:
    """Parse YAML state section into field types.

    Supports simple type strings:
        state:
          concept: str
          count: int

    Args:
        state_config: Dict from YAML 'state' section

    Returns:
        Dict of field_name -> Python type
    """
    fields: dict[str, type] = {}

    for field_name, type_spec in state_config.items():
        if isinstance(type_spec, str):
            # Simple type: "str", "int", etc.
            normalized = type_spec.lower()
            if normalized not in TYPE_MAP:
                supported = ", ".join(sorted(set(TYPE_MAP.keys())))
                logger.warning(
                    f"Unknown type '{type_spec}' for state field '{field_name}'. "
                    f"Supported types: {supported}. Defaulting to Any."
                )
            python_type = TYPE_MAP.get(normalized, Any)
            fields[field_name] = python_type
        else:
            # Unknown format, use Any
            logger.warning(
                f"Invalid type specification for state field '{field_name}': "
                f"expected string, got {type(type_spec).__name__}. Defaulting to Any."
            )
            fields[field_name] = Any

    return fields


def build_state_class(config: dict) -> type:
    """Build TypedDict state class from graph configuration.

    Dynamically generates a TypedDict with:
    - Base infrastructure fields (errors, messages, thread_id, etc.)
    - Common input fields (topic, style, input, message, etc.)
    - Custom fields from YAML 'state' section
    - Fields extracted from node state_key
    - Special fields for agent/router node types

    Args:
        config: Parsed YAML graph configuration dict

    Returns:
        TypedDict class with total=False (all fields optional)
    """
    # Start with base and common fields
    fields: dict[str, type] = {}
    fields.update(BASE_FIELDS)
    fields.update(COMMON_INPUT_FIELDS)

    # Add custom state fields from YAML 'state' section
    state_config = config.get("state", {})
    custom_fields = parse_state_config(state_config)
    fields.update(custom_fields)

    # Extract fields from nodes
    nodes = config.get("nodes", {})
    node_fields = extract_node_fields(nodes)
    fields.update(node_fields)

    # Build TypedDict programmatically
    return TypedDict("GraphState", fields, total=False)


def extract_node_fields(nodes: dict) -> dict[str, type]:
    """Extract state fields from node configurations.

    Analyzes node configs to determine required state fields:
    - state_key: Where node stores its output
    - type: agent → adds input, _tool_results
    - type: router → adds _route

    Args:
        nodes: Dict of node_name -> node_config

    Returns:
        Dict of field_name -> type for the state
    """
    fields: dict[str, type] = {}

    for node_name, node_config in nodes.items():
        if not isinstance(node_config, dict):
            continue

        # state_key → Any (accepts Pydantic models)
        if state_key := node_config.get("state_key"):
            fields[state_key] = Any

        # Node type-specific fields
        node_type = node_config.get("type", "llm")

        if node_type == "agent":
            fields["input"] = str
            fields["_tool_results"] = list

        elif node_type == "router":
            fields["_route"] = str

        elif node_type == "map":
            # Map node collect field needs sorted reducer for ordered fan-in
            if collect_key := node_config.get("collect"):
                fields[collect_key] = Annotated[list, sorted_add]

    return fields


def create_initial_state(
    topic: str = "",
    style: str = "informative",
    word_count: int = 300,
    thread_id: str | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Create an initial state for a new pipeline run.

    Args:
        topic: The topic to generate content about
        style: Writing style (default: informative)
        word_count: Target word count (default: 300)
        thread_id: Optional thread ID (auto-generated if not provided)
        **kwargs: Additional state fields (e.g., input for agents)

    Returns:
        Initialized state dictionary
    """
    import uuid
    from datetime import datetime

    return {
        "thread_id": thread_id or uuid.uuid4().hex[:16],
        "topic": topic,
        "style": style,
        "word_count": word_count,
        "current_step": "init",
        "error": None,
        "errors": [],
        "messages": [],
        "started_at": datetime.now(),
        "completed_at": None,
        **kwargs,
    }
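A small sketch (assuming the package is installed; the config field names here are illustrative) of how the state builder turns a graph config into a TypedDict and an initial state:

from yamlgraph.models.state_builder import build_state_class, create_initial_state

config = {
    "state": {"concept": "str", "count": "int"},  # custom fields from the YAML 'state' section
    "nodes": {
        "classify": {"type": "router", "state_key": "classification"},
        "expand": {"type": "map", "collect": "results"},  # collect key gets the sorted_add reducer
    },
}

GraphState = build_state_class(config)
# Base infrastructure fields plus concept, count, classification, _route and results are all present.
print(sorted(GraphState.__annotations__)[:5])

state = create_initial_state(topic="graph databases", input="compare engines")
print(state["current_step"], len(state["thread_id"]))  # init 16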
yamlgraph/node_factory.py
ADDED
@@ -0,0 +1,240 @@
"""Node function factory for YAML-defined graphs.

Creates LangGraph node functions from YAML configuration with support for:
- Resume (skip if output exists)
- Error handling (skip, retry, fail, fallback)
- Router nodes with dynamic routing
- Loop counting and limits
"""

import logging
from typing import Any, Callable

from yamlgraph.constants import ErrorHandler, NodeType
from yamlgraph.error_handlers import (
    check_loop_limit,
    check_requirements,
    handle_default,
    handle_fail,
    handle_fallback,
    handle_retry,
    handle_skip,
)
from yamlgraph.executor import execute_prompt
from yamlgraph.utils.expressions import resolve_template
from yamlgraph.utils.prompts import resolve_prompt_path

# Type alias for dynamic state
GraphState = dict[str, Any]

logger = logging.getLogger(__name__)


def resolve_class(class_path: str) -> type:
    """Dynamically import and return a class from a module path.

    Args:
        class_path: Full path like "yamlgraph.models.GenericReport" or short name

    Returns:
        The class object
    """
    import importlib

    parts = class_path.rsplit(".", 1)
    if len(parts) != 2:
        # Try to find in yamlgraph.models.schemas
        try:
            from yamlgraph.models import schemas

            if hasattr(schemas, class_path):
                return getattr(schemas, class_path)
        except ImportError:
            pass
        raise ValueError(f"Invalid class path: {class_path}")

    module_path, class_name = parts
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def get_output_model_for_node(
    node_config: dict[str, Any], prompts_dir: str | None = None
) -> type | None:
    """Get output model for a node, checking inline schema if no explicit model.

    Priority:
    1. Explicit output_model in node config (class path)
    2. Inline schema in prompt YAML file
    3. None (raw string output)

    Args:
        node_config: Node configuration from YAML
        prompts_dir: Base prompts directory

    Returns:
        Pydantic model class or None
    """
    # 1. Check for explicit output_model
    if model_path := node_config.get("output_model"):
        return resolve_class(model_path)

    # 2. Check for inline schema in prompt YAML
    prompt_name = node_config.get("prompt")
    if prompt_name:
        try:
            from yamlgraph.schema_loader import load_schema_from_yaml

            yaml_path = resolve_prompt_path(prompt_name, prompts_dir)
            return load_schema_from_yaml(yaml_path)
        except FileNotFoundError:
            # Prompt file doesn't exist yet - will fail later
            pass

    # 3. No output model
    return None


def create_node_function(
    node_name: str,
    node_config: dict,
    defaults: dict,
) -> Callable[[GraphState], dict]:
    """Create a node function from YAML config.

    Args:
        node_name: Name of the node
        node_config: Node configuration from YAML
        defaults: Default configuration values

    Returns:
        Node function compatible with LangGraph
    """
    node_type = node_config.get("type", NodeType.LLM)
    prompt_name = node_config.get("prompt")

    # Resolve output model (explicit > inline schema > None)
    output_model = get_output_model_for_node(node_config)

    # Get config values (node > defaults)
    temperature = node_config.get("temperature", defaults.get("temperature", 0.7))
    provider = node_config.get("provider", defaults.get("provider"))
    state_key = node_config.get("state_key", node_name)
    variable_templates = node_config.get("variables", {})
    requires = node_config.get("requires", [])

    # Error handling
    on_error = node_config.get("on_error")
    max_retries = node_config.get("max_retries", 3)
    fallback_config = node_config.get("fallback", {})
    fallback_provider = fallback_config.get("provider") if fallback_config else None

    # Router config
    routes = node_config.get("routes", {})
    default_route = node_config.get("default_route")

    # Loop limit
    loop_limit = node_config.get("loop_limit")

    # Skip if exists (default true for resume support, false for loop nodes)
    skip_if_exists = node_config.get("skip_if_exists", True)

    def node_fn(state: dict) -> dict:
        """Generated node function."""
        loop_counts = dict(state.get("_loop_counts") or {})
        current_count = loop_counts.get(node_name, 0)

        # Check loop limit
        if check_loop_limit(node_name, loop_limit, current_count):
            return {"_loop_limit_reached": True, "current_step": node_name}

        loop_counts[node_name] = current_count + 1

        # Skip if output exists (resume support) - disabled for loop nodes
        if skip_if_exists and state.get(state_key) is not None:
            logger.info(f"Node {node_name} skipped - {state_key} already in state")
            return {"current_step": node_name, "_loop_counts": loop_counts}

        # Check requirements
        if error := check_requirements(requires, state, node_name):
            return {
                "errors": [error],
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

        # Resolve variables
        variables = {}
        for key, template in variable_templates.items():
            resolved = resolve_template(template, state)
            if isinstance(resolved, list):
                resolved = ", ".join(str(item) for item in resolved)
            variables[key] = resolved

        def attempt_execute(use_provider: str | None) -> tuple[Any, Exception | None]:
            try:
                result = execute_prompt(
                    prompt_name=prompt_name,
                    variables=variables,
                    output_model=output_model,
                    temperature=temperature,
                    provider=use_provider,
                )
                return result, None
            except Exception as e:
                return None, e

        result, error = attempt_execute(provider)

        if error is None:
            logger.info(f"Node {node_name} completed successfully")
            update = {
                state_key: result,
                "current_step": node_name,
                "_loop_counts": loop_counts,
            }

            # Router: add _route to state
            if node_type == NodeType.ROUTER and routes:
                route_key = getattr(result, "tone", None) or getattr(
                    result, "intent", None
                )
                if route_key and route_key in routes:
                    update["_route"] = routes[route_key]
                elif default_route:
                    update["_route"] = default_route
                else:
                    update["_route"] = list(routes.values())[0]
                logger.info(f"Router {node_name} routing to: {update['_route']}")
            return update

        # Error handling - dispatch to strategy handlers
        if on_error == ErrorHandler.SKIP:
            handle_skip(node_name, error, loop_counts)
            return {"current_step": node_name, "_loop_counts": loop_counts}

        elif on_error == ErrorHandler.FAIL:
            handle_fail(node_name, error)

        elif on_error == ErrorHandler.RETRY:
            result = handle_retry(
                node_name,
                lambda: attempt_execute(provider),
                max_retries,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        elif on_error == ErrorHandler.FALLBACK and fallback_provider:
            result = handle_fallback(
                node_name,
                attempt_execute,
                fallback_provider,
            )
            return result.to_state_update(state_key, node_name, loop_counts)

        else:
            result = handle_default(node_name, error)
            return result.to_state_update(state_key, node_name, loop_counts)

    node_fn.__name__ = f"{node_name}_node"
    return node_fn
yamlgraph/routing.py
ADDED
@@ -0,0 +1,87 @@
"""Routing utilities for LangGraph edge conditions.

Provides factory functions for creating router functions that determine
which node to route to based on state values and expressions.
"""

import logging
from collections.abc import Callable
from typing import Any

from langgraph.graph import END

from yamlgraph.utils.conditions import evaluate_condition

# Type alias for dynamic state
GraphState = dict[str, Any]

logger = logging.getLogger(__name__)


def make_router_fn(targets: list[str]) -> Callable[[dict], str]:
    """Create a router function that reads _route from state.

    Used for type: router nodes with conditional edges to multiple targets.

    NOTE: Use `state: dict` not `state: GraphState` - type hints cause
    LangGraph to filter state fields. See docs/debug-router-type-hints.md

    Args:
        targets: List of valid target node names

    Returns:
        Router function that returns the target node name
    """

    def router_fn(state: dict) -> str:
        route = state.get("_route")
        logger.debug(f"Router: _route={route}, targets={targets}")
        if route and route in targets:
            logger.debug(f"Router: matched route {route}")
            return route
        # Default to first target
        logger.debug(f"Router: defaulting to {targets[0]}")
        return targets[0]

    return router_fn


def make_expr_router_fn(
    edges: list[tuple[str, str]],
    source_node: str,
) -> Callable[[GraphState], str]:
    """Create router that evaluates expression conditions.

    Used for reflexion-style loops with expression-based conditions
    like "critique.score < 0.8".

    Args:
        edges: List of (condition, target) tuples
        source_node: Name of the source node (for logging)

    Returns:
        Router function that evaluates conditions and returns target
    """

    def expr_router_fn(state: GraphState) -> str:
        # Check loop limit first
        if state.get("_loop_limit_reached"):
            return END

        for condition, target in edges:
            try:
                if evaluate_condition(condition, state):
                    logger.debug(
                        f"Condition '{condition}' matched, routing to {target}"
                    )
                    return target
            except ValueError as e:
                logger.warning(f"Failed to evaluate condition '{condition}': {e}")

        # No condition matched - this shouldn't happen with well-formed graphs
        logger.warning(f"No condition matched for {source_node}, defaulting to END")
        return END

    return expr_router_fn


__all__ = ["make_router_fn", "make_expr_router_fn"]
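A brief sketch of both router factories (assumes yamlgraph and langgraph are installed; the expression syntax follows the docstring example above):

from langgraph.graph import END
from yamlgraph.routing import make_expr_router_fn, make_router_fn

route = make_router_fn(["writer", "critic"])
print(route({"_route": "critic"}))  # critic
print(route({}))                    # writer (falls back to the first target)

expr_route = make_expr_router_fn(
    [("critique.score < 0.8", "writer"), ("critique.score >= 0.8", "publish")],
    source_node="critic",
)
# The loop-limit flag short-circuits straight to END before any condition is evaluated.
print(expr_route({"_loop_limit_reached": True}) is END)  # True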