openai-agents 0.0.18__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of openai-agents was flagged as possibly problematic by the registry scanner.

agents/run.py CHANGED
@@ -2,13 +2,15 @@ from __future__ import annotations
 
 import asyncio
 import copy
+import inspect
 from dataclasses import dataclass, field
-from typing import Any, cast
+from typing import Any, Generic, cast
 
 from openai.types.responses import ResponseCompletedEvent
 from openai.types.responses.response_prompt_param import (
     ResponsePromptParam,
 )
+from typing_extensions import NotRequired, TypedDict, Unpack
 
 from ._run_impl import (
     AgentToolUseTracker,
@@ -31,7 +33,12 @@ from .exceptions import (
     OutputGuardrailTripwireTriggered,
     RunErrorDetails,
 )
-from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
+from .guardrail import (
+    InputGuardrail,
+    InputGuardrailResult,
+    OutputGuardrail,
+    OutputGuardrailResult,
+)
 from .handoffs import Handoff, HandoffInputFilter, handoff
 from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
 from .lifecycle import RunHooks
@@ -50,6 +57,27 @@ from .util import _coro, _error_tracing
 
 DEFAULT_MAX_TURNS = 10
 
+DEFAULT_AGENT_RUNNER: AgentRunner = None  # type: ignore
+# the value is set at the end of the module
+
+
+def set_default_agent_runner(runner: AgentRunner | None) -> None:
+    """
+    WARNING: this class is experimental and not part of the public API
+    It should not be used directly.
+    """
+    global DEFAULT_AGENT_RUNNER
+    DEFAULT_AGENT_RUNNER = runner or AgentRunner()
+
+
+def get_default_agent_runner() -> AgentRunner:
+    """
+    WARNING: this class is experimental and not part of the public API
+    It should not be used directly.
+    """
+    global DEFAULT_AGENT_RUNNER
+    return DEFAULT_AGENT_RUNNER
+
 
 @dataclass
 class RunConfig:
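The hunk above makes the runner behind `Runner` swappable at module level: `Runner`'s classmethods resolve `DEFAULT_AGENT_RUNNER` at call time. A minimal sketch of what that indirection allows, purely illustrative: the diff itself marks `AgentRunner` as experimental and not meant to be subclassed, and `LoggingRunner` is a hypothetical name.

# Illustrative sketch only; AgentRunner and set_default_agent_runner are flagged
# experimental in this release. LoggingRunner is a hypothetical subclass.
from agents.run import AgentRunner, set_default_agent_runner

class LoggingRunner(AgentRunner):
    async def run(self, starting_agent, input, **kwargs):
        # Log before delegating to the stock implementation.
        print(f"run starting with agent {starting_agent.name!r}")
        return await super().run(starting_agent, input, **kwargs)

set_default_agent_runner(LoggingRunner())
# Runner.run()/run_sync()/run_streamed() now delegate to this instance,
# because they look up DEFAULT_AGENT_RUNNER at call time.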
@@ -110,6 +138,25 @@ class RunConfig:
     """
 
 
+class RunOptions(TypedDict, Generic[TContext]):
+    """Arguments for ``AgentRunner`` methods."""
+
+    context: NotRequired[TContext | None]
+    """The context for the run."""
+
+    max_turns: NotRequired[int]
+    """The maximum number of turns to run for."""
+
+    hooks: NotRequired[RunHooks[TContext] | None]
+    """Lifecycle hooks for the run."""
+
+    run_config: NotRequired[RunConfig | None]
+    """Run configuration."""
+
+    previous_response_id: NotRequired[str | None]
+    """The ID of the previous response, if any."""
+
+
 class Runner:
     @classmethod
     async def run(
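Because `RunOptions` is a plain `TypedDict`, per-call options can be assembled as a dict and splatted into the new `AgentRunner` methods. A small sketch under that assumption; the agent object and the surrounding async context are placeholders.

# Sketch: building RunOptions separately and unpacking it into AgentRunner.run().
# Field names come from the diff above; `agent` is assumed to exist already.
from agents.run import RunOptions, get_default_agent_runner

options: RunOptions = {"max_turns": 5, "run_config": None}
# result = await get_default_agent_runner().run(agent, "Hello", **options)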
@@ -130,13 +177,10 @@ class Runner:
             `agent.output_type`, the loop terminates.
         3. If there's a handoff, we run the loop again, with the new agent.
         4. Else, we run tool calls (if any), and re-run the loop.
-
         In two cases, the agent may raise an exception:
         1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
         2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
-
         Note that only the first agent's input guardrails are run.
-
         Args:
             starting_agent: The starting agent to run.
             input: The initial input to the agent. You can pass a single string for a user message,
@@ -148,11 +192,139 @@ class Runner:
             run_config: Global settings for the entire agent run.
             previous_response_id: The ID of the previous response, if using OpenAI models via the
                 Responses API, this allows you to skip passing in input from the previous turn.
+        Returns:
+            A run result containing all the inputs, guardrail results and the output of the last
+            agent. Agents may perform handoffs, so we don't know the specific type of the output.
+        """
+        runner = DEFAULT_AGENT_RUNNER
+        return await runner.run(
+            starting_agent,
+            input,
+            context=context,
+            max_turns=max_turns,
+            hooks=hooks,
+            run_config=run_config,
+            previous_response_id=previous_response_id,
+        )
 
+    @classmethod
+    def run_sync(
+        cls,
+        starting_agent: Agent[TContext],
+        input: str | list[TResponseInputItem],
+        *,
+        context: TContext | None = None,
+        max_turns: int = DEFAULT_MAX_TURNS,
+        hooks: RunHooks[TContext] | None = None,
+        run_config: RunConfig | None = None,
+        previous_response_id: str | None = None,
+    ) -> RunResult:
+        """Run a workflow synchronously, starting at the given agent. Note that this just wraps the
+        `run` method, so it will not work if there's already an event loop (e.g. inside an async
+        function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
+        the `run` method instead.
+        The agent will run in a loop until a final output is generated. The loop runs like so:
+        1. The agent is invoked with the given input.
+        2. If there is a final output (i.e. the agent produces something of type
+            `agent.output_type`, the loop terminates.
+        3. If there's a handoff, we run the loop again, with the new agent.
+        4. Else, we run tool calls (if any), and re-run the loop.
+        In two cases, the agent may raise an exception:
+        1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
+        2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
+        Note that only the first agent's input guardrails are run.
+        Args:
+            starting_agent: The starting agent to run.
+            input: The initial input to the agent. You can pass a single string for a user message,
+                or a list of input items.
+            context: The context to run the agent with.
+            max_turns: The maximum number of turns to run the agent for. A turn is defined as one
+                AI invocation (including any tool calls that might occur).
+            hooks: An object that receives callbacks on various lifecycle events.
+            run_config: Global settings for the entire agent run.
+            previous_response_id: The ID of the previous response, if using OpenAI models via the
+                Responses API, this allows you to skip passing in input from the previous turn.
         Returns:
             A run result containing all the inputs, guardrail results and the output of the last
             agent. Agents may perform handoffs, so we don't know the specific type of the output.
         """
+        runner = DEFAULT_AGENT_RUNNER
+        return runner.run_sync(
+            starting_agent,
+            input,
+            context=context,
+            max_turns=max_turns,
+            hooks=hooks,
+            run_config=run_config,
+            previous_response_id=previous_response_id,
+        )
+
+    @classmethod
+    def run_streamed(
+        cls,
+        starting_agent: Agent[TContext],
+        input: str | list[TResponseInputItem],
+        context: TContext | None = None,
+        max_turns: int = DEFAULT_MAX_TURNS,
+        hooks: RunHooks[TContext] | None = None,
+        run_config: RunConfig | None = None,
+        previous_response_id: str | None = None,
+    ) -> RunResultStreaming:
+        """Run a workflow starting at the given agent in streaming mode. The returned result object
+        contains a method you can use to stream semantic events as they are generated.
+        The agent will run in a loop until a final output is generated. The loop runs like so:
+        1. The agent is invoked with the given input.
+        2. If there is a final output (i.e. the agent produces something of type
+            `agent.output_type`, the loop terminates.
+        3. If there's a handoff, we run the loop again, with the new agent.
+        4. Else, we run tool calls (if any), and re-run the loop.
+        In two cases, the agent may raise an exception:
+        1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
+        2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
+        Note that only the first agent's input guardrails are run.
+        Args:
+            starting_agent: The starting agent to run.
+            input: The initial input to the agent. You can pass a single string for a user message,
+                or a list of input items.
+            context: The context to run the agent with.
+            max_turns: The maximum number of turns to run the agent for. A turn is defined as one
+                AI invocation (including any tool calls that might occur).
+            hooks: An object that receives callbacks on various lifecycle events.
+            run_config: Global settings for the entire agent run.
+            previous_response_id: The ID of the previous response, if using OpenAI models via the
+                Responses API, this allows you to skip passing in input from the previous turn.
+        Returns:
+            A result object that contains data about the run, as well as a method to stream events.
+        """
+        runner = DEFAULT_AGENT_RUNNER
+        return runner.run_streamed(
+            starting_agent,
+            input,
+            context=context,
+            max_turns=max_turns,
+            hooks=hooks,
+            run_config=run_config,
+            previous_response_id=previous_response_id,
+        )
+
+
+class AgentRunner:
+    """
+    WARNING: this class is experimental and not part of the public API
+    It should not be used directly or subclassed.
+    """
+
+    async def run(
+        self,
+        starting_agent: Agent[TContext],
+        input: str | list[TResponseInputItem],
+        **kwargs: Unpack[RunOptions[TContext]],
+    ) -> RunResult:
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         if hooks is None:
             hooks = RunHooks[Any]()
         if run_config is None:
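The public `Runner` entry points keep their signatures and behavior; they now simply delegate to the default `AgentRunner` instance. A minimal sketch of the unchanged surface; the agent name and instructions are placeholders.

# Minimal sketch of the public entry point, which now forwards to DEFAULT_AGENT_RUNNER.
import asyncio
from agents import Agent, Runner

async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply concisely.")
    result = await Runner.run(agent, "What is the capital of France?", max_turns=3)
    print(result.final_output)

asyncio.run(main())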
@@ -184,13 +356,16 @@ class Runner:
 
         try:
             while True:
-                all_tools = await cls._get_all_tools(current_agent, context_wrapper)
+                all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
 
                 # Start an agent span if we don't have one. This span is ended if the current
                 # agent changes, or if the agent loop ends.
                 if current_span is None:
-                    handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
-                    if output_schema := cls._get_output_schema(current_agent):
+                    handoff_names = [
+                        h.agent_name
+                        for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
+                    ]
+                    if output_schema := AgentRunner._get_output_schema(current_agent):
                         output_type_name = output_schema.name()
                     else:
                         output_type_name = "str"
@@ -220,14 +395,14 @@ class Runner:
 
                 if current_turn == 1:
                     input_guardrail_results, turn_result = await asyncio.gather(
-                        cls._run_input_guardrails(
+                        self._run_input_guardrails(
                             starting_agent,
                             starting_agent.input_guardrails
                             + (run_config.input_guardrails or []),
                             copy.deepcopy(input),
                             context_wrapper,
                         ),
-                        cls._run_single_turn(
+                        self._run_single_turn(
                             agent=current_agent,
                             all_tools=all_tools,
                             original_input=original_input,
@@ -241,7 +416,7 @@ class Runner:
                         ),
                     )
                 else:
-                    turn_result = await cls._run_single_turn(
+                    turn_result = await self._run_single_turn(
                         agent=current_agent,
                         all_tools=all_tools,
                         original_input=original_input,
@@ -260,7 +435,7 @@ class Runner:
                 generated_items = turn_result.generated_items
 
                 if isinstance(turn_result.next_step, NextStepFinalOutput):
-                    output_guardrail_results = await self._run_output_guardrails(
+                    output_guardrail_results = await self._run_output_guardrails(
                         current_agent.output_guardrails + (run_config.output_guardrails or []),
                         current_agent,
                         turn_result.next_step.output,
@@ -302,54 +477,19 @@ class Runner:
             if current_span:
                 current_span.finish(reset_current=True)
 
-    @classmethod
     def run_sync(
-        cls,
+        self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        *,
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[RunOptions[TContext]],
     ) -> RunResult:
-        """Run a workflow synchronously, starting at the given agent. Note that this just wraps the
-        `run` method, so it will not work if there's already an event loop (e.g. inside an async
-        function, or in a Jupyter notebook or async context like FastAPI). For those cases, use
-        the `run` method instead.
-
-        The agent will run in a loop until a final output is generated. The loop runs like so:
-        1. The agent is invoked with the given input.
-        2. If there is a final output (i.e. the agent produces something of type
-            `agent.output_type`, the loop terminates.
-        3. If there's a handoff, we run the loop again, with the new agent.
-        4. Else, we run tool calls (if any), and re-run the loop.
-
-        In two cases, the agent may raise an exception:
-        1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
-        2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
-
-        Note that only the first agent's input guardrails are run.
-
-        Args:
-            starting_agent: The starting agent to run.
-            input: The initial input to the agent. You can pass a single string for a user message,
-                or a list of input items.
-            context: The context to run the agent with.
-            max_turns: The maximum number of turns to run the agent for. A turn is defined as one
-                AI invocation (including any tool calls that might occur).
-            hooks: An object that receives callbacks on various lifecycle events.
-            run_config: Global settings for the entire agent run.
-            previous_response_id: The ID of the previous response, if using OpenAI models via the
-                Responses API, this allows you to skip passing in input from the previous turn.
-
-        Returns:
-            A run result containing all the inputs, guardrail results and the output of the last
-            agent. Agents may perform handoffs, so we don't know the specific type of the output.
-        """
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         return asyncio.get_event_loop().run_until_complete(
-            cls.run(
+            self.run(
                 starting_agent,
                 input,
                 context=context,
@@ -360,47 +500,17 @@ class Runner:
             )
         )
 
-    @classmethod
     def run_streamed(
-        cls,
+        self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[RunOptions[TContext]],
     ) -> RunResultStreaming:
-        """Run a workflow starting at the given agent in streaming mode. The returned result object
-        contains a method you can use to stream semantic events as they are generated.
-
-        The agent will run in a loop until a final output is generated. The loop runs like so:
-        1. The agent is invoked with the given input.
-        2. If there is a final output (i.e. the agent produces something of type
-            `agent.output_type`, the loop terminates.
-        3. If there's a handoff, we run the loop again, with the new agent.
-        4. Else, we run tool calls (if any), and re-run the loop.
-
-        In two cases, the agent may raise an exception:
-        1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised.
-        2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered exception is raised.
-
-        Note that only the first agent's input guardrails are run.
-
-        Args:
-            starting_agent: The starting agent to run.
-            input: The initial input to the agent. You can pass a single string for a user message,
-                or a list of input items.
-            context: The context to run the agent with.
-            max_turns: The maximum number of turns to run the agent for. A turn is defined as one
-                AI invocation (including any tool calls that might occur).
-            hooks: An object that receives callbacks on various lifecycle events.
-            run_config: Global settings for the entire agent run.
-            previous_response_id: The ID of the previous response, if using OpenAI models via the
-                Responses API, this allows you to skip passing in input from the previous turn.
-        Returns:
-            A result object that contains data about the run, as well as a method to stream events.
-        """
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         if hooks is None:
             hooks = RunHooks[Any]()
         if run_config is None:
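`run_streamed` likewise keeps its public calling convention while its body moves into `AgentRunner`. A short streaming sketch of the unchanged consumer side, assuming an existing `agent` and an async caller.

# Sketch: consuming a streamed run; `agent` is assumed to exist already.
from openai.types.responses import ResponseTextDeltaEvent
from agents import Runner

async def stream_demo(agent) -> None:
    result = Runner.run_streamed(agent, "Tell me a short joke.")
    async for event in result.stream_events():
        # Raw response events carry the model's incremental output.
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)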
@@ -421,7 +531,7 @@ class Runner:
             )
         )
 
-        output_schema = cls._get_output_schema(starting_agent)
+        output_schema = AgentRunner._get_output_schema(starting_agent)
         context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
             context=context  # type: ignore
         )
@@ -444,7 +554,7 @@ class Runner:
 
         # Kick off the actual agent loop in the background and return the streamed result object.
         streamed_result._run_impl_task = asyncio.create_task(
-            cls._run_streamed_impl(
+            self._start_streaming(
                 starting_input=input,
                 streamed_result=streamed_result,
                 starting_agent=starting_agent,
@@ -501,7 +611,7 @@ class Runner:
         streamed_result.input_guardrail_results = guardrail_results
 
     @classmethod
-    async def _run_streamed_impl(
+    async def _start_streaming(
         cls,
         starting_input: str | list[TResponseInputItem],
         streamed_result: RunResultStreaming,
@@ -533,7 +643,10 @@ class Runner:
                 # Start an agent span if we don't have one. This span is ended if the current
                 # agent changes, or if the agent loop ends.
                 if current_span is None:
-                    handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
+                    handoff_names = [
+                        h.agent_name
+                        for h in await cls._get_handoffs(current_agent, context_wrapper)
+                    ]
                     if output_schema := cls._get_output_schema(current_agent):
                         output_type_name = output_schema.name()
                     else:
@@ -690,7 +803,7 @@ class Runner:
             agent.get_prompt(context_wrapper),
         )
 
-        handoffs = cls._get_handoffs(agent)
+        handoffs = await cls._get_handoffs(agent, context_wrapper)
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
@@ -790,7 +903,7 @@ class Runner:
         )
 
         output_schema = cls._get_output_schema(agent)
-        handoffs = cls._get_handoffs(agent)
+        handoffs = await cls._get_handoffs(agent, context_wrapper)
         input = ItemHelpers.input_to_new_input_list(original_input)
         input.extend([generated_item.to_input_item() for generated_item in generated_items])
 
@@ -983,14 +1096,28 @@ class Runner:
         return AgentOutputSchema(agent.output_type)
 
     @classmethod
-    def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
+    async def _get_handoffs(
+        cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
+    ) -> list[Handoff]:
         handoffs = []
         for handoff_item in agent.handoffs:
             if isinstance(handoff_item, Handoff):
                 handoffs.append(handoff_item)
             elif isinstance(handoff_item, Agent):
                 handoffs.append(handoff(handoff_item))
-        return handoffs
+
+        async def _check_handoff_enabled(handoff_obj: Handoff) -> bool:
+            attr = handoff_obj.is_enabled
+            if isinstance(attr, bool):
+                return attr
+            res = attr(context_wrapper, agent)
+            if inspect.isawaitable(res):
+                return bool(await res)
+            return bool(res)
+
+        results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
+        enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok]
+        return enabled
 
     @classmethod
     async def _get_all_tools(
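`_get_handoffs` becomes async so it can evaluate each handoff's `is_enabled` attribute (a bool, a callable, or an async callable receiving the run context wrapper and the agent) and drop disabled handoffs for that run. A hedged sketch of the pattern this enables, assuming the `handoff()` factory in this release accepts an `is_enabled` argument matching the field read above; the agents and the `is_paid_user` attribute are placeholders.

# Sketch only: conditional handoffs filtered per run context.
from agents import Agent, handoff

billing_agent = Agent(name="Billing agent")

def only_for_paid_users(ctx, agent) -> bool:
    # Receives the RunContextWrapper and the agent that owns the handoff.
    return bool(getattr(ctx.context, "is_paid_user", False))

triage_agent = Agent(
    name="Triage agent",
    handoffs=[handoff(billing_agent, is_enabled=only_for_paid_users)],
)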
@@ -1008,3 +1135,6 @@ class Runner:
             return agent.model
 
         return run_config.model_provider.get_model(agent.model)
+
+
+DEFAULT_AGENT_RUNNER = AgentRunner()
agents/tool.py CHANGED
@@ -7,6 +7,10 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
 
 from openai.types.responses.file_search_tool_param import Filters, RankingOptions
+from openai.types.responses.response_computer_tool_call import (
+    PendingSafetyCheck,
+    ResponseComputerToolCall,
+)
 from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
 from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp
 from openai.types.responses.web_search_tool_param import UserLocation
@@ -26,6 +30,7 @@ from .util import _error_tracing
 from .util._types import MaybeAwaitable
 
 if TYPE_CHECKING:
+
     from .agent import Agent
 
 ToolParams = ParamSpec("ToolParams")
@@ -141,11 +146,31 @@ class ComputerTool:
     as well as implements the computer actions like click, screenshot, etc.
     """
 
+    on_safety_check: Callable[[ComputerToolSafetyCheckData], MaybeAwaitable[bool]] | None = None
+    """Optional callback to acknowledge computer tool safety checks."""
+
     @property
     def name(self):
         return "computer_use_preview"
 
 
+@dataclass
+class ComputerToolSafetyCheckData:
+    """Information about a computer tool safety check."""
+
+    ctx_wrapper: RunContextWrapper[Any]
+    """The run context."""
+
+    agent: Agent[Any]
+    """The agent performing the computer action."""
+
+    tool_call: ResponseComputerToolCall
+    """The computer tool call."""
+
+    safety_check: PendingSafetyCheck
+    """The pending safety check to acknowledge."""
+
+
 @dataclass
 class MCPToolApprovalRequest:
     """A request to approve a tool call."""
agents/tracing/__init__.py CHANGED
@@ -18,7 +18,8 @@ from .create import (
 )
 from .processor_interface import TracingProcessor
 from .processors import default_exporter, default_processor
-from .setup import GLOBAL_TRACE_PROVIDER
+from .provider import DefaultTraceProvider, TraceProvider
+from .setup import get_trace_provider, set_trace_provider
 from .span_data import (
     AgentSpanData,
     CustomSpanData,
@@ -45,10 +46,12 @@ __all__ = [
     "generation_span",
     "get_current_span",
     "get_current_trace",
+    "get_trace_provider",
     "guardrail_span",
     "handoff_span",
     "response_span",
     "set_trace_processors",
+    "set_trace_provider",
     "set_tracing_disabled",
     "trace",
     "Trace",
@@ -67,6 +70,7 @@ __all__ = [
     "SpeechSpanData",
     "TranscriptionSpanData",
     "TracingProcessor",
+    "TraceProvider",
     "gen_trace_id",
     "gen_span_id",
     "speech_group_span",
@@ -80,21 +84,21 @@ def add_trace_processor(span_processor: TracingProcessor) -> None:
     """
     Adds a new trace processor. This processor will receive all traces/spans.
     """
-    GLOBAL_TRACE_PROVIDER.register_processor(span_processor)
+    get_trace_provider().register_processor(span_processor)
 
 
 def set_trace_processors(processors: list[TracingProcessor]) -> None:
     """
     Set the list of trace processors. This will replace the current list of processors.
     """
-    GLOBAL_TRACE_PROVIDER.set_processors(processors)
+    get_trace_provider().set_processors(processors)
 
 
 def set_tracing_disabled(disabled: bool) -> None:
     """
     Set whether tracing is globally disabled.
     """
-    GLOBAL_TRACE_PROVIDER.set_disabled(disabled)
+    get_trace_provider().set_disabled(disabled)
 
 
 def set_tracing_export_api_key(api_key: str) -> None:
@@ -104,10 +108,11 @@ def set_tracing_export_api_key(api_key: str) -> None:
     default_exporter().set_api_key(api_key)
 
 
+set_trace_provider(DefaultTraceProvider())
 # Add the default processor, which exports traces and spans to the backend in batches. You can
 # change the default behavior by either:
 # 1. calling add_trace_processor(), which adds additional processors, or
 # 2. calling set_trace_processors(), which replaces the default processor.
 add_trace_processor(default_processor())
 
-atexit.register(GLOBAL_TRACE_PROVIDER.shutdown)
+atexit.register(get_trace_provider().shutdown)
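Tracing setup now goes through a provider indirection: the module installs a `DefaultTraceProvider` via `set_trace_provider()` at import time, and the helper functions and the atexit hook reach it through `get_trace_provider()` instead of the old `GLOBAL_TRACE_PROVIDER` constant. A small sketch of interacting with the installed provider, using only the calls visible in this diff.

# Sketch: the same provider object backs the module-level helpers and the atexit hook.
from agents.tracing import get_trace_provider, set_trace_processors

provider = get_trace_provider()   # the DefaultTraceProvider installed at import time
set_trace_processors([])          # e.g. replace the default batch exporter with nothing
provider.set_disabled(True)       # same effect as set_tracing_disabled(True)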