openai-agents 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents has been flagged as potentially problematic; consult the package's page on its registry for further details.

agents/run.py CHANGED
@@ -3,9 +3,13 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import copy
5
5
  from dataclasses import dataclass, field
6
- from typing import Any, cast
6
+ from typing import Any, Generic, cast
7
7
 
8
8
  from openai.types.responses import ResponseCompletedEvent
9
+ from openai.types.responses.response_prompt_param import (
10
+ ResponsePromptParam,
11
+ )
12
+ from typing_extensions import NotRequired, TypedDict, Unpack
9
13
 
10
14
  from ._run_impl import (
11
15
  AgentToolUseTracker,
@@ -28,7 +32,12 @@ from .exceptions import (
28
32
  OutputGuardrailTripwireTriggered,
29
33
  RunErrorDetails,
30
34
  )
31
- from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
35
+ from .guardrail import (
36
+ InputGuardrail,
37
+ InputGuardrailResult,
38
+ OutputGuardrail,
39
+ OutputGuardrailResult,
40
+ )
32
41
  from .handoffs import Handoff, HandoffInputFilter, handoff
33
42
  from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
34
43
  from .lifecycle import RunHooks
@@ -47,6 +56,27 @@ from .util import _coro, _error_tracing
47
56
 
48
57
  DEFAULT_MAX_TURNS = 10
49
58
 
59
+ DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore
60
+ # the value is set at the end of the module
61
+
62
+
63
+ def set_default_agent_runner(runner: AgentRunner | None) -> None:
64
+ """
65
+ WARNING: this class is experimental and not part of the public API
66
+ It should not be used directly.
67
+ """
68
+ global DEFAULT_AGENT_RUNNER
69
+ DEFAULT_AGENT_RUNNER = runner or AgentRunner()
70
+
71
+
72
+ def get_default_agent_runner() -> AgentRunner:
73
+ """
74
+ WARNING: this class is experimental and not part of the public API
75
+ It should not be used directly.
76
+ """
77
+ global DEFAULT_AGENT_RUNNER
78
+ return DEFAULT_AGENT_RUNNER
79
+
50
80
 
51
81
  @dataclass
52
82
  class RunConfig:
@@ -107,6 +137,25 @@ class RunConfig:
107
137
  """
108
138
 
109
139
 
140
+ class RunOptions(TypedDict, Generic[TContext]):
141
+ """Arguments for ``AgentRunner`` methods."""
142
+
143
+ context: NotRequired[TContext | None]
144
+ """The context for the run."""
145
+
146
+ max_turns: NotRequired[int]
147
+ """The maximum number of turns to run for."""
148
+
149
+ hooks: NotRequired[RunHooks[TContext] | None]
150
+ """Lifecycle hooks for the run."""
151
+
152
+ run_config: NotRequired[RunConfig | None]
153
+ """Run configuration."""
154
+
155
+ previous_response_id: NotRequired[str | None]
156
+ """The ID of the previous response, if any."""
157
+
158
+
110
159
  class Runner:
111
160
  @classmethod
112
161
  async def run(
@@ -127,13 +176,10 @@ class Runner:
127
176
  `agent.output_type`, the loop terminates.
128
177
  3. If there's a handoff, we run the loop again, with the new agent.
129
178
  4. Else, we run tool calls (if any), and re-run the loop.
130
-
131
179
  In two cases, the agent may raise an exception:
132
180
  1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
133
181
  2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
134
-
135
182
  Note that only the first agent's input guardrails are run.
136
-
137
183
  Args:
138
184
  starting_agent: The starting agent to run.
139
185
  input: The initial input to the agent. You can pass a single string for a user message,
@@ -145,11 +191,139 @@ class Runner:
145
191
  run_config: Global settings for the entire agent run.
146
192
  previous_response_id: The ID of the previous response, if using OpenAI models via the
147
193
  Responses API, this allows you to skip passing in input from the previous turn.
194
+ Returns:
195
+ A run result containing all the inputs, guardrail results and the output of the last
196
+ agent. Agents may perform handoffs, so we don't know the specific type of the output.
197
+ """
198
+ runner = DEFAULT_AGENT_RUNNER
199
+ return await runner.run(
200
+ starting_agent,
201
+ input,
202
+ context=context,
203
+ max_turns=max_turns,
204
+ hooks=hooks,
205
+ run_config=run_config,
206
+ previous_response_id=previous_response_id,
207
+ )
148
208
 
209
+ @classmethod
210
+ def run_sync(
211
+ cls,
212
+ starting_agent: Agent[TContext],
213
+ input: str | list[TResponseInputItem],
214
+ *,
215
+ context: TContext | None = None,
216
+ max_turns: int = DEFAULT_MAX_TURNS,
217
+ hooks: RunHooks[TContext] | None = None,
218
+ run_config: RunConfig | None = None,
219
+ previous_response_id: str | None = None,
220
+ ) -> RunResult:
221
+ """Run a workflow synchronously, starting at the given agent. Note that this just wraps the
222
+ `run` method, so it will not work if there's already an event loop (e.g. inside an async
223
+ function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
224
+ the `run` method instead.
225
+ The agent will run in a loop until a final output is generated. The loop runs like so:
226
+ 1. The agent is invoked with the given input.
227
+ 2. If there is a final output (i.e. the agent produces something of type
228
+ `agent.output_type`, the loop terminates.
229
+ 3. If there's a handoff, we run the loop again, with the new agent.
230
+ 4. Else, we run tool calls (if any), and re-run the loop.
231
+ In two cases, the agent may raise an exception:
232
+ 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
233
+ 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
234
+ Note that only the first agent's input guardrails are run.
235
+ Args:
236
+ starting_agent: The starting agent to run.
237
+ input: The initial input to the agent. You can pass a single string for a user message,
238
+ or a list of input items.
239
+ context: The context to run the agent with.
240
+ max_turns: The maximum number of turns to run the agent for. A turn is defined as one
241
+ AI invocation (including any tool calls that might occur).
242
+ hooks: An object that receives callbacks on various lifecycle events.
243
+ run_config: Global settings for the entire agent run.
244
+ previous_response_id: The ID of the previous response, if using OpenAI models via the
245
+ Responses API, this allows you to skip passing in input from the previous turn.
149
246
  Returns:
150
247
  A run result containing all the inputs, guardrail results and the output of the last
151
248
  agent. Agents may perform handoffs, so we don't know the specific type of the output.
152
249
  """
250
+ runner = DEFAULT_AGENT_RUNNER
251
+ return runner.run_sync(
252
+ starting_agent,
253
+ input,
254
+ context=context,
255
+ max_turns=max_turns,
256
+ hooks=hooks,
257
+ run_config=run_config,
258
+ previous_response_id=previous_response_id,
259
+ )
260
+
261
+ @classmethod
262
+ def run_streamed(
263
+ cls,
264
+ starting_agent: Agent[TContext],
265
+ input: str | list[TResponseInputItem],
266
+ context: TContext | None = None,
267
+ max_turns: int = DEFAULT_MAX_TURNS,
268
+ hooks: RunHooks[TContext] | None = None,
269
+ run_config: RunConfig | None = None,
270
+ previous_response_id: str | None = None,
271
+ ) -> RunResultStreaming:
272
+ """Run a workflow starting at the given agent in streaming mode. The returned result object
273
+ contains a method you can use to stream semantic events as they are generated.
274
+ The agent will run in a loop until a final output is generated. The loop runs like so:
275
+ 1. The agent is invoked with the given input.
276
+ 2. If there is a final output (i.e. the agent produces something of type
277
+ `agent.output_type`, the loop terminates.
278
+ 3. If there's a handoff, we run the loop again, with the new agent.
279
+ 4. Else, we run tool calls (if any), and re-run the loop.
280
+ In two cases, the agent may raise an exception:
281
+ 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
282
+ 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
283
+ Note that only the first agent's input guardrails are run.
284
+ Args:
285
+ starting_agent: The starting agent to run.
286
+ input: The initial input to the agent. You can pass a single string for a user message,
287
+ or a list of input items.
288
+ context: The context to run the agent with.
289
+ max_turns: The maximum number of turns to run the agent for. A turn is defined as one
290
+ AI invocation (including any tool calls that might occur).
291
+ hooks: An object that receives callbacks on various lifecycle events.
292
+ run_config: Global settings for the entire agent run.
293
+ previous_response_id: The ID of the previous response, if using OpenAI models via the
294
+ Responses API, this allows you to skip passing in input from the previous turn.
295
+ Returns:
296
+ A result object that contains data about the run, as well as a method to stream events.
297
+ """
298
+ runner = DEFAULT_AGENT_RUNNER
299
+ return runner.run_streamed(
300
+ starting_agent,
301
+ input,
302
+ context=context,
303
+ max_turns=max_turns,
304
+ hooks=hooks,
305
+ run_config=run_config,
306
+ previous_response_id=previous_response_id,
307
+ )
308
+
309
+
310
+ class AgentRunner:
311
+ """
312
+ WARNING: this class is experimental and not part of the public API
313
+ It should not be used directly or subclassed.
314
+ """
315
+
316
+ async def run(
317
+ self,
318
+ starting_agent: Agent[TContext],
319
+ input: str | list[TResponseInputItem],
320
+ **kwargs: Unpack[RunOptions[TContext]],
321
+ ) -> RunResult:
322
+ context = kwargs.get("context")
323
+ max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
324
+ hooks = kwargs.get("hooks")
325
+ run_config = kwargs.get("run_config")
326
+ previous_response_id = kwargs.get("previous_response_id")
153
327
  if hooks is None:
154
328
  hooks = RunHooks[Any]()
155
329
  if run_config is None:
@@ -181,13 +355,15 @@ class Runner:
181
355
 
182
356
  try:
183
357
  while True:
184
- all_tools = await cls._get_all_tools(current_agent, context_wrapper)
358
+ all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
185
359
 
186
360
  # Start an agent span if we don't have one. This span is ended if the current
187
361
  # agent changes, or if the agent loop ends.
188
362
  if current_span is None:
189
- handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
190
- if output_schema := cls._get_output_schema(current_agent):
363
+ handoff_names = [
364
+ h.agent_name for h in AgentRunner._get_handoffs(current_agent)
365
+ ]
366
+ if output_schema := AgentRunner._get_output_schema(current_agent):
191
367
  output_type_name = output_schema.name()
192
368
  else:
193
369
  output_type_name = "str"
@@ -217,14 +393,14 @@ class Runner:
217
393
 
218
394
  if current_turn == 1:
219
395
  input_guardrail_results, turn_result = await asyncio.gather(
220
- cls._run_input_guardrails(
396
+ self._run_input_guardrails(
221
397
  starting_agent,
222
398
  starting_agent.input_guardrails
223
399
  + (run_config.input_guardrails or []),
224
400
  copy.deepcopy(input),
225
401
  context_wrapper,
226
402
  ),
227
- cls._run_single_turn(
403
+ self._run_single_turn(
228
404
  agent=current_agent,
229
405
  all_tools=all_tools,
230
406
  original_input=original_input,
@@ -238,7 +414,7 @@ class Runner:
238
414
  ),
239
415
  )
240
416
  else:
241
- turn_result = await cls._run_single_turn(
417
+ turn_result = await self._run_single_turn(
242
418
  agent=current_agent,
243
419
  all_tools=all_tools,
244
420
  original_input=original_input,
@@ -257,7 +433,7 @@ class Runner:
257
433
  generated_items = turn_result.generated_items
258
434
 
259
435
  if isinstance(turn_result.next_step, NextStepFinalOutput):
260
- output_guardrail_results = await cls._run_output_guardrails(
436
+ output_guardrail_results = await self._run_output_guardrails(
261
437
  current_agent.output_guardrails + (run_config.output_guardrails or []),
262
438
  current_agent,
263
439
  turn_result.next_step.output,
@@ -299,54 +475,19 @@ class Runner:
299
475
  if current_span:
300
476
  current_span.finish(reset_current=True)
301
477
 
302
- @classmethod
303
478
  def run_sync(
304
- cls,
479
+ self,
305
480
  starting_agent: Agent[TContext],
306
481
  input: str | list[TResponseInputItem],
307
- *,
308
- context: TContext | None = None,
309
- max_turns: int = DEFAULT_MAX_TURNS,
310
- hooks: RunHooks[TContext] | None = None,
311
- run_config: RunConfig | None = None,
312
- previous_response_id: str | None = None,
482
+ **kwargs: Unpack[RunOptions[TContext]],
313
483
  ) -> RunResult:
314
- """Run a workflow synchronously, starting at the given agent. Note that this just wraps the
315
- `run` method, so it will not work if there's already an event loop (e.g. inside an async
316
- function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
317
- the `run` method instead.
318
-
319
- The agent will run in a loop until a final output is generated. The loop runs like so:
320
- 1. The agent is invoked with the given input.
321
- 2. If there is a final output (i.e. the agent produces something of type
322
- `agent.output_type`, the loop terminates.
323
- 3. If there's a handoff, we run the loop again, with the new agent.
324
- 4. Else, we run tool calls (if any), and re-run the loop.
325
-
326
- In two cases, the agent may raise an exception:
327
- 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
328
- 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
329
-
330
- Note that only the first agent's input guardrails are run.
331
-
332
- Args:
333
- starting_agent: The starting agent to run.
334
- input: The initial input to the agent. You can pass a single string for a user message,
335
- or a list of input items.
336
- context: The context to run the agent with.
337
- max_turns: The maximum number of turns to run the agent for. A turn is defined as one
338
- AI invocation (including any tool calls that might occur).
339
- hooks: An object that receives callbacks on various lifecycle events.
340
- run_config: Global settings for the entire agent run.
341
- previous_response_id: The ID of the previous response, if using OpenAI models via the
342
- Responses API, this allows you to skip passing in input from the previous turn.
343
-
344
- Returns:
345
- A run result containing all the inputs, guardrail results and the output of the last
346
- agent. Agents may perform handoffs, so we don't know the specific type of the output.
347
- """
484
+ context = kwargs.get("context")
485
+ max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
486
+ hooks = kwargs.get("hooks")
487
+ run_config = kwargs.get("run_config")
488
+ previous_response_id = kwargs.get("previous_response_id")
348
489
  return asyncio.get_event_loop().run_until_complete(
349
- cls.run(
490
+ self.run(
350
491
  starting_agent,
351
492
  input,
352
493
  context=context,
@@ -357,47 +498,17 @@ class Runner:
357
498
  )
358
499
  )
359
500
 
360
- @classmethod
361
501
  def run_streamed(
362
- cls,
502
+ self,
363
503
  starting_agent: Agent[TContext],
364
504
  input: str | list[TResponseInputItem],
365
- context: TContext | None = None,
366
- max_turns: int = DEFAULT_MAX_TURNS,
367
- hooks: RunHooks[TContext] | None = None,
368
- run_config: RunConfig | None = None,
369
- previous_response_id: str | None = None,
505
+ **kwargs: Unpack[RunOptions[TContext]],
370
506
  ) -> RunResultStreaming:
371
- """Run a workflow starting at the given agent in streaming mode. The returned result object
372
- contains a method you can use to stream semantic events as they are generated.
373
-
374
- The agent will run in a loop until a final output is generated. The loop runs like so:
375
- 1. The agent is invoked with the given input.
376
- 2. If there is a final output (i.e. the agent produces something of type
377
- `agent.output_type`, the loop terminates.
378
- 3. If there's a handoff, we run the loop again, with the new agent.
379
- 4. Else, we run tool calls (if any), and re-run the loop.
380
-
381
- In two cases, the agent may raise an exception:
382
- 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
383
- 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
384
-
385
- Note that only the first agent's input guardrails are run.
386
-
387
- Args:
388
- starting_agent: The starting agent to run.
389
- input: The initial input to the agent. You can pass a single string for a user message,
390
- or a list of input items.
391
- context: The context to run the agent with.
392
- max_turns: The maximum number of turns to run the agent for. A turn is defined as one
393
- AI invocation (including any tool calls that might occur).
394
- hooks: An object that receives callbacks on various lifecycle events.
395
- run_config: Global settings for the entire agent run.
396
- previous_response_id: The ID of the previous response, if using OpenAI models via the
397
- Responses API, this allows you to skip passing in input from the previous turn.
398
- Returns:
399
- A result object that contains data about the run, as well as a method to stream events.
400
- """
507
+ context = kwargs.get("context")
508
+ max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
509
+ hooks = kwargs.get("hooks")
510
+ run_config = kwargs.get("run_config")
511
+ previous_response_id = kwargs.get("previous_response_id")
401
512
  if hooks is None:
402
513
  hooks = RunHooks[Any]()
403
514
  if run_config is None:
@@ -418,7 +529,7 @@ class Runner:
418
529
  )
419
530
  )
420
531
 
421
- output_schema = cls._get_output_schema(starting_agent)
532
+ output_schema = AgentRunner._get_output_schema(starting_agent)
422
533
  context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
423
534
  context=context # type: ignore
424
535
  )
@@ -441,7 +552,7 @@ class Runner:
441
552
 
442
553
  # Kick off the actual agent loop in the background and return the streamed result object.
443
554
  streamed_result._run_impl_task = asyncio.create_task(
444
- cls._run_streamed_impl(
555
+ self._start_streaming(
445
556
  starting_input=input,
446
557
  streamed_result=streamed_result,
447
558
  starting_agent=starting_agent,
@@ -498,7 +609,7 @@ class Runner:
498
609
  streamed_result.input_guardrail_results = guardrail_results
499
610
 
500
611
  @classmethod
501
- async def _run_streamed_impl(
612
+ async def _start_streaming(
502
613
  cls,
503
614
  starting_input: str | list[TResponseInputItem],
504
615
  streamed_result: RunResultStreaming,
@@ -682,7 +793,10 @@ class Runner:
682
793
  streamed_result.current_agent = agent
683
794
  streamed_result._current_agent_output_schema = output_schema
684
795
 
685
- system_prompt = await agent.get_system_prompt(context_wrapper)
796
+ system_prompt, prompt_config = await asyncio.gather(
797
+ agent.get_system_prompt(context_wrapper),
798
+ agent.get_prompt(context_wrapper),
799
+ )
686
800
 
687
801
  handoffs = cls._get_handoffs(agent)
688
802
  model = cls._get_model(agent, run_config)
@@ -706,6 +820,7 @@ class Runner:
706
820
  run_config.tracing_disabled, run_config.trace_include_sensitive_data
707
821
  ),
708
822
  previous_response_id=previous_response_id,
823
+ prompt=prompt_config,
709
824
  ):
710
825
  if isinstance(event, ResponseCompletedEvent):
711
826
  usage = (
@@ -777,7 +892,10 @@ class Runner:
777
892
  ),
778
893
  )
779
894
 
780
- system_prompt = await agent.get_system_prompt(context_wrapper)
895
+ system_prompt, prompt_config = await asyncio.gather(
896
+ agent.get_system_prompt(context_wrapper),
897
+ agent.get_prompt(context_wrapper),
898
+ )
781
899
 
782
900
  output_schema = cls._get_output_schema(agent)
783
901
  handoffs = cls._get_handoffs(agent)
@@ -795,6 +913,7 @@ class Runner:
795
913
  run_config,
796
914
  tool_use_tracker,
797
915
  previous_response_id,
916
+ prompt_config,
798
917
  )
799
918
 
800
919
  return await cls._get_single_step_result_from_response(
@@ -938,6 +1057,7 @@ class Runner:
938
1057
  run_config: RunConfig,
939
1058
  tool_use_tracker: AgentToolUseTracker,
940
1059
  previous_response_id: str | None,
1060
+ prompt_config: ResponsePromptParam | None,
941
1061
  ) -> ModelResponse:
942
1062
  model = cls._get_model(agent, run_config)
943
1063
  model_settings = agent.model_settings.resolve(run_config.model_settings)
@@ -954,6 +1074,7 @@ class Runner:
954
1074
  run_config.tracing_disabled, run_config.trace_include_sensitive_data
955
1075
  ),
956
1076
  previous_response_id=previous_response_id,
1077
+ prompt=prompt_config,
957
1078
  )
958
1079
 
959
1080
  context_wrapper.usage.add(new_response.usage)
@@ -995,3 +1116,6 @@ class Runner:
995
1116
  return agent.model
996
1117
 
997
1118
  return run_config.model_provider.get_model(agent.model)
1119
+
1120
+
1121
+ DEFAULT_AGENT_RUNNER = AgentRunner()
agents/tool.py CHANGED
@@ -20,6 +20,7 @@ from .function_schema import DocstringStyle, function_schema
20
20
  from .items import RunItem
21
21
  from .logger import logger
22
22
  from .run_context import RunContextWrapper
23
+ from .tool_context import ToolContext
23
24
  from .tracing import SpanError
24
25
  from .util import _error_tracing
25
26
  from .util._types import MaybeAwaitable
@@ -31,8 +32,13 @@ ToolParams = ParamSpec("ToolParams")
31
32
 
32
33
  ToolFunctionWithoutContext = Callable[ToolParams, Any]
33
34
  ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
35
+ ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
34
36
 
35
- ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
37
+ ToolFunction = Union[
38
+ ToolFunctionWithoutContext[ToolParams],
39
+ ToolFunctionWithContext[ToolParams],
40
+ ToolFunctionWithToolContext[ToolParams],
41
+ ]
36
42
 
37
43
 
38
44
  @dataclass
@@ -62,7 +68,7 @@ class FunctionTool:
62
68
  params_json_schema: dict[str, Any]
63
69
  """The JSON schema for the tool's parameters."""
64
70
 
65
- on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
71
+ on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
66
72
  """A function that invokes the tool with the given context and parameters. The params passed
67
73
  are:
68
74
  1. The tool run context.
@@ -344,7 +350,7 @@ def function_tool(
344
350
  strict_json_schema=strict_mode,
345
351
  )
346
352
 
347
- async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
353
+ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
348
354
  try:
349
355
  json_data: dict[str, Any] = json.loads(input) if input else {}
350
356
  except Exception as e:
@@ -393,7 +399,7 @@ def function_tool(
393
399
 
394
400
  return result
395
401
 
396
- async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
402
+ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
397
403
  try:
398
404
  return await _on_invoke_tool_impl(ctx, input)
399
405
  except Exception as e:
agents/tool_context.py ADDED
@@ -0,0 +1,29 @@
1
+ from dataclasses import dataclass, field, fields
2
+ from typing import Any
3
+
4
+ from .run_context import RunContextWrapper, TContext
5
+
6
+
7
+ def _assert_must_pass_tool_call_id() -> str:
8
+ raise ValueError("tool_call_id must be passed to ToolContext")
9
+
10
+
11
+ @dataclass
12
+ class ToolContext(RunContextWrapper[TContext]):
13
+ """The context of a tool call."""
14
+
15
+ tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
16
+ """The ID of the tool call."""
17
+
18
+ @classmethod
19
+ def from_agent_context(
20
+ cls, context: RunContextWrapper[TContext], tool_call_id: str
21
+ ) -> "ToolContext":
22
+ """
23
+ Create a ToolContext from a RunContextWrapper.
24
+ """
25
+ # Grab the names of the RunContextWrapper's init=True fields
26
+ base_values: dict[str, Any] = {
27
+ f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
28
+ }
29
+ return cls(tool_call_id=tool_call_id, **base_values)
@@ -1,5 +1,7 @@
1
1
  import atexit
2
2
 
3
+ from agents.tracing.provider import DefaultTraceProvider, TraceProvider
4
+
3
5
  from .create import (
4
6
  agent_span,
5
7
  custom_span,
@@ -18,7 +20,7 @@ from .create import (
18
20
  )
19
21
  from .processor_interface import TracingProcessor
20
22
  from .processors import default_exporter, default_processor
21
- from .setup import GLOBAL_TRACE_PROVIDER
23
+ from .setup import get_trace_provider, set_trace_provider
22
24
  from .span_data import (
23
25
  AgentSpanData,
24
26
  CustomSpanData,
@@ -45,10 +47,12 @@ __all__ = [
45
47
  "generation_span",
46
48
  "get_current_span",
47
49
  "get_current_trace",
50
+ "get_trace_provider",
48
51
  "guardrail_span",
49
52
  "handoff_span",
50
53
  "response_span",
51
54
  "set_trace_processors",
55
+ "set_trace_provider",
52
56
  "set_tracing_disabled",
53
57
  "trace",
54
58
  "Trace",
@@ -67,6 +71,7 @@ __all__ = [
67
71
  "SpeechSpanData",
68
72
  "TranscriptionSpanData",
69
73
  "TracingProcessor",
74
+ "TraceProvider",
70
75
  "gen_trace_id",
71
76
  "gen_span_id",
72
77
  "speech_group_span",
@@ -80,21 +85,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
80
85
  """
81
86
  Adds a new trace processor. This processor will receive all traces/spans.
82
87
  """
83
- GLOBAL_TRACE_PROVIDER.register_processor(span_processor)
88
+ get_trace_provider().register_processor(span_processor)
84
89
 
85
90
 
86
91
  def set_trace_processors(processors: list[TracingProcessor]) -> None:
87
92
  """
88
93
  Set the list of trace processors. This will replace the current list of processors.
89
94
  """
90
- GLOBAL_TRACE_PROVIDER.set_processors(processors)
95
+ get_trace_provider().set_processors(processors)
91
96
 
92
97
 
93
98
  def set_tracing_disabled(disabled: bool) -> None:
94
99
  """
95
100
  Set whether tracing is globally disabled.
96
101
  """
97
- GLOBAL_TRACE_PROVIDER.set_disabled(disabled)
102
+ get_trace_provider().set_disabled(disabled)
98
103
 
99
104
 
100
105
  def set_tracing_export_api_key(api_key: str) -> None:
@@ -104,10 +109,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
104
109
  default_exporter().set_api_key(api_key)
105
110
 
106
111
 
112
+ set_trace_provider(DefaultTraceProvider())
107
113
  # Add the default processor, which exports traces and spans to the backend in batches. You can
108
114
  # change the default behavior by either:
109
115
  # 1. calling add_trace_processor(), which adds additional processors, or
110
116
  # 2. calling set_trace_processors(), which replaces the default processor.
111
117
  add_trace_processor(default_processor())
112
118
 
113
- atexit.register(GLOBAL_TRACE_PROVIDER.shutdown)
119
+ atexit.register(get_trace_provider().shutdown)