yamlgraph 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Async LLM Factory - Async versions of LLM creation.
|
|
2
|
+
|
|
3
|
+
This module provides async-compatible LLM creation with support for
|
|
4
|
+
non-blocking I/O operations in async contexts.
|
|
5
|
+
|
|
6
|
+
Note: This module is a foundation for future async support. Currently,
|
|
7
|
+
LangChain's LLM implementations use sync HTTP clients internally, so
|
|
8
|
+
this wraps them for use in async contexts via run_in_executor.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
import contextvars
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import TypeVar

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from yamlgraph.utils.llm_factory import ProviderType, create_llm
|
|
22
|
+
|
|
23
|
+
# Module logger (not referenced by the functions defined in this module).
logger = logging.getLogger(__name__)

# Type variable for the structured-output model class accepted by
# invoke_async(); bound to pydantic BaseModel.
T = TypeVar("T", bound=BaseModel)
|
|
26
|
+
|
|
27
|
+
# Shared executor for running sync LLM calls
|
|
28
|
+
_executor: ThreadPoolExecutor | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_executor() -> ThreadPoolExecutor:
|
|
32
|
+
"""Get or create the shared thread pool executor."""
|
|
33
|
+
global _executor
|
|
34
|
+
if _executor is None:
|
|
35
|
+
_executor = ThreadPoolExecutor(max_workers=4)
|
|
36
|
+
return _executor
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def create_llm_async(
    provider: ProviderType | None = None,
    model: str | None = None,
    temperature: float = 0.7,
) -> BaseChatModel:
    """Create an LLM instance without blocking the event loop.

    Currently wraps the sync ``create_llm``, running it on the shared thread
    pool from ``get_executor()``. Future versions may use native async LLM
    implementations. Must be awaited from within a running event loop.

    Args:
        provider: LLM provider ("anthropic", "mistral", "openai")
        model: Model name
        temperature: Temperature for generation

    Returns:
        The configured LLM instance returned by ``create_llm``.
    """
    # Inside a coroutine the asyncio docs prefer get_running_loop() over
    # get_event_loop(): it always returns the loop running this coroutine and
    # avoids get_event_loop()'s policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        get_executor(),
        partial(create_llm, provider=provider, model=model, temperature=temperature),
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def invoke_async(
    llm: BaseChatModel,
    messages: list[BaseMessage],
    output_model: type[T] | None = None,
) -> T | str:
    """Invoke an LLM without blocking the event loop.

    Runs the synchronous ``invoke`` on the shared thread pool from
    ``get_executor()``. The caller's context variables are copied into the
    worker so that context-local state set by the calling coroutine is
    visible to the synchronous call.

    Args:
        llm: The LLM instance
        messages: Messages to send
        output_model: Optional Pydantic model class for structured output
            (applied via ``llm.with_structured_output``)

    Returns:
        The parsed ``output_model`` instance when ``output_model`` is given,
        otherwise the response message's ``content``.
    """
    # Preferred over get_event_loop() inside coroutines (asyncio docs).
    loop = asyncio.get_running_loop()

    def sync_invoke() -> T | str:
        if output_model is not None:
            structured_llm = llm.with_structured_output(output_model)
            return structured_llm.invoke(messages)
        response = llm.invoke(messages)
        # NOTE(review): some chat models return a list of content blocks
        # rather than a plain str here; confirm callers handle that.
        return response.content

    # loop.run_in_executor() does not propagate contextvars to the worker
    # thread (unlike asyncio.to_thread). Without this copy, context-local state
    # of the caller (for example tracing/callback context kept in contextvars)
    # is lost, and state left in a reused worker thread could leak into later
    # calls. A fresh copy per call keeps each invocation isolated.
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(get_executor(), ctx.run, sync_invoke)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def shutdown_executor() -> None:
    """Release the shared thread pool, if one has been created.

    Blocks until already-submitted work finishes (``wait=True``) and then
    clears the module-level reference, so a later ``get_executor()`` call
    builds a fresh pool. Intended for application shutdown; a no-op when no
    pool exists.
    """
    global _executor
    pool = _executor
    if pool is None:
        return
    pool.shutdown(wait=True)
    _executor = None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Public names exported by ``import *`` from this module.
__all__ = ["create_llm_async", "invoke_async", "shutdown_executor"]
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Structured logging configuration for yamlgraph.
|
|
2
|
+
|
|
3
|
+
Provides consistent logging across all modules with JSON-formatted
|
|
4
|
+
output for production environments.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StructuredFormatter(logging.Formatter):
    """Formatter that outputs structured log messages.

    In production (LOG_FORMAT=json), outputs one JSON object per record.
    In development, outputs a human-readable line (plus any traceback).

    Fields passed via ``logger.info(..., extra={...})`` are included in the
    JSON output. The logging module stores each ``extra`` key as an attribute
    of the LogRecord (there is no ``record.extra`` dict), so they are found by
    comparing the record's attributes with those every LogRecord carries.
    """

    # Attribute names present on every LogRecord, plus the ones that
    # Formatter.format() adds. Any other attribute came from ``extra=``.
    _RESERVED_ATTRS = frozenset(
        logging.LogRecord("", logging.NOTSET, "", 0, "", None, None).__dict__
    ) | {"message", "asctime"}

    def __init__(self, use_json: bool = False):
        """Create the formatter.

        Args:
            use_json: If True, ``format`` emits JSON lines; otherwise a
                human-readable line.
        """
        super().__init__()
        self.use_json = use_json

    def _extra_fields(self, record: logging.LogRecord) -> dict:
        """Return the user-supplied ``extra`` fields attached to ``record``."""
        fields = {
            key: value
            for key, value in record.__dict__.items()
            if key not in self._RESERVED_ATTRS
        }
        # Backward compatibility: extra={"extra": {...}} lands on record.extra.
        # Flatten it, as the previous implementation did.
        nested = fields.pop("extra", None)
        if isinstance(nested, dict):
            fields.update(nested)
        elif nested is not None:
            fields["extra"] = nested
        return fields

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record as a JSON line or a human-readable line."""
        if self.use_json:
            import json

            log_data = {
                "timestamp": self.formatTime(record),
                "level": record.levelname,
                "logger": record.name,
                "message": record.getMessage(),
            }
            # Add extra fields if present
            log_data.update(self._extra_fields(record))
            if record.exc_info:
                log_data["exception"] = self.formatException(record.exc_info)
            # default=str: a log call must never fail because an extra value
            # (datetime, Path, arbitrary object, ...) is not JSON-serializable.
            return json.dumps(log_data, default=str)

        # Human-readable format
        base = f"{self.formatTime(record)} [{record.levelname}] {record.name}: {record.getMessage()}"
        if record.exc_info:
            base += f"\n{self.formatException(record.exc_info)}"
        return base
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _resolve_level(level: str | int) -> int | None:
    """Translate a level name or number into a ``logging`` level integer.

    Accepts ints (returned unchanged), digit strings ("10"), and level names
    in any case with surrounding whitespace ("debug", " INFO "). Returns None
    when the value does not name a standard logging level.
    """
    if isinstance(level, int):
        return level
    name = str(level).strip().upper()
    if name.isdigit():
        return int(name)
    value = getattr(logging, name, None)
    # A bare getattr would also "accept" non-level attributes such as
    # BASIC_FORMAT (a str). Only integer constants are real levels.
    return value if isinstance(value, int) else None


def setup_logging(
    level: str | int | None = None,
    use_json: bool | None = None,
) -> logging.Logger:
    """Configure logging for yamlgraph.

    Installs a single stderr handler using StructuredFormatter on the
    "yamlgraph" logger, replacing any handlers it already had, and stops
    propagation to the root logger. Calling it again reconfigures rather
    than duplicating handlers.

    Args:
        level: Log level name (DEBUG, INFO, WARNING, ERROR; case-insensitive)
            or a numeric level. Defaults to LOG_LEVEL env var or INFO.
            Unrecognised values fall back to INFO and log a warning instead
            of raising.
        use_json: If True, output JSON lines.
            Defaults to LOG_FORMAT=json env var.

    Returns:
        Root logger for the yamlgraph package
    """
    if level is None:
        level = os.getenv("LOG_LEVEL", "INFO")

    if use_json is None:
        use_json = os.getenv("LOG_FORMAT", "").lower() == "json"

    resolved = _resolve_level(level)

    # Get the yamlgraph logger
    logger = logging.getLogger("yamlgraph")
    logger.setLevel(resolved if resolved is not None else logging.INFO)

    # Remove existing handlers
    logger.handlers.clear()

    # Add handler with formatter
    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(StructuredFormatter(use_json=use_json))
    logger.addHandler(handler)

    # Don't propagate to root logger
    logger.propagate = False

    if resolved is None:
        # This module calls setup_logging() at import time with LOG_LEVEL from
        # the environment. A typo there must not make the package unimportable,
        # so warn (through the handler just installed) instead of raising.
        logger.warning("Unknown log level %r; falling back to INFO", level)

    return logger
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_logger(name: str) -> logging.Logger:
    """Return the standard-library logger registered under ``name``.

    Thin wrapper over ``logging.getLogger``. Repeated calls with the same
    name return the same logger object.

    Args:
        name: Dotted logger name, usually the caller's ``__name__``.

    Returns:
        The ``logging.Logger`` for ``name``.

    Example:
        >>> logger = get_logger(__name__)
        >>> logger.info("Processing started", extra={"topic": "AI"})
    """
    named_logger = logging.getLogger(name)
    return named_logger
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Initialize logging on import: configures the "yamlgraph" logger from the
# LOG_LEVEL / LOG_FORMAT environment variables (setup_logging() defaults)
# as a side effect of importing this module.
_root_logger = setup_logging()
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Unified prompt loading and path resolution.
|
|
2
|
+
|
|
3
|
+
This module consolidates prompt loading logic used by executor.py
|
|
4
|
+
and node_factory.py into a single, testable module.
|
|
5
|
+
|
|
6
|
+
Search order for prompts:
|
|
7
|
+
1. If prompts_relative + prompts_dir + graph_path: graph_path.parent/prompts_dir/{prompt_name}.yaml
|
|
8
|
+
2. If prompts_dir specified: prompts_dir/{prompt_name}.yaml
|
|
9
|
+
3. If prompts_relative + graph_path: graph_path.parent/{prompt_name}.yaml
|
|
10
|
+
4. Default: PROMPTS_DIR/{prompt_name}.yaml
|
|
11
|
+
5. Fallback: {parent}/prompts/{basename}.yaml (external examples)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import yaml
|
|
17
|
+
|
|
18
|
+
from yamlgraph.config import PROMPTS_DIR
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def resolve_prompt_path(
|
|
22
|
+
prompt_name: str,
|
|
23
|
+
prompts_dir: Path | None = None,
|
|
24
|
+
graph_path: Path | None = None,
|
|
25
|
+
prompts_relative: bool = False,
|
|
26
|
+
) -> Path:
|
|
27
|
+
"""Resolve a prompt name to its full YAML file path.
|
|
28
|
+
|
|
29
|
+
Resolution order:
|
|
30
|
+
1. If prompts_relative + prompts_dir + graph_path: graph_path.parent/prompts_dir/{prompt_name}.yaml
|
|
31
|
+
2. If prompts_dir specified: prompts_dir/{prompt_name}.yaml
|
|
32
|
+
3. If prompts_relative + graph_path: graph_path.parent/{prompt_name}.yaml
|
|
33
|
+
4. Default: PROMPTS_DIR/{prompt_name}.yaml
|
|
34
|
+
5. Fallback: {parent}/prompts/{basename}.yaml (external examples)
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
prompt_name: Prompt name like "greet" or "prompts/opening"
|
|
38
|
+
prompts_dir: Explicit prompts directory (combined with graph_path if prompts_relative=True)
|
|
39
|
+
graph_path: Path to the graph YAML file (for relative resolution)
|
|
40
|
+
prompts_relative: If True, resolve relative to graph_path.parent
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
Path to the YAML file
|
|
44
|
+
|
|
45
|
+
Raises:
|
|
46
|
+
FileNotFoundError: If prompt file doesn't exist
|
|
47
|
+
ValueError: If prompts_relative=True but graph_path not provided
|
|
48
|
+
|
|
49
|
+
Examples:
|
|
50
|
+
>>> resolve_prompt_path("greet")
|
|
51
|
+
PosixPath('/path/to/prompts/greet.yaml')
|
|
52
|
+
|
|
53
|
+
>>> resolve_prompt_path("prompts/opening", graph_path=Path("graphs/demo.yaml"), prompts_relative=True)
|
|
54
|
+
PosixPath('/path/to/graphs/prompts/opening.yaml')
|
|
55
|
+
|
|
56
|
+
>>> resolve_prompt_path("opening", prompts_dir="prompts", graph_path=Path("graphs/demo.yaml"), prompts_relative=True)
|
|
57
|
+
PosixPath('/path/to/graphs/prompts/opening.yaml')
|
|
58
|
+
"""
|
|
59
|
+
# Validate prompts_relative requires graph_path
|
|
60
|
+
if prompts_relative and graph_path is None and prompts_dir is None:
|
|
61
|
+
raise ValueError("graph_path required when prompts_relative=True")
|
|
62
|
+
|
|
63
|
+
# 1. Graph-relative with explicit prompts_dir (combine them)
|
|
64
|
+
if prompts_relative and prompts_dir is not None and graph_path is not None:
|
|
65
|
+
graph_dir = Path(graph_path).parent
|
|
66
|
+
yaml_path = graph_dir / prompts_dir / f"{prompt_name}.yaml"
|
|
67
|
+
if yaml_path.exists():
|
|
68
|
+
return yaml_path
|
|
69
|
+
# Fall through if not found
|
|
70
|
+
|
|
71
|
+
# 2. Explicit prompts_dir (absolute path or CWD-relative)
|
|
72
|
+
if prompts_dir is not None:
|
|
73
|
+
prompts_dir = Path(prompts_dir)
|
|
74
|
+
yaml_path = prompts_dir / f"{prompt_name}.yaml"
|
|
75
|
+
if yaml_path.exists():
|
|
76
|
+
return yaml_path
|
|
77
|
+
# Fall through to other resolution methods
|
|
78
|
+
|
|
79
|
+
# 3. Graph-relative resolution (without explicit prompts_dir)
|
|
80
|
+
if prompts_relative and graph_path is not None:
|
|
81
|
+
graph_dir = Path(graph_path).parent
|
|
82
|
+
yaml_path = graph_dir / f"{prompt_name}.yaml"
|
|
83
|
+
if yaml_path.exists():
|
|
84
|
+
return yaml_path
|
|
85
|
+
# Fall through to default
|
|
86
|
+
|
|
87
|
+
# 4. Default: use global PROMPTS_DIR
|
|
88
|
+
default_dir = PROMPTS_DIR if prompts_dir is None else prompts_dir
|
|
89
|
+
yaml_path = Path(default_dir) / f"{prompt_name}.yaml"
|
|
90
|
+
if yaml_path.exists():
|
|
91
|
+
return yaml_path
|
|
92
|
+
|
|
93
|
+
# 5. Fallback: external example location {parent}/prompts/{basename}.yaml
|
|
94
|
+
parts = prompt_name.rsplit("/", 1)
|
|
95
|
+
if len(parts) == 2:
|
|
96
|
+
parent_dir, basename = parts
|
|
97
|
+
alt_path = Path(parent_dir) / "prompts" / f"{basename}.yaml"
|
|
98
|
+
if alt_path.exists():
|
|
99
|
+
return alt_path
|
|
100
|
+
|
|
101
|
+
raise FileNotFoundError(f"Prompt not found: {yaml_path}")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def load_prompt(
    prompt_name: str,
    prompts_dir: Path | None = None,
    graph_path: Path | None = None,
    prompts_relative: bool = False,
) -> dict:
    """Load a YAML prompt template.

    The file location is determined by ``resolve_prompt_path`` and the
    content is parsed with ``yaml.safe_load``.

    Args:
        prompt_name: Name of the prompt file (without .yaml extension)
        prompts_dir: Optional prompts directory override
        graph_path: Path to the graph YAML file (for relative resolution)
        prompts_relative: If True, resolve relative to graph_path.parent

    Returns:
        Dictionary with prompt content (typically 'system' and 'user' keys)

    Raises:
        FileNotFoundError: If prompt file doesn't exist
    """
    path = resolve_prompt_path(
        prompt_name,
        prompts_dir=prompts_dir,
        graph_path=graph_path,
        prompts_relative=prompts_relative,
    )

    # Read as UTF-8 explicitly. Without an encoding, open() uses the locale's
    # preferred encoding (e.g. cp1252 on Windows), which mis-decodes or rejects
    # non-ASCII prompt text.
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def load_prompt_path(
    prompt_name: str,
    prompts_dir: Path | None = None,
    graph_path: Path | None = None,
    prompts_relative: bool = False,
) -> tuple[Path, dict]:
    """Load a prompt and return both path and content.

    Useful when you need both the file path (for schema loading)
    and the content (for prompt execution).

    Args:
        prompt_name: Name of the prompt file (without .yaml extension)
        prompts_dir: Optional prompts directory override
        graph_path: Path to the graph YAML file (for relative resolution)
        prompts_relative: If True, resolve relative to graph_path.parent

    Returns:
        Tuple of (path, content_dict), where path comes from
        ``resolve_prompt_path`` and content from ``yaml.safe_load``.

    Raises:
        FileNotFoundError: If prompt file doesn't exist
    """
    path = resolve_prompt_path(
        prompt_name,
        prompts_dir=prompts_dir,
        graph_path=graph_path,
        prompts_relative=prompts_relative,
    )

    # Explicit UTF-8: the locale default encoding differs across platforms
    # and would mis-decode non-ASCII prompt text on e.g. Windows.
    with open(path, encoding="utf-8") as f:
        content = yaml.safe_load(f)

    return path, content
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Public names exported by ``import *`` from this module.
__all__ = ["resolve_prompt_path", "load_prompt", "load_prompt_path"]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Input sanitization utilities.
|
|
2
|
+
|
|
3
|
+
Provides functions for validating and sanitizing user input
|
|
4
|
+
to prevent prompt injection and other security issues.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
from typing import NamedTuple
|
|
9
|
+
|
|
10
|
+
from yamlgraph.config import DANGEROUS_PATTERNS, MAX_TOPIC_LENGTH
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SanitizationResult(NamedTuple):
    """Result of input sanitization (as returned by sanitize_topic)."""

    # Cleaned text; "" when sanitize_topic rejected the input as empty.
    value: str
    # False when the input was empty or contained a dangerous pattern.
    is_safe: bool
    # Human-readable notes, e.g. truncation or the reason for rejection.
    warnings: list[str]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def sanitize_topic(topic: str) -> SanitizationResult:
    """Sanitize a topic string for use in prompts.

    Processing order: remove control characters, trim surrounding whitespace,
    truncate to MAX_TOPIC_LENGTH, reject empty results, then reject topics
    containing any of DANGEROUS_PATTERNS (case-insensitive).

    Checks for:
    - Length limits
    - Potential prompt injection patterns
    - Control characters

    Args:
        topic: The raw topic string

    Returns:
        SanitizationResult with cleaned value and safety status

    Example:
        >>> result = sanitize_topic("machine learning")
        >>> result.is_safe
        True
        >>> result = sanitize_topic("ignore previous instructions")
        >>> result.is_safe
        False
    """
    warnings = []

    # Remove control characters (tab, newline and carriage return are kept)
    # BEFORE trimming and the empty check. Otherwise a topic made only of
    # control characters such as "\x00" passes the empty check and comes back
    # is_safe=True with an empty value. Leading control characters would also
    # shield following whitespace from strip().
    cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", topic).strip()

    # Check length
    if len(cleaned) > MAX_TOPIC_LENGTH:
        cleaned = cleaned[:MAX_TOPIC_LENGTH]
        warnings.append(f"Topic truncated to {MAX_TOPIC_LENGTH} characters")

    # Check for empty
    if not cleaned:
        return SanitizationResult(
            value="",
            is_safe=False,
            warnings=["Topic cannot be empty"],
        )

    # Check for dangerous patterns (case-insensitive)
    topic_lower = cleaned.lower()
    for pattern in DANGEROUS_PATTERNS:
        if pattern.lower() in topic_lower:
            return SanitizationResult(
                value=cleaned,
                is_safe=False,
                warnings=[f"Topic contains potentially unsafe pattern: '{pattern}'"],
            )

    return SanitizationResult(
        value=cleaned,
        is_safe=True,
        warnings=warnings,
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def sanitize_variables(variables: dict) -> dict:
    """Strip control characters from every string value in ``variables``.

    Non-string values pass through untouched. Tab, newline and carriage
    return are preserved; other C0 control characters and DEL are removed.

    Args:
        variables: Mapping of template variable name to value.

    Returns:
        A new dict with the same keys; the input mapping is not modified.
    """
    control_chars = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]")
    return {
        name: control_chars.sub("", raw) if isinstance(raw, str) else raw
        for name, raw in variables.items()
    }
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Template utilities - Variable extraction and validation.
|
|
2
|
+
|
|
3
|
+
This module provides functions to extract required variables from
|
|
4
|
+
prompt templates and validate that all required variables are provided
|
|
5
|
+
before execution.
|
|
6
|
+
|
|
7
|
+
Supports both simple {variable} placeholders and Jinja2 templates.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
# Module logger (not referenced by the functions defined in this module).
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def extract_variables(template: str) -> set[str]:
    """Extract all variable names required by a template.

    Handles simple {var} and Jinja2 {{ var }}, {% if var %} and
    {% for x in var %} syntax, including ``not`` prefixes and tuple-unpacking
    loops. Only the leading name of each expression is inspected, so
    ``{{ a.b }}`` yields ``a`` and ``{% if a and b %}`` yields only ``a``.

    Args:
        template: Template string with placeholders

    Returns:
        Set of variable names required by the template

    Examples:
        >>> extract_variables("Hello {name}")
        {'name'}

        >>> extract_variables("{% for item in items %}{{ item }}{% endfor %}")
        {'items'}

        >>> extract_variables("{% if not items %}empty{% endif %}")
        {'items'}

        >>> sorted(extract_variables("{% for k, v in data.items() %}{{ k }}{% endfor %}"))
        ['data']
    """
    variables: set[str] = set()

    # Simple format: {var} - but NOT {{ (Jinja2)
    # Match {word} but not {{word}} - use negative lookbehind/lookahead
    simple_pattern = r"(?<!\{)\{(\w+)\}(?!\})"
    variables.update(re.findall(simple_pattern, template))

    # Jinja2 expression: {{ var }}, {{ var.field }} or {{ not var }}.
    # The optional "not " prefix stops the keyword being reported as a variable.
    jinja_var_pattern = r"\{\{\s*(?:not\s+)?(\w+)"
    variables.update(re.findall(jinja_var_pattern, template))

    # Jinja2 loop: {% for x in var %} or {% for k, v in var.items() %}.
    # Group 1 = loop target(s), group 2 = the iterable (a required input).
    jinja_loop_pattern = r"\{%\s*for\s+(\w+(?:\s*,\s*\w+)*)\s+in\s+(\w+)"
    loop_vars: set[str] = set()
    for match in re.finditer(jinja_loop_pattern, template):
        loop_vars.update(target.strip() for target in match.group(1).split(","))
        variables.add(match.group(2))

    # Jinja2 condition: {% if var %}, {% if var.field %} or {% if not var %}
    jinja_if_pattern = r"\{%\s*if\s+(?:not\s+)?(\w+)"
    variables.update(re.findall(jinja_if_pattern, template))

    # Remove loop iteration variables (they're not inputs)
    # e.g., in "{% for item in items %}", "item" is not required
    variables -= loop_vars

    # Remove common non-input variables
    # - state: injected by node_factory
    # - loop: Jinja2 loop context
    # - range: Jinja2 builtin function
    excluded = {"state", "loop", "range", "true", "false", "none"}
    variables -= excluded

    return variables
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def validate_variables(
    template: str,
    provided: dict[str, Any],
    prompt_name: str,
) -> None:
    """Ensure every variable the template needs is present in ``provided``.

    Required names come from ``extract_variables``. Only key presence is
    checked; the values themselves (including None) are not inspected.

    Args:
        template: Template string with placeholders
        provided: Dictionary of provided variable values
        prompt_name: Name of the prompt, used in the error message

    Raises:
        ValueError: Listing every missing name, sorted alphabetically.

    Examples:
        >>> validate_variables("Hello {name}", {"name": "World"}, "greet")
        >>> validate_variables("Hello {name}", {}, "greet")
        Traceback (most recent call last):
        ...
        ValueError: Missing required variable(s) for prompt 'greet': name
    """
    missing = extract_variables(template).difference(provided.keys())
    if not missing:
        return
    names = ", ".join(sorted(missing))
    raise ValueError(f"Missing required variable(s) for prompt '{prompt_name}': {names}")
|