openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +62 -30
- agents/agent_output.py +2 -2
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/handoffs.py +32 -14
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +82 -11
- agents/mcp/util.py +16 -9
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +20 -3
- agents/models/chatcmpl_stream_handler.py +134 -43
- agents/models/openai_responses.py +12 -5
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +177 -0
- agents/realtime/agent.py +89 -0
- agents/realtime/config.py +188 -0
- agents/realtime/events.py +216 -0
- agents/realtime/handoffs.py +165 -0
- agents/realtime/items.py +184 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +670 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +535 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Literal
|
|
|
5
5
|
from openai import AsyncOpenAI
|
|
6
6
|
|
|
7
7
|
from . import _config
|
|
8
|
-
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
|
8
|
+
from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
|
9
9
|
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
|
|
10
10
|
from .computer import AsyncComputer, Button, Computer, Environment
|
|
11
11
|
from .exceptions import (
|
|
@@ -40,6 +40,7 @@ from .items import (
|
|
|
40
40
|
TResponseInputItem,
|
|
41
41
|
)
|
|
42
42
|
from .lifecycle import AgentHooks, RunHooks
|
|
43
|
+
from .memory import Session, SQLiteSession
|
|
43
44
|
from .model_settings import ModelSettings
|
|
44
45
|
from .models.interface import Model, ModelProvider, ModelTracing
|
|
45
46
|
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
|
|
@@ -160,6 +161,7 @@ def enable_verbose_stdout_logging():
|
|
|
160
161
|
|
|
161
162
|
__all__ = [
|
|
162
163
|
"Agent",
|
|
164
|
+
"AgentBase",
|
|
163
165
|
"ToolsToFinalOutputFunction",
|
|
164
166
|
"ToolsToFinalOutputResult",
|
|
165
167
|
"Runner",
|
|
@@ -209,6 +211,8 @@ __all__ = [
|
|
|
209
211
|
"ItemHelpers",
|
|
210
212
|
"RunHooks",
|
|
211
213
|
"AgentHooks",
|
|
214
|
+
"Session",
|
|
215
|
+
"SQLiteSession",
|
|
212
216
|
"RunContextWrapper",
|
|
213
217
|
"TContext",
|
|
214
218
|
"RunErrorDetails",
|
agents/_run_impl.py
CHANGED
|
@@ -548,7 +548,11 @@ class RunImpl:
|
|
|
548
548
|
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
|
549
549
|
) -> Any:
|
|
550
550
|
with function_span(func_tool.name) as span_fn:
|
|
551
|
-
tool_context = ToolContext.from_agent_context(
|
|
551
|
+
tool_context = ToolContext.from_agent_context(
|
|
552
|
+
context_wrapper,
|
|
553
|
+
tool_call.call_id,
|
|
554
|
+
tool_call=tool_call,
|
|
555
|
+
)
|
|
552
556
|
if config.trace_include_sensitive_data:
|
|
553
557
|
span_fn.span_data.input = tool_call.arguments
|
|
554
558
|
try:
|
agents/agent.py
CHANGED
|
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
@dataclass
|
|
70
|
-
class
|
|
70
|
+
class AgentBase(Generic[TContext]):
|
|
71
|
+
"""Base class for `Agent` and `RealtimeAgent`."""
|
|
72
|
+
|
|
73
|
+
name: str
|
|
74
|
+
"""The name of the agent."""
|
|
75
|
+
|
|
76
|
+
handoff_description: str | None = None
|
|
77
|
+
"""A description of the agent. This is used when the agent is used as a handoff, so that an
|
|
78
|
+
LLM knows what it does and when to invoke it.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
tools: list[Tool] = field(default_factory=list)
|
|
82
|
+
"""A list of tools that the agent can use."""
|
|
83
|
+
|
|
84
|
+
mcp_servers: list[MCPServer] = field(default_factory=list)
|
|
85
|
+
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
|
|
86
|
+
the agent can use. Every time the agent runs, it will include tools from these servers in the
|
|
87
|
+
list of available tools.
|
|
88
|
+
|
|
89
|
+
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
|
|
90
|
+
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
|
|
91
|
+
longer needed.
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
|
95
|
+
"""Configuration for MCP servers."""
|
|
96
|
+
|
|
97
|
+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
|
|
98
|
+
"""Fetches the available tools from the MCP servers."""
|
|
99
|
+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
100
|
+
return await MCPUtil.get_all_function_tools(
|
|
101
|
+
self.mcp_servers, convert_schemas_to_strict, run_context, self
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
|
105
|
+
"""All agent tools, including MCP tools and function tools."""
|
|
106
|
+
mcp_tools = await self.get_mcp_tools(run_context)
|
|
107
|
+
|
|
108
|
+
async def _check_tool_enabled(tool: Tool) -> bool:
|
|
109
|
+
if not isinstance(tool, FunctionTool):
|
|
110
|
+
return True
|
|
111
|
+
|
|
112
|
+
attr = tool.is_enabled
|
|
113
|
+
if isinstance(attr, bool):
|
|
114
|
+
return attr
|
|
115
|
+
res = attr(run_context, self)
|
|
116
|
+
if inspect.isawaitable(res):
|
|
117
|
+
return bool(await res)
|
|
118
|
+
return bool(res)
|
|
119
|
+
|
|
120
|
+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
|
121
|
+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
|
122
|
+
return [*mcp_tools, *enabled]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
|
|
126
|
+
class Agent(AgentBase, Generic[TContext]):
|
|
71
127
|
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
|
72
128
|
|
|
73
129
|
We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
|
|
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
|
|
|
76
132
|
|
|
77
133
|
Agents are generic on the context type. The context is a (mutable) object you create. It is
|
|
78
134
|
passed to tool functions, handoffs, guardrails, etc.
|
|
79
|
-
"""
|
|
80
135
|
|
|
81
|
-
|
|
82
|
-
"""
|
|
136
|
+
See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
|
|
137
|
+
"""
|
|
83
138
|
|
|
84
139
|
instructions: (
|
|
85
140
|
str
|
|
@@ -103,12 +158,7 @@ class Agent(Generic[TContext]):
|
|
|
103
158
|
usable with OpenAI models, using the Responses API.
|
|
104
159
|
"""
|
|
105
160
|
|
|
106
|
-
|
|
107
|
-
"""A description of the agent. This is used when the agent is used as a handoff, so that an
|
|
108
|
-
LLM knows what it does and when to invoke it.
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
|
-
handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
|
|
161
|
+
handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
|
|
112
162
|
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
|
|
113
163
|
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
|
|
114
164
|
modularity.
|
|
@@ -125,22 +175,6 @@ class Agent(Generic[TContext]):
|
|
|
125
175
|
"""Configures model-specific tuning parameters (e.g. temperature, top_p).
|
|
126
176
|
"""
|
|
127
177
|
|
|
128
|
-
tools: list[Tool] = field(default_factory=list)
|
|
129
|
-
"""A list of tools that the agent can use."""
|
|
130
|
-
|
|
131
|
-
mcp_servers: list[MCPServer] = field(default_factory=list)
|
|
132
|
-
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
|
|
133
|
-
the agent can use. Every time the agent runs, it will include tools from these servers in the
|
|
134
|
-
list of available tools.
|
|
135
|
-
|
|
136
|
-
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
|
|
137
|
-
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
|
|
138
|
-
longer needed.
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
|
-
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
|
142
|
-
"""Configuration for MCP servers."""
|
|
143
|
-
|
|
144
178
|
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
|
145
179
|
"""A list of checks that run in parallel to the agent's execution, before generating a
|
|
146
180
|
response. Runs only if the agent is the first agent in the chain.
|
|
@@ -176,7 +210,7 @@ class Agent(Generic[TContext]):
|
|
|
176
210
|
The final output will be the output of the first matching tool call. The LLM does not
|
|
177
211
|
process the result of the tool call.
|
|
178
212
|
- A function: If you pass a function, it will be called with the run context and the list of
|
|
179
|
-
tool results. It must return a `
|
|
213
|
+
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
|
|
180
214
|
calls result in a final output.
|
|
181
215
|
|
|
182
216
|
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
|
@@ -256,9 +290,7 @@ class Agent(Generic[TContext]):
|
|
|
256
290
|
"""Get the prompt for the agent."""
|
|
257
291
|
return await PromptUtil.to_model_input(self.prompt, run_context, self)
|
|
258
292
|
|
|
259
|
-
async def get_mcp_tools(
|
|
260
|
-
self, run_context: RunContextWrapper[TContext]
|
|
261
|
-
) -> list[Tool]:
|
|
293
|
+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
|
|
262
294
|
"""Fetches the available tools from the MCP servers."""
|
|
263
295
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
264
296
|
return await MCPUtil.get_all_function_tools(
|
agents/agent_output.py
CHANGED
|
@@ -115,8 +115,8 @@ class AgentOutputSchema(AgentOutputSchemaBase):
|
|
|
115
115
|
except UserError as e:
|
|
116
116
|
raise UserError(
|
|
117
117
|
"Strict JSON schema is enabled, but the output type is not valid. "
|
|
118
|
-
"Either make the output type strict,
|
|
119
|
-
"your
|
|
118
|
+
"Either make the output type strict, "
|
|
119
|
+
"or wrap your type with AgentOutputSchema(your_type, strict_json_schema=False)"
|
|
120
120
|
) from e
|
|
121
121
|
|
|
122
122
|
def is_plain_text(self) -> bool:
|
agents/function_schema.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
|
|
|
9
9
|
|
|
10
10
|
from griffe import Docstring, DocstringSectionKind
|
|
11
11
|
from pydantic import BaseModel, Field, create_model
|
|
12
|
+
from pydantic.fields import FieldInfo
|
|
12
13
|
|
|
13
14
|
from .exceptions import UserError
|
|
14
15
|
from .run_context import RunContextWrapper
|
|
@@ -319,6 +320,14 @@ def function_schema(
|
|
|
319
320
|
ann,
|
|
320
321
|
Field(..., description=field_description),
|
|
321
322
|
)
|
|
323
|
+
elif isinstance(default, FieldInfo):
|
|
324
|
+
# Parameter with a default value that is a Field(...)
|
|
325
|
+
fields[name] = (
|
|
326
|
+
ann,
|
|
327
|
+
FieldInfo.merge_field_infos(
|
|
328
|
+
default, description=field_description or default.description
|
|
329
|
+
),
|
|
330
|
+
)
|
|
322
331
|
else:
|
|
323
332
|
# Parameter with a default value
|
|
324
333
|
fields[name] = (
|
|
@@ -337,7 +346,8 @@ def function_schema(
|
|
|
337
346
|
# 5. Return as a FuncSchema dataclass
|
|
338
347
|
return FuncSchema(
|
|
339
348
|
name=func_name,
|
|
340
|
-
|
|
349
|
+
# Ensure description_override takes precedence even if docstring info is disabled.
|
|
350
|
+
description=description_override or (doc_info.description if doc_info else None),
|
|
341
351
|
params_pydantic_model=dynamic_model,
|
|
342
352
|
params_json_schema=json_schema,
|
|
343
353
|
signature=sig,
|
agents/guardrail.py
CHANGED
|
@@ -241,7 +241,11 @@ def input_guardrail(
|
|
|
241
241
|
def decorator(
|
|
242
242
|
f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
|
|
243
243
|
) -> InputGuardrail[TContext_co]:
|
|
244
|
-
return InputGuardrail(
|
|
244
|
+
return InputGuardrail(
|
|
245
|
+
guardrail_function=f,
|
|
246
|
+
# If not set, guardrail name uses the function’s name by default.
|
|
247
|
+
name=name if name else f.__name__,
|
|
248
|
+
)
|
|
245
249
|
|
|
246
250
|
if func is not None:
|
|
247
251
|
# Decorator was used without parentheses
|
agents/handoffs.py
CHANGED
|
@@ -18,12 +18,15 @@ from .util import _error_tracing, _json, _transforms
|
|
|
18
18
|
from .util._types import MaybeAwaitable
|
|
19
19
|
|
|
20
20
|
if TYPE_CHECKING:
|
|
21
|
-
from .agent import Agent
|
|
21
|
+
from .agent import Agent, AgentBase
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
# The handoff input type is the type of data passed when the agent is called via a handoff.
|
|
25
25
|
THandoffInput = TypeVar("THandoffInput", default=Any)
|
|
26
26
|
|
|
27
|
+
# The agent type that the handoff returns
|
|
28
|
+
TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]")
|
|
29
|
+
|
|
27
30
|
OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
|
|
28
31
|
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]
|
|
29
32
|
|
|
@@ -52,7 +55,7 @@ HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData]
|
|
|
52
55
|
|
|
53
56
|
|
|
54
57
|
@dataclass
|
|
55
|
-
class Handoff(Generic[TContext]):
|
|
58
|
+
class Handoff(Generic[TContext, TAgent]):
|
|
56
59
|
"""A handoff is when an agent delegates a task to another agent.
|
|
57
60
|
For example, in a customer support scenario you might have a "triage agent" that determines
|
|
58
61
|
which agent should handle the user's request, and sub-agents that specialize in different
|
|
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
|
|
|
69
72
|
"""The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
|
|
70
73
|
"""
|
|
71
74
|
|
|
72
|
-
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[
|
|
75
|
+
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]]
|
|
73
76
|
"""The function that invokes the handoff. The parameters passed are:
|
|
74
77
|
1. The handoff run context
|
|
75
78
|
2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
|
|
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
|
|
|
100
103
|
True, as it increases the likelihood of correct JSON input.
|
|
101
104
|
"""
|
|
102
105
|
|
|
103
|
-
is_enabled: bool | Callable[[RunContextWrapper[Any],
|
|
106
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
|
|
107
|
+
True
|
|
108
|
+
)
|
|
104
109
|
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
|
|
105
110
|
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
|
|
106
111
|
a handoff based on your context/state."""
|
|
107
112
|
|
|
108
|
-
def get_transfer_message(self, agent:
|
|
113
|
+
def get_transfer_message(self, agent: AgentBase[Any]) -> str:
|
|
109
114
|
return json.dumps({"assistant": agent.name})
|
|
110
115
|
|
|
111
116
|
@classmethod
|
|
112
|
-
def default_tool_name(cls, agent:
|
|
117
|
+
def default_tool_name(cls, agent: AgentBase[Any]) -> str:
|
|
113
118
|
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
|
|
114
119
|
|
|
115
120
|
@classmethod
|
|
116
|
-
def default_tool_description(cls, agent:
|
|
121
|
+
def default_tool_description(cls, agent: AgentBase[Any]) -> str:
|
|
117
122
|
return (
|
|
118
123
|
f"Handoff to the {agent.name} agent to handle the request. "
|
|
119
124
|
f"{agent.handoff_description or ''}"
|
|
@@ -128,7 +133,7 @@ def handoff(
|
|
|
128
133
|
tool_description_override: str | None = None,
|
|
129
134
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
130
135
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
131
|
-
) -> Handoff[TContext]: ...
|
|
136
|
+
) -> Handoff[TContext, Agent[TContext]]: ...
|
|
132
137
|
|
|
133
138
|
|
|
134
139
|
@overload
|
|
@@ -141,7 +146,7 @@ def handoff(
|
|
|
141
146
|
tool_name_override: str | None = None,
|
|
142
147
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
143
148
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
144
|
-
) -> Handoff[TContext]: ...
|
|
149
|
+
) -> Handoff[TContext, Agent[TContext]]: ...
|
|
145
150
|
|
|
146
151
|
|
|
147
152
|
@overload
|
|
@@ -153,7 +158,7 @@ def handoff(
|
|
|
153
158
|
tool_name_override: str | None = None,
|
|
154
159
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
155
160
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
156
|
-
) -> Handoff[TContext]: ...
|
|
161
|
+
) -> Handoff[TContext, Agent[TContext]]: ...
|
|
157
162
|
|
|
158
163
|
|
|
159
164
|
def handoff(
|
|
@@ -163,8 +168,9 @@ def handoff(
|
|
|
163
168
|
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
|
|
164
169
|
input_type: type[THandoffInput] | None = None,
|
|
165
170
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
166
|
-
is_enabled: bool
|
|
167
|
-
|
|
171
|
+
is_enabled: bool
|
|
172
|
+
| Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True,
|
|
173
|
+
) -> Handoff[TContext, Agent[TContext]]:
|
|
168
174
|
"""Create a handoff from an agent.
|
|
169
175
|
|
|
170
176
|
Args:
|
|
@@ -202,7 +208,7 @@ def handoff(
|
|
|
202
208
|
|
|
203
209
|
async def _invoke_handoff(
|
|
204
210
|
ctx: RunContextWrapper[Any], input_json: str | None = None
|
|
205
|
-
) -> Agent[
|
|
211
|
+
) -> Agent[TContext]:
|
|
206
212
|
if input_type is not None and type_adapter is not None:
|
|
207
213
|
if input_json is None:
|
|
208
214
|
_error_tracing.attach_error_to_current_span(
|
|
@@ -239,6 +245,18 @@ def handoff(
|
|
|
239
245
|
# If there is a need, we can make this configurable in the future
|
|
240
246
|
input_json_schema = ensure_strict_json_schema(input_json_schema)
|
|
241
247
|
|
|
248
|
+
async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
|
|
249
|
+
from .agent import Agent
|
|
250
|
+
|
|
251
|
+
assert callable(is_enabled), "is_enabled must be non-null here"
|
|
252
|
+
assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent"
|
|
253
|
+
result = is_enabled(ctx, agent_base)
|
|
254
|
+
|
|
255
|
+
if inspect.isawaitable(result):
|
|
256
|
+
return await result
|
|
257
|
+
|
|
258
|
+
return result
|
|
259
|
+
|
|
242
260
|
return Handoff(
|
|
243
261
|
tool_name=tool_name,
|
|
244
262
|
tool_description=tool_description,
|
|
@@ -246,5 +264,5 @@ def handoff(
|
|
|
246
264
|
on_invoke_handoff=_invoke_handoff,
|
|
247
265
|
input_filter=input_filter,
|
|
248
266
|
agent_name=agent.name,
|
|
249
|
-
is_enabled=is_enabled,
|
|
267
|
+
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
|
|
250
268
|
)
|
agents/lifecycle.py
CHANGED
|
@@ -1,25 +1,27 @@
|
|
|
1
1
|
from typing import Any, Generic
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from typing_extensions import TypeVar
|
|
4
|
+
|
|
5
|
+
from .agent import Agent, AgentBase
|
|
4
6
|
from .run_context import RunContextWrapper, TContext
|
|
5
7
|
from .tool import Tool
|
|
6
8
|
|
|
9
|
+
TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
|
|
10
|
+
|
|
7
11
|
|
|
8
|
-
class
|
|
12
|
+
class RunHooksBase(Generic[TContext, TAgent]):
|
|
9
13
|
"""A class that receives callbacks on various lifecycle events in an agent run. Subclass and
|
|
10
14
|
override the methods you need.
|
|
11
15
|
"""
|
|
12
16
|
|
|
13
|
-
async def on_agent_start(
|
|
14
|
-
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
|
|
15
|
-
) -> None:
|
|
17
|
+
async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
|
|
16
18
|
"""Called before the agent is invoked. Called each time the current agent changes."""
|
|
17
19
|
pass
|
|
18
20
|
|
|
19
21
|
async def on_agent_end(
|
|
20
22
|
self,
|
|
21
23
|
context: RunContextWrapper[TContext],
|
|
22
|
-
agent:
|
|
24
|
+
agent: TAgent,
|
|
23
25
|
output: Any,
|
|
24
26
|
) -> None:
|
|
25
27
|
"""Called when the agent produces a final output."""
|
|
@@ -28,8 +30,8 @@ class RunHooks(Generic[TContext]):
|
|
|
28
30
|
async def on_handoff(
|
|
29
31
|
self,
|
|
30
32
|
context: RunContextWrapper[TContext],
|
|
31
|
-
from_agent:
|
|
32
|
-
to_agent:
|
|
33
|
+
from_agent: TAgent,
|
|
34
|
+
to_agent: TAgent,
|
|
33
35
|
) -> None:
|
|
34
36
|
"""Called when a handoff occurs."""
|
|
35
37
|
pass
|
|
@@ -37,7 +39,7 @@ class RunHooks(Generic[TContext]):
|
|
|
37
39
|
async def on_tool_start(
|
|
38
40
|
self,
|
|
39
41
|
context: RunContextWrapper[TContext],
|
|
40
|
-
agent:
|
|
42
|
+
agent: TAgent,
|
|
41
43
|
tool: Tool,
|
|
42
44
|
) -> None:
|
|
43
45
|
"""Called before a tool is invoked."""
|
|
@@ -46,7 +48,7 @@ class RunHooks(Generic[TContext]):
|
|
|
46
48
|
async def on_tool_end(
|
|
47
49
|
self,
|
|
48
50
|
context: RunContextWrapper[TContext],
|
|
49
|
-
agent:
|
|
51
|
+
agent: TAgent,
|
|
50
52
|
tool: Tool,
|
|
51
53
|
result: str,
|
|
52
54
|
) -> None:
|
|
@@ -54,14 +56,14 @@ class RunHooks(Generic[TContext]):
|
|
|
54
56
|
pass
|
|
55
57
|
|
|
56
58
|
|
|
57
|
-
class
|
|
59
|
+
class AgentHooksBase(Generic[TContext, TAgent]):
|
|
58
60
|
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
|
|
59
61
|
set this on `agent.hooks` to receive events for that specific agent.
|
|
60
62
|
|
|
61
63
|
Subclass and override the methods you need.
|
|
62
64
|
"""
|
|
63
65
|
|
|
64
|
-
async def on_start(self, context: RunContextWrapper[TContext], agent:
|
|
66
|
+
async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
|
|
65
67
|
"""Called before the agent is invoked. Called each time the running agent is changed to this
|
|
66
68
|
agent."""
|
|
67
69
|
pass
|
|
@@ -69,7 +71,7 @@ class AgentHooks(Generic[TContext]):
|
|
|
69
71
|
async def on_end(
|
|
70
72
|
self,
|
|
71
73
|
context: RunContextWrapper[TContext],
|
|
72
|
-
agent:
|
|
74
|
+
agent: TAgent,
|
|
73
75
|
output: Any,
|
|
74
76
|
) -> None:
|
|
75
77
|
"""Called when the agent produces a final output."""
|
|
@@ -78,8 +80,8 @@ class AgentHooks(Generic[TContext]):
|
|
|
78
80
|
async def on_handoff(
|
|
79
81
|
self,
|
|
80
82
|
context: RunContextWrapper[TContext],
|
|
81
|
-
agent:
|
|
82
|
-
source:
|
|
83
|
+
agent: TAgent,
|
|
84
|
+
source: TAgent,
|
|
83
85
|
) -> None:
|
|
84
86
|
"""Called when the agent is being handed off to. The `source` is the agent that is handing
|
|
85
87
|
off to this agent."""
|
|
@@ -88,7 +90,7 @@ class AgentHooks(Generic[TContext]):
|
|
|
88
90
|
async def on_tool_start(
|
|
89
91
|
self,
|
|
90
92
|
context: RunContextWrapper[TContext],
|
|
91
|
-
agent:
|
|
93
|
+
agent: TAgent,
|
|
92
94
|
tool: Tool,
|
|
93
95
|
) -> None:
|
|
94
96
|
"""Called before a tool is invoked."""
|
|
@@ -97,9 +99,16 @@ class AgentHooks(Generic[TContext]):
|
|
|
97
99
|
async def on_tool_end(
|
|
98
100
|
self,
|
|
99
101
|
context: RunContextWrapper[TContext],
|
|
100
|
-
agent:
|
|
102
|
+
agent: TAgent,
|
|
101
103
|
tool: Tool,
|
|
102
104
|
result: str,
|
|
103
105
|
) -> None:
|
|
104
106
|
"""Called after a tool is invoked."""
|
|
105
107
|
pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
RunHooks = RunHooksBase[TContext, Agent]
|
|
111
|
+
"""Run hooks when using `Agent`."""
|
|
112
|
+
|
|
113
|
+
AgentHooks = AgentHooksBase[TContext, Agent]
|
|
114
|
+
"""Agent hooks for `Agent`s."""
|