yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
yamlgraph/constants.py ADDED
@@ -0,0 +1,70 @@
1
+ """Type-safe constants for YAML graph configuration.
2
+
3
+ Provides enums for node types, error handlers, and other magic strings
4
+ used throughout the codebase to enable static type checking and IDE support.
5
+ """
6
+
7
+ from enum import StrEnum
8
+
9
+
10
+ class NodeType(StrEnum):
11
+ """Valid node types in YAML graph configuration."""
12
+
13
+ LLM = "llm"
14
+ ROUTER = "router"
15
+ TOOL = "tool"
16
+ AGENT = "agent"
17
+ PYTHON = "python"
18
+ MAP = "map"
19
+ TOOL_CALL = "tool_call"
20
+ INTERRUPT = "interrupt"
21
+ SUBGRAPH = "subgraph"
22
+ PASSTHROUGH = "passthrough"
23
+
24
+ @classmethod
25
+ def requires_prompt(cls, node_type: str) -> bool:
26
+ """Check if node type requires a prompt field.
27
+
28
+ Args:
29
+ node_type: The node type string
30
+
31
+ Returns:
32
+ True if the node type requires a prompt
33
+ """
34
+ return node_type in (cls.LLM, cls.ROUTER)
35
+
36
+
37
+ class ErrorHandler(StrEnum):
38
+ """Valid on_error handling strategies."""
39
+
40
+ SKIP = "skip" # Skip node and continue pipeline
41
+ RETRY = "retry" # Retry with max_retries attempts
42
+ FAIL = "fail" # Raise exception immediately
43
+ FALLBACK = "fallback" # Try fallback provider
44
+
45
+ @classmethod
46
+ def all_values(cls) -> set[str]:
47
+ """Return all valid error handler values.
48
+
49
+ Returns:
50
+ Set of valid error handler strings
51
+ """
52
+ return {handler.value for handler in cls}
53
+
54
+
55
+ class EdgeType(StrEnum):
56
+ """Valid edge types in graph configuration."""
57
+
58
+ SIMPLE = "simple" # Direct edge from -> to
59
+ CONDITIONAL = "conditional" # Edge with conditions
60
+
61
+
62
+ class SpecialNodes(StrEnum):
63
+ """Special node names with semantic meaning."""
64
+
65
+ START = "__start__"
66
+ END = "__end__"
67
+
68
+
69
# Explicit public API of this module.
__all__ = [
    "NodeType",
    "ErrorHandler",
    "EdgeType",
    "SpecialNodes",
]
@@ -0,0 +1,227 @@
1
+ """Error handling strategies for node execution.
2
+
3
+ Provides strategy functions for different error handling modes:
4
+ - skip: Continue without output
5
+ - fail: Raise exception immediately
6
+ - retry: Retry up to N times
7
+ - fallback: Try fallback provider
8
+ """
9
+
10
+ import logging
11
+ from collections.abc import Callable
12
+ from typing import Any
13
+
14
+ from yamlgraph.models import ErrorType, PipelineError
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class NodeResult:
    """Uniform container for the outcome of a node execution.

    Attributes:
        success: Whether the node ran without error
        output: The produced value (populated on success)
        error: PipelineError describing the failure (populated on failure)
        state_updates: Extra key/value pairs to merge into the state
    """

    def __init__(
        self,
        success: bool,
        output: Any = None,
        error: PipelineError | None = None,
        state_updates: dict | None = None,
    ):
        self.success = success
        self.output = output
        self.error = error
        # Normalise a missing dict to empty so callers can merge blindly.
        self.state_updates = state_updates or {}

    def to_state_update(
        self,
        state_key: str,
        node_name: str,
        loop_counts: dict,
    ) -> dict:
        """Render this result as a LangGraph state-update dict.

        Args:
            state_key: State key under which the output is stored
            node_name: Name of the node that produced this result
            loop_counts: Current per-node loop counters

        Returns:
            Dict ready to be merged into the graph state
        """
        merged = {
            "current_step": node_name,
            "_loop_counts": loop_counts,
        }
        if self.success:
            merged[state_key] = self.output
        elif self.error:
            # Failures always accumulate under the single 'errors' list key.
            merged["errors"] = [self.error]
        merged.update(self.state_updates)
        return merged
70
+
71
+
72
def handle_skip(
    node_name: str,
    error: Exception,
    loop_counts: dict,
) -> NodeResult:
    """Apply the 'skip' strategy: log the failure and continue with no output.

    Args:
        node_name: Name of the failing node
        error: Exception raised by the node
        loop_counts: Current loop counters (unused here; kept for a uniform API)

    Returns:
        Successful NodeResult carrying a None output
    """
    logger.warning(f"Node {node_name} failed, skipping: {error}")
    return NodeResult(success=True, output=None)
89
+
90
+
91
def handle_fail(
    node_name: str,
    error: Exception,
) -> None:
    """Apply the 'fail' strategy: log, then re-raise the original exception.

    Args:
        node_name: Name of the failing node
        error: Exception raised by the node

    Raises:
        Exception: Always re-raises *error* unchanged
    """
    logger.error(f"Node {node_name} failed (on_error=fail): {error}")
    raise error
106
+
107
+
108
def handle_retry(
    node_name: str,
    execute_fn: Callable[[], tuple[Any, Exception | None]],
    max_retries: int,
) -> NodeResult:
    """Apply the 'retry' strategy: re-run the node up to max_retries times.

    Args:
        node_name: Name of the failing node
        execute_fn: Zero-arg callable returning a (result, error) pair,
            where error is None on success
        max_retries: Total number of attempts to make

    Returns:
        Successful NodeResult on the first clean run, otherwise a failed
        NodeResult wrapping the last exception observed
    """
    final_error: Exception | None = None

    attempt = 1
    while attempt <= max_retries:
        logger.info(f"Node {node_name} retry {attempt}/{max_retries}")
        outcome, err = execute_fn()
        if err is None:
            return NodeResult(success=True, output=outcome)
        final_error = err
        attempt += 1

    logger.error(f"Node {node_name} failed after {max_retries} attempts")
    pipeline_error = PipelineError.from_exception(
        final_error or Exception("Unknown error"), node=node_name
    )
    return NodeResult(success=False, error=pipeline_error)
137
+
138
+
139
def handle_fallback(
    node_name: str,
    execute_fn: Callable[[str | None], tuple[Any, Exception | None]],
    fallback_provider: str,
) -> NodeResult:
    """Apply the 'fallback' strategy: re-run the node on a backup provider.

    Args:
        node_name: Name of the failing node
        execute_fn: Callable taking a provider name, returning a
            (result, error) pair where error is None on success
        fallback_provider: Provider to retry with

    Returns:
        Successful NodeResult if the fallback run succeeds, otherwise a
        failed NodeResult wrapping the fallback error
    """
    logger.info(f"Node {node_name} trying fallback: {fallback_provider}")
    outcome, err = execute_fn(fallback_provider)

    if err is not None:
        # Both the primary run (which brought us here) and the fallback failed.
        logger.error(f"Node {node_name} failed with primary and fallback")
        return NodeResult(
            success=False,
            error=PipelineError.from_exception(err, node=node_name),
        )

    return NodeResult(success=True, output=outcome)
163
+
164
+
165
def handle_default(
    node_name: str,
    error: Exception,
) -> NodeResult:
    """Default error strategy: log the failure and return it as a result.

    Args:
        node_name: Name of the failing node
        error: Exception raised by the node

    Returns:
        Failed NodeResult wrapping *error* as a PipelineError
    """
    logger.error(f"Node {node_name} failed: {error}")
    return NodeResult(
        success=False,
        error=PipelineError.from_exception(error, node=node_name),
    )
181
+
182
+
183
def check_requirements(
    requires: list[str],
    state: dict,
    node_name: str,
) -> PipelineError | None:
    """Verify that every required state key holds a non-None value.

    Args:
        requires: State keys the node depends on
        state: Current pipeline state
        node_name: Name of the node being checked

    Returns:
        PipelineError for the first missing key, or None when all are present
    """
    missing = next((key for key in requires if state.get(key) is None), None)
    if missing is None:
        return None
    return PipelineError(
        type=ErrorType.STATE_ERROR,
        message=f"Missing required state: {missing}",
        node=node_name,
        retryable=False,
    )
207
+
208
+
209
+ def check_loop_limit(
210
+ node_name: str,
211
+ loop_limit: int | None,
212
+ current_count: int,
213
+ ) -> bool:
214
+ """Check if loop limit has been reached.
215
+
216
+ Args:
217
+ node_name: Name of the node
218
+ loop_limit: Maximum loop iterations (None = no limit)
219
+ current_count: Current iteration count
220
+
221
+ Returns:
222
+ True if limit reached, False otherwise
223
+ """
224
+ if loop_limit is not None and current_count >= loop_limit:
225
+ logger.warning(f"Node {node_name} hit loop limit ({loop_limit})")
226
+ return True
227
+ return False
yamlgraph/executor.py ADDED
@@ -0,0 +1,290 @@
1
+ """YAML Prompt Executor - Unified interface for LLM calls.
2
+
3
+ This module provides a simple, reusable executor for YAML-defined prompts
4
+ with support for structured outputs via Pydantic models.
5
+ """
6
+
7
+ import logging
8
+ import threading
9
+ import time
10
+ from pathlib import Path
11
+ from typing import TypeVar
12
+
13
+ from langchain_core.language_models.chat_models import BaseChatModel
14
+ from langchain_core.messages import HumanMessage, SystemMessage
15
+ from pydantic import BaseModel
16
+
17
+ from yamlgraph.config import (
18
+ DEFAULT_TEMPERATURE,
19
+ MAX_RETRIES,
20
+ RETRY_BASE_DELAY,
21
+ RETRY_MAX_DELAY,
22
+ )
23
+ from yamlgraph.utils.llm_factory import create_llm
24
+ from yamlgraph.utils.prompts import load_prompt
25
+ from yamlgraph.utils.template import validate_variables
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ T = TypeVar("T", bound=BaseModel)
30
+
31
+
32
# Exception class names treated as transient and therefore retryable.
RETRYABLE_EXCEPTIONS = (
    "RateLimitError",
    "APIConnectionError",
    "APITimeoutError",
    "InternalServerError",
    "ServiceUnavailableError",
)


def is_retryable(exception: Exception) -> bool:
    """Check if an exception is retryable.

    Matching is done by class name so provider-specific exception types can
    be recognised without importing every provider SDK.

    Args:
        exception: The exception to check

    Returns:
        True if the exception should be retried
    """
    name = type(exception).__name__
    # Catch-all: any class whose name mentions "rate" (rate limiting).
    return name in RETRYABLE_EXCEPTIONS or "rate" in name.lower()
53
+
54
+
55
+ def format_prompt(
56
+ template: str,
57
+ variables: dict,
58
+ state: dict | None = None,
59
+ ) -> str:
60
+ """Format a prompt template with variables.
61
+
62
+ Supports both simple {variable} placeholders and Jinja2 templates.
63
+ If the template contains Jinja2 syntax ({%, {{), uses Jinja2 rendering.
64
+
65
+ Args:
66
+ template: Template string with {variable} or Jinja2 placeholders
67
+ variables: Dictionary of variable values
68
+ state: Optional state dict for Jinja2 templates (accessible as {{ state.field }})
69
+
70
+ Returns:
71
+ Formatted string
72
+
73
+ Examples:
74
+ Simple format:
75
+ format_prompt("Hello {name}", {"name": "World"})
76
+
77
+ Jinja2 with variables:
78
+ format_prompt("{% for item in items %}{{ item }}{% endfor %}", {"items": [1, 2]})
79
+
80
+ Jinja2 with state:
81
+ format_prompt("Topic: {{ state.topic }}", {}, state={"topic": "AI"})
82
+ """
83
+ # Check for Jinja2 syntax
84
+ if "{%" in template or "{{" in template:
85
+ from jinja2 import Template
86
+
87
+ jinja_template = Template(template)
88
+ # Pass both variables and state to Jinja2
89
+ context = {"state": state or {}, **variables}
90
+ return jinja_template.render(**context)
91
+
92
+ # Fall back to simple format - stringify lists for compatibility
93
+ safe_vars = {
94
+ k: (", ".join(map(str, v)) if isinstance(v, list) else v)
95
+ for k, v in variables.items()
96
+ }
97
+ return template.format(**safe_vars)
98
+
99
+
100
def execute_prompt(
    prompt_name: str,
    variables: dict | None = None,
    output_model: type[T] | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    provider: str | None = None,
    graph_path: "Path | None" = None,
    prompts_dir: "Path | None" = None,
    prompts_relative: bool = False,
) -> T | str:
    """Execute a YAML prompt with optional structured output.

    Thin module-level convenience wrapper: all work is delegated to the
    shared PromptExecutor singleton, which caches LLM instances.

    Args:
        prompt_name: Name of the prompt file (without .yaml)
        variables: Variables to substitute in the template
        output_model: Optional Pydantic model for structured output
        temperature: LLM temperature setting
        provider: LLM provider ("anthropic", "mistral", "openai").
            Can also be set in YAML metadata or PROVIDER env var.
        graph_path: Path to graph file for relative prompt resolution
        prompts_dir: Explicit prompts directory override
        prompts_relative: If True, resolve prompts relative to graph_path

    Returns:
        Parsed Pydantic model if output_model provided, else raw string

    Example:
        >>> result = execute_prompt(
        ...     "greet",
        ...     variables={"name": "World", "style": "formal"},
        ...     output_model=GenericReport,
        ... )
        >>> print(result.summary)
    """
    shared_executor = get_executor()
    return shared_executor.execute(
        prompt_name=prompt_name,
        variables=variables,
        output_model=output_model,
        temperature=temperature,
        provider=provider,
        graph_path=graph_path,
        prompts_dir=prompts_dir,
        prompts_relative=prompts_relative,
    )
146
+
147
+
148
# Default executor instance for LLM caching.
# Access through get_executor(); set_executor() allows dependency injection.
_executor: "PromptExecutor | None" = None
_executor_lock = threading.Lock()


def get_executor() -> "PromptExecutor":
    """Return the process-wide PromptExecutor, creating it lazily (thread-safe).

    Uses double-checked locking: the common already-created path takes no
    lock, while concurrent first calls still build exactly one instance.

    Returns:
        PromptExecutor instance with LLM caching
    """
    global _executor
    if _executor is not None:
        return _executor
    with _executor_lock:
        # Re-check under the lock in case another thread won the race.
        if _executor is None:
            _executor = PromptExecutor()
    return _executor
169
+
170
+
171
class PromptExecutor:
    """Reusable prompt executor with LLM caching and retry logic.

    Loads YAML prompt definitions, renders their templates, and invokes a
    cached LLM with exponential-backoff retry for transient failures.
    """

    def __init__(self, max_retries: int = MAX_RETRIES):
        """Initialize the executor.

        Args:
            max_retries: Maximum LLM invocation attempts per call; must be
                >= 1 for any call to be made.
        """
        self._max_retries = max_retries

    def _get_llm(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
    ) -> BaseChatModel:
        """Get or create cached LLM instance.

        Uses llm_factory which handles caching internally, so repeated
        calls with the same settings are cheap.
        """
        return create_llm(temperature=temperature, provider=provider)

    def _invoke_with_retry(
        self, llm, messages, output_model: type[T] | None = None
    ) -> T | str:
        """Invoke LLM with exponential backoff retry.

        Args:
            llm: The LLM instance to use
            messages: Messages to send
            output_model: Optional Pydantic model for structured output

        Returns:
            LLM response (parsed model or string)

        Raises:
            RuntimeError: If max_retries <= 0, so no attempt was ever made.
            Exception: The last LLM error once retries are exhausted, or
                immediately for non-retryable errors.
        """
        last_exception = None

        for attempt in range(self._max_retries):
            try:
                if output_model:
                    structured_llm = llm.with_structured_output(output_model)
                    return structured_llm.invoke(messages)
                else:
                    response = llm.invoke(messages)
                    return response.content

            except Exception as e:
                last_exception = e

                # Give up immediately on non-retryable errors; on the final
                # attempt propagate whatever we got.
                if not is_retryable(e) or attempt == self._max_retries - 1:
                    raise

                # Exponential backoff, capped at RETRY_MAX_DELAY.
                # (No jitter is applied despite the delay formula name.)
                delay = min(RETRY_BASE_DELAY * (2**attempt), RETRY_MAX_DELAY)
                logger.warning(
                    f"LLM call failed (attempt {attempt + 1}/{self._max_retries}): {e}. "
                    f"Retrying in {delay:.1f}s..."
                )
                time.sleep(delay)

        # Only reachable when the loop body never ran (max_retries <= 0):
        # bare `raise last_exception` would raise None -> TypeError.
        if last_exception is None:
            raise RuntimeError(
                "LLM invocation made no attempts (max_retries <= 0)"
            )
        raise last_exception

    def execute(
        self,
        prompt_name: str,
        variables: dict | None = None,
        output_model: type[T] | None = None,
        temperature: float = DEFAULT_TEMPERATURE,
        provider: str | None = None,
        graph_path: "Path | None" = None,
        prompts_dir: "Path | None" = None,
        prompts_relative: bool = False,
    ) -> T | str:
        """Execute a prompt using cached LLM with retry logic.

        Same interface as execute_prompt() but with LLM caching and
        automatic retry for transient failures.

        Provider priority: parameter > YAML metadata > env var > default

        Args:
            prompt_name: Name of the prompt file (without .yaml)
            variables: Variables to substitute in the template
            output_model: Optional Pydantic model for structured output
            temperature: LLM temperature setting
            provider: LLM provider ("anthropic", "mistral", "openai")
            graph_path: Path to graph file for relative prompt resolution
            prompts_dir: Explicit prompts directory override
            prompts_relative: If True, resolve prompts relative to graph_path

        Returns:
            Parsed Pydantic model if output_model provided, else raw string

        Raises:
            ValueError: If required template variables are missing
        """
        variables = variables or {}

        prompt_config = load_prompt(
            prompt_name,
            prompts_dir=prompts_dir,
            graph_path=graph_path,
            prompts_relative=prompts_relative,
        )

        # Validate all required variables are provided (fail fast)
        full_template = prompt_config.get("system", "") + prompt_config.get("user", "")
        validate_variables(full_template, variables, prompt_name)

        # Extract provider from YAML metadata if not provided explicitly
        if provider is None and "provider" in prompt_config:
            provider = prompt_config["provider"]
            logger.debug(f"Using provider from YAML metadata: {provider}")

        system_text = format_prompt(prompt_config.get("system", ""), variables)
        user_text = format_prompt(prompt_config["user"], variables)

        # A system message is only attached when the template produced text.
        messages = []
        if system_text:
            messages.append(SystemMessage(content=system_text))
        messages.append(HumanMessage(content=user_text))

        llm = self._get_llm(temperature=temperature, provider=provider)

        return self._invoke_with_retry(llm, messages, output_model)