openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (39)
  1. agents/__init__.py +5 -1
  2. agents/_run_impl.py +5 -1
  3. agents/agent.py +62 -30
  4. agents/agent_output.py +2 -2
  5. agents/function_schema.py +11 -1
  6. agents/guardrail.py +5 -1
  7. agents/handoffs.py +32 -14
  8. agents/lifecycle.py +26 -17
  9. agents/mcp/server.py +82 -11
  10. agents/mcp/util.py +16 -9
  11. agents/memory/__init__.py +3 -0
  12. agents/memory/session.py +369 -0
  13. agents/model_settings.py +15 -7
  14. agents/models/chatcmpl_converter.py +20 -3
  15. agents/models/chatcmpl_stream_handler.py +134 -43
  16. agents/models/openai_responses.py +12 -5
  17. agents/realtime/README.md +3 -0
  18. agents/realtime/__init__.py +177 -0
  19. agents/realtime/agent.py +89 -0
  20. agents/realtime/config.py +188 -0
  21. agents/realtime/events.py +216 -0
  22. agents/realtime/handoffs.py +165 -0
  23. agents/realtime/items.py +184 -0
  24. agents/realtime/model.py +69 -0
  25. agents/realtime/model_events.py +159 -0
  26. agents/realtime/model_inputs.py +100 -0
  27. agents/realtime/openai_realtime.py +670 -0
  28. agents/realtime/runner.py +118 -0
  29. agents/realtime/session.py +535 -0
  30. agents/run.py +106 -4
  31. agents/tool.py +6 -7
  32. agents/tool_context.py +16 -3
  33. agents/voice/models/openai_stt.py +1 -1
  34. agents/voice/pipeline.py +6 -0
  35. agents/voice/workflow.py +8 -0
  36. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
  37. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
  38. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
  39. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Literal
 from openai import AsyncOpenAI
 
 from . import _config
-from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
+from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
 from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
 from .computer import AsyncComputer, Button, Computer, Environment
 from .exceptions import (
@@ -40,6 +40,7 @@ from .items import (
     TResponseInputItem,
 )
 from .lifecycle import AgentHooks, RunHooks
+from .memory import Session, SQLiteSession
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider, ModelTracing
 from .models.openai_chatcompletions import OpenAIChatCompletionsModel
@@ -160,6 +161,7 @@ def enable_verbose_stdout_logging():
 
 __all__ = [
     "Agent",
+    "AgentBase",
     "ToolsToFinalOutputFunction",
     "ToolsToFinalOutputResult",
     "Runner",
@@ -209,6 +211,8 @@ __all__ = [
     "ItemHelpers",
     "RunHooks",
     "AgentHooks",
+    "Session",
+    "SQLiteSession",
     "RunContextWrapper",
     "TContext",
     "RunErrorDetails",
agents/_run_impl.py CHANGED
@@ -548,7 +548,11 @@ class RunImpl:
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
         ) -> Any:
             with function_span(func_tool.name) as span_fn:
-                tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
+                tool_context = ToolContext.from_agent_context(
+                    context_wrapper,
+                    tool_call.call_id,
+                    tool_call=tool_call,
+                )
                 if config.trace_include_sensitive_data:
                     span_fn.span_data.input = tool_call.arguments
                 try:
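
The call site now forwards the whole `ResponseFunctionToolCall` into the `ToolContext` (see the agents/tool_context.py entry in the file list), so a function tool can inspect metadata about its own invocation. A rough sketch, assuming the tool's first parameter may be annotated as `ToolContext` in the same way `RunContextWrapper` is:

    from agents import function_tool
    from agents.tool_context import ToolContext

    @function_tool
    def audit_ping(ctx: ToolContext, message: str) -> str:
        # tool_call_id identifies this specific invocation; with this change the
        # raw tool_call is carried on the context as well.
        return f"call {ctx.tool_call_id} received: {message}"
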
agents/agent.py CHANGED
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
 
 
 @dataclass
-class Agent(Generic[TContext]):
+class AgentBase(Generic[TContext]):
+    """Base class for `Agent` and `RealtimeAgent`."""
+
+    name: str
+    """The name of the agent."""
+
+    handoff_description: str | None = None
+    """A description of the agent. This is used when the agent is used as a handoff, so that an
+    LLM knows what it does and when to invoke it.
+    """
+
+    tools: list[Tool] = field(default_factory=list)
+    """A list of tools that the agent can use."""
+
+    mcp_servers: list[MCPServer] = field(default_factory=list)
+    """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
+    the agent can use. Every time the agent runs, it will include tools from these servers in the
+    list of available tools.
+
+    NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
+    `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
+    longer needed.
+    """
+
+    mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
+    """Configuration for MCP servers."""
+
+    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
+        """Fetches the available tools from the MCP servers."""
+        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
+        return await MCPUtil.get_all_function_tools(
+            self.mcp_servers, convert_schemas_to_strict, run_context, self
+        )
+
+    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
+        """All agent tools, including MCP tools and function tools."""
+        mcp_tools = await self.get_mcp_tools(run_context)
+
+        async def _check_tool_enabled(tool: Tool) -> bool:
+            if not isinstance(tool, FunctionTool):
+                return True
+
+            attr = tool.is_enabled
+            if isinstance(attr, bool):
+                return attr
+            res = attr(run_context, self)
+            if inspect.isawaitable(res):
+                return bool(await res)
+            return bool(res)
+
+        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
+        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
+        return [*mcp_tools, *enabled]
+
+
+@dataclass
+class Agent(AgentBase, Generic[TContext]):
     """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
 
     We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
 
     Agents are generic on the context type. The context is a (mutable) object you create. It is
     passed to tool functions, handoffs, guardrails, etc.
-    """
 
-    name: str
-    """The name of the agent."""
+    See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
+    """
 
     instructions: (
         str
@@ -103,12 +158,7 @@
     usable with OpenAI models, using the Responses API.
     """
 
-    handoff_description: str | None = None
-    """A description of the agent. This is used when the agent is used as a handoff, so that an
-    LLM knows what it does and when to invoke it.
-    """
-
-    handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
+    handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
     """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
     and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
     modularity.
@@ -125,22 +175,6 @@
     """Configures model-specific tuning parameters (e.g. temperature, top_p).
     """
 
-    tools: list[Tool] = field(default_factory=list)
-    """A list of tools that the agent can use."""
-
-    mcp_servers: list[MCPServer] = field(default_factory=list)
-    """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
-    the agent can use. Every time the agent runs, it will include tools from these servers in the
-    list of available tools.
-
-    NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
-    `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
-    longer needed.
-    """
-
-    mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
-    """Configuration for MCP servers."""
-
     input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
     """A list of checks that run in parallel to the agent's execution, before generating a
     response. Runs only if the agent is the first agent in the chain.
@@ -176,7 +210,7 @@
         The final output will be the output of the first matching tool call. The LLM does not
         process the result of the tool call.
     - A function: If you pass a function, it will be called with the run context and the list of
-      tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
+      tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
       calls result in a final output.
 
     NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
@@ -256,9 +290,7 @@
         """Get the prompt for the agent."""
         return await PromptUtil.to_model_input(self.prompt, run_context, self)
 
-    async def get_mcp_tools(
-        self, run_context: RunContextWrapper[TContext]
-    ) -> list[Tool]:
+    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
        """Fetches the available tools from the MCP servers."""
        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
        return await MCPUtil.get_all_function_tools(
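
`get_all_tools` on the new `AgentBase` is what the runner consults each turn: function tools whose `is_enabled` resolves falsy are dropped before MCP tools are appended. A hedged sketch of a conditionally enabled tool, assuming the `function_tool` decorator forwards an `is_enabled` argument onto the `FunctionTool` it builds (the admin flag is illustrative):

    from agents import Agent, RunContextWrapper, function_tool

    def only_for_admins(ctx: RunContextWrapper, agent: Agent) -> bool:
        # Illustrative context shape: the run context carries an is_admin flag.
        return bool(getattr(ctx.context, "is_admin", False))

    @function_tool(is_enabled=only_for_admins)
    def delete_account(user_id: str) -> str:
        """Delete a user account."""
        return f"deleted {user_id}"

    agent = Agent(name="Support agent", tools=[delete_account])
    # agent.get_all_tools(run_context) omits delete_account whenever the check is False.
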
agents/agent_output.py CHANGED
@@ -115,8 +115,8 @@ class AgentOutputSchema(AgentOutputSchemaBase):
         except UserError as e:
             raise UserError(
                 "Strict JSON schema is enabled, but the output type is not valid. "
-                "Either make the output type strict, or pass output_schema_strict=False to "
-                "your Agent()"
+                "Either make the output type strict, "
+                "or wrap your type with AgentOutputSchema(your_type, strict_json_schema=False)"
             ) from e
 
     def is_plain_text(self) -> bool:
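
The corrected message points at the real escape hatch (`Agent()` has no `output_schema_strict` parameter): wrap the output type in `AgentOutputSchema` with strict mode off. A small sketch; the `Quote` type is illustrative:

    from dataclasses import dataclass

    from agents import Agent, AgentOutputSchema

    @dataclass
    class Quote:
        text: str
        # Open-ended dicts are not expressible as a strict JSON schema, so the
        # strict conversion would raise UserError for this type.
        metadata: dict[str, object]

    agent = Agent(
        name="Quoter",
        instructions="Return a quote as structured output.",
        output_type=AgentOutputSchema(Quote, strict_json_schema=False),
    )
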
agents/function_schema.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
 
 from griffe import Docstring, DocstringSectionKind
 from pydantic import BaseModel, Field, create_model
+from pydantic.fields import FieldInfo
 
 from .exceptions import UserError
 from .run_context import RunContextWrapper
@@ -319,6 +320,14 @@ def function_schema(
                 ann,
                 Field(..., description=field_description),
             )
+        elif isinstance(default, FieldInfo):
+            # Parameter with a default value that is a Field(...)
+            fields[name] = (
+                ann,
+                FieldInfo.merge_field_infos(
+                    default, description=field_description or default.description
+                ),
+            )
         else:
             # Parameter with a default value
             fields[name] = (
@@ -337,7 +346,8 @@
     # 5. Return as a FuncSchema dataclass
     return FuncSchema(
         name=func_name,
-        description=description_override or doc_info.description if doc_info else None,
+        # Ensure description_override takes precedence even if docstring info is disabled.
+        description=description_override or (doc_info.description if doc_info else None),
         params_pydantic_model=dynamic_model,
         params_json_schema=json_schema,
         signature=sig,
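
With the new `FieldInfo` branch, a tool parameter whose default is a pydantic `Field(...)` keeps its description and constraints instead of being treated as a plain default. A hedged example of the kind of signature this supports (the tool itself is illustrative):

    from pydantic import Field

    from agents import function_tool

    @function_tool
    def search_docs(
        query: str,
        # The Field default is merged into the generated params model, so the
        # description and bounds show up in the tool's JSON schema.
        limit: int = Field(default=5, ge=1, le=50, description="Maximum number of results"),
    ) -> str:
        return f"searching for {query!r} (limit={limit})"
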
agents/guardrail.py CHANGED
@@ -241,7 +241,11 @@
     def decorator(
         f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
     ) -> InputGuardrail[TContext_co]:
-        return InputGuardrail(guardrail_function=f, name=name)
+        return InputGuardrail(
+            guardrail_function=f,
+            # If not set, guardrail name uses the function's name by default.
+            name=name if name else f.__name__,
+        )
 
     if func is not None:
         # Decorator was used without parentheses
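
Decorated guardrails therefore get a stable default name for tracing and error reporting. A minimal sketch of a named-by-default input guardrail (the check itself is illustrative):

    from agents import (
        Agent,
        GuardrailFunctionOutput,
        RunContextWrapper,
        TResponseInputItem,
        input_guardrail,
    )

    @input_guardrail  # no name= given, so the guardrail is named "block_profanity"
    def block_profanity(
        ctx: RunContextWrapper, agent: Agent, user_input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        flagged = isinstance(user_input, str) and "badword" in user_input.lower()
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=flagged)

    agent = Agent(name="Assistant", input_guardrails=[block_profanity])
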
agents/handoffs.py CHANGED
@@ -18,12 +18,15 @@ from .util import _error_tracing, _json, _transforms
 from .util._types import MaybeAwaitable
 
 if TYPE_CHECKING:
-    from .agent import Agent
+    from .agent import Agent, AgentBase
 
 
 # The handoff input type is the type of data passed when the agent is called via a handoff.
 THandoffInput = TypeVar("THandoffInput", default=Any)
 
+# The agent type that the handoff returns
+TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]")
+
 OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
 OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]
 
@@ -52,7 +55,7 @@ HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData]
 
 
 @dataclass
-class Handoff(Generic[TContext]):
+class Handoff(Generic[TContext, TAgent]):
     """A handoff is when an agent delegates a task to another agent.
     For example, in a customer support scenario you might have a "triage agent" that determines
     which agent should handle the user's request, and sub-agents that specialize in different
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
     """The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
     """
 
-    on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]]
+    on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]]
     """The function that invokes the handoff. The parameters passed are:
     1. The handoff run context
     2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
     True, as it increases the likelihood of correct JSON input.
     """
 
-    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
+    is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
+        True
+    )
     """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
     agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
     a handoff based on your context/state."""
 
-    def get_transfer_message(self, agent: Agent[Any]) -> str:
+    def get_transfer_message(self, agent: AgentBase[Any]) -> str:
         return json.dumps({"assistant": agent.name})
 
     @classmethod
-    def default_tool_name(cls, agent: Agent[Any]) -> str:
+    def default_tool_name(cls, agent: AgentBase[Any]) -> str:
         return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
 
     @classmethod
-    def default_tool_description(cls, agent: Agent[Any]) -> str:
+    def default_tool_description(cls, agent: AgentBase[Any]) -> str:
         return (
             f"Handoff to the {agent.name} agent to handle the request. "
             f"{agent.handoff_description or ''}"
@@ -128,7 +133,7 @@ def handoff(
     tool_description_override: str | None = None,
     input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
     is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
-) -> Handoff[TContext]: ...
+) -> Handoff[TContext, Agent[TContext]]: ...
 
 
 @overload
@@ -141,7 +146,7 @@ def handoff(
     tool_name_override: str | None = None,
     input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
     is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
-) -> Handoff[TContext]: ...
+) -> Handoff[TContext, Agent[TContext]]: ...
 
 
 @overload
@@ -153,7 +158,7 @@
     tool_name_override: str | None = None,
    input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
-) -> Handoff[TContext]: ...
+) -> Handoff[TContext, Agent[TContext]]: ...
 
 
 def handoff(
@@ -163,8 +168,9 @@
     on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
     input_type: type[THandoffInput] | None = None,
     input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
-    is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
-) -> Handoff[TContext]:
+    is_enabled: bool
+    | Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True,
+) -> Handoff[TContext, Agent[TContext]]:
     """Create a handoff from an agent.
 
     Args:
@@ -202,7 +208,7 @@
 
     async def _invoke_handoff(
         ctx: RunContextWrapper[Any], input_json: str | None = None
-    ) -> Agent[Any]:
+    ) -> Agent[TContext]:
         if input_type is not None and type_adapter is not None:
             if input_json is None:
                 _error_tracing.attach_error_to_current_span(
@@ -239,6 +245,18 @@
     # If there is a need, we can make this configurable in the future
     input_json_schema = ensure_strict_json_schema(input_json_schema)
 
+    async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
+        from .agent import Agent
+
+        assert callable(is_enabled), "is_enabled must be non-null here"
+        assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent"
+        result = is_enabled(ctx, agent_base)
+
+        if inspect.isawaitable(result):
+            return await result
+
+        return result
+
     return Handoff(
         tool_name=tool_name,
         tool_description=tool_description,
@@ -246,5 +264,5 @@
         on_invoke_handoff=_invoke_handoff,
         input_filter=input_filter,
         agent_name=agent.name,
-        is_enabled=is_enabled,
+        is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
     )
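
Since `is_enabled` can now be a sync or async callable, `handoff()` wraps it in `_is_enabled` so the runner can await it uniformly while the `Handoff` field keeps the broader `AgentBase` signature. A sketch of a dynamically enabled handoff (the refund flag is illustrative):

    from agents import Agent, RunContextWrapper, handoff

    refund_agent = Agent(name="Refund agent", instructions="Handle refund requests.")

    async def refunds_open(ctx: RunContextWrapper, agent: Agent) -> bool:
        # Illustrative flag on the user-supplied context object.
        return bool(getattr(ctx.context, "refunds_enabled", False))

    triage_agent = Agent(
        name="Triage agent",
        instructions="Route the user to the right specialist.",
        handoffs=[handoff(refund_agent, is_enabled=refunds_open)],
    )
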
agents/lifecycle.py CHANGED
@@ -1,25 +1,27 @@
 from typing import Any, Generic
 
-from .agent import Agent
+from typing_extensions import TypeVar
+
+from .agent import Agent, AgentBase
 from .run_context import RunContextWrapper, TContext
 from .tool import Tool
 
+TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
+
 
-class RunHooks(Generic[TContext]):
+class RunHooksBase(Generic[TContext, TAgent]):
     """A class that receives callbacks on various lifecycle events in an agent run. Subclass and
     override the methods you need.
     """
 
-    async def on_agent_start(
-        self, context: RunContextWrapper[TContext], agent: Agent[TContext]
-    ) -> None:
+    async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
         """Called before the agent is invoked. Called each time the current agent changes."""
         pass
 
     async def on_agent_end(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         output: Any,
     ) -> None:
         """Called when the agent produces a final output."""
@@ -28,8 +30,8 @@ class RunHooks(Generic[TContext]):
     async def on_handoff(
         self,
         context: RunContextWrapper[TContext],
-        from_agent: Agent[TContext],
-        to_agent: Agent[TContext],
+        from_agent: TAgent,
+        to_agent: TAgent,
     ) -> None:
         """Called when a handoff occurs."""
         pass
@@ -37,7 +39,7 @@ class RunHooks(Generic[TContext]):
     async def on_tool_start(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         tool: Tool,
     ) -> None:
         """Called before a tool is invoked."""
@@ -46,7 +48,7 @@ class RunHooks(Generic[TContext]):
     async def on_tool_end(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         tool: Tool,
         result: str,
     ) -> None:
@@ -54,14 +56,14 @@ class RunHooks(Generic[TContext]):
         pass
 
 
-class AgentHooks(Generic[TContext]):
+class AgentHooksBase(Generic[TContext, TAgent]):
     """A class that receives callbacks on various lifecycle events for a specific agent. You can
     set this on `agent.hooks` to receive events for that specific agent.
 
     Subclass and override the methods you need.
     """
 
-    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
+    async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
         """Called before the agent is invoked. Called each time the running agent is changed to this
         agent."""
         pass
@@ -69,7 +71,7 @@ class AgentHooks(Generic[TContext]):
     async def on_end(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         output: Any,
     ) -> None:
         """Called when the agent produces a final output."""
@@ -78,8 +80,8 @@ class AgentHooks(Generic[TContext]):
     async def on_handoff(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
-        source: Agent[TContext],
+        agent: TAgent,
+        source: TAgent,
     ) -> None:
         """Called when the agent is being handed off to. The `source` is the agent that is handing
         off to this agent."""
@@ -88,7 +90,7 @@ class AgentHooks(Generic[TContext]):
     async def on_tool_start(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         tool: Tool,
     ) -> None:
         """Called before a tool is invoked."""
@@ -97,9 +99,16 @@ class AgentHooks(Generic[TContext]):
     async def on_tool_end(
         self,
         context: RunContextWrapper[TContext],
-        agent: Agent[TContext],
+        agent: TAgent,
         tool: Tool,
         result: str,
     ) -> None:
         """Called after a tool is invoked."""
         pass
+
+
+RunHooks = RunHooksBase[TContext, Agent]
+"""Run hooks when using `Agent`."""
+
+AgentHooks = AgentHooksBase[TContext, Agent]
+"""Agent hooks for `Agent`s."""