openai-agents 0.0.19__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. More details are linked from the original release page.

Files changed (43)
  1. agents/__init__.py +5 -2
  2. agents/_run_impl.py +35 -1
  3. agents/agent.py +65 -29
  4. agents/extensions/models/litellm_model.py +7 -3
  5. agents/function_schema.py +11 -1
  6. agents/guardrail.py +5 -1
  7. agents/handoffs.py +14 -0
  8. agents/lifecycle.py +26 -17
  9. agents/mcp/__init__.py +13 -1
  10. agents/mcp/server.py +173 -16
  11. agents/mcp/util.py +89 -6
  12. agents/memory/__init__.py +3 -0
  13. agents/memory/session.py +369 -0
  14. agents/model_settings.py +60 -6
  15. agents/models/chatcmpl_converter.py +31 -2
  16. agents/models/chatcmpl_stream_handler.py +128 -16
  17. agents/models/openai_chatcompletions.py +12 -10
  18. agents/models/openai_responses.py +25 -8
  19. agents/realtime/README.md +3 -0
  20. agents/realtime/__init__.py +174 -0
  21. agents/realtime/agent.py +80 -0
  22. agents/realtime/config.py +128 -0
  23. agents/realtime/events.py +216 -0
  24. agents/realtime/items.py +91 -0
  25. agents/realtime/model.py +69 -0
  26. agents/realtime/model_events.py +159 -0
  27. agents/realtime/model_inputs.py +100 -0
  28. agents/realtime/openai_realtime.py +584 -0
  29. agents/realtime/runner.py +118 -0
  30. agents/realtime/session.py +502 -0
  31. agents/repl.py +1 -4
  32. agents/run.py +131 -10
  33. agents/tool.py +30 -6
  34. agents/tool_context.py +16 -3
  35. agents/tracing/__init__.py +1 -2
  36. agents/tracing/processor_interface.py +1 -1
  37. agents/voice/models/openai_stt.py +1 -1
  38. agents/voice/pipeline.py +6 -0
  39. agents/voice/workflow.py +8 -0
  40. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/METADATA +133 -8
  41. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/RECORD +43 -29
  42. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/WHEEL +0 -0
  43. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Literal
5
5
  from openai import AsyncOpenAI
6
6
 
7
7
  from . import _config
8
- from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
8
+ from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
9
9
  from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
10
10
  from .computer import AsyncComputer, Button, Computer, Environment
11
11
  from .exceptions import (
@@ -40,6 +40,7 @@ from .items import (
40
40
  TResponseInputItem,
41
41
  )
42
42
  from .lifecycle import AgentHooks, RunHooks
43
+ from .memory import Session, SQLiteSession
43
44
  from .model_settings import ModelSettings
44
45
  from .models.interface import Model, ModelProvider, ModelTracing
45
46
  from .models.openai_chatcompletions import OpenAIChatCompletionsModel
@@ -160,6 +161,7 @@ def enable_verbose_stdout_logging():
160
161
 
161
162
  __all__ = [
162
163
  "Agent",
164
+ "AgentBase",
163
165
  "ToolsToFinalOutputFunction",
164
166
  "ToolsToFinalOutputResult",
165
167
  "Runner",
@@ -206,10 +208,11 @@ __all__ = [
206
208
  "ToolCallItem",
207
209
  "ToolCallOutputItem",
208
210
  "ReasoningItem",
209
- "ModelResponse",
210
211
  "ItemHelpers",
211
212
  "RunHooks",
212
213
  "AgentHooks",
214
+ "Session",
215
+ "SQLiteSession",
213
216
  "RunContextWrapper",
214
217
  "TContext",
215
218
  "RunErrorDetails",
agents/_run_impl.py CHANGED
@@ -28,6 +28,9 @@ from openai.types.responses.response_computer_tool_call import (
28
28
  ActionType,
29
29
  ActionWait,
30
30
  )
31
+ from openai.types.responses.response_input_item_param import (
32
+ ComputerCallOutputAcknowledgedSafetyCheck,
33
+ )
31
34
  from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
32
35
  from openai.types.responses.response_output_item import (
33
36
  ImageGenerationCall,
@@ -67,6 +70,7 @@ from .run_context import RunContextWrapper, TContext
67
70
  from .stream_events import RunItemStreamEvent, StreamEvent
68
71
  from .tool import (
69
72
  ComputerTool,
73
+ ComputerToolSafetyCheckData,
70
74
  FunctionTool,
71
75
  FunctionToolResult,
72
76
  HostedMCPTool,
@@ -544,7 +548,11 @@ class RunImpl:
544
548
  func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
545
549
  ) -> Any:
546
550
  with function_span(func_tool.name) as span_fn:
547
- tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
551
+ tool_context = ToolContext.from_agent_context(
552
+ context_wrapper,
553
+ tool_call.call_id,
554
+ tool_call=tool_call,
555
+ )
548
556
  if config.trace_include_sensitive_data:
549
557
  span_fn.span_data.input = tool_call.arguments
550
558
  try:
@@ -638,6 +646,29 @@ class RunImpl:
638
646
  results: list[RunItem] = []
639
647
  # Need to run these serially, because each action can affect the computer state
640
648
  for action in actions:
649
+ acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
650
+ if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
651
+ acknowledged = []
652
+ for check in action.tool_call.pending_safety_checks:
653
+ data = ComputerToolSafetyCheckData(
654
+ ctx_wrapper=context_wrapper,
655
+ agent=agent,
656
+ tool_call=action.tool_call,
657
+ safety_check=check,
658
+ )
659
+ maybe = action.computer_tool.on_safety_check(data)
660
+ ack = await maybe if inspect.isawaitable(maybe) else maybe
661
+ if ack:
662
+ acknowledged.append(
663
+ ComputerCallOutputAcknowledgedSafetyCheck(
664
+ id=check.id,
665
+ code=check.code,
666
+ message=check.message,
667
+ )
668
+ )
669
+ else:
670
+ raise UserError("Computer tool safety check was not acknowledged")
671
+
641
672
  results.append(
642
673
  await ComputerAction.execute(
643
674
  agent=agent,
@@ -645,6 +676,7 @@ class RunImpl:
645
676
  hooks=hooks,
646
677
  context_wrapper=context_wrapper,
647
678
  config=config,
679
+ acknowledged_safety_checks=acknowledged,
648
680
  )
649
681
  )
650
682
 
@@ -998,6 +1030,7 @@ class ComputerAction:
998
1030
  hooks: RunHooks[TContext],
999
1031
  context_wrapper: RunContextWrapper[TContext],
1000
1032
  config: RunConfig,
1033
+ acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None,
1001
1034
  ) -> RunItem:
1002
1035
  output_func = (
1003
1036
  cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
@@ -1036,6 +1069,7 @@ class ComputerAction:
1036
1069
  "image_url": image_url,
1037
1070
  },
1038
1071
  type="computer_call_output",
1072
+ acknowledged_safety_checks=acknowledged_safety_checks,
1039
1073
  ),
1040
1074
  )
1041
1075
 
agents/agent.py CHANGED
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
67
67
 
68
68
 
69
69
  @dataclass
70
- class Agent(Generic[TContext]):
70
+ class AgentBase(Generic[TContext]):
71
+ """Base class for `Agent` and `RealtimeAgent`."""
72
+
73
+ name: str
74
+ """The name of the agent."""
75
+
76
+ handoff_description: str | None = None
77
+ """A description of the agent. This is used when the agent is used as a handoff, so that an
78
+ LLM knows what it does and when to invoke it.
79
+ """
80
+
81
+ tools: list[Tool] = field(default_factory=list)
82
+ """A list of tools that the agent can use."""
83
+
84
+ mcp_servers: list[MCPServer] = field(default_factory=list)
85
+ """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
86
+ the agent can use. Every time the agent runs, it will include tools from these servers in the
87
+ list of available tools.
88
+
89
+ NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
90
+ `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
91
+ longer needed.
92
+ """
93
+
94
+ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
95
+ """Configuration for MCP servers."""
96
+
97
+ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
98
+ """Fetches the available tools from the MCP servers."""
99
+ convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
100
+ return await MCPUtil.get_all_function_tools(
101
+ self.mcp_servers, convert_schemas_to_strict, run_context, self
102
+ )
103
+
104
+ async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
105
+ """All agent tools, including MCP tools and function tools."""
106
+ mcp_tools = await self.get_mcp_tools(run_context)
107
+
108
+ async def _check_tool_enabled(tool: Tool) -> bool:
109
+ if not isinstance(tool, FunctionTool):
110
+ return True
111
+
112
+ attr = tool.is_enabled
113
+ if isinstance(attr, bool):
114
+ return attr
115
+ res = attr(run_context, self)
116
+ if inspect.isawaitable(res):
117
+ return bool(await res)
118
+ return bool(res)
119
+
120
+ results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
121
+ enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
122
+ return [*mcp_tools, *enabled]
123
+
124
+
125
+ @dataclass
126
+ class Agent(AgentBase, Generic[TContext]):
71
127
  """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
72
128
 
73
129
  We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
76
132
 
77
133
  Agents are generic on the context type. The context is a (mutable) object you create. It is
78
134
  passed to tool functions, handoffs, guardrails, etc.
79
- """
80
135
 
81
- name: str
82
- """The name of the agent."""
136
+ See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
137
+ """
83
138
 
84
139
  instructions: (
85
140
  str
@@ -103,11 +158,6 @@ class Agent(Generic[TContext]):
103
158
  usable with OpenAI models, using the Responses API.
104
159
  """
105
160
 
106
- handoff_description: str | None = None
107
- """A description of the agent. This is used when the agent is used as a handoff, so that an
108
- LLM knows what it does and when to invoke it.
109
- """
110
-
111
161
  handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
112
162
  """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
113
163
  and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
@@ -125,22 +175,6 @@ class Agent(Generic[TContext]):
125
175
  """Configures model-specific tuning parameters (e.g. temperature, top_p).
126
176
  """
127
177
 
128
- tools: list[Tool] = field(default_factory=list)
129
- """A list of tools that the agent can use."""
130
-
131
- mcp_servers: list[MCPServer] = field(default_factory=list)
132
- """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
133
- the agent can use. Every time the agent runs, it will include tools from these servers in the
134
- list of available tools.
135
-
136
- NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
137
- `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
138
- longer needed.
139
- """
140
-
141
- mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
142
- """Configuration for MCP servers."""
143
-
144
178
  input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
145
179
  """A list of checks that run in parallel to the agent's execution, before generating a
146
180
  response. Runs only if the agent is the first agent in the chain.
@@ -176,7 +210,7 @@ class Agent(Generic[TContext]):
176
210
  The final output will be the output of the first matching tool call. The LLM does not
177
211
  process the result of the tool call.
178
212
  - A function: If you pass a function, it will be called with the run context and the list of
179
- tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
213
+ tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
180
214
  calls result in a final output.
181
215
 
182
216
  NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
@@ -256,14 +290,16 @@ class Agent(Generic[TContext]):
256
290
  """Get the prompt for the agent."""
257
291
  return await PromptUtil.to_model_input(self.prompt, run_context, self)
258
292
 
259
- async def get_mcp_tools(self) -> list[Tool]:
293
+ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
260
294
  """Fetches the available tools from the MCP servers."""
261
295
  convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
262
- return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
296
+ return await MCPUtil.get_all_function_tools(
297
+ self.mcp_servers, convert_schemas_to_strict, run_context, self
298
+ )
263
299
 
264
300
  async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
265
301
  """All agent tools, including MCP tools and function tools."""
266
- mcp_tools = await self.get_mcp_tools()
302
+ mcp_tools = await self.get_mcp_tools(run_context)
267
303
 
268
304
  async def _check_tool_enabled(tool: Tool) -> bool:
269
305
  if not isinstance(tool, FunctionTool):
agents/extensions/models/litellm_model.py CHANGED
@@ -98,7 +98,11 @@ class LitellmModel(Model):
98
98
  logger.debug("Received model response")
99
99
  else:
100
100
  logger.debug(
101
- f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
101
+ f"""LLM resp:\n{
102
+ json.dumps(
103
+ response.choices[0].message.model_dump(), indent=2, ensure_ascii=False
104
+ )
105
+ }\n"""
102
106
  )
103
107
 
104
108
  if hasattr(response, "usage"):
@@ -269,8 +273,8 @@ class LitellmModel(Model):
269
273
  else:
270
274
  logger.debug(
271
275
  f"Calling Litellm model: {self.model}\n"
272
- f"{json.dumps(converted_messages, indent=2)}\n"
273
- f"Tools:\n{json.dumps(converted_tools, indent=2)}\n"
276
+ f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n"
277
+ f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n"
274
278
  f"Stream: {stream}\n"
275
279
  f"Tool choice: {tool_choice}\n"
276
280
  f"Response format: {response_format}\n"
agents/function_schema.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
9
9
 
10
10
  from griffe import Docstring, DocstringSectionKind
11
11
  from pydantic import BaseModel, Field, create_model
12
+ from pydantic.fields import FieldInfo
12
13
 
13
14
  from .exceptions import UserError
14
15
  from .run_context import RunContextWrapper
@@ -319,6 +320,14 @@ def function_schema(
319
320
  ann,
320
321
  Field(..., description=field_description),
321
322
  )
323
+ elif isinstance(default, FieldInfo):
324
+ # Parameter with a default value that is a Field(...)
325
+ fields[name] = (
326
+ ann,
327
+ FieldInfo.merge_field_infos(
328
+ default, description=field_description or default.description
329
+ ),
330
+ )
322
331
  else:
323
332
  # Parameter with a default value
324
333
  fields[name] = (
@@ -337,7 +346,8 @@ def function_schema(
337
346
  # 5. Return as a FuncSchema dataclass
338
347
  return FuncSchema(
339
348
  name=func_name,
340
- description=description_override or doc_info.description if doc_info else None,
349
+ # Ensure description_override takes precedence even if docstring info is disabled.
350
+ description=description_override or (doc_info.description if doc_info else None),
341
351
  params_pydantic_model=dynamic_model,
342
352
  params_json_schema=json_schema,
343
353
  signature=sig,
agents/guardrail.py CHANGED
@@ -241,7 +241,11 @@ def input_guardrail(
241
241
  def decorator(
242
242
  f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
243
243
  ) -> InputGuardrail[TContext_co]:
244
- return InputGuardrail(guardrail_function=f, name=name)
244
+ return InputGuardrail(
245
+ guardrail_function=f,
246
+ # If not set, guardrail name uses the function’s name by default.
247
+ name=name if name else f.__name__
248
+ )
245
249
 
246
250
  if func is not None:
247
251
  # Decorator was used without parentheses
agents/handoffs.py CHANGED
@@ -15,6 +15,7 @@ from .run_context import RunContextWrapper, TContext
15
15
  from .strict_schema import ensure_strict_json_schema
16
16
  from .tracing.spans import SpanError
17
17
  from .util import _error_tracing, _json, _transforms
18
+ from .util._types import MaybeAwaitable
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  from .agent import Agent
@@ -99,6 +100,11 @@ class Handoff(Generic[TContext]):
99
100
  True, as it increases the likelihood of correct JSON input.
100
101
  """
101
102
 
103
+ is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
104
+ """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105
+ agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106
+ a handoff based on your context/state."""
107
+
102
108
  def get_transfer_message(self, agent: Agent[Any]) -> str:
103
109
  return json.dumps({"assistant": agent.name})
104
110
 
@@ -121,6 +127,7 @@ def handoff(
121
127
  tool_name_override: str | None = None,
122
128
  tool_description_override: str | None = None,
123
129
  input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130
+ is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
124
131
  ) -> Handoff[TContext]: ...
125
132
 
126
133
 
@@ -133,6 +140,7 @@ def handoff(
133
140
  tool_description_override: str | None = None,
134
141
  tool_name_override: str | None = None,
135
142
  input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143
+ is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
136
144
  ) -> Handoff[TContext]: ...
137
145
 
138
146
 
@@ -144,6 +152,7 @@ def handoff(
144
152
  tool_description_override: str | None = None,
145
153
  tool_name_override: str | None = None,
146
154
  input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155
+ is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
147
156
  ) -> Handoff[TContext]: ...
148
157
 
149
158
 
@@ -154,6 +163,7 @@ def handoff(
154
163
  on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
155
164
  input_type: type[THandoffInput] | None = None,
156
165
  input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166
+ is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
157
167
  ) -> Handoff[TContext]:
158
168
  """Create a handoff from an agent.
159
169
 
@@ -166,6 +176,9 @@ def handoff(
166
176
  input_type: the type of the input to the handoff. If provided, the input will be validated
167
177
  against this type. Only relevant if you pass a function that takes an input.
168
178
  input_filter: a function that filters the inputs that are passed to the next agent.
179
+ is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
180
+ context and agent and returns whether the handoff is enabled. Disabled handoffs are
181
+ hidden from the LLM at runtime.
169
182
  """
170
183
  assert (on_handoff and input_type) or not (on_handoff and input_type), (
171
184
  "You must provide either both on_handoff and input_type, or neither"
@@ -233,4 +246,5 @@ def handoff(
233
246
  on_invoke_handoff=_invoke_handoff,
234
247
  input_filter=input_filter,
235
248
  agent_name=agent.name,
249
+ is_enabled=is_enabled,
236
250
  )
agents/lifecycle.py CHANGED
@@ -1,25 +1,27 @@
1
1
  from typing import Any, Generic
2
2
 
3
- from .agent import Agent
3
+ from typing_extensions import TypeVar
4
+
5
+ from .agent import Agent, AgentBase
4
6
  from .run_context import RunContextWrapper, TContext
5
7
  from .tool import Tool
6
8
 
9
+ TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
10
+
7
11
 
8
- class RunHooks(Generic[TContext]):
12
+ class RunHooksBase(Generic[TContext, TAgent]):
9
13
  """A class that receives callbacks on various lifecycle events in an agent run. Subclass and
10
14
  override the methods you need.
11
15
  """
12
16
 
13
- async def on_agent_start(
14
- self, context: RunContextWrapper[TContext], agent: Agent[TContext]
15
- ) -> None:
17
+ async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
16
18
  """Called before the agent is invoked. Called each time the current agent changes."""
17
19
  pass
18
20
 
19
21
  async def on_agent_end(
20
22
  self,
21
23
  context: RunContextWrapper[TContext],
22
- agent: Agent[TContext],
24
+ agent: TAgent,
23
25
  output: Any,
24
26
  ) -> None:
25
27
  """Called when the agent produces a final output."""
@@ -28,8 +30,8 @@ class RunHooks(Generic[TContext]):
28
30
  async def on_handoff(
29
31
  self,
30
32
  context: RunContextWrapper[TContext],
31
- from_agent: Agent[TContext],
32
- to_agent: Agent[TContext],
33
+ from_agent: TAgent,
34
+ to_agent: TAgent,
33
35
  ) -> None:
34
36
  """Called when a handoff occurs."""
35
37
  pass
@@ -37,7 +39,7 @@ class RunHooks(Generic[TContext]):
37
39
  async def on_tool_start(
38
40
  self,
39
41
  context: RunContextWrapper[TContext],
40
- agent: Agent[TContext],
42
+ agent: TAgent,
41
43
  tool: Tool,
42
44
  ) -> None:
43
45
  """Called before a tool is invoked."""
@@ -46,7 +48,7 @@ class RunHooks(Generic[TContext]):
46
48
  async def on_tool_end(
47
49
  self,
48
50
  context: RunContextWrapper[TContext],
49
- agent: Agent[TContext],
51
+ agent: TAgent,
50
52
  tool: Tool,
51
53
  result: str,
52
54
  ) -> None:
@@ -54,14 +56,14 @@ class RunHooks(Generic[TContext]):
54
56
  pass
55
57
 
56
58
 
57
- class AgentHooks(Generic[TContext]):
59
+ class AgentHooksBase(Generic[TContext, TAgent]):
58
60
  """A class that receives callbacks on various lifecycle events for a specific agent. You can
59
61
  set this on `agent.hooks` to receive events for that specific agent.
60
62
 
61
63
  Subclass and override the methods you need.
62
64
  """
63
65
 
64
- async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
66
+ async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
65
67
  """Called before the agent is invoked. Called each time the running agent is changed to this
66
68
  agent."""
67
69
  pass
@@ -69,7 +71,7 @@ class AgentHooks(Generic[TContext]):
69
71
  async def on_end(
70
72
  self,
71
73
  context: RunContextWrapper[TContext],
72
- agent: Agent[TContext],
74
+ agent: TAgent,
73
75
  output: Any,
74
76
  ) -> None:
75
77
  """Called when the agent produces a final output."""
@@ -78,8 +80,8 @@ class AgentHooks(Generic[TContext]):
78
80
  async def on_handoff(
79
81
  self,
80
82
  context: RunContextWrapper[TContext],
81
- agent: Agent[TContext],
82
- source: Agent[TContext],
83
+ agent: TAgent,
84
+ source: TAgent,
83
85
  ) -> None:
84
86
  """Called when the agent is being handed off to. The `source` is the agent that is handing
85
87
  off to this agent."""
@@ -88,7 +90,7 @@ class AgentHooks(Generic[TContext]):
88
90
  async def on_tool_start(
89
91
  self,
90
92
  context: RunContextWrapper[TContext],
91
- agent: Agent[TContext],
93
+ agent: TAgent,
92
94
  tool: Tool,
93
95
  ) -> None:
94
96
  """Called before a tool is invoked."""
@@ -97,9 +99,16 @@ class AgentHooks(Generic[TContext]):
97
99
  async def on_tool_end(
98
100
  self,
99
101
  context: RunContextWrapper[TContext],
100
- agent: Agent[TContext],
102
+ agent: TAgent,
101
103
  tool: Tool,
102
104
  result: str,
103
105
  ) -> None:
104
106
  """Called after a tool is invoked."""
105
107
  pass
108
+
109
+
110
+ RunHooks = RunHooksBase[TContext, Agent]
111
+ """Run hooks when using `Agent`."""
112
+
113
+ AgentHooks = AgentHooksBase[TContext, Agent]
114
+ """Agent hooks for `Agent`s."""
agents/mcp/__init__.py CHANGED
@@ -11,7 +11,14 @@ try:
11
11
  except ImportError:
12
12
  pass
13
13
 
14
- from .util import MCPUtil
14
+ from .util import (
15
+ MCPUtil,
16
+ ToolFilter,
17
+ ToolFilterCallable,
18
+ ToolFilterContext,
19
+ ToolFilterStatic,
20
+ create_static_tool_filter,
21
+ )
15
22
 
16
23
  __all__ = [
17
24
  "MCPServer",
@@ -22,4 +29,9 @@ __all__ = [
22
29
  "MCPServerStreamableHttp",
23
30
  "MCPServerStreamableHttpParams",
24
31
  "MCPUtil",
32
+ "ToolFilter",
33
+ "ToolFilterCallable",
34
+ "ToolFilterContext",
35
+ "ToolFilterStatic",
36
+ "create_static_tool_filter",
25
37
  ]