yamlgraph-0.3.9-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
yamlgraph/constants.py
ADDED
@@ -0,0 +1,70 @@
"""Type-safe constants for YAML graph configuration.

Provides enums for node types, error handlers, and other magic strings
used throughout the codebase to enable static type checking and IDE support.
"""

from enum import StrEnum


class NodeType(StrEnum):
    """Valid node types in YAML graph configuration."""

    LLM = "llm"
    ROUTER = "router"
    TOOL = "tool"
    AGENT = "agent"
    PYTHON = "python"
    MAP = "map"
    TOOL_CALL = "tool_call"
    INTERRUPT = "interrupt"
    SUBGRAPH = "subgraph"
    PASSTHROUGH = "passthrough"

    @classmethod
    def requires_prompt(cls, node_type: str) -> bool:
        """Check if node type requires a prompt field.

        Args:
            node_type: The node type string

        Returns:
            True if the node type requires a prompt
        """
        return node_type in (cls.LLM, cls.ROUTER)


class ErrorHandler(StrEnum):
    """Valid on_error handling strategies."""

    SKIP = "skip"  # Skip node and continue pipeline
    RETRY = "retry"  # Retry with max_retries attempts
    FAIL = "fail"  # Raise exception immediately
    FALLBACK = "fallback"  # Try fallback provider

    @classmethod
    def all_values(cls) -> set[str]:
        """Return all valid error handler values.

        Returns:
            Set of valid error handler strings
        """
        return {handler.value for handler in cls}


class EdgeType(StrEnum):
    """Valid edge types in graph configuration."""

    SIMPLE = "simple"  # Direct edge from -> to
    CONDITIONAL = "conditional"  # Edge with conditions


class SpecialNodes(StrEnum):
    """Special node names with semantic meaning."""

    START = "__start__"
    END = "__end__"


# Re-export for convenience
__all__ = ["NodeType", "ErrorHandler", "EdgeType", "SpecialNodes"]
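Since StrEnum members are also plain strings, values parsed from YAML can be compared and validated directly against these constants. A minimal usage sketch of the module above (illustrative only; the behavior follows from the definitions in constants.py):

from yamlgraph.constants import ErrorHandler, NodeType, SpecialNodes

# StrEnum members compare equal to their string values,
# so raw YAML strings can be checked without conversion
assert NodeType.LLM == "llm"
assert NodeType.requires_prompt("router")      # llm and router need prompts
assert not NodeType.requires_prompt("python")  # other node types do not

# Validate an on_error value taken from a graph config
assert "retry" in ErrorHandler.all_values()

# Special node names used in edge definitions
print(SpecialNodes.START, SpecialNodes.END)  # __start__ __end__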
yamlgraph/error_handlers.py
ADDED
@@ -0,0 +1,227 @@
"""Error handling strategies for node execution.

Provides strategy functions for different error handling modes:
- skip: Continue without output
- fail: Raise exception immediately
- retry: Retry up to N times
- fallback: Try fallback provider
"""

import logging
from collections.abc import Callable
from typing import Any

from yamlgraph.models import ErrorType, PipelineError

logger = logging.getLogger(__name__)


class NodeResult:
    """Result of node execution with consistent structure.

    Attributes:
        success: Whether execution succeeded
        output: The result value (if success)
        error: PipelineError (if failure)
        state_updates: Additional state updates
    """

    def __init__(
        self,
        success: bool,
        output: Any = None,
        error: PipelineError | None = None,
        state_updates: dict | None = None,
    ):
        self.success = success
        self.output = output
        self.error = error
        self.state_updates = state_updates or {}

    def to_state_update(
        self,
        state_key: str,
        node_name: str,
        loop_counts: dict,
    ) -> dict:
        """Convert to LangGraph state update dict.

        Args:
            state_key: Key to store output under
            node_name: Name of the node
            loop_counts: Current loop counts

        Returns:
            State update dict with consistent structure
        """
        update = {
            "current_step": node_name,
            "_loop_counts": loop_counts,
        }

        if self.success:
            update[state_key] = self.output
        elif self.error:
            # Always use 'errors' list for consistency
            update["errors"] = [self.error]

        update.update(self.state_updates)
        return update


def handle_skip(
    node_name: str,
    error: Exception,
    loop_counts: dict,
) -> NodeResult:
    """Handle error with skip strategy.

    Args:
        node_name: Name of the node
        error: The exception that occurred
        loop_counts: Current loop counts

    Returns:
        NodeResult with empty output
    """
    logger.warning(f"Node {node_name} failed, skipping: {error}")
    return NodeResult(success=True, output=None)


def handle_fail(
    node_name: str,
    error: Exception,
) -> None:
    """Handle error with fail strategy.

    Args:
        node_name: Name of the node
        error: The exception that occurred

    Raises:
        Exception: Always raises the original error
    """
    logger.error(f"Node {node_name} failed (on_error=fail): {error}")
    raise error


def handle_retry(
    node_name: str,
    execute_fn: Callable[[], tuple[Any, Exception | None]],
    max_retries: int,
) -> NodeResult:
    """Handle error with retry strategy.

    Args:
        node_name: Name of the node
        execute_fn: Function to execute (returns result, error)
        max_retries: Maximum retry attempts

    Returns:
        NodeResult with output or error
    """
    last_exception: Exception | None = None

    for attempt in range(1, max_retries + 1):
        logger.info(f"Node {node_name} retry {attempt}/{max_retries}")
        result, error = execute_fn()
        if error is None:
            return NodeResult(success=True, output=result)
        last_exception = error

    logger.error(f"Node {node_name} failed after {max_retries} attempts")
    pipeline_error = PipelineError.from_exception(
        last_exception or Exception("Unknown error"), node=node_name
    )
    return NodeResult(success=False, error=pipeline_error)


def handle_fallback(
    node_name: str,
    execute_fn: Callable[[str | None], tuple[Any, Exception | None]],
    fallback_provider: str,
) -> NodeResult:
    """Handle error with fallback strategy.

    Args:
        node_name: Name of the node
        execute_fn: Function to execute with provider param
        fallback_provider: Fallback provider to try

    Returns:
        NodeResult with output or error
    """
    logger.info(f"Node {node_name} trying fallback: {fallback_provider}")
    result, fallback_error = execute_fn(fallback_provider)

    if fallback_error is None:
        return NodeResult(success=True, output=result)

    logger.error(f"Node {node_name} failed with primary and fallback")
    pipeline_error = PipelineError.from_exception(fallback_error, node=node_name)
    return NodeResult(success=False, error=pipeline_error)


def handle_default(
    node_name: str,
    error: Exception,
) -> NodeResult:
    """Handle error with default strategy (log and return error).

    Args:
        node_name: Name of the node
        error: The exception that occurred

    Returns:
        NodeResult with error
    """
    logger.error(f"Node {node_name} failed: {error}")
    pipeline_error = PipelineError.from_exception(error, node=node_name)
    return NodeResult(success=False, error=pipeline_error)


def check_requirements(
    requires: list[str],
    state: dict,
    node_name: str,
) -> PipelineError | None:
    """Check if all required state keys are present.

    Args:
        requires: List of required state keys
        state: Current state
        node_name: Name of the node

    Returns:
        PipelineError if requirements not met, None otherwise
    """
    for req in requires:
        if state.get(req) is None:
            return PipelineError(
                type=ErrorType.STATE_ERROR,
                message=f"Missing required state: {req}",
                node=node_name,
                retryable=False,
            )
    return None


def check_loop_limit(
    node_name: str,
    loop_limit: int | None,
    current_count: int,
) -> bool:
    """Check if loop limit has been reached.

    Args:
        node_name: Name of the node
        loop_limit: Maximum loop iterations (None = no limit)
        current_count: Current iteration count

    Returns:
        True if limit reached, False otherwise
    """
    if loop_limit is not None and current_count >= loop_limit:
        logger.warning(f"Node {node_name} hit loop limit ({loop_limit})")
        return True
    return False
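To illustrate the execute_fn contract that handle_retry expects (a zero-argument callable returning a (result, error) tuple rather than raising), here is a small sketch; flaky_call and its failure pattern are invented for the example:

from yamlgraph.error_handlers import handle_retry

attempts = {"count": 0}

def flaky_call() -> tuple[str | None, Exception | None]:
    """Hypothetical operation: fails twice, then succeeds."""
    attempts["count"] += 1
    if attempts["count"] < 3:
        return None, RuntimeError(f"transient failure #{attempts['count']}")
    return "ok", None

# Third attempt succeeds, so the result carries the output;
# if all attempts failed, result.error would hold a PipelineError
result = handle_retry("demo_node", flaky_call, max_retries=3)
assert result.success and result.output == "ok"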
yamlgraph/executor.py
ADDED
@@ -0,0 +1,290 @@
"""YAML Prompt Executor - Unified interface for LLM calls.

This module provides a simple, reusable executor for YAML-defined prompts
with support for structured outputs via Pydantic models.
"""

import logging
import threading
import time
from pathlib import Path
from typing import TypeVar

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel

from yamlgraph.config import (
    DEFAULT_TEMPERATURE,
    MAX_RETRIES,
    RETRY_BASE_DELAY,
    RETRY_MAX_DELAY,
)
from yamlgraph.utils.llm_factory import create_llm
from yamlgraph.utils.prompts import load_prompt
from yamlgraph.utils.template import validate_variables

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


# Exception type names that are retryable (matched by class name)
RETRYABLE_EXCEPTIONS = (
    "RateLimitError",
    "APIConnectionError",
    "APITimeoutError",
    "InternalServerError",
    "ServiceUnavailableError",
)


def is_retryable(exception: Exception) -> bool:
    """Check if an exception is retryable.

    Args:
        exception: The exception to check

    Returns:
        True if the exception should be retried
    """
    exc_name = type(exception).__name__
    return exc_name in RETRYABLE_EXCEPTIONS or "rate" in exc_name.lower()


def format_prompt(
    template: str,
    variables: dict,
    state: dict | None = None,
) -> str:
    """Format a prompt template with variables.

    Supports both simple {variable} placeholders and Jinja2 templates.
    If the template contains Jinja2 syntax ({%, {{), uses Jinja2 rendering.

    Args:
        template: Template string with {variable} or Jinja2 placeholders
        variables: Dictionary of variable values
        state: Optional state dict for Jinja2 templates (accessible as {{ state.field }})

    Returns:
        Formatted string

    Examples:
        Simple format:
            format_prompt("Hello {name}", {"name": "World"})

        Jinja2 with variables:
            format_prompt("{% for item in items %}{{ item }}{% endfor %}", {"items": [1, 2]})

        Jinja2 with state:
            format_prompt("Topic: {{ state.topic }}", {}, state={"topic": "AI"})
    """
    # Check for Jinja2 syntax
    if "{%" in template or "{{" in template:
        from jinja2 import Template

        jinja_template = Template(template)
        # Pass both variables and state to Jinja2
        context = {"state": state or {}, **variables}
        return jinja_template.render(**context)

    # Fall back to simple format - stringify lists for compatibility
    safe_vars = {
        k: (", ".join(map(str, v)) if isinstance(v, list) else v)
        for k, v in variables.items()
    }
    return template.format(**safe_vars)


def execute_prompt(
    prompt_name: str,
    variables: dict | None = None,
    output_model: type[T] | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    provider: str | None = None,
    graph_path: Path | None = None,
    prompts_dir: Path | None = None,
    prompts_relative: bool = False,
) -> T | str:
    """Execute a YAML prompt with optional structured output.

    Uses the singleton PromptExecutor for LLM caching.

    Args:
        prompt_name: Name of the prompt file (without .yaml)
        variables: Variables to substitute in the template
        output_model: Optional Pydantic model for structured output
        temperature: LLM temperature setting
        provider: LLM provider ("anthropic", "mistral", "openai").
            Can also be set in YAML metadata or PROVIDER env var.
        graph_path: Path to graph file for relative prompt resolution
        prompts_dir: Explicit prompts directory override
        prompts_relative: If True, resolve prompts relative to graph_path

    Returns:
        Parsed Pydantic model if output_model provided, else raw string

    Example:
        >>> result = execute_prompt(
        ...     "greet",
        ...     variables={"name": "World", "style": "formal"},
        ...     output_model=GenericReport,
        ... )
        >>> print(result.summary)
    """
    return get_executor().execute(
        prompt_name=prompt_name,
        variables=variables,
        output_model=output_model,
        temperature=temperature,
        provider=provider,
        graph_path=graph_path,
        prompts_dir=prompts_dir,
        prompts_relative=prompts_relative,
    )


# Default executor instance for LLM caching.
# Use get_executor() to access, or set_executor() for dependency injection.
_executor: "PromptExecutor | None" = None
_executor_lock = threading.Lock()


def get_executor() -> "PromptExecutor":
    """Get the executor instance (thread-safe).

    Returns the default singleton or a custom instance set via set_executor().

    Returns:
        PromptExecutor instance with LLM caching
    """
    global _executor
    if _executor is None:
        with _executor_lock:
            # Double-check after acquiring lock
            if _executor is None:
                _executor = PromptExecutor()
    return _executor


class PromptExecutor:
    """Reusable executor with LLM caching and retry logic."""

    def __init__(self, max_retries: int = MAX_RETRIES):
        self._max_retries = max_retries

    def _get_llm(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
    ) -> BaseChatModel:
        """Get or create cached LLM instance.

        Uses llm_factory which handles caching internally.
        """
        return create_llm(temperature=temperature, provider=provider)

    def _invoke_with_retry(
        self, llm, messages, output_model: type[T] | None = None
    ) -> T | str:
        """Invoke LLM with exponential backoff retry.

        Args:
            llm: The LLM instance to use
            messages: Messages to send
            output_model: Optional Pydantic model for structured output

        Returns:
            LLM response (parsed model or string)

        Raises:
            Exception: The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self._max_retries):
            try:
                if output_model:
                    structured_llm = llm.with_structured_output(output_model)
                    return structured_llm.invoke(messages)
                else:
                    response = llm.invoke(messages)
                    return response.content

            except Exception as e:
                last_exception = e

                if not is_retryable(e) or attempt == self._max_retries - 1:
                    raise

                # Exponential backoff: delay doubles each attempt, capped at RETRY_MAX_DELAY
                delay = min(RETRY_BASE_DELAY * (2**attempt), RETRY_MAX_DELAY)
                logger.warning(
                    f"LLM call failed (attempt {attempt + 1}/{self._max_retries}): {e}. "
                    f"Retrying in {delay:.1f}s..."
                )
                time.sleep(delay)

        raise last_exception

    def execute(
        self,
        prompt_name: str,
        variables: dict | None = None,
        output_model: type[T] | None = None,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
        graph_path: Path | None = None,
        prompts_dir: Path | None = None,
        prompts_relative: bool = False,
    ) -> T | str:
        """Execute a prompt using cached LLM with retry logic.

        Same interface as execute_prompt() but with LLM caching and
        automatic retry for transient failures.

        Provider priority: parameter > YAML metadata > env var > default

        Args:
            prompt_name: Name of the prompt file (without .yaml)
            variables: Variables to substitute in the template
            output_model: Optional Pydantic model for structured output
            temperature: LLM temperature setting
            provider: LLM provider ("anthropic", "mistral", "openai")
            graph_path: Path to graph file for relative prompt resolution
            prompts_dir: Explicit prompts directory override
            prompts_relative: If True, resolve prompts relative to graph_path

        Raises:
            ValueError: If required template variables are missing
        """
        variables = variables or {}

        prompt_config = load_prompt(
            prompt_name,
            prompts_dir=prompts_dir,
            graph_path=graph_path,
            prompts_relative=prompts_relative,
        )

        # Validate all required variables are provided (fail fast)
        full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
        validate_variables(full_template, variables, prompt_name)

        # Extract provider from YAML metadata if not provided
        if provider is None and "provider" in prompt_config:
            provider = prompt_config["provider"]
            logger.debug(f"Using provider from YAML metadata: {provider}")

        system_text = format_prompt(prompt_config.get("system", ""), variables)
        user_text = format_prompt(prompt_config["user"], variables)

        messages = []
        if system_text:
            messages.append(SystemMessage(content=system_text))
        messages.append(HumanMessage(content=user_text))

        llm = self._get_llm(temperature=temperature, provider=provider)

        return self._invoke_with_retry(llm, messages, output_model)