yamlgraph-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of yamlgraph might be problematic.

Files changed (111)
  1. examples/__init__.py +1 -0
  2. examples/storyboard/__init__.py +1 -0
  3. examples/storyboard/generate_videos.py +335 -0
  4. examples/storyboard/nodes/__init__.py +10 -0
  5. examples/storyboard/nodes/animated_character_node.py +248 -0
  6. examples/storyboard/nodes/animated_image_node.py +138 -0
  7. examples/storyboard/nodes/character_node.py +162 -0
  8. examples/storyboard/nodes/image_node.py +118 -0
  9. examples/storyboard/nodes/replicate_tool.py +238 -0
  10. examples/storyboard/retry_images.py +118 -0
  11. tests/__init__.py +1 -0
  12. tests/conftest.py +178 -0
  13. tests/integration/__init__.py +1 -0
  14. tests/integration/test_animated_storyboard.py +63 -0
  15. tests/integration/test_cli_commands.py +242 -0
  16. tests/integration/test_map_demo.py +50 -0
  17. tests/integration/test_memory_demo.py +281 -0
  18. tests/integration/test_pipeline_flow.py +105 -0
  19. tests/integration/test_providers.py +163 -0
  20. tests/integration/test_resume.py +75 -0
  21. tests/unit/__init__.py +1 -0
  22. tests/unit/test_agent_nodes.py +200 -0
  23. tests/unit/test_checkpointer.py +212 -0
  24. tests/unit/test_cli.py +121 -0
  25. tests/unit/test_cli_package.py +81 -0
  26. tests/unit/test_compile_graph_map.py +132 -0
  27. tests/unit/test_conditions_routing.py +253 -0
  28. tests/unit/test_config.py +93 -0
  29. tests/unit/test_conversation_memory.py +270 -0
  30. tests/unit/test_database.py +145 -0
  31. tests/unit/test_deprecation.py +104 -0
  32. tests/unit/test_executor.py +60 -0
  33. tests/unit/test_executor_async.py +179 -0
  34. tests/unit/test_export.py +150 -0
  35. tests/unit/test_expressions.py +178 -0
  36. tests/unit/test_format_prompt.py +145 -0
  37. tests/unit/test_generic_report.py +200 -0
  38. tests/unit/test_graph_commands.py +327 -0
  39. tests/unit/test_graph_loader.py +299 -0
  40. tests/unit/test_graph_schema.py +193 -0
  41. tests/unit/test_inline_schema.py +151 -0
  42. tests/unit/test_issues.py +164 -0
  43. tests/unit/test_jinja2_prompts.py +85 -0
  44. tests/unit/test_langsmith.py +319 -0
  45. tests/unit/test_llm_factory.py +109 -0
  46. tests/unit/test_llm_factory_async.py +118 -0
  47. tests/unit/test_loops.py +403 -0
  48. tests/unit/test_map_node.py +144 -0
  49. tests/unit/test_no_backward_compat.py +56 -0
  50. tests/unit/test_node_factory.py +225 -0
  51. tests/unit/test_prompts.py +166 -0
  52. tests/unit/test_python_nodes.py +198 -0
  53. tests/unit/test_reliability.py +298 -0
  54. tests/unit/test_result_export.py +234 -0
  55. tests/unit/test_router.py +296 -0
  56. tests/unit/test_sanitize.py +99 -0
  57. tests/unit/test_schema_loader.py +295 -0
  58. tests/unit/test_shell_tools.py +229 -0
  59. tests/unit/test_state_builder.py +331 -0
  60. tests/unit/test_state_builder_map.py +104 -0
  61. tests/unit/test_state_config.py +197 -0
  62. tests/unit/test_template.py +190 -0
  63. tests/unit/test_tool_nodes.py +129 -0
  64. yamlgraph/__init__.py +35 -0
  65. yamlgraph/builder.py +110 -0
  66. yamlgraph/cli/__init__.py +139 -0
  67. yamlgraph/cli/__main__.py +6 -0
  68. yamlgraph/cli/commands.py +232 -0
  69. yamlgraph/cli/deprecation.py +92 -0
  70. yamlgraph/cli/graph_commands.py +382 -0
  71. yamlgraph/cli/validators.py +37 -0
  72. yamlgraph/config.py +67 -0
  73. yamlgraph/constants.py +66 -0
  74. yamlgraph/error_handlers.py +226 -0
  75. yamlgraph/executor.py +275 -0
  76. yamlgraph/executor_async.py +122 -0
  77. yamlgraph/graph_loader.py +337 -0
  78. yamlgraph/map_compiler.py +138 -0
  79. yamlgraph/models/__init__.py +36 -0
  80. yamlgraph/models/graph_schema.py +141 -0
  81. yamlgraph/models/schemas.py +124 -0
  82. yamlgraph/models/state_builder.py +236 -0
  83. yamlgraph/node_factory.py +240 -0
  84. yamlgraph/routing.py +87 -0
  85. yamlgraph/schema_loader.py +160 -0
  86. yamlgraph/storage/__init__.py +17 -0
  87. yamlgraph/storage/checkpointer.py +72 -0
  88. yamlgraph/storage/database.py +320 -0
  89. yamlgraph/storage/export.py +269 -0
  90. yamlgraph/tools/__init__.py +1 -0
  91. yamlgraph/tools/agent.py +235 -0
  92. yamlgraph/tools/nodes.py +124 -0
  93. yamlgraph/tools/python_tool.py +178 -0
  94. yamlgraph/tools/shell.py +205 -0
  95. yamlgraph/utils/__init__.py +47 -0
  96. yamlgraph/utils/conditions.py +157 -0
  97. yamlgraph/utils/expressions.py +111 -0
  98. yamlgraph/utils/langsmith.py +308 -0
  99. yamlgraph/utils/llm_factory.py +118 -0
  100. yamlgraph/utils/llm_factory_async.py +105 -0
  101. yamlgraph/utils/logging.py +127 -0
  102. yamlgraph/utils/prompts.py +116 -0
  103. yamlgraph/utils/sanitize.py +98 -0
  104. yamlgraph/utils/template.py +102 -0
  105. yamlgraph/utils/validators.py +181 -0
  106. yamlgraph-0.1.1.dist-info/METADATA +854 -0
  107. yamlgraph-0.1.1.dist-info/RECORD +111 -0
  108. yamlgraph-0.1.1.dist-info/WHEEL +5 -0
  109. yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
  110. yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
  111. yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
yamlgraph/error_handlers.py ADDED
@@ -0,0 +1,226 @@
+"""Error handling strategies for node execution.
+
+Provides strategy functions for different error handling modes:
+- skip: Continue without output
+- fail: Raise exception immediately
+- retry: Retry up to N times
+- fallback: Try fallback provider
+"""
+
+import logging
+from typing import Any, Callable
+
+from yamlgraph.models import ErrorType, PipelineError
+
+logger = logging.getLogger(__name__)
+
+
+class NodeResult:
+    """Result of node execution with consistent structure.
+
+    Attributes:
+        success: Whether execution succeeded
+        output: The result value (if success)
+        error: PipelineError (if failure)
+        state_updates: Additional state updates
+    """
+
+    def __init__(
+        self,
+        success: bool,
+        output: Any = None,
+        error: PipelineError | None = None,
+        state_updates: dict | None = None,
+    ):
+        self.success = success
+        self.output = output
+        self.error = error
+        self.state_updates = state_updates or {}
+
+    def to_state_update(
+        self,
+        state_key: str,
+        node_name: str,
+        loop_counts: dict,
+    ) -> dict:
+        """Convert to LangGraph state update dict.
+
+        Args:
+            state_key: Key to store output under
+            node_name: Name of the node
+            loop_counts: Current loop counts
+
+        Returns:
+            State update dict with consistent structure
+        """
+        update = {
+            "current_step": node_name,
+            "_loop_counts": loop_counts,
+        }
+
+        if self.success:
+            update[state_key] = self.output
+        elif self.error:
+            # Always use 'errors' list for consistency
+            update["errors"] = [self.error]
+
+        update.update(self.state_updates)
+        return update
+
+
+def handle_skip(
+    node_name: str,
+    error: Exception,
+    loop_counts: dict,
+) -> NodeResult:
+    """Handle error with skip strategy.
+
+    Args:
+        node_name: Name of the node
+        error: The exception that occurred
+        loop_counts: Current loop counts
+
+    Returns:
+        NodeResult with empty output
+    """
+    logger.warning(f"Node {node_name} failed, skipping: {error}")
+    return NodeResult(success=True, output=None)
+
+
+def handle_fail(
+    node_name: str,
+    error: Exception,
+) -> None:
+    """Handle error with fail strategy.
+
+    Args:
+        node_name: Name of the node
+        error: The exception that occurred
+
+    Raises:
+        Exception: Always raises the original error
+    """
+    logger.error(f"Node {node_name} failed (on_error=fail): {error}")
+    raise error
+
+
+def handle_retry(
+    node_name: str,
+    execute_fn: Callable[[], tuple[Any, Exception | None]],
+    max_retries: int,
+) -> NodeResult:
+    """Handle error with retry strategy.
+
+    Args:
+        node_name: Name of the node
+        execute_fn: Function to execute (returns result, error)
+        max_retries: Maximum retry attempts
+
+    Returns:
+        NodeResult with output or error
+    """
+    last_exception: Exception | None = None
+
+    for attempt in range(1, max_retries + 1):
+        logger.info(f"Node {node_name} retry {attempt}/{max_retries}")
+        result, error = execute_fn()
+        if error is None:
+            return NodeResult(success=True, output=result)
+        last_exception = error
+
+    logger.error(f"Node {node_name} failed after {max_retries} attempts")
+    pipeline_error = PipelineError.from_exception(
+        last_exception or Exception("Unknown error"), node=node_name
+    )
+    return NodeResult(success=False, error=pipeline_error)
+
+
+def handle_fallback(
+    node_name: str,
+    execute_fn: Callable[[str | None], tuple[Any, Exception | None]],
+    fallback_provider: str,
+) -> NodeResult:
+    """Handle error with fallback strategy.
+
+    Args:
+        node_name: Name of the node
+        execute_fn: Function to execute with provider param
+        fallback_provider: Fallback provider to try
+
+    Returns:
+        NodeResult with output or error
+    """
+    logger.info(f"Node {node_name} trying fallback: {fallback_provider}")
+    result, fallback_error = execute_fn(fallback_provider)
+
+    if fallback_error is None:
+        return NodeResult(success=True, output=result)
+
+    logger.error(f"Node {node_name} failed with primary and fallback")
+    pipeline_error = PipelineError.from_exception(fallback_error, node=node_name)
+    return NodeResult(success=False, error=pipeline_error)
+
+
+def handle_default(
+    node_name: str,
+    error: Exception,
+) -> NodeResult:
+    """Handle error with default strategy (log and return error).
+
+    Args:
+        node_name: Name of the node
+        error: The exception that occurred
+
+    Returns:
+        NodeResult with error
+    """
+    logger.error(f"Node {node_name} failed: {error}")
+    pipeline_error = PipelineError.from_exception(error, node=node_name)
+    return NodeResult(success=False, error=pipeline_error)
+
+
+def check_requirements(
+    requires: list[str],
+    state: dict,
+    node_name: str,
+) -> PipelineError | None:
+    """Check if all required state keys are present.
+
+    Args:
+        requires: List of required state keys
+        state: Current state
+        node_name: Name of the node
+
+    Returns:
+        PipelineError if requirements not met, None otherwise
+    """
+    for req in requires:
+        if state.get(req) is None:
+            return PipelineError(
+                type=ErrorType.STATE_ERROR,
+                message=f"Missing required state: {req}",
+                node=node_name,
+                retryable=False,
+            )
+    return None
+
+
+def check_loop_limit(
+    node_name: str,
+    loop_limit: int | None,
+    current_count: int,
+) -> bool:
+    """Check if loop limit has been reached.
+
+    Args:
+        node_name: Name of the node
+        loop_limit: Maximum loop iterations (None = no limit)
+        current_count: Current iteration count
+
+    Returns:
+        True if limit reached, False otherwise
+    """
+    if loop_limit is not None and current_count >= loop_limit:
+        logger.warning(f"Node {node_name} hit loop limit ({loop_limit})")
+        return True
+    return False
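
For orientation, here is a minimal sketch of how these strategies compose, using only what is shown above and assuming yamlgraph 0.1.1 is installed: an execute_fn returning a (result, error) tuple is handed to handle_retry, and the resulting NodeResult is converted into a LangGraph-style state update. The names demo_node, demo_output, and flaky_call are illustrative, not part of the package.

# Hypothetical driver for the retry strategy.
from yamlgraph.error_handlers import handle_retry

attempts = {"n": 0}

def flaky_call():
    # Matches the execute_fn contract: return (result, error); fail twice, then succeed.
    attempts["n"] += 1
    if attempts["n"] < 3:
        return None, RuntimeError("transient failure")
    return "ok", None

result = handle_retry("demo_node", flaky_call, max_retries=3)
print(result.success)  # True on the third attempt
print(result.to_state_update("demo_output", "demo_node", loop_counts={}))
# {'current_step': 'demo_node', '_loop_counts': {}, 'demo_output': 'ok'}
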
yamlgraph/executor.py ADDED
@@ -0,0 +1,275 @@
+"""YAML Prompt Executor - Unified interface for LLM calls.
+
+This module provides a simple, reusable executor for YAML-defined prompts
+with support for structured outputs via Pydantic models.
+"""
+
+import logging
+import threading
+import time
+from typing import Type, TypeVar
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel
+
+from yamlgraph.config import (
+    DEFAULT_TEMPERATURE,
+    MAX_RETRIES,
+    RETRY_BASE_DELAY,
+    RETRY_MAX_DELAY,
+)
+from yamlgraph.utils.llm_factory import create_llm
+from yamlgraph.utils.prompts import load_prompt
+from yamlgraph.utils.template import validate_variables
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+# Exceptions that are retryable
+RETRYABLE_EXCEPTIONS = (
+    "RateLimitError",
+    "APIConnectionError",
+    "APITimeoutError",
+    "InternalServerError",
+    "ServiceUnavailableError",
+)
+
+
+def is_retryable(exception: Exception) -> bool:
+    """Check if an exception is retryable.
+
+    Args:
+        exception: The exception to check
+
+    Returns:
+        True if the exception should be retried
+    """
+    exc_name = type(exception).__name__
+    return exc_name in RETRYABLE_EXCEPTIONS or "rate" in exc_name.lower()
+
+
+def format_prompt(
+    template: str,
+    variables: dict,
+    state: dict | None = None,
+) -> str:
+    """Format a prompt template with variables.
+
+    Supports both simple {variable} placeholders and Jinja2 templates.
+    If the template contains Jinja2 syntax ({%, {{), uses Jinja2 rendering.
+
+    Args:
+        template: Template string with {variable} or Jinja2 placeholders
+        variables: Dictionary of variable values
+        state: Optional state dict for Jinja2 templates (accessible as {{ state.field }})
+
+    Returns:
+        Formatted string
+
+    Examples:
+        Simple format:
+            format_prompt("Hello {name}", {"name": "World"})
+
+        Jinja2 with variables:
+            format_prompt("{% for item in items %}{{ item }}{% endfor %}", {"items": [1, 2]})
+
+        Jinja2 with state:
+            format_prompt("Topic: {{ state.topic }}", {}, state={"topic": "AI"})
+    """
+    # Check for Jinja2 syntax
+    if "{%" in template or "{{" in template:
+        from jinja2 import Template
+
+        jinja_template = Template(template)
+        # Pass both variables and state to Jinja2
+        context = {"state": state or {}, **variables}
+        return jinja_template.render(**context)
+
+    # Fall back to simple format - stringify lists for compatibility
+    safe_vars = {
+        k: (", ".join(map(str, v)) if isinstance(v, list) else v)
+        for k, v in variables.items()
+    }
+    return template.format(**safe_vars)
+
+
+def execute_prompt(
+    prompt_name: str,
+    variables: dict | None = None,
+    output_model: Type[T] | None = None,
+    temperature: float = DEFAULT_TEMPERATURE,
+    provider: str | None = None,
+) -> T | str:
+    """Execute a YAML prompt with optional structured output.
+
+    Uses the singleton PromptExecutor for LLM caching.
+
+    Args:
+        prompt_name: Name of the prompt file (without .yaml)
+        variables: Variables to substitute in the template
+        output_model: Optional Pydantic model for structured output
+        temperature: LLM temperature setting
+        provider: LLM provider ("anthropic", "mistral", "openai").
+            Can also be set in YAML metadata or PROVIDER env var.
+
+    Returns:
+        Parsed Pydantic model if output_model provided, else raw string
+
+    Example:
+        >>> result = execute_prompt(
+        ...     "greet",
+        ...     variables={"name": "World", "style": "formal"},
+        ...     output_model=GenericReport,
+        ... )
+        >>> print(result.summary)
+    """
+    return get_executor().execute(
+        prompt_name=prompt_name,
+        variables=variables,
+        output_model=output_model,
+        temperature=temperature,
+        provider=provider,
+    )
+
+
+# Default executor instance for LLM caching
+# Use get_executor() to access, or set_executor() for dependency injection
+_executor: "PromptExecutor | None" = None
+_executor_lock = threading.Lock()
+
+
+def get_executor() -> "PromptExecutor":
+    """Get the executor instance (thread-safe).
+
+    Returns the default singleton or a custom instance set via set_executor().
+
+    Returns:
+        PromptExecutor instance with LLM caching
+    """
+    global _executor
+    if _executor is None:
+        with _executor_lock:
+            # Double-check after acquiring lock
+            if _executor is None:
+                _executor = PromptExecutor()
+    return _executor
+
+
+def set_executor(executor: "PromptExecutor | None") -> None:
+    """Set a custom executor instance for dependency injection (thread-safe).
+
+    Useful for testing or when you need different executor configurations.
+
+    Args:
+        executor: Custom PromptExecutor instance, or None to reset to default
+    """
+    global _executor
+    with _executor_lock:
+        _executor = executor
+
+
+class PromptExecutor:
+    """Reusable executor with LLM caching and retry logic."""
+
+    def __init__(self, max_retries: int = MAX_RETRIES):
+        self._max_retries = max_retries
+
+    def _get_llm(
+        self,
+        temperature: float = DEFAULT_TEMPERATURE,
+        provider: str | None = None,
+    ) -> BaseChatModel:
+        """Get or create cached LLM instance.
+
+        Uses llm_factory which handles caching internally.
+        """
+        return create_llm(temperature=temperature, provider=provider)
+
+    def _invoke_with_retry(
+        self, llm, messages, output_model: Type[T] | None = None
+    ) -> T | str:
+        """Invoke LLM with exponential backoff retry.
+
+        Args:
+            llm: The LLM instance to use
+            messages: Messages to send
+            output_model: Optional Pydantic model for structured output
+
+        Returns:
+            LLM response (parsed model or string)
+
+        Raises:
+            Last exception if all retries fail
+        """
+        last_exception = None
+
+        for attempt in range(self._max_retries):
+            try:
+                if output_model:
+                    structured_llm = llm.with_structured_output(output_model)
+                    return structured_llm.invoke(messages)
+                else:
+                    response = llm.invoke(messages)
+                    return response.content
+
+            except Exception as e:
+                last_exception = e
+
+                if not is_retryable(e) or attempt == self._max_retries - 1:
+                    raise
+
+                # Exponential backoff with jitter
+                delay = min(RETRY_BASE_DELAY * (2**attempt), RETRY_MAX_DELAY)
+                logger.warning(
+                    f"LLM call failed (attempt {attempt + 1}/{self._max_retries}): {e}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                time.sleep(delay)
+
+        raise last_exception
+
+    def execute(
+        self,
+        prompt_name: str,
+        variables: dict | None = None,
+        output_model: Type[T] | None = None,
+        temperature: float = DEFAULT_TEMPERATURE,
+        provider: str | None = None,
+    ) -> T | str:
+        """Execute a prompt using cached LLM with retry logic.
+
+        Same interface as execute_prompt() but with LLM caching and
+        automatic retry for transient failures.
+
+        Provider priority: parameter > YAML metadata > env var > default
+
+        Raises:
+            ValueError: If required template variables are missing
+        """
+        variables = variables or {}
+
+        prompt_config = load_prompt(prompt_name)
+
+        # Validate all required variables are provided (fail fast)
+        full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
+        validate_variables(full_template, variables, prompt_name)
+
+        # Extract provider from YAML metadata if not provided
+        if provider is None and "provider" in prompt_config:
+            provider = prompt_config["provider"]
+            logger.debug(f"Using provider from YAML metadata: {provider}")
+
+        system_text = format_prompt(prompt_config.get("system", ""), variables)
+        user_text = format_prompt(prompt_config["user"], variables)
+
+        messages = []
+        if system_text:
+            messages.append(SystemMessage(content=system_text))
+        messages.append(HumanMessage(content=user_text))
+
+        llm = self._get_llm(temperature=temperature, provider=provider)
+
+        return self._invoke_with_retry(llm, messages, output_model)
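
As a quick illustration of the two rendering paths in format_prompt, the following sketch runs with just yamlgraph 0.1.1 and jinja2 installed; it exercises the plain str.format branch (list values joined with ", ") and the Jinja2 branch (triggered by "{{" or "{%", with the state dict exposed as state). The variable names are illustrative only.

# Sketch of format_prompt usage.
from yamlgraph.executor import format_prompt

# Simple placeholder path: list values are stringified for compatibility.
print(format_prompt("Tags: {tags}", {"tags": ["ai", "yaml"]}))
# Tags: ai, yaml

# Jinja2 path: state is available as {{ state.<field> }} alongside the variables.
print(format_prompt("Topic: {{ state.topic }} ({{ level }})", {"level": "intro"}, state={"topic": "AI"}))
# Topic: AI (intro)
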
yamlgraph/executor_async.py ADDED
@@ -0,0 +1,122 @@
+"""Async Prompt Executor - Async interface for LLM calls.
+
+This module provides async versions of execute_prompt for use in
+async contexts like web servers or concurrent pipelines.
+
+Note: This is a foundation module. The underlying LLM calls still
+use sync HTTP clients wrapped with run_in_executor.
+"""
+
+import asyncio
+import logging
+from typing import Type, TypeVar
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel
+
+from yamlgraph.config import DEFAULT_TEMPERATURE
+from yamlgraph.executor import format_prompt, load_prompt
+from yamlgraph.utils.llm_factory import create_llm
+from yamlgraph.utils.llm_factory_async import invoke_async
+from yamlgraph.utils.template import validate_variables
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+async def execute_prompt_async(
+    prompt_name: str,
+    variables: dict | None = None,
+    output_model: Type[T] | None = None,
+    temperature: float = DEFAULT_TEMPERATURE,
+    provider: str | None = None,
+) -> T | str:
+    """Execute a YAML prompt asynchronously.
+
+    Async version of execute_prompt for use in async contexts.
+
+    Args:
+        prompt_name: Name of the prompt file (without .yaml)
+        variables: Variables to substitute in the template
+        output_model: Optional Pydantic model for structured output
+        temperature: LLM temperature setting
+        provider: LLM provider ("anthropic", "mistral", "openai")
+
+    Returns:
+        Parsed Pydantic model if output_model provided, else raw string
+
+    Example:
+        >>> result = await execute_prompt_async(
+        ...     "greet",
+        ...     variables={"name": "World"},
+        ...     output_model=GenericReport,
+        ... )
+    """
+    variables = variables or {}
+
+    # Load and validate prompt (sync - file I/O is fast)
+    prompt_config = load_prompt(prompt_name)
+
+    full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
+    validate_variables(full_template, variables, prompt_name)
+
+    # Extract provider from YAML metadata if not provided
+    if provider is None and "provider" in prompt_config:
+        provider = prompt_config["provider"]
+        logger.debug(f"Using provider from YAML metadata: {provider}")
+
+    system_text = format_prompt(prompt_config.get("system", ""), variables)
+    user_text = format_prompt(prompt_config["user"], variables)
+
+    messages = []
+    if system_text:
+        messages.append(SystemMessage(content=system_text))
+    messages.append(HumanMessage(content=user_text))
+
+    # Create LLM (cached via factory)
+    llm = create_llm(temperature=temperature, provider=provider)
+
+    # Invoke asynchronously
+    return await invoke_async(llm, messages, output_model)
+
+
+async def execute_prompts_concurrent(
+    prompts: list[dict],
+) -> list[BaseModel | str]:
+    """Execute multiple prompts concurrently.
+
+    Useful for parallel LLM calls in pipelines.
+
+    Args:
+        prompts: List of dicts with keys:
+            - prompt_name: str (required)
+            - variables: dict (optional)
+            - output_model: Type[BaseModel] (optional)
+            - temperature: float (optional)
+            - provider: str (optional)
+
+    Returns:
+        List of results in same order as input prompts
+
+    Example:
+        >>> results = await execute_prompts_concurrent([
+        ...     {"prompt_name": "summarize", "variables": {"text": "..."}},
+        ...     {"prompt_name": "analyze", "variables": {"text": "..."}},
+        ... ])
+    """
+    tasks = []
+    for prompt_config in prompts:
+        task = execute_prompt_async(
+            prompt_name=prompt_config["prompt_name"],
+            variables=prompt_config.get("variables"),
+            output_model=prompt_config.get("output_model"),
+            temperature=prompt_config.get("temperature", DEFAULT_TEMPERATURE),
+            provider=prompt_config.get("provider"),
+        )
+        tasks.append(task)
+
+    return await asyncio.gather(*tasks)
+
+
+__all__ = ["execute_prompt_async", "execute_prompts_concurrent"]
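
To tie it together, here is a minimal sketch of the concurrent path, assuming yamlgraph 0.1.1 is installed, a provider API key is configured, and prompt files named "summarize" and "analyze" exist (those names are placeholders, not files shipped with the package).

# Hypothetical end-to-end use of execute_prompts_concurrent.
import asyncio

from yamlgraph.executor_async import execute_prompts_concurrent

async def main() -> None:
    results = await execute_prompts_concurrent([
        {"prompt_name": "summarize", "variables": {"text": "LangGraph pipelines defined in YAML."}},
        {"prompt_name": "analyze", "variables": {"text": "LangGraph pipelines defined in YAML."}},
    ])
    for result in results:
        print(result)

asyncio.run(main())

Because the calls are gathered with asyncio.gather, results come back in the same order as the input list regardless of which call finishes first.
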