polos_sdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/llm/providers/together.py
ADDED
@@ -0,0 +1,40 @@
"""Together provider - routes to OpenAI provider with chat_completions API."""

from .base import register_provider
from .openai import OpenAIProvider


@register_provider("together")
class TogetherProvider(OpenAIProvider):
    """Together provider using OpenAI provider with Chat Completions API."""

    def __init__(self, api_key=None):
        """
        Initialize Together provider.

        Args:
            api_key: Together API key. If not provided, uses TOGETHER_API_KEY env var.
        """
        import os

        together_api_key = api_key or os.getenv("TOGETHER_API_KEY")
        if not together_api_key:
            raise ValueError(
                "Together API key not provided. Set TOGETHER_API_KEY environment variable "
                "or pass api_key parameter."
            )

        try:
            from openai import AsyncOpenAI  # noqa: F401
        except ImportError:
            raise ImportError(
                "OpenAI SDK not installed. Install it with: pip install 'polos[together]'"
            ) from None

        # Initialize with Together's base URL and chat_completions API version
        super().__init__(
            api_key=together_api_key,
            base_url="https://api.together.xyz/v1",
            llm_api="chat_completions",
        )
        self.supports_structured_output = False
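
For orientation, a minimal sketch of how this provider could be resolved from calling code, assuming TOGETHER_API_KEY is set and that the registered name is passed to get_provider the same way polos/llm/stream.py does; the placeholder key and the comments are not from the package.

import os

from polos.llm.providers import get_provider

# Assumes the Together API key is available in the environment (placeholder value).
os.environ.setdefault("TOGETHER_API_KEY", "placeholder-key")

# Resolves TogetherProvider via the @register_provider("together") registration above.
provider = get_provider("together")

# Per __init__ above: requests go to https://api.together.xyz/v1 over the
# chat_completions API, and supports_structured_output is False.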
polos/llm/stream.py
ADDED
@@ -0,0 +1,183 @@
"""Built-in LLM streaming function."""

from typing import Any

from ..core.context import WorkflowContext
from ..core.workflow import _execution_context
from ..types.types import AgentConfig
from ..utils.agent import convert_input_to_messages
from ..utils.client_context import get_client_or_raise
from .providers import get_provider


async def _llm_stream(ctx: WorkflowContext, payload: dict[str, Any]) -> dict[str, Any]:
    """
    Durable function for LLM streaming.

    Must be executed within a workflow execution context. Uses step_outputs for durability.

    Args:
        ctx: WorkflowContext for the current execution
        payload: Dictionary containing:
            - agent_run_id: str
            - agent_config: Dict with provider, model, tools, system_prompt, etc.
            - input: Union[str, List[Dict]] - Input data
            - agent_step: int - Step in agent conversation (1 = first, 2 = after tools, etc.)

    Returns:
        Dictionary with streaming result containing content, tool_calls, usage, etc.
    """
    # Check we're in a workflow execution context
    exec_context = _execution_context.get()
    if not exec_context or not exec_context.get("execution_id"):
        raise ValueError("_llm_stream must be executed within an agent")

    # Extract payload
    agent_run_id = payload["agent_run_id"]
    agent_config = AgentConfig.model_validate(payload["agent_config"])
    input_data = payload["input"]
    agent_step = payload.get("agent_step", 1)
    tool_results = payload.get("tool_results")  # Optional tool results in OpenAI format

    # Get LLM provider
    provider_kwargs = {}
    if agent_config.provider_base_url:
        provider_kwargs["base_url"] = agent_config.provider_base_url
    if agent_config.provider_llm_api:
        provider_kwargs["llm_api"] = agent_config.provider_llm_api
    provider = get_provider(agent_config.provider, **provider_kwargs)

    # Convert input to messages format (without system_prompt - provider will handle it)
    messages = convert_input_to_messages(input_data, system_prompt=None)

    topic = f"workflow:{agent_run_id}"

    # Stream from LLM API and publish events
    # Call helper function to handle streaming using step.run() for durable execution
    streaming_result = await ctx.step.run(
        f"llm_stream:{agent_step}",
        _stream_from_provider,
        ctx=ctx,
        provider=provider,
        messages=messages,
        agent_config=agent_config,
        tool_results=tool_results,
        topic=topic,
        agent_step=agent_step,
        agent_run_id=agent_run_id,
    )

    return streaming_result


async def _stream_from_provider(
    ctx: WorkflowContext,
    provider: Any,
    messages: list[dict[str, Any]],
    agent_config: AgentConfig,
    tool_results: list[dict[str, Any]] | None,
    topic: str,
    agent_step: int,
    agent_run_id: str,
) -> dict[str, Any]:
    """
    Helper function to stream from provider and publish events.

    Returns:
        Dictionary with chunk_index, response_content, response_tool_calls, usage, raw_output
    """
    from ..features.events import publish as publish_event

    chunk_index = 0
    response_content = None
    response_tool_calls = None
    usage = None
    raw_output = None
    polos_client = get_client_or_raise()

    # Publish start event
    # This is needed for invalidating events in the case of failures during
    # agent streaming
    # If the consumer sees a stream_start event for the same agent_step,
    # discard previous events for that agent_step
    await publish_event(
        client=polos_client,
        topic=topic,
        event_type="stream_start",
        data={"step": agent_step},
    )

    # Stream from provider
    # Pass agent_config and tool_results to provider
    # Include provider_kwargs if provided
    provider_kwargs = agent_config.provider_kwargs or {}
    async for event in provider.stream(
        messages=messages,
        model=agent_config.model,
        tools=agent_config.tools,
        temperature=agent_config.temperature,
        max_tokens=agent_config.max_output_tokens,
        top_p=agent_config.top_p,
        agent_config=agent_config.model_dump(mode="json"),
        tool_results=tool_results,
        output_schema=agent_config.output_schema,
        output_schema_name=agent_config.output_schema_name,
        **provider_kwargs,
    ):
        # Handle error events
        if event.get("type") == "error":
            error_msg = event.get("data", {}).get("error", "Unknown error")
            raise RuntimeError(f"LLM streaming error: {error_msg}")

        # Event is already in normalized format from provider
        normalized_chunk = event
        event_type = None

        # Accumulate response data for llm_calls update
        if normalized_chunk["type"] == "text_delta":
            event_type = "text_delta"
            if response_content is None:
                response_content = ""
            content = normalized_chunk["data"].get("content", "")
            if content:
                response_content += content
        elif normalized_chunk["type"] == "tool_call":
            event_type = "tool_call"
            if response_tool_calls is None:
                response_tool_calls = []
            tool_call = normalized_chunk["data"].get("tool_call")
            if tool_call:
                response_tool_calls.append(tool_call)
        elif normalized_chunk["type"] == "done":
            usage = normalized_chunk["data"].get("usage")
            raw_output = normalized_chunk["data"].get("raw_output")

        # Publish chunk as event (skip "done" events)
        if normalized_chunk["type"] != "done" and event_type:
            await publish_event(
                client=polos_client,
                topic=topic,
                event_type=event_type,
                data={
                    "step": agent_step,
                    "chunk_index": chunk_index,
                    "content": normalized_chunk["data"].get("content"),
                    "tool_call": normalized_chunk["data"].get("tool_call"),
                    "usage": normalized_chunk["data"].get("usage"),
                    "_metadata": {
                        "execution_id": agent_run_id,
                        "workflow_id": ctx.workflow_id,
                    },
                },
            )
            chunk_index += 1

    return {
        "agent_run_id": agent_run_id,
        "chunk_count": chunk_index,
        "status": "completed",
        "content": response_content,
        "tool_calls": response_tool_calls if response_tool_calls else None,
        "usage": usage if usage else None,
        "raw_output": raw_output if raw_output else None,
    }
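
As a reading aid, the sketch below shows the payload shape _llm_stream expects, pieced together from its docstring and the AgentConfig fields this module reads. The agent name, model string, and values are invented placeholders, and the call only works inside a workflow/agent execution context.

payload = {
    "agent_run_id": "run_123",                 # also used for the "workflow:<id>" event topic
    "agent_config": {                          # validated with AgentConfig.model_validate(...)
        "name": "example-agent",
        "provider": "together",
        "model": "example/model-name",         # placeholder model id
        "temperature": 0.2,
        "max_output_tokens": 512,
    },
    "input": "Summarize the latest ticket.",   # str or list of message dicts
    "agent_step": 1,                           # 1 = first LLM call, 2 = after tool results, ...
    "tool_results": None,                      # optional tool results in OpenAI format
}

# Inside a workflow, one would then await _llm_stream(ctx, payload) and read
# result["content"], result["tool_calls"], result["usage"], and result["chunk_count"].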
polos/middleware/__init__.py
File without changes
polos/middleware/guardrail.py
ADDED
@@ -0,0 +1,148 @@
"""Guardrail classes for validating/modifying LLM responses before tool execution.

Guardrails are executed after LLM calls but before tool execution.
They can validate, filter, or modify the LLM content and tool_calls.
"""

import inspect
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel

from ..types.types import AgentConfig, Step, ToolCall
from .hook import HookAction, HookResult


class GuardrailContext(BaseModel):
    """Context specific to guardrails - what they receive.

    Guardrails receive the LLM response (content and tool_calls) along with
    execution context to make validation/modification decisions.
    """

    # LLM response data
    content: Any | None = None  # LLM response content
    tool_calls: list[ToolCall] | None = None  # LLM tool calls

    # Execution context (for guardrail to make decisions)
    agent_workflow_id: str | None = ""
    agent_run_id: str | None = ""
    session_id: str | None = None
    user_id: str | None = None
    llm_config: AgentConfig = AgentConfig(name="", provider="", model="")
    steps: list[Step] = []  # Previous conversation steps

    def to_dict(self) -> dict[str, Any]:
        """Convert to dict for serialization."""
        return self.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: Any) -> "GuardrailContext":
        """Create GuardrailContext from dictionary."""
        if isinstance(data, GuardrailContext):
            return data
        if isinstance(data, dict):
            return cls.model_validate(data)
        raise TypeError(f"Cannot create GuardrailContext from {type(data)}")


class GuardrailResult(HookResult):
    """Result from guardrail execution.

    Inherits from HookResult but adds guardrail-specific modifications
    for LLM content and tool_calls.
    """

    # Modify LLM response
    modified_content: Any | None = None  # Modified content
    modified_tool_calls: list[ToolCall] | None = None  # Modified tool calls
    modified_llm_config: AgentConfig | None = None

    # If modified_tool_calls is empty list [], no tools will be executed
    # If modified_tool_calls is None, original tool_calls are used

    def to_dict(self) -> dict[str, Any]:
        """Convert to dict, including guardrail-specific fields."""
        return self.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "GuardrailResult":
        """Create GuardrailResult from dictionary."""
        return cls.model_validate(data)

    @classmethod
    def continue_with(cls, **modifications) -> "GuardrailResult":
        """Continue with optional modifications.

        Args:
            **modifications: Can include modified_content, modified_tool_calls,
                modified_llm_config, etc.
        """
        return cls(action=HookAction.CONTINUE, **modifications)

    @classmethod
    def fail(cls, message: str) -> "GuardrailResult":
        """Fail processing with an error message.

        Args:
            message: Error message to return
        """
        return cls(action=HookAction.FAIL, error_message=message)


def _validate_guardrail_signature(func: Callable) -> None:
    """Validate that guardrail function has correct signature.

    Expected: (ctx: WorkflowContext, guardrail_context: GuardrailContext) -> GuardrailResult

    Raises:
        TypeError: If signature is invalid
    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())

    # Must have exactly 2 parameters: ctx and guardrail_context
    if len(params) != 2:
        raise TypeError(
            f"Guardrail function '{func.__name__}' must have exactly 2 parameters: "
            f"(ctx: WorkflowContext, guardrail_context: GuardrailContext). "
            f"Got {len(params)} parameters."
        )


def guardrail(func: Callable | None = None):
    """
    Decorator to mark a function as a guardrail.

    Guardrail functions must have the signature:
        (ctx: WorkflowContext, guardrail_context: GuardrailContext) -> GuardrailResult

    Usage:
        @guardrail
        def my_guardrail(
            ctx: WorkflowContext, guardrail_context: GuardrailContext
        ) -> GuardrailResult:
            return GuardrailResult.continue_with()

    Args:
        func: The function to decorate (when used as @guardrail)

    Returns:
        The function itself (validated)

    Raises:
        TypeError: If function signature is invalid
    """

    def decorator(f: Callable) -> Callable:
        # Validate function signature
        _validate_guardrail_signature(f)
        return f

    # Handle @guardrail (without parentheses) - the function is passed as the first argument
    if func is not None:
        return decorator(func)

    # Handle @guardrail() - return decorator
    return decorator
polos/middleware/guardrail_executor.py
ADDED
@@ -0,0 +1,253 @@
"""Guardrail execution infrastructure.

Guardrails are executed sequentially within a workflow execution context
and support durable execution.
"""

import json
from collections.abc import Callable
from typing import Any

from ..core.context import WorkflowContext
from ..core.workflow import _execution_context
from ..types.types import AgentConfig
from .guardrail import GuardrailContext, GuardrailResult
from .hook import HookAction


def _get_guardrail_identifier(guardrail: Callable | str, index: int) -> str:
    """Get a unique identifier for a guardrail.

    Uses function name if callable, otherwise uses a string identifier.
    """
    if isinstance(guardrail, str):
        # For string guardrails, use a truncated version
        return f"guardrail_string_{guardrail[:50]}"
    if hasattr(guardrail, "__name__") and guardrail.__name__ != "<lambda>":
        return guardrail.__name__
    return f"guardrail_{index}"


async def execute_guardrails(
    guardrail_name: str,
    guardrails: list[Callable | str],
    guardrail_context: GuardrailContext,
    ctx: WorkflowContext,
    agent_config: AgentConfig | None = None,
) -> GuardrailResult:
    """
    Execute a list of guardrails sequentially and return the combined result.

    Guardrails are executed within a workflow execution context. Each guardrail execution:
    1. Checks for cached result (for durable execution)
    2. If cached, returns cached result
    3. If not cached, executes guardrail and stores result

    Guardrails can be:
    - Callable: Functions decorated with @guardrail
    - str: String prompts evaluated using LLM with structured output

    Each guardrail can:
    - Return CONTINUE to proceed to the next guardrail
    - Return FAIL to stop execution with an error

    Modifications from guardrails are accumulated and applied in order.

    Args:
        guardrails: List of guardrail callables or strings
        guardrail_context: Context to pass to guardrails
        ctx: WorkflowContext for the current execution
        agent_config: Optional agent config (model, provider, etc.) for string guardrails

    Returns:
        GuardrailResult with action and any modifications

    Raises:
        ValueError: If not executed within a workflow execution context
    """
    if not guardrails:
        return GuardrailResult.continue_with()

    # Check we're in a workflow execution context
    exec_context = _execution_context.get()
    if not exec_context or not exec_context.get("execution_id"):
        raise ValueError("Guardrails must be executed within a workflow execution context")

    # Accumulated modifications
    if guardrail_context.content is not None:
        if isinstance(guardrail_context.content, str):
            modified_content = guardrail_context.content
        else:
            modified_content = guardrail_context.content.copy()
    else:
        modified_content = None

    modified_tool_calls = (
        guardrail_context.tool_calls.copy() if guardrail_context.tool_calls else []
    )
    modified_llm_config = guardrail_context.llm_config.copy()

    # Execute guardrails sequentially
    for index, guardrail in enumerate(guardrails):
        # Get identifier for durable execution
        guardrail_id = _get_guardrail_identifier(guardrail, index)

        # Update guardrail context with accumulated modifications
        guardrail_context.content = modified_content
        guardrail_context.tool_calls = modified_tool_calls
        guardrail_context.llm_config = modified_llm_config

        # Execute guardrail (callable or string)
        if isinstance(guardrail, str):
            # String guardrail: evaluate using LLM
            if not agent_config:
                raise ValueError("agent_config is required for string guardrails")

            guardrail_result = await _execute_string_guardrail(
                guardrail_name=f"{guardrail_name}:{index}",
                guardrail_string=guardrail,
                guardrail_context=guardrail_context,
                ctx=ctx,
                agent_config=agent_config,
                index=index,
            )
        elif callable(guardrail):
            guardrail_result = await ctx.step.run(
                f"{guardrail_name}.{guardrail_id}.{index}", guardrail, ctx, guardrail_context
            )
        else:
            raise TypeError(
                f"Guardrail at index {index} is neither callable nor string: {type(guardrail)}"
            )

        # Ensure result is GuardrailResult
        if not isinstance(guardrail_result, GuardrailResult):
            guardrail_result = GuardrailResult.fail(
                f"Guardrail '{guardrail_id}' returned invalid result type: "
                f"{type(guardrail_result)}. Expected GuardrailResult."
            )

        # Apply modifications
        if guardrail_result.modified_content is not None:
            modified_content = guardrail_result.modified_content

        if guardrail_result.modified_tool_calls is not None:
            modified_tool_calls = guardrail_result.modified_tool_calls

        if guardrail_result.modified_llm_config is not None:
            modified_llm_config.update(guardrail_result.modified_llm_config)

        # Check action
        if guardrail_result.action == HookAction.FAIL:
            # Fail execution - return error
            return guardrail_result

        # CONTINUE - proceed to next guardrail

    # All guardrails completed with CONTINUE - return accumulated modifications
    return GuardrailResult.continue_with(
        modified_content=modified_content,
        modified_tool_calls=modified_tool_calls,
        modified_llm_config=modified_llm_config,
    )


async def _execute_string_guardrail(
    guardrail_name: str,
    guardrail_string: str,
    guardrail_context: GuardrailContext,
    ctx: WorkflowContext,
    agent_config: AgentConfig,
    index: int,
) -> GuardrailResult:
    """Execute a string guardrail using LLM with structured output.

    Args:
        guardrail_string: The guardrail prompt/instruction
        guardrail_context: Context containing LLM response to validate
        ctx: WorkflowContext for the current execution
        agent_config: Agent configuration (model, provider, etc.)

    Returns:
        GuardrailResult indicating pass/fail
    """
    import json

    from ..llm import _llm_generate

    # Create JSON schema for structured output: {passed: bool, reason: Optional[str]}
    guardrail_output_schema = {
        "type": "object",
        "properties": {"passed": {"type": "boolean"}, "reason": {"type": "string", "default": ""}},
        "required": ["passed", "reason"],
        "additionalProperties": False,
    }

    # Build the evaluation prompt
    # Include the guardrail instruction and the LLM response to validate
    llm_content = guardrail_context.content
    if isinstance(llm_content, str):
        content_to_validate = llm_content
    else:
        content_to_validate = json.dumps(llm_content) if llm_content else ""

    if not content_to_validate:
        return GuardrailResult.continue_with()

    evaluation_prompt = f"""{guardrail_string}

Please evaluate the following LLM response against the criteria above.
Return a JSON object with "passed" (boolean) and optionally "reason"
(string if the guardrail fails).

LLM Response to evaluate:
{content_to_validate}"""

    # Prepare agent config for guardrail evaluation
    # Use the agent's model, provider, etc., but override output_schema
    config_dict = agent_config.model_dump(mode="json")
    config_dict["output_schema"] = guardrail_output_schema
    config_dict["output_schema_name"] = "GuardrailEvaluationResult"
    guardrail_agent_config = AgentConfig.model_validate(config_dict)

    # Call _llm_generate to evaluate the guardrail
    llm_result = await _llm_generate(
        ctx,
        {
            "agent_run_id": guardrail_context.agent_run_id,
            "agent_config": guardrail_agent_config.model_dump(mode="json"),
            "input": evaluation_prompt,
            "agent_step": 0,  # Guardrail evaluation doesn't count as agent step
            "guardrails": None,  # Don't recurse guardrails on guardrail evaluation
            "guardrail_max_retries": 0,
        },
    )

    return await ctx.step.run(
        f"{guardrail_name}.{guardrail_string[:50]}.{index}", _parse_llm_guardrail_result, llm_result
    )


async def _parse_llm_guardrail_result(llm_result: dict[str, Any]) -> GuardrailResult:
    """Parse the LLM result and return a GuardrailResult."""
    # Parse the structured output
    result_content = llm_result.get("content", "")
    if not result_content:
        return GuardrailResult.fail("Guardrail evaluation returned empty response")

    try:
        evaluation_result = (
            json.loads(result_content) if isinstance(result_content, str) else result_content
        )

        passed = evaluation_result.get("passed", False)
        reason = evaluation_result.get("reason")

        if passed:
            return GuardrailResult.continue_with()
        else:
            error_message = reason or "Guardrail validation failed"
            return GuardrailResult.fail(error_message)
    except (json.JSONDecodeError, TypeError, AttributeError) as e:
        # If parsing fails, treat as failure
        return GuardrailResult.fail(f"Failed to parse guardrail evaluation result: {str(e)}")
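
To close the loop, a hedged sketch of how calling code might wire these pieces together: one callable guardrail (the block_internal_leaks sketch shown earlier) and one string guardrail, run through execute_guardrails inside a workflow execution context. The function name, guardrail prompt, and error handling are illustrative; only the signatures and fields visible in this diff are relied on.

from polos.middleware.guardrail import GuardrailContext
from polos.middleware.guardrail_executor import execute_guardrails
from polos.middleware.hook import HookAction


async def apply_post_llm_guardrails(ctx, agent_config, content, tool_calls):
    result = await execute_guardrails(
        guardrail_name="post_llm",                              # prefix used to build durable step names
        guardrails=[
            block_internal_leaks,                               # callable: run via ctx.step.run(...)
            "The response must not mention internal ticket IDs.",  # str: LLM-evaluated as {"passed", "reason"}
        ],
        guardrail_context=GuardrailContext(
            content=content,
            tool_calls=tool_calls,
            llm_config=agent_config,
        ),
        ctx=ctx,
        agent_config=agent_config,                              # required because a string guardrail is present
    )
    if result.action == HookAction.FAIL:
        raise RuntimeError(result.error_message)                # error_message is set by GuardrailResult.fail()
    return result.modified_content, result.modified_tool_calls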