yamlgraph 0.1.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Potentially problematic release: this version of yamlgraph has been flagged and might be problematic.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
yamlgraph/error_handlers.py
ADDED

@@ -0,0 +1,226 @@

"""Error handling strategies for node execution.

Provides strategy functions for different error handling modes:
- skip: Continue without output
- fail: Raise exception immediately
- retry: Retry up to N times
- fallback: Try fallback provider
"""

import logging
from typing import Any, Callable

from yamlgraph.models import ErrorType, PipelineError

logger = logging.getLogger(__name__)


class NodeResult:
    """Result of node execution with consistent structure.

    Attributes:
        success: Whether execution succeeded
        output: The result value (if success)
        error: PipelineError (if failure)
        state_updates: Additional state updates
    """

    def __init__(
        self,
        success: bool,
        output: Any = None,
        error: PipelineError | None = None,
        state_updates: dict | None = None,
    ):
        self.success = success
        self.output = output
        self.error = error
        self.state_updates = state_updates or {}

    def to_state_update(
        self,
        state_key: str,
        node_name: str,
        loop_counts: dict,
    ) -> dict:
        """Convert to LangGraph state update dict.

        Args:
            state_key: Key to store output under
            node_name: Name of the node
            loop_counts: Current loop counts

        Returns:
            State update dict with consistent structure
        """
        update = {
            "current_step": node_name,
            "_loop_counts": loop_counts,
        }

        if self.success:
            update[state_key] = self.output
        elif self.error:
            # Always use 'errors' list for consistency
            update["errors"] = [self.error]

        update.update(self.state_updates)
        return update


def handle_skip(
    node_name: str,
    error: Exception,
    loop_counts: dict,
) -> NodeResult:
    """Handle error with skip strategy.

    Args:
        node_name: Name of the node
        error: The exception that occurred
        loop_counts: Current loop counts

    Returns:
        NodeResult with empty output
    """
    logger.warning(f"Node {node_name} failed, skipping: {error}")
    return NodeResult(success=True, output=None)


def handle_fail(
    node_name: str,
    error: Exception,
) -> None:
    """Handle error with fail strategy.

    Args:
        node_name: Name of the node
        error: The exception that occurred

    Raises:
        Exception: Always raises the original error
    """
    logger.error(f"Node {node_name} failed (on_error=fail): {error}")
    raise error


def handle_retry(
    node_name: str,
    execute_fn: Callable[[], tuple[Any, Exception | None]],
    max_retries: int,
) -> NodeResult:
    """Handle error with retry strategy.

    Args:
        node_name: Name of the node
        execute_fn: Function to execute (returns result, error)
        max_retries: Maximum retry attempts

    Returns:
        NodeResult with output or error
    """
    last_exception: Exception | None = None

    for attempt in range(1, max_retries + 1):
        logger.info(f"Node {node_name} retry {attempt}/{max_retries}")
        result, error = execute_fn()
        if error is None:
            return NodeResult(success=True, output=result)
        last_exception = error

    logger.error(f"Node {node_name} failed after {max_retries} attempts")
    pipeline_error = PipelineError.from_exception(
        last_exception or Exception("Unknown error"), node=node_name
    )
    return NodeResult(success=False, error=pipeline_error)


def handle_fallback(
    node_name: str,
    execute_fn: Callable[[str | None], tuple[Any, Exception | None]],
    fallback_provider: str,
) -> NodeResult:
    """Handle error with fallback strategy.

    Args:
        node_name: Name of the node
        execute_fn: Function to execute with provider param
        fallback_provider: Fallback provider to try

    Returns:
        NodeResult with output or error
    """
    logger.info(f"Node {node_name} trying fallback: {fallback_provider}")
    result, fallback_error = execute_fn(fallback_provider)

    if fallback_error is None:
        return NodeResult(success=True, output=result)

    logger.error(f"Node {node_name} failed with primary and fallback")
    pipeline_error = PipelineError.from_exception(fallback_error, node=node_name)
    return NodeResult(success=False, error=pipeline_error)


def handle_default(
    node_name: str,
    error: Exception,
) -> NodeResult:
    """Handle error with default strategy (log and return error).

    Args:
        node_name: Name of the node
        error: The exception that occurred

    Returns:
        NodeResult with error
    """
    logger.error(f"Node {node_name} failed: {error}")
    pipeline_error = PipelineError.from_exception(error, node=node_name)
    return NodeResult(success=False, error=pipeline_error)


def check_requirements(
    requires: list[str],
    state: dict,
    node_name: str,
) -> PipelineError | None:
    """Check if all required state keys are present.

    Args:
        requires: List of required state keys
        state: Current state
        node_name: Name of the node

    Returns:
        PipelineError if requirements not met, None otherwise
    """
    for req in requires:
        if state.get(req) is None:
            return PipelineError(
                type=ErrorType.STATE_ERROR,
                message=f"Missing required state: {req}",
                node=node_name,
                retryable=False,
            )
    return None


def check_loop_limit(
    node_name: str,
    loop_limit: int | None,
    current_count: int,
) -> bool:
    """Check if loop limit has been reached.

    Args:
        node_name: Name of the node
        loop_limit: Maximum loop iterations (None = no limit)
        current_count: Current iteration count

    Returns:
        True if limit reached, False otherwise
    """
    if loop_limit is not None and current_count >= loop_limit:
        logger.warning(f"Node {node_name} hit loop limit ({loop_limit})")
        return True
    return False
yamlgraph/executor.py
ADDED

@@ -0,0 +1,275 @@

"""YAML Prompt Executor - Unified interface for LLM calls.

This module provides a simple, reusable executor for YAML-defined prompts
with support for structured outputs via Pydantic models.
"""

import logging
import threading
import time
from typing import Type, TypeVar

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel

from yamlgraph.config import (
    DEFAULT_TEMPERATURE,
    MAX_RETRIES,
    RETRY_BASE_DELAY,
    RETRY_MAX_DELAY,
)
from yamlgraph.utils.llm_factory import create_llm
from yamlgraph.utils.prompts import load_prompt
from yamlgraph.utils.template import validate_variables

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


# Exceptions that are retryable
RETRYABLE_EXCEPTIONS = (
    "RateLimitError",
    "APIConnectionError",
    "APITimeoutError",
    "InternalServerError",
    "ServiceUnavailableError",
)


def is_retryable(exception: Exception) -> bool:
    """Check if an exception is retryable.

    Args:
        exception: The exception to check

    Returns:
        True if the exception should be retried
    """
    exc_name = type(exception).__name__
    return exc_name in RETRYABLE_EXCEPTIONS or "rate" in exc_name.lower()


def format_prompt(
    template: str,
    variables: dict,
    state: dict | None = None,
) -> str:
    """Format a prompt template with variables.

    Supports both simple {variable} placeholders and Jinja2 templates.
    If the template contains Jinja2 syntax ({%, {{), uses Jinja2 rendering.

    Args:
        template: Template string with {variable} or Jinja2 placeholders
        variables: Dictionary of variable values
        state: Optional state dict for Jinja2 templates (accessible as {{ state.field }})

    Returns:
        Formatted string

    Examples:
        Simple format:
            format_prompt("Hello {name}", {"name": "World"})

        Jinja2 with variables:
            format_prompt("{% for item in items %}{{ item }}{% endfor %}", {"items": [1, 2]})

        Jinja2 with state:
            format_prompt("Topic: {{ state.topic }}", {}, state={"topic": "AI"})
    """
    # Check for Jinja2 syntax
    if "{%" in template or "{{" in template:
        from jinja2 import Template

        jinja_template = Template(template)
        # Pass both variables and state to Jinja2
        context = {"state": state or {}, **variables}
        return jinja_template.render(**context)

    # Fall back to simple format - stringify lists for compatibility
    safe_vars = {
        k: (", ".join(map(str, v)) if isinstance(v, list) else v)
        for k, v in variables.items()
    }
    return template.format(**safe_vars)


def execute_prompt(
    prompt_name: str,
    variables: dict | None = None,
    output_model: Type[T] | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    provider: str | None = None,
) -> T | str:
    """Execute a YAML prompt with optional structured output.

    Uses the singleton PromptExecutor for LLM caching.

    Args:
        prompt_name: Name of the prompt file (without .yaml)
        variables: Variables to substitute in the template
        output_model: Optional Pydantic model for structured output
        temperature: LLM temperature setting
        provider: LLM provider ("anthropic", "mistral", "openai").
            Can also be set in YAML metadata or PROVIDER env var.

    Returns:
        Parsed Pydantic model if output_model provided, else raw string

    Example:
        >>> result = execute_prompt(
        ...     "greet",
        ...     variables={"name": "World", "style": "formal"},
        ...     output_model=GenericReport,
        ... )
        >>> print(result.summary)
    """
    return get_executor().execute(
        prompt_name=prompt_name,
        variables=variables,
        output_model=output_model,
        temperature=temperature,
        provider=provider,
    )


# Default executor instance for LLM caching
# Use get_executor() to access, or set_executor() for dependency injection
_executor: "PromptExecutor | None" = None
_executor_lock = threading.Lock()


def get_executor() -> "PromptExecutor":
    """Get the executor instance (thread-safe).

    Returns the default singleton or a custom instance set via set_executor().

    Returns:
        PromptExecutor instance with LLM caching
    """
    global _executor
    if _executor is None:
        with _executor_lock:
            # Double-check after acquiring lock
            if _executor is None:
                _executor = PromptExecutor()
    return _executor


def set_executor(executor: "PromptExecutor | None") -> None:
    """Set a custom executor instance for dependency injection (thread-safe).

    Useful for testing or when you need different executor configurations.

    Args:
        executor: Custom PromptExecutor instance, or None to reset to default
    """
    global _executor
    with _executor_lock:
        _executor = executor


class PromptExecutor:
    """Reusable executor with LLM caching and retry logic."""

    def __init__(self, max_retries: int = MAX_RETRIES):
        self._max_retries = max_retries

    def _get_llm(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
    ) -> BaseChatModel:
        """Get or create cached LLM instance.

        Uses llm_factory which handles caching internally.
        """
        return create_llm(temperature=temperature, provider=provider)

    def _invoke_with_retry(
        self, llm, messages, output_model: Type[T] | None = None
    ) -> T | str:
        """Invoke LLM with exponential backoff retry.

        Args:
            llm: The LLM instance to use
            messages: Messages to send
            output_model: Optional Pydantic model for structured output

        Returns:
            LLM response (parsed model or string)

        Raises:
            Last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self._max_retries):
            try:
                if output_model:
                    structured_llm = llm.with_structured_output(output_model)
                    return structured_llm.invoke(messages)
                else:
                    response = llm.invoke(messages)
                    return response.content

            except Exception as e:
                last_exception = e

                if not is_retryable(e) or attempt == self._max_retries - 1:
                    raise

                # Exponential backoff with jitter
                delay = min(RETRY_BASE_DELAY * (2**attempt), RETRY_MAX_DELAY)
                logger.warning(
                    f"LLM call failed (attempt {attempt + 1}/{self._max_retries}): {e}. "
                    f"Retrying in {delay:.1f}s..."
                )
                time.sleep(delay)

        raise last_exception

    def execute(
        self,
        prompt_name: str,
        variables: dict | None = None,
        output_model: Type[T] | None = None,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
    ) -> T | str:
        """Execute a prompt using cached LLM with retry logic.

        Same interface as execute_prompt() but with LLM caching and
        automatic retry for transient failures.

        Provider priority: parameter > YAML metadata > env var > default

        Raises:
            ValueError: If required template variables are missing
        """
        variables = variables or {}

        prompt_config = load_prompt(prompt_name)

        # Validate all required variables are provided (fail fast)
        full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
        validate_variables(full_template, variables, prompt_name)

        # Extract provider from YAML metadata if not provided
        if provider is None and "provider" in prompt_config:
            provider = prompt_config["provider"]
            logger.debug(f"Using provider from YAML metadata: {provider}")

        system_text = format_prompt(prompt_config.get("system", ""), variables)
        user_text = format_prompt(prompt_config["user"], variables)

        messages = []
        if system_text:
            messages.append(SystemMessage(content=system_text))
        messages.append(HumanMessage(content=user_text))

        llm = self._get_llm(temperature=temperature, provider=provider)

        return self._invoke_with_retry(llm, messages, output_model)
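
The docstrings above already show the call shape; for reference, here is a consolidated, hedged usage sketch. It assumes yamlgraph 0.1.1 is installed, a "greet" prompt YAML exists on the prompt path, and provider credentials are configured in the environment; the Greeting model is a stand-in for any Pydantic output model, not part of the package.

from pydantic import BaseModel

from yamlgraph.executor import PromptExecutor, execute_prompt, set_executor


class Greeting(BaseModel):
    # Hypothetical structured output; any Pydantic model works as output_model.
    message: str


# Module-level helper: uses the thread-safe singleton executor under the hood.
result = execute_prompt(
    "greet",
    variables={"name": "World", "style": "formal"},
    output_model=Greeting,
    provider="anthropic",  # optional; falls back to YAML metadata / env var
)
print(result.message)

# Dependency injection: swap in a custom executor (e.g. fewer retries for tests),
# then pass None to reset to the default singleton.
set_executor(PromptExecutor(max_retries=1))
set_executor(None)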

yamlgraph/executor_async.py
ADDED

@@ -0,0 +1,122 @@

"""Async Prompt Executor - Async interface for LLM calls.

This module provides async versions of execute_prompt for use in
async contexts like web servers or concurrent pipelines.

Note: This is a foundation module. The underlying LLM calls still
use sync HTTP clients wrapped with run_in_executor.
"""

import asyncio
import logging
from typing import Type, TypeVar

from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel

from yamlgraph.config import DEFAULT_TEMPERATURE
from yamlgraph.executor import format_prompt, load_prompt
from yamlgraph.utils.llm_factory import create_llm
from yamlgraph.utils.llm_factory_async import invoke_async
from yamlgraph.utils.template import validate_variables

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


async def execute_prompt_async(
    prompt_name: str,
    variables: dict | None = None,
    output_model: Type[T] | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    provider: str | None = None,
) -> T | str:
    """Execute a YAML prompt asynchronously.

    Async version of execute_prompt for use in async contexts.

    Args:
        prompt_name: Name of the prompt file (without .yaml)
        variables: Variables to substitute in the template
        output_model: Optional Pydantic model for structured output
        temperature: LLM temperature setting
        provider: LLM provider ("anthropic", "mistral", "openai")

    Returns:
        Parsed Pydantic model if output_model provided, else raw string

    Example:
        >>> result = await execute_prompt_async(
        ...     "greet",
        ...     variables={"name": "World"},
        ...     output_model=GenericReport,
        ... )
    """
    variables = variables or {}

    # Load and validate prompt (sync - file I/O is fast)
    prompt_config = load_prompt(prompt_name)

    full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
    validate_variables(full_template, variables, prompt_name)

    # Extract provider from YAML metadata if not provided
    if provider is None and "provider" in prompt_config:
        provider = prompt_config["provider"]
        logger.debug(f"Using provider from YAML metadata: {provider}")

    system_text = format_prompt(prompt_config.get("system", ""), variables)
    user_text = format_prompt(prompt_config["user"], variables)

    messages = []
    if system_text:
        messages.append(SystemMessage(content=system_text))
    messages.append(HumanMessage(content=user_text))

    # Create LLM (cached via factory)
    llm = create_llm(temperature=temperature, provider=provider)

    # Invoke asynchronously
    return await invoke_async(llm, messages, output_model)


async def execute_prompts_concurrent(
    prompts: list[dict],
) -> list[BaseModel | str]:
    """Execute multiple prompts concurrently.

    Useful for parallel LLM calls in pipelines.

    Args:
        prompts: List of dicts with keys:
            - prompt_name: str (required)
            - variables: dict (optional)
            - output_model: Type[BaseModel] (optional)
            - temperature: float (optional)
            - provider: str (optional)

    Returns:
        List of results in same order as input prompts

    Example:
        >>> results = await execute_prompts_concurrent([
        ...     {"prompt_name": "summarize", "variables": {"text": "..."}},
        ...     {"prompt_name": "analyze", "variables": {"text": "..."}},
        ... ])
    """
    tasks = []
    for prompt_config in prompts:
        task = execute_prompt_async(
            prompt_name=prompt_config["prompt_name"],
            variables=prompt_config.get("variables"),
            output_model=prompt_config.get("output_model"),
            temperature=prompt_config.get("temperature", DEFAULT_TEMPERATURE),
            provider=prompt_config.get("provider"),
        )
        tasks.append(task)

    return await asyncio.gather(*tasks)


__all__ = ["execute_prompt_async", "execute_prompts_concurrent"]
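
A hedged sketch of the async entry points, under the same assumptions as above (package installed, "summarize" and "analyze" prompts exist, provider credentials configured); the prompt names mirror the docstring example and are illustrative only.

import asyncio

from yamlgraph.executor_async import execute_prompt_async, execute_prompts_concurrent


async def main() -> None:
    # Single async call; returns a raw string when no output_model is given.
    summary = await execute_prompt_async("summarize", variables={"text": "..."})
    print(summary)

    # Fan out several prompts; results come back in input order via asyncio.gather.
    results = await execute_prompts_concurrent([
        {"prompt_name": "summarize", "variables": {"text": "..."}},
        {"prompt_name": "analyze", "variables": {"text": "..."}, "temperature": 0.2},
    ])
    print(len(results))


if __name__ == "__main__":
    asyncio.run(main())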