meshagent-openai 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of meshagent-openai might be problematic; see the registry's advisory for this release for more details.

@@ -1,17 +1,25 @@
1
-
2
- from meshagent.agents.agent import Agent, AgentChatContext, AgentCallContext
3
- from meshagent.api import WebSocketClientProtocol, RoomClient, RoomException
1
+ from meshagent.agents.agent import AgentChatContext
2
+ from meshagent.api import RoomClient, RoomException
4
3
  from meshagent.tools.blob import Blob, BlobStorage
5
4
  from meshagent.tools import Toolkit, ToolContext, Tool, BaseTool
6
- from meshagent.api.messaging import Response, LinkResponse, FileResponse, JsonResponse, TextResponse, EmptyResponse, RawOutputs, ensure_response
5
+ from meshagent.api.messaging import (
6
+ Response,
7
+ LinkResponse,
8
+ FileResponse,
9
+ JsonResponse,
10
+ TextResponse,
11
+ EmptyResponse,
12
+ RawOutputs,
13
+ ensure_response,
14
+ )
7
15
  from meshagent.agents.adapter import ToolResponseAdapter, LLMAdapter
8
16
  import json
9
17
  from typing import List, Literal
10
18
  from meshagent.openai.proxy import get_client
11
- from openai import AsyncOpenAI, APIStatusError, NOT_GIVEN, APIStatusError
12
- from openai.types.responses import ResponseFunctionToolCall, ResponseComputerToolCall, ResponseStreamEvent, ResponseImageGenCallCompletedEvent
19
+ from openai import AsyncOpenAI, NOT_GIVEN, APIStatusError
20
+ from openai.types.responses import ResponseFunctionToolCall, ResponseStreamEvent
13
21
  import os
14
- from typing import Optional, Any, Callable
22
+ from typing import Optional, Callable
15
23
  import base64
16
24
 
17
25
  import logging
@@ -20,30 +28,28 @@ import asyncio
20
28
 
21
29
  from pydantic import BaseModel
22
30
  import copy
23
-
24
- logger = logging.getLogger("openai_agent")
25
-
26
31
  from opentelemetry import trace
27
32
 
33
+ logger = logging.getLogger("openai_agent")
28
34
  tracer = trace.get_tracer("openai.llm.responses")
29
35
 
30
36
 
31
37
  def safe_json_dump(data: dict):
32
-
33
38
  return json.dumps(copy.deepcopy(data))
34
39
 
40
+
35
41
  def safe_model_dump(model: BaseModel):
36
42
  try:
37
- return safe_json_dump(model.model_dump(mode='json'))
38
- except:
39
- return {"error":"unable to dump json for model"}
43
+ return safe_json_dump(model.model_dump(mode="json"))
44
+ except Exception:
45
+ return {"error": "unable to dump json for model"}
40
46
 
41
47
 
42
48
  def _replace_non_matching(text: str, allowed_chars: str, replacement: str) -> str:
43
49
  """
44
50
  Replaces every character in `text` that does not match the given
45
51
  `allowed_chars` regex set with `replacement`.
46
-
52
+
47
53
  Parameters:
48
54
  -----------
49
55
  text : str
@@ -53,7 +59,7 @@ def _replace_non_matching(text: str, allowed_chars: str, replacement: str) -> st
53
59
  For example, "a-zA-Z0-9" will keep only letters and digits.
54
60
  replacement : str
55
61
  The string to replace non-matching characters with.
56
-
62
+
57
63
  Returns:
58
64
  --------
59
65
  str
@@ -63,9 +69,11 @@ def _replace_non_matching(text: str, allowed_chars: str, replacement: str) -> st
63
69
  pattern = rf"[^{allowed_chars}]"
64
70
  return re.sub(pattern, replacement, text)
65
71
 
72
+
66
73
  def safe_tool_name(name: str):
67
74
  return _replace_non_matching(name, "a-zA-Z0-9_-", "_")
68
75
 
76
+
69
77
  # Collects a group of tool proxies and manages execution of openai tool calls
70
78
  class ResponsesToolBundle:
71
79
  def __init__(self, toolkits: List[Toolkit]):
@@ -75,51 +83,49 @@ class ResponsesToolBundle:
75
83
  self._tools_by_name = {}
76
84
 
77
85
  open_ai_tools = []
78
-
79
- for toolkit in toolkits:
80
- for v in toolkit.tools:
81
86
 
87
+ for toolkit in toolkits:
88
+ for v in toolkit.tools:
82
89
  k = v.name
83
90
 
84
91
  name = safe_tool_name(k)
85
92
 
86
93
  if k in self._executors:
87
- raise Exception(f"duplicate in bundle '{k}', tool names must be unique.")
94
+ raise Exception(
95
+ f"duplicate in bundle '{k}', tool names must be unique."
96
+ )
88
97
 
89
98
  self._executors[k] = toolkit
90
99
 
91
100
  self._safe_names[name] = k
92
101
  self._tools_by_name[name] = v
93
-
94
- if isinstance(v, OpenAIResponsesTool):
95
102
 
103
+ if isinstance(v, OpenAIResponsesTool):
96
104
  fns = v.get_open_ai_tool_definitions()
97
105
  for fn in fns:
98
106
  open_ai_tools.append(fn)
99
107
 
100
108
  elif isinstance(v, Tool):
101
-
102
109
  strict = True
103
110
  if hasattr(v, "strict"):
104
- strict = getattr(v, "strict") == True
111
+ strict = getattr(v, "strict")
105
112
 
106
113
  fn = {
107
- "type" : "function",
108
- "name" : name,
109
- "description" : v.description,
110
- "parameters" : {
114
+ "type": "function",
115
+ "name": name,
116
+ "description": v.description,
117
+ "parameters": {
111
118
  **v.input_schema,
112
119
  },
113
120
  "strict": strict,
114
121
  }
115
122
 
116
- if v.defs != None:
123
+ if v.defs is not None:
117
124
  fn["parameters"]["$defs"] = v.defs
118
-
125
+
119
126
  open_ai_tools.append(fn)
120
127
 
121
128
  else:
122
-
123
129
  raise RoomException(f"unsupported tool type {type(v)}")
124
130
 
125
131
  if len(open_ai_tools) == 0:
@@ -127,15 +133,18 @@ class ResponsesToolBundle:
127
133
 
128
134
  self._open_ai_tools = open_ai_tools
129
135
 
130
- async def execute(self, *, context: ToolContext, tool_call: ResponseFunctionToolCall) -> Response:
136
+ async def execute(
137
+ self, *, context: ToolContext, tool_call: ResponseFunctionToolCall
138
+ ) -> Response:
131
139
  try:
132
-
133
140
  name = tool_call.name
134
141
  arguments = json.loads(tool_call.arguments)
135
142
 
136
143
  if name not in self._safe_names:
137
- raise RoomException(f"Invalid tool name {name}, check the name of the tool")
138
-
144
+ raise RoomException(
145
+ f"Invalid tool name {name}, check the name of the tool"
146
+ )
147
+
139
148
  name = self._safe_names[name]
140
149
 
141
150
  if name not in self._executors:
@@ -144,25 +153,27 @@ class ResponsesToolBundle:
144
153
  logger.info("executing %s %s %s", tool_call.id, name, arguments)
145
154
 
146
155
  proxy = self._executors[name]
147
- result = await proxy.execute(context=context, name=name, arguments=arguments)
148
- logger.info("success calling %s %s %s", tool_call.id, name, result)
156
+ result = await proxy.execute(
157
+ context=context, name=name, arguments=arguments
158
+ )
159
+ logger.info("success calling %s %s %s", tool_call.id, name, result)
149
160
  return ensure_response(result)
150
161
 
151
162
  except Exception as e:
152
163
  logger.error("failed calling %s %s", tool_call.id, name, exc_info=e)
153
164
  raise
154
-
165
+
155
166
  def get_tool(self, name: str) -> BaseTool | None:
156
167
  return self._tools_by_name.get(name, None)
157
-
168
+
158
169
  def contains(self, name: str) -> bool:
159
170
  return name in self._open_ai_tools
160
171
 
161
172
  def to_json(self) -> List[dict] | None:
162
- if self._open_ai_tools == None:
173
+ if self._open_ai_tools is None:
163
174
  return None
164
175
  return self._open_ai_tools.copy()
165
-
176
+
166
177
 
167
178
  # Converts a tool response into a series of messages that can be inserted into the openai context
168
179
  class OpenAIResponsesToolResponseAdapter(ToolResponseAdapter):
@@ -171,91 +182,116 @@ class OpenAIResponsesToolResponseAdapter(ToolResponseAdapter):
171
182
  pass
172
183
 
173
184
  async def to_plain_text(self, *, room: RoomClient, response: Response) -> str:
174
- if isinstance(response, LinkResponse):
175
- return json.dumps({
176
- "name" : response.name,
177
- "url" : response.url,
178
- })
179
-
180
- elif isinstance(response, JsonResponse):
181
-
185
+ if isinstance(response, LinkResponse):
186
+ return json.dumps(
187
+ {
188
+ "name": response.name,
189
+ "url": response.url,
190
+ }
191
+ )
192
+
193
+ elif isinstance(response, JsonResponse):
182
194
  return json.dumps(response.json)
183
-
195
+
184
196
  elif isinstance(response, TextResponse):
185
197
  return response.text
186
-
187
- elif isinstance(response, FileResponse):
188
198
 
199
+ elif isinstance(response, FileResponse):
189
200
  blob = Blob(mime_type=response.mime_type, data=response.data)
190
201
  uri = self._blob_storage.store(blob=blob)
191
-
202
+
192
203
  return f"The results have been written to a blob with the uri {uri} with the mime type {blob.mime_type}."
193
-
204
+
194
205
  elif isinstance(response, EmptyResponse):
195
206
  return "ok"
196
-
197
- #elif isinstance(response, ImageResponse):
198
- # context.messages.append({
199
- # "role" : "assistant",
200
- # "content" : "the user will upload the image",
201
- # "tool_call_id" : tool_call.id,
202
- # })
203
- # context.messages.append({
204
- # "role" : "user",
205
- # "content" : [
206
- # { "type" : "text", "text": "this is the image from tool call id {tool_call.id}" },
207
- # { "type" : "image_url", "image_url": {"url": response.url, "detail": "auto"} }
208
- # ]
209
- # })
210
-
211
-
212
- elif isinstance(response, dict):
207
+
208
+ # elif isinstance(response, ImageResponse):
209
+ # context.messages.append({
210
+ # "role" : "assistant",
211
+ # "content" : "the user will upload the image",
212
+ # "tool_call_id" : tool_call.id,
213
+ # })
214
+ # context.messages.append({
215
+ # "role" : "user",
216
+ # "content" : [
217
+ # { "type" : "text", "text": "this is the image from tool call id {tool_call.id}" },
218
+ # { "type" : "image_url", "image_url": {"url": response.url, "detail": "auto"} }
219
+ # ]
220
+ # })
221
+
222
+ elif isinstance(response, dict):
213
223
  return json.dumps(response)
214
-
215
- elif isinstance(response, str):
224
+
225
+ elif isinstance(response, str):
216
226
  return response
217
227
 
218
- elif response == None:
228
+ elif response is None:
219
229
  return "ok"
220
-
221
- else:
222
- raise Exception("unexpected return type: {type}".format(type=type(response)))
223
230
 
224
- async def create_messages(self, *, context: AgentChatContext, tool_call: ResponseFunctionToolCall, room: RoomClient, response: Response) -> list:
231
+ else:
232
+ raise Exception(
233
+ "unexpected return type: {type}".format(type=type(response))
234
+ )
225
235
 
236
+ async def create_messages(
237
+ self,
238
+ *,
239
+ context: AgentChatContext,
240
+ tool_call: ResponseFunctionToolCall,
241
+ room: RoomClient,
242
+ response: Response,
243
+ ) -> list:
226
244
  with tracer.start_as_current_span("llm.tool_adapter.create_messages") as span:
227
-
228
-
229
245
  if isinstance(response, RawOutputs):
230
246
  span.set_attribute("kind", "raw")
231
247
  for output in response.outputs:
232
-
233
- room.developer.log_nowait(type="llm.message", data={ "context" : context.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "message" : output })
234
-
248
+ room.developer.log_nowait(
249
+ type="llm.message",
250
+ data={
251
+ "context": context.id,
252
+ "participant_id": room.local_participant.id,
253
+ "participant_name": room.local_participant.get_attribute(
254
+ "name"
255
+ ),
256
+ "message": output,
257
+ },
258
+ )
259
+
235
260
  return response.outputs
236
261
  else:
237
-
238
262
  span.set_attribute("kind", "text")
239
263
  output = await self.to_plain_text(room=room, response=response)
240
264
  span.set_attribute("output", output)
241
-
265
+
242
266
  message = {
243
- "output" : output,
244
- "call_id" : tool_call.call_id,
245
- "type" : "function_call_output"
267
+ "output": output,
268
+ "call_id": tool_call.call_id,
269
+ "type": "function_call_output",
246
270
  }
247
271
 
248
- room.developer.log_nowait(type="llm.message", data={ "context" : context.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "message" : message })
272
+ room.developer.log_nowait(
273
+ type="llm.message",
274
+ data={
275
+ "context": context.id,
276
+ "participant_id": room.local_participant.id,
277
+ "participant_name": room.local_participant.get_attribute(
278
+ "name"
279
+ ),
280
+ "message": message,
281
+ },
282
+ )
283
+
284
+ return [message]
249
285
 
250
- return [ message ]
251
286
 
252
287
  class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
253
- def __init__(self,
254
- model: str = os.getenv("OPENAI_MODEL","gpt-4.1"),
255
- parallel_tool_calls : Optional[bool] = None,
288
+ def __init__(
289
+ self,
290
+ model: str = os.getenv("OPENAI_MODEL", "gpt-4.1"),
291
+ parallel_tool_calls: Optional[bool] = None,
256
292
  client: Optional[AsyncOpenAI] = None,
257
- response_options : Optional[dict] = None,
258
- provider: str = "openai"
293
+ response_options: Optional[dict] = None,
294
+ provider: str = "openai",
259
295
  ):
260
296
  self._model = model
261
297
  self._parallel_tool_calls = parallel_tool_calls
@@ -273,20 +309,15 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
273
309
  system_role = "developer"
274
310
  elif self._model.startswith("computer-use"):
275
311
  system_role = "developer"
276
-
277
312
 
278
- context = AgentChatContext(
279
- system_role=system_role
280
- )
313
+ context = AgentChatContext(system_role=system_role)
281
314
 
282
315
  return context
283
-
284
- async def check_for_termination(self, *, context: AgentChatContext, room: RoomClient) -> bool:
285
- if len(context.previous_messages) > 0:
286
- last_message = context.previous_messages[-1]
287
-
288
- for message in context.messages:
289
316
 
317
+ async def check_for_termination(
318
+ self, *, context: AgentChatContext, room: RoomClient
319
+ ) -> bool:
320
+ for message in context.messages:
290
321
  if message.get("type", "message") != "message":
291
322
  return False
292
323
 
@@ -294,36 +325,29 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
294
325
 
295
326
  # Takes the current chat context, executes a completion request and processes the response.
296
327
  # If a tool calls are requested, invokes the tools, processes the tool calls results, and appends the tool call results to the context
297
- async def next(self,
328
+ async def next(
329
+ self,
298
330
  *,
299
331
  context: AgentChatContext,
300
332
  room: RoomClient,
301
333
  toolkits: list[Toolkit],
302
334
  tool_adapter: Optional[ToolResponseAdapter] = None,
303
335
  output_schema: Optional[dict] = None,
304
- event_handler: Optional[Callable[[ResponseStreamEvent],None]] = None
336
+ event_handler: Optional[Callable[[ResponseStreamEvent], None]] = None,
305
337
  ):
306
338
  with tracer.start_as_current_span("llm.turn") as span:
339
+ span.set_attributes({"chat_context": context.id, "api": "responses"})
307
340
 
308
- span.set_attributes({
309
- "chat_context" : context.id,
310
- "api" : "responses"
311
- })
312
-
313
- if tool_adapter == None:
341
+ if tool_adapter is None:
314
342
  tool_adapter = OpenAIResponsesToolResponseAdapter()
315
-
343
+
316
344
  try:
317
-
318
345
  while True:
319
-
320
346
  with tracer.start_as_current_span("llm.turn.iteration") as span:
347
+ span.set_attributes(
348
+ {"model": self._model, "provider": self._provider}
349
+ )
321
350
 
322
- span.set_attributes({
323
- "model": self._model,
324
- "provider": self._provider
325
- })
326
-
327
351
  openai = get_client(room=room)
328
352
 
329
353
  response_schema = output_schema
@@ -331,156 +355,257 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
331
355
 
332
356
  # We need to do this inside the loop because tools can change mid loop
333
357
  # for example computer use adds goto tools after the first interaction
334
- tool_bundle = ResponsesToolBundle(toolkits=[
335
- *toolkits,
336
- ])
358
+ tool_bundle = ResponsesToolBundle(
359
+ toolkits=[
360
+ *toolkits,
361
+ ]
362
+ )
337
363
  open_ai_tools = tool_bundle.to_json()
338
364
 
339
- if open_ai_tools == None:
365
+ if open_ai_tools is None:
340
366
  open_ai_tools = NOT_GIVEN
341
-
367
+
342
368
  ptc = self._parallel_tool_calls
343
369
  extra = {}
344
- if ptc != None and self._model.startswith("o") == False:
345
- extra["parallel_tool_calls"] = ptc
370
+ if ptc is not None and not self._model.startswith("o"):
371
+ extra["parallel_tool_calls"] = ptc
346
372
  span.set_attribute("parallel_tool_calls", ptc)
347
373
  else:
348
374
  span.set_attribute("parallel_tool_calls", False)
349
-
375
+
350
376
  text = NOT_GIVEN
351
- if output_schema != None:
377
+ if output_schema is not None:
352
378
  span.set_attribute("response_format", "json_schema")
353
379
  text = {
354
- "format" : {
355
- "type" : "json_schema",
356
- "name" : response_name,
357
- "schema" : response_schema,
358
- "strict" : True,
380
+ "format": {
381
+ "type": "json_schema",
382
+ "name": response_name,
383
+ "schema": response_schema,
384
+ "strict": True,
359
385
  }
360
386
  }
361
387
  else:
362
388
  span.set_attribute("response_format", "text")
363
389
 
364
-
365
390
  previous_response_id = NOT_GIVEN
366
- if context.previous_response_id != None:
391
+ if context.previous_response_id is not None:
367
392
  previous_response_id = context.previous_response_id
368
-
369
- stream = event_handler != None
370
-
371
-
393
+
394
+ stream = event_handler is not None
395
+
372
396
  with tracer.start_as_current_span("llm.invoke") as span:
373
397
  response_options = self._response_options
374
- if response_options == None:
398
+ if response_options is None:
375
399
  response_options = {}
376
- response : Response = await openai.responses.create(
400
+ response: Response = await openai.responses.create(
377
401
  stream=stream,
378
- model = self._model,
379
- input = context.messages,
380
- tools = open_ai_tools,
381
- text = text,
402
+ model=self._model,
403
+ input=context.messages,
404
+ tools=open_ai_tools,
405
+ text=text,
382
406
  previous_response_id=previous_response_id,
383
-
384
- **response_options
407
+ **response_options,
385
408
  )
386
409
 
387
410
  async def handle_message(message: BaseModel):
388
-
389
- with tracer.start_as_current_span("llm.handle_response") as span:
390
-
391
- span.set_attributes({
392
- "type" : message.type,
393
- "message" : safe_model_dump(message)
394
- })
395
-
396
- room.developer.log_nowait(type=f"llm.message", data={
397
- "context" : context.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "message" : message.to_dict()
398
- })
411
+ with tracer.start_as_current_span(
412
+ "llm.handle_response"
413
+ ) as span:
414
+ span.set_attributes(
415
+ {
416
+ "type": message.type,
417
+ "message": safe_model_dump(message),
418
+ }
419
+ )
420
+
421
+ room.developer.log_nowait(
422
+ type="llm.message",
423
+ data={
424
+ "context": context.id,
425
+ "participant_id": room.local_participant.id,
426
+ "participant_name": room.local_participant.get_attribute(
427
+ "name"
428
+ ),
429
+ "message": message.to_dict(),
430
+ },
431
+ )
399
432
 
400
433
  if message.type == "function_call":
401
-
402
434
  tasks = []
403
435
 
404
- async def do_tool_call(tool_call: ResponseFunctionToolCall):
405
-
406
- try:
407
- with tracer.start_as_current_span("llm.handle_tool_call") as span:
408
-
409
- span.set_attributes({
436
+ async def do_tool_call(
437
+ tool_call: ResponseFunctionToolCall,
438
+ ):
439
+ try:
440
+ with tracer.start_as_current_span(
441
+ "llm.handle_tool_call"
442
+ ) as span:
443
+ span.set_attributes(
444
+ {
410
445
  "id": tool_call.id,
411
446
  "name": tool_call.name,
412
447
  "call_id": tool_call.call_id,
413
- "arguments": json.dumps(tool_call.arguments)
414
- })
415
-
416
- tool_context = ToolContext(
417
- room=room,
418
- caller=room.local_participant,
419
- caller_context={ "chat" : context.to_json() }
448
+ "arguments": json.dumps(
449
+ tool_call.arguments
450
+ ),
451
+ }
452
+ )
453
+
454
+ tool_context = ToolContext(
455
+ room=room,
456
+ caller=room.local_participant,
457
+ caller_context={
458
+ "chat": context.to_json()
459
+ },
460
+ )
461
+ tool_response = (
462
+ await tool_bundle.execute(
463
+ context=tool_context,
464
+ tool_call=tool_call,
420
465
  )
421
- tool_response = await tool_bundle.execute(context=tool_context, tool_call=tool_call)
422
- if tool_response.caller_context != None:
423
- if tool_response.caller_context.get("chat", None) != None:
424
- tool_chat_context = AgentChatContext.from_json(tool_response.caller_context["chat"])
425
- if tool_chat_context.previous_response_id != None:
426
- context.track_response(tool_chat_context.previous_response_id)
427
-
428
- logger.info(f"tool response {tool_response}")
429
- return await tool_adapter.create_messages(context=context, tool_call=tool_call, room=room, response=tool_response)
430
-
431
- except Exception as e:
432
- logger.error(f"unable to complete tool call {tool_call}", exc_info=e)
433
- room.developer.log_nowait(type="llm.error", data={ "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "error" : f"{e}" })
434
-
435
- return [{
436
- "output" : json.dumps({"error":f"unable to complete tool call: {e}"}),
437
- "call_id" : tool_call.call_id,
438
- "type" : "function_call_output"
439
- }]
440
-
441
-
442
- tasks.append(asyncio.create_task(do_tool_call(message)))
466
+ )
467
+ if (
468
+ tool_response.caller_context
469
+ is not None
470
+ ):
471
+ if (
472
+ tool_response.caller_context.get(
473
+ "chat", None
474
+ )
475
+ is not None
476
+ ):
477
+ tool_chat_context = AgentChatContext.from_json(
478
+ tool_response.caller_context[
479
+ "chat"
480
+ ]
481
+ )
482
+ if (
483
+ tool_chat_context.previous_response_id
484
+ is not None
485
+ ):
486
+ context.track_response(
487
+ tool_chat_context.previous_response_id
488
+ )
489
+
490
+ logger.info(
491
+ f"tool response {tool_response}"
492
+ )
493
+ return await tool_adapter.create_messages(
494
+ context=context,
495
+ tool_call=tool_call,
496
+ room=room,
497
+ response=tool_response,
498
+ )
499
+
500
+ except Exception as e:
501
+ logger.error(
502
+ f"unable to complete tool call {tool_call}",
503
+ exc_info=e,
504
+ )
505
+ room.developer.log_nowait(
506
+ type="llm.error",
507
+ data={
508
+ "participant_id": room.local_participant.id,
509
+ "participant_name": room.local_participant.get_attribute(
510
+ "name"
511
+ ),
512
+ "error": f"{e}",
513
+ },
514
+ )
515
+
516
+ return [
517
+ {
518
+ "output": json.dumps(
519
+ {
520
+ "error": f"unable to complete tool call: {e}"
521
+ }
522
+ ),
523
+ "call_id": tool_call.call_id,
524
+ "type": "function_call_output",
525
+ }
526
+ ]
527
+
528
+ tasks.append(
529
+ asyncio.create_task(do_tool_call(message))
530
+ )
443
531
 
444
532
  results = await asyncio.gather(*tasks)
445
533
 
446
534
  all_results = []
447
535
  for result in results:
448
- room.developer.log_nowait(type="llm.message", data={ "context" : context.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "message" : result })
536
+ room.developer.log_nowait(
537
+ type="llm.message",
538
+ data={
539
+ "context": context.id,
540
+ "participant_id": room.local_participant.id,
541
+ "participant_name": room.local_participant.get_attribute(
542
+ "name"
543
+ ),
544
+ "message": result,
545
+ },
546
+ )
449
547
  all_results.extend(result)
450
548
 
451
549
  return all_results, False
452
550
 
453
551
  elif message.type == "message":
454
-
455
-
456
552
  contents = message.content
457
- if response_schema == None:
553
+ if response_schema is None:
458
554
  return [], False
459
555
  else:
460
556
  for content in contents:
461
557
  # First try to parse the result
462
558
  try:
463
- full_response = json.loads(content.text)
464
-
559
+ full_response = json.loads(
560
+ content.text
561
+ )
562
+
465
563
  # sometimes open ai packs two JSON chunks seperated by newline, check if that's why we couldn't parse
466
- except json.decoder.JSONDecodeError as e:
467
- for part in content.text.splitlines():
564
+ except json.decoder.JSONDecodeError:
565
+ for (
566
+ part
567
+ ) in content.text.splitlines():
468
568
  if len(part.strip()) > 0:
469
- full_response = json.loads(part)
470
-
569
+ full_response = json.loads(
570
+ part
571
+ )
572
+
471
573
  try:
472
- self.validate(response=full_response, output_schema=response_schema)
574
+ self.validate(
575
+ response=full_response,
576
+ output_schema=response_schema,
577
+ )
473
578
  except Exception as e:
474
- logger.error("recieved invalid response, retrying", exc_info=e)
475
- error = { "role" : "user", "content" : "encountered a validation error with the output: {error}".format(error=e)}
476
- room.developer.log_nowait(type="llm.message", data={ "context" : message.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "message" : error })
477
- context.messages.append(error)
579
+ logger.error(
580
+ "recieved invalid response, retrying",
581
+ exc_info=e,
582
+ )
583
+ error = {
584
+ "role": "user",
585
+ "content": "encountered a validation error with the output: {error}".format(
586
+ error=e
587
+ ),
588
+ }
589
+ room.developer.log_nowait(
590
+ type="llm.message",
591
+ data={
592
+ "context": message.id,
593
+ "participant_id": room.local_participant.id,
594
+ "participant_name": room.local_participant.get_attribute(
595
+ "name"
596
+ ),
597
+ "message": error,
598
+ },
599
+ )
600
+ context.messages.append(
601
+ error
602
+ )
478
603
  continue
479
-
480
- return [ full_response ], True
481
- #elif message.type == "computer_call" and tool_bundle.get_tool("computer_call"):
604
+
605
+ return [full_response], True
606
+ # elif message.type == "computer_call" and tool_bundle.get_tool("computer_call"):
482
607
  # with tracer.start_as_current_span("llm.handle_computer_call") as span:
483
- #
608
+ #
484
609
  # computer_call :ResponseComputerToolCall = message
485
610
  # span.set_attributes({
486
611
  # "id": computer_call.id,
@@ -497,49 +622,78 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
497
622
  # outputs = (await tool_bundle.get_tool("computer_call").execute(context=tool_context, arguments=message.model_dump(mode="json"))).outputs
498
623
 
499
624
  # return outputs, False
500
-
501
625
 
502
626
  else:
503
627
  for toolkit in toolkits:
504
628
  for tool in toolkit.tools:
505
- if isinstance(tool, OpenAIResponsesTool):
506
- with tracer.start_as_current_span("llm.handle_tool_call") as span:
507
-
508
- arguments = message.model_dump(mode="json")
509
- span.set_attributes({
510
- "type" : message.type,
511
- "arguments" : safe_json_dump(arguments)
512
- })
629
+ if isinstance(
630
+ tool, OpenAIResponsesTool
631
+ ):
632
+ with tracer.start_as_current_span(
633
+ "llm.handle_tool_call"
634
+ ) as span:
635
+ arguments = message.model_dump(
636
+ mode="json"
637
+ )
638
+ span.set_attributes(
639
+ {
640
+ "type": message.type,
641
+ "arguments": safe_json_dump(
642
+ arguments
643
+ ),
644
+ }
645
+ )
513
646
 
514
647
  handlers = tool.get_open_ai_output_handlers()
515
648
  if message.type in handlers:
516
649
  tool_context = ToolContext(
517
650
  room=room,
518
651
  caller=room.local_participant,
519
- caller_context={ "chat" : context.to_json() }
652
+ caller_context={
653
+ "chat": context.to_json()
654
+ },
520
655
  )
521
- result = await handlers[message.type](tool_context, **arguments)
522
-
523
- if result != None:
524
- span.set_attribute("result", safe_json_dump(result))
525
- return [ result ], False
656
+ result = await handlers[
657
+ message.type
658
+ ](tool_context, **arguments)
659
+
660
+ if result is not None:
661
+ span.set_attribute(
662
+ "result",
663
+ safe_json_dump(
664
+ result
665
+ ),
666
+ )
667
+ return [result], False
526
668
  else:
669
+ logger.warning(
670
+ f"OpenAI response handler was not registered for {message.type}"
671
+ )
527
672
 
528
- logger.warning(f"OpenAI response handler was not registered for {message.type}")
529
-
530
-
531
673
  return [], False
532
-
533
- if stream == False:
534
- room.developer.log_nowait(type="llm.message", data={ "context" : context.id, "participant_id" : room.local_participant.id, "participant_name" : room.local_participant.get_attribute("name"), "response" : response.to_dict() })
535
-
674
+
675
+ if not stream:
676
+ room.developer.log_nowait(
677
+ type="llm.message",
678
+ data={
679
+ "context": context.id,
680
+ "participant_id": room.local_participant.id,
681
+ "participant_name": room.local_participant.get_attribute(
682
+ "name"
683
+ ),
684
+ "response": response.to_dict(),
685
+ },
686
+ )
687
+
536
688
  context.track_response(response.id)
537
689
 
538
690
  final_outputs = []
539
-
691
+
540
692
  for message in response.output:
541
693
  context.previous_messages.append(message.to_dict())
542
- outputs, done = await handle_message(message=message)
694
+ outputs, done = await handle_message(
695
+ message=message
696
+ )
543
697
  if done:
544
698
  final_outputs.extend(outputs)
545
699
  else:
@@ -547,12 +701,14 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
547
701
  context.messages.append(output)
548
702
 
549
703
  if len(final_outputs) > 0:
550
-
551
704
  return final_outputs[0]
552
-
553
- with tracer.start_as_current_span("llm.turn.check_for_termination") as span:
554
705
 
555
- term = await self.check_for_termination(context=context, room=room)
706
+ with tracer.start_as_current_span(
707
+ "llm.turn.check_for_termination"
708
+ ) as span:
709
+ term = await self.check_for_termination(
710
+ context=context, room=room
711
+ )
556
712
  if term:
557
713
  span.set_attribute("terminate", True)
558
714
  text = ""
@@ -565,109 +721,121 @@ class OpenAIResponsesAdapter(LLMAdapter[ResponsesToolBundle]):
565
721
  else:
566
722
  span.set_attribute("terminate", False)
567
723
 
568
-
569
724
  else:
570
-
571
725
  final_outputs = []
572
726
  all_outputs = []
573
727
  async for e in response:
574
- with tracer.start_as_current_span("llm.stream.event") as span:
575
-
576
- event : ResponseStreamEvent = e
577
- span.set_attributes({
578
- "type" : event.type,
579
- "event" : safe_model_dump(event)
580
- })
728
+ with tracer.start_as_current_span(
729
+ "llm.stream.event"
730
+ ) as span:
731
+ event: ResponseStreamEvent = e
732
+ span.set_attributes(
733
+ {
734
+ "type": event.type,
735
+ "event": safe_model_dump(event),
736
+ }
737
+ )
581
738
 
582
739
  event_handler(event)
583
740
 
584
741
  if event.type == "response.completed":
585
-
586
-
587
742
  context.track_response(event.response.id)
588
-
743
+
589
744
  context.messages.extend(all_outputs)
590
745
 
591
- with tracer.start_as_current_span("llm.turn.check_for_termination") as span:
592
- term = await self.check_for_termination(context=context, room=room)
593
-
746
+ with tracer.start_as_current_span(
747
+ "llm.turn.check_for_termination"
748
+ ) as span:
749
+ term = await self.check_for_termination(
750
+ context=context, room=room
751
+ )
752
+
594
753
  if term:
595
- span.set_attribute("terminate", True)
596
-
754
+ span.set_attribute(
755
+ "terminate", True
756
+ )
757
+
597
758
  text = ""
598
759
  for output in event.response.output:
599
760
  if output.type == "message":
600
- for content in output.content:
761
+ for (
762
+ content
763
+ ) in output.content:
601
764
  text += content.text
602
765
 
603
766
  return text
604
767
 
605
768
  span.set_attribute("terminate", False)
606
769
 
607
-
608
770
  all_outputs = []
609
771
 
610
772
  elif event.type == "response.output_item.done":
611
-
612
- context.previous_messages.append(event.item.to_dict())
613
-
614
- outputs, done = await handle_message(message=event.item)
615
- if done:
616
- final_outputs.extend(outputs)
617
- else:
618
- for output in outputs:
619
- all_outputs.append(output)
773
+ context.previous_messages.append(
774
+ event.item.to_dict()
775
+ )
776
+
777
+ outputs, done = await handle_message(
778
+ message=event.item
779
+ )
780
+ if done:
781
+ final_outputs.extend(outputs)
782
+ else:
783
+ for output in outputs:
784
+ all_outputs.append(output)
620
785
 
621
786
  else:
622
787
  for toolkit in toolkits:
623
788
  for tool in toolkit.tools:
624
-
625
- if isinstance(tool, OpenAIResponsesTool):
626
-
789
+ if isinstance(
790
+ tool, OpenAIResponsesTool
791
+ ):
627
792
  callbacks = tool.get_open_ai_stream_callbacks()
628
793
 
629
794
  if event.type in callbacks:
630
-
631
795
  tool_context = ToolContext(
632
796
  room=room,
633
797
  caller=room.local_participant,
634
- caller_context={ "chat" : context.to_json() }
798
+ caller_context={
799
+ "chat": context.to_json()
800
+ },
635
801
  )
636
802
 
637
- await callbacks[event.type](tool_context, **event.to_dict())
638
-
803
+ await callbacks[event.type](
804
+ tool_context,
805
+ **event.to_dict(),
806
+ )
639
807
 
640
808
  if len(final_outputs) > 0:
641
-
642
809
  return final_outputs[0]
643
-
810
+
644
811
  except APIStatusError as e:
645
812
  raise RoomException(f"Error from OpenAI: {e}")
646
-
647
813
 
648
- class OpenAIResponsesTool(BaseTool):
649
814
 
815
+ class OpenAIResponsesTool(BaseTool):
650
816
  def get_open_ai_tool_definitions(self) -> list[dict]:
651
817
  return []
652
-
818
+
653
819
  def get_open_ai_stream_callbacks(self) -> dict[str, Callable]:
654
820
  return {}
655
-
656
- def get_open_ai_output_handlers(self) -> dict[str, Callable]:
821
+
822
+ def get_open_ai_output_handlers(self) -> dict[str, Callable]:
657
823
  return {}
658
-
824
+
659
825
 
660
826
  class ImageGenerationTool(OpenAIResponsesTool):
661
- def __init__(self, *,
662
- background: Literal["transparent","opaque","auto"] = None,
827
+ def __init__(
828
+ self,
829
+ *,
830
+ background: Literal["transparent", "opaque", "auto"] = None,
663
831
  input_image_mask_url: Optional[str] = None,
664
832
  model: Optional[str] = None,
665
833
  moderation: Optional[str] = None,
666
834
  output_compression: Optional[int] = None,
667
- output_format: Optional[Literal["png","webp","jpeg"]] = None,
835
+ output_format: Optional[Literal["png", "webp", "jpeg"]] = None,
668
836
  partial_images: Optional[int] = None,
669
837
  quality: Optional[Literal["auto", "low", "medium", "high"]] = None,
670
- size: Optional[Literal["1024x1024","1024x1536","1536x1024","auto"]] = None
838
+ size: Optional[Literal["1024x1024", "1024x1536", "1536x1024", "auto"]] = None,
671
839
  ):
672
840
  super().__init__(name="image_generation")
673
841
  self.background = background
@@ -676,82 +844,154 @@ class ImageGenerationTool(OpenAIResponsesTool):
676
844
  self.moderation = moderation
677
845
  self.output_compression = output_compression
678
846
  self.output_format = output_format
679
- if partial_images == None:
680
- partial_images = 1 # streaming wants non zero, and we stream by default
847
+ if partial_images is None:
848
+ partial_images = 1 # streaming wants non zero, and we stream by default
681
849
  self.partial_images = partial_images
682
850
  self.quality = quality
683
851
  self.size = size
684
852
 
685
853
  def get_open_ai_tool_definitions(self):
686
- opts = {
687
- "type" : "image_generation"
688
- }
854
+ opts = {"type": "image_generation"}
689
855
 
690
- if self.background != None:
856
+ if self.background is not None:
691
857
  opts["background"] = self.background
692
858
 
693
- if self.input_image_mask_url != None:
694
- opts["input_image_mask"] = { "image_url" : self.input_image_mask_url }
859
+ if self.input_image_mask_url is not None:
860
+ opts["input_image_mask"] = {"image_url": self.input_image_mask_url}
695
861
 
696
- if self.model != None:
862
+ if self.model is not None:
697
863
  opts["model"] = self.model
698
864
 
699
- if self.moderation != None:
865
+ if self.moderation is not None:
700
866
  opts["moderation"] = self.moderation
701
867
 
702
- if self.output_compression != None:
868
+ if self.output_compression is not None:
703
869
  opts["output_compression"] = self.output_compression
704
870
 
705
- if self.output_format != None:
871
+ if self.output_format is not None:
706
872
  opts["output_format"] = self.output_format
707
873
 
708
- if self.partial_images != None:
874
+ if self.partial_images is not None:
709
875
  opts["partial_images"] = self.partial_images
710
876
 
711
- if self.quality != None:
877
+ if self.quality is not None:
712
878
  opts["quality"] = self.quality
713
879
 
714
- if self.size != None:
880
+ if self.size is not None:
715
881
  opts["size"] = self.size
716
882
 
717
- return [ opts ]
718
-
883
+ return [opts]
884
+
719
885
  def get_open_ai_stream_callbacks(self):
720
886
  return {
721
- "response.image_generation_call.completed" : self.on_image_generation_completed,
722
- "response.image_generation_call.in_progress" : self.on_image_generation_in_progress,
723
- "response.image_generation_call.generating" : self.on_image_generation_generating,
724
- "response.image_generation_call.partial_image" : self.on_image_generation_partial,
887
+ "response.image_generation_call.completed": self.on_image_generation_completed,
888
+ "response.image_generation_call.in_progress": self.on_image_generation_in_progress,
889
+ "response.image_generation_call.generating": self.on_image_generation_generating,
890
+ "response.image_generation_call.partial_image": self.on_image_generation_partial,
725
891
  }
726
-
892
+
727
893
  def get_open_ai_output_handlers(self):
728
- return {
729
- "image_generation_call" : self.handle_image_generated
730
- }
894
+ return {"image_generation_call": self.handle_image_generated}
731
895
 
732
896
  # response.image_generation_call.completed
733
- async def on_image_generation_completed(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
897
+ async def on_image_generation_completed(
898
+ self,
899
+ context: ToolContext,
900
+ *,
901
+ item_id: str,
902
+ output_index: int,
903
+ sequence_number: int,
904
+ type: str,
905
+ **extra,
906
+ ):
734
907
  pass
735
908
 
736
909
  # response.image_generation_call.in_progress
737
- async def on_image_generation_in_progress(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
910
+ async def on_image_generation_in_progress(
911
+ self,
912
+ context: ToolContext,
913
+ *,
914
+ item_id: str,
915
+ output_index: int,
916
+ sequence_number: int,
917
+ type: str,
918
+ **extra,
919
+ ):
738
920
  pass
739
921
 
740
922
  # response.image_generation_call.generating
741
- async def on_image_generation_generating(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
923
+ async def on_image_generation_generating(
924
+ self,
925
+ context: ToolContext,
926
+ *,
927
+ item_id: str,
928
+ output_index: int,
929
+ sequence_number: int,
930
+ type: str,
931
+ **extra,
932
+ ):
742
933
  pass
743
934
 
744
935
  # response.image_generation_call.partial_image
745
- async def on_image_generation_partial(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, partial_image_b64: str, partial_image_index: int, size: str, quality: str, background: str, output_format: str, **extra):
936
+ async def on_image_generation_partial(
937
+ self,
938
+ context: ToolContext,
939
+ *,
940
+ item_id: str,
941
+ output_index: int,
942
+ sequence_number: int,
943
+ type: str,
944
+ partial_image_b64: str,
945
+ partial_image_index: int,
946
+ size: str,
947
+ quality: str,
948
+ background: str,
949
+ output_format: str,
950
+ **extra,
951
+ ):
746
952
  pass
747
953
 
748
- async def on_image_generated(self, context: ToolContext, *, item_id: str, data: bytes, status: str, size: str, quality: str, background: str, output_format: str, **extra):
954
+ async def on_image_generated(
955
+ self,
956
+ context: ToolContext,
957
+ *,
958
+ item_id: str,
959
+ data: bytes,
960
+ status: str,
961
+ size: str,
962
+ quality: str,
963
+ background: str,
964
+ output_format: str,
965
+ **extra,
966
+ ):
749
967
  pass
750
968
 
751
- async def handle_image_generated(self, context: ToolContext, *, id: str, result: str | None, status: str, type: str, size: str, quality: str, background: str, output_format: str, **extra):
752
- if result != None:
969
+ async def handle_image_generated(
970
+ self,
971
+ context: ToolContext,
972
+ *,
973
+ id: str,
974
+ result: str | None,
975
+ status: str,
976
+ type: str,
977
+ size: str,
978
+ quality: str,
979
+ background: str,
980
+ output_format: str,
981
+ **extra,
982
+ ):
983
+ if result is not None:
753
984
  data = base64.b64decode(result)
754
- await self.on_image_generated(context, item_id=id, data=data, status=status, size=size, quality=quality, background=background, output_format=output_format)
985
+ await self.on_image_generated(
986
+ context,
987
+ item_id=id,
988
+ data=data,
989
+ status=status,
990
+ size=size,
991
+ quality=quality,
992
+ background=background,
993
+ output_format=output_format,
994
+ )
755
995
 
756
996
 
757
997
  class LocalShellTool(OpenAIResponsesTool):
@@ -759,19 +999,22 @@ class LocalShellTool(OpenAIResponsesTool):
759
999
  super().__init__(name="local_shell")
760
1000
 
761
1001
  def get_open_ai_tool_definitions(self):
762
- return [
763
- {
764
- "type" : "local_shell"
765
- }
766
- ]
767
-
768
- def get_open_ai_output_handlers(self):
769
- return {
770
- "local_shell_call" : self.handle_local_shell_call
771
- }
1002
+ return [{"type": "local_shell"}]
772
1003
 
773
- async def execute_shell_command(self, context: ToolContext, *, command: list[str], env: dict, type: str, timeout_ms: int | None = None, user: str | None = None, working_directory: str | None = None):
1004
+ def get_open_ai_output_handlers(self):
1005
+ return {"local_shell_call": self.handle_local_shell_call}
774
1006
 
1007
+ async def execute_shell_command(
1008
+ self,
1009
+ context: ToolContext,
1010
+ *,
1011
+ command: list[str],
1012
+ env: dict,
1013
+ type: str,
1014
+ timeout_ms: int | None = None,
1015
+ user: str | None = None,
1016
+ working_directory: str | None = None,
1017
+ ):
775
1018
  merged_env = {**os.environ, **(env or {})}
776
1019
 
777
1020
  # Spawn the process
@@ -789,20 +1032,28 @@ class LocalShellTool(OpenAIResponsesTool):
789
1032
  timeout=timeout_ms / 1000 if timeout_ms else None,
790
1033
  )
791
1034
  except asyncio.TimeoutError:
792
- proc.kill() # send SIGKILL / TerminateProcess
1035
+ proc.kill() # send SIGKILL / TerminateProcess
793
1036
  stdout, stderr = await proc.communicate()
794
- raise # re-raise so caller sees the timeout
795
-
1037
+ raise # re-raise so caller sees the timeout
796
1038
 
797
1039
  encoding = os.device_encoding(1) or "utf-8"
798
1040
  stdout = stdout.decode(encoding, errors="replace")
799
1041
  stderr = stderr.decode(encoding, errors="replace")
800
1042
 
801
1043
  return stdout + stderr
802
-
803
- async def handle_local_shell_call(self, context, *, id: str, action: dict, call_id: str, status: str, type: str, **extra):
804
-
805
- result = await self.execute_shell_command(context, **action)
1044
+
1045
+ async def handle_local_shell_call(
1046
+ self,
1047
+ context,
1048
+ *,
1049
+ id: str,
1050
+ action: dict,
1051
+ call_id: str,
1052
+ status: str,
1053
+ type: str,
1054
+ **extra,
1055
+ ):
1056
+ result = await self.execute_shell_command(context, **action)
806
1057
 
807
1058
  output_item = {
808
1059
  "type": "local_shell_call_output",
@@ -819,78 +1070,110 @@ class ContainerFile:
819
1070
  self.mime_type = mime_type
820
1071
  self.container_id = container_id
821
1072
 
1073
+
822
1074
  class CodeInterpreterTool(OpenAIResponsesTool):
823
- def __init__(self, *, container_id: Optional[str] = None, file_ids: Optional[List[str]] = None):
1075
+ def __init__(
1076
+ self,
1077
+ *,
1078
+ container_id: Optional[str] = None,
1079
+ file_ids: Optional[List[str]] = None,
1080
+ ):
824
1081
  super().__init__(name="code_interpreter_call")
825
1082
  self.container_id = container_id
826
1083
  self.file_ids = file_ids
827
1084
 
828
1085
  def get_open_ai_tool_definitions(self):
829
- opts = {
830
- "type" : "code_interpreter"
831
- }
1086
+ opts = {"type": "code_interpreter"}
832
1087
 
833
- if self.container_id != None:
1088
+ if self.container_id is not None:
834
1089
  opts["container_id"] = self.container_id
835
1090
 
836
- if self.file_ids != None:
837
- if self.container_id != None:
838
- raise Exception("Cannot specify both an existing container and files to upload in a code interpreter tool")
839
-
840
- opts["container"] = {
841
- "type" : "auto",
842
- "file_ids" : self.file_ids
843
- }
1091
+ if self.file_ids is not None:
1092
+ if self.container_id is not None:
1093
+ raise Exception(
1094
+ "Cannot specify both an existing container and files to upload in a code interpreter tool"
1095
+ )
1096
+
1097
+ opts["container"] = {"type": "auto", "file_ids": self.file_ids}
1098
+
1099
+ return [opts]
844
1100
 
845
- return [
846
- opts
847
- ]
848
-
849
1101
  def get_open_ai_output_handlers(self):
850
- return {
851
- "code_interpreter_call" : self.handle_code_interpreter_call
852
- }
853
-
854
- async def on_code_interpreter_result(self, context: ToolContext, *, code: str, logs: list[str], files: list[ContainerFile]):
1102
+ return {"code_interpreter_call": self.handle_code_interpreter_call}
1103
+
1104
+ async def on_code_interpreter_result(
1105
+ self,
1106
+ context: ToolContext,
1107
+ *,
1108
+ code: str,
1109
+ logs: list[str],
1110
+ files: list[ContainerFile],
1111
+ ):
855
1112
  pass
856
-
857
- async def handle_code_interpreter_call(self, context, *, code: str, id: str, results: list[dict], call_id: str, status: str, type: str, container_id: str, **extra):
858
-
1113
+
1114
+ async def handle_code_interpreter_call(
1115
+ self,
1116
+ context,
1117
+ *,
1118
+ code: str,
1119
+ id: str,
1120
+ results: list[dict],
1121
+ call_id: str,
1122
+ status: str,
1123
+ type: str,
1124
+ container_id: str,
1125
+ **extra,
1126
+ ):
859
1127
  logs = []
860
1128
  files = []
861
1129
 
862
1130
  for result in results:
863
-
864
1131
  if result.type == "logs":
865
-
866
- logs.append(results["logs"])
867
-
868
- elif result.type == "files":
869
-
870
- files.append(ContainerFile(container_id=container_id, file_id=result["file_id"], mime_type=result["mime_type"]))
1132
+ logs.append(results["logs"])
871
1133
 
872
- await self.on_code_interpreter_result(context, code=code, logs=logs, files=files)
1134
+ elif result.type == "files":
1135
+ files.append(
1136
+ ContainerFile(
1137
+ container_id=container_id,
1138
+ file_id=result["file_id"],
1139
+ mime_type=result["mime_type"],
1140
+ )
1141
+ )
1142
+
1143
+ await self.on_code_interpreter_result(
1144
+ context, code=code, logs=logs, files=files
1145
+ )
873
1146
 
874
1147
 
875
1148
  class MCPToolDefinition:
876
- def __init__(self, *, input_schema: dict, name: str, annotations: dict | None, description: str | None):
1149
+ def __init__(
1150
+ self,
1151
+ *,
1152
+ input_schema: dict,
1153
+ name: str,
1154
+ annotations: dict | None,
1155
+ description: str | None,
1156
+ ):
877
1157
  self.input_schema = input_schema
878
1158
  self.name = name
879
1159
  self.annotations = annotations
880
1160
  self.description = description
881
1161
 
1162
+
882
1163
  class MCPServer:
883
- def __init__(self, *,
1164
+ def __init__(
1165
+ self,
1166
+ *,
884
1167
  server_label: str,
885
1168
  server_url: str,
886
1169
  allowed_tools: Optional[list[str]] = None,
887
1170
  headers: Optional[dict] = None,
888
1171
  # require approval for all tools
889
- require_approval: Optional[Literal["always","never"]] = None,
1172
+ require_approval: Optional[Literal["always", "never"]] = None,
890
1173
  # list of tools that always require approval
891
1174
  always_require_approval: Optional[list[str]] = None,
892
1175
  # list of tools that never require approval
893
- never_require_approval: Optional[list[str]] = None
1176
+ never_require_approval: Optional[list[str]] = None,
894
1177
  ):
895
1178
  self.server_label = server_label
896
1179
  self.server_url = server_url
@@ -902,38 +1185,40 @@ class MCPServer:
902
1185
 
903
1186
 
904
1187
  class MCPTool(OpenAIResponsesTool):
905
- def __init__(self, *,
906
- servers: list[MCPServer]
907
- ):
1188
+ def __init__(self, *, servers: list[MCPServer]):
908
1189
  super().__init__(name="mcp")
909
1190
  self.servers = servers
910
1191
 
911
-
912
1192
  def get_open_ai_tool_definitions(self):
913
-
914
1193
  defs = []
915
1194
  for server in self.servers:
916
1195
  opts = {
917
- "type" : "mcp",
918
- "server_label" : server.server_label,
919
- "server_url" : server.server_url,
1196
+ "type": "mcp",
1197
+ "server_label": server.server_label,
1198
+ "server_url": server.server_url,
920
1199
  }
921
1200
 
922
- if server.allowed_tools != None:
1201
+ if server.allowed_tools is not None:
923
1202
  opts["allowed_tools"] = server.allowed_tools
924
1203
 
925
- if server.headers != None:
1204
+ if server.headers is not None:
926
1205
  opts["headers"] = server.headers
927
-
928
1206
 
929
- if server.always_require_approval != None or server.never_require_approval != None:
1207
+ if (
1208
+ server.always_require_approval is not None
1209
+ or server.never_require_approval is not None
1210
+ ):
930
1211
  opts["require_approval"] = {}
931
1212
 
932
- if server.always_require_approval != None:
933
- opts["require_approval"]["always"] = { "tool_names" : server.always_require_approval }
934
-
935
- if server.never_require_approval != None:
936
- opts["require_approval"]["never"] = { "tool_names" : server.never_require_approval }
1213
+ if server.always_require_approval is not None:
1214
+ opts["require_approval"]["always"] = {
1215
+ "tool_names": server.always_require_approval
1216
+ }
1217
+
1218
+ if server.never_require_approval is not None:
1219
+ opts["require_approval"]["never"] = {
1220
+ "tool_names": server.never_require_approval
1221
+ }
937
1222
 
938
1223
  if server.require_approval:
939
1224
  opts["require_approval"] = server.require_approval
@@ -941,90 +1226,200 @@ class MCPTool(OpenAIResponsesTool):
941
1226
  defs.append(opts)
942
1227
 
943
1228
  return defs
944
-
1229
+
945
1230
  def get_open_ai_stream_callbacks(self):
946
1231
  return {
947
- "response.mcp_list_tools.in_progress" : self.on_mcp_list_tools_in_progress,
948
- "response.mcp_list_tools.failed" : self.on_mcp_list_tools_failed,
949
- "response.mcp_list_tools.completed" : self.on_mcp_list_tools_completed,
950
- "response.mcp_call.in_progress" : self.on_mcp_call_in_progress,
951
- "response.mcp_call.failed" : self.on_mcp_call_failed,
952
- "response.mcp_call.completed" : self.on_mcp_call_completed,
953
- "response.mcp_call.arguments.done" : self.on_mcp_call_arguments_done,
954
- "response.mcp_call.arguments.delta" : self.on_mcp_call_arguments_delta,
1232
+ "response.mcp_list_tools.in_progress": self.on_mcp_list_tools_in_progress,
1233
+ "response.mcp_list_tools.failed": self.on_mcp_list_tools_failed,
1234
+ "response.mcp_list_tools.completed": self.on_mcp_list_tools_completed,
1235
+ "response.mcp_call.in_progress": self.on_mcp_call_in_progress,
1236
+ "response.mcp_call.failed": self.on_mcp_call_failed,
1237
+ "response.mcp_call.completed": self.on_mcp_call_completed,
1238
+ "response.mcp_call.arguments.done": self.on_mcp_call_arguments_done,
1239
+ "response.mcp_call.arguments.delta": self.on_mcp_call_arguments_delta,
955
1240
  }
956
-
957
- async def on_mcp_list_tools_in_progress(self, context: ToolContext, *, sequence_number: int, type: str, **extra):
958
- pass
959
1241
 
960
- async def on_mcp_list_tools_failed(self, context: ToolContext, *, sequence_number: int, type: str, **extra):
1242
+ async def on_mcp_list_tools_in_progress(
1243
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1244
+ ):
961
1245
  pass
962
1246
 
963
- async def on_mcp_list_tools_completed(self, context: ToolContext, *, sequence_number: int, type: str, **extra):
1247
+ async def on_mcp_list_tools_failed(
1248
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1249
+ ):
964
1250
  pass
965
1251
 
966
- async def on_mcp_call_in_progress(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1252
+ async def on_mcp_list_tools_completed(
1253
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1254
+ ):
967
1255
  pass
968
1256
 
969
- async def on_mcp_call_failed(self, context: ToolContext, *, sequence_number: int, type: str, **extra):
1257
+ async def on_mcp_call_in_progress(
1258
+ self,
1259
+ context: ToolContext,
1260
+ *,
1261
+ item_id: str,
1262
+ output_index: int,
1263
+ sequence_number: int,
1264
+ type: str,
1265
+ **extra,
1266
+ ):
970
1267
  pass
971
1268
 
972
- async def on_mcp_call_completed(self, context: ToolContext, *, sequence_number: int, type: str, **extra):
1269
+ async def on_mcp_call_failed(
1270
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1271
+ ):
973
1272
  pass
974
1273
 
975
- async def on_mcp_call_arguments_done(self, context: ToolContext, *, arguments: dict, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1274
+ async def on_mcp_call_completed(
1275
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1276
+ ):
976
1277
  pass
977
1278
 
978
- async def on_mcp_call_arguments_delta(self, context: ToolContext, *, delta: dict, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1279
+ async def on_mcp_call_arguments_done(
1280
+ self,
1281
+ context: ToolContext,
1282
+ *,
1283
+ arguments: dict,
1284
+ item_id: str,
1285
+ output_index: int,
1286
+ sequence_number: int,
1287
+ type: str,
1288
+ **extra,
1289
+ ):
979
1290
  pass
980
1291
 
1292
+ async def on_mcp_call_arguments_delta(
1293
+ self,
1294
+ context: ToolContext,
1295
+ *,
1296
+ delta: dict,
1297
+ item_id: str,
1298
+ output_index: int,
1299
+ sequence_number: int,
1300
+ type: str,
1301
+ **extra,
1302
+ ):
1303
+ pass
981
1304
 
982
1305
  def get_open_ai_output_handlers(self):
983
1306
  return {
984
- "mcp_call" : self.handle_mcp_call,
985
- "mcp_list_tools" : self.handle_mcp_list_tools,
986
- "mcp_approval_request" : self.handle_mcp_approval_request,
1307
+ "mcp_call": self.handle_mcp_call,
1308
+ "mcp_list_tools": self.handle_mcp_list_tools,
1309
+ "mcp_approval_request": self.handle_mcp_approval_request,
987
1310
  }
988
1311
 
989
- async def on_mcp_list_tools(self, context: ToolContext, *, server_label: str, tools: list[MCPToolDefinition], error: str | None, **extra):
1312
+ async def on_mcp_list_tools(
1313
+ self,
1314
+ context: ToolContext,
1315
+ *,
1316
+ server_label: str,
1317
+ tools: list[MCPToolDefinition],
1318
+ error: str | None,
1319
+ **extra,
1320
+ ):
990
1321
  pass
991
-
992
- async def handle_mcp_list_tools(self, context, *, id: str, server_label: str, tools: list, type: str, error: str | None = None, **extra):
993
-
1322
+
1323
+ async def handle_mcp_list_tools(
1324
+ self,
1325
+ context,
1326
+ *,
1327
+ id: str,
1328
+ server_label: str,
1329
+ tools: list,
1330
+ type: str,
1331
+ error: str | None = None,
1332
+ **extra,
1333
+ ):
994
1334
  mcp_tools = []
995
1335
  for tool in tools:
996
- mcp_tools.append(MCPToolDefinition(input_schema=tool["input_schema"], name=tool["name"], annotations=tool["annotations"], description=tool["description"]))
997
-
998
- await self.on_mcp_list_tools(context, server_label=server_label, tools=mcp_tools, error=error)
1336
+ mcp_tools.append(
1337
+ MCPToolDefinition(
1338
+ input_schema=tool["input_schema"],
1339
+ name=tool["name"],
1340
+ annotations=tool["annotations"],
1341
+ description=tool["description"],
1342
+ )
1343
+ )
999
1344
 
1345
+ await self.on_mcp_list_tools(
1346
+ context, server_label=server_label, tools=mcp_tools, error=error
1347
+ )
1000
1348
 
1001
- async def on_mcp_call(self, context: ToolContext, *, name: str, arguments: str, server_label: str, error: str | None, output: str | None, **extra):
1349
+ async def on_mcp_call(
1350
+ self,
1351
+ context: ToolContext,
1352
+ *,
1353
+ name: str,
1354
+ arguments: str,
1355
+ server_label: str,
1356
+ error: str | None,
1357
+ output: str | None,
1358
+ **extra,
1359
+ ):
1002
1360
  pass
1003
-
1004
- async def handle_mcp_call(self, context, *, arguments: str, id: str, name: str, server_label: str, type: str, error: str | None, output: str | None, **extra):
1005
-
1006
- await self.on_mcp_call(context, name=name, arguments=arguments, server_label=server_label, error=error, output=output)
1007
1361
 
1362
+ async def handle_mcp_call(
1363
+ self,
1364
+ context,
1365
+ *,
1366
+ arguments: str,
1367
+ id: str,
1368
+ name: str,
1369
+ server_label: str,
1370
+ type: str,
1371
+ error: str | None,
1372
+ output: str | None,
1373
+ **extra,
1374
+ ):
1375
+ await self.on_mcp_call(
1376
+ context,
1377
+ name=name,
1378
+ arguments=arguments,
1379
+ server_label=server_label,
1380
+ error=error,
1381
+ output=output,
1382
+ )
1008
1383
 
1009
- async def on_mcp_approval_request(self, context: ToolContext, *, name: str, arguments: str, server_label: str, **extra) -> bool:
1384
+ async def on_mcp_approval_request(
1385
+ self,
1386
+ context: ToolContext,
1387
+ *,
1388
+ name: str,
1389
+ arguments: str,
1390
+ server_label: str,
1391
+ **extra,
1392
+ ) -> bool:
1010
1393
  return True
1011
-
1012
- async def handle_mcp_approval_request(self, context: ToolContext, *, arguments: str, id: str, name: str, server_label: str, type: str, **extra):
1394
+
1395
+ async def handle_mcp_approval_request(
1396
+ self,
1397
+ context: ToolContext,
1398
+ *,
1399
+ arguments: str,
1400
+ id: str,
1401
+ name: str,
1402
+ server_label: str,
1403
+ type: str,
1404
+ **extra,
1405
+ ):
1013
1406
  logger.info("approval requested for MCP tool {server_label}.{name}")
1014
- should_approve = await self.on_mcp_approval_request(context, arguments=arguments, name=name, server_label=server_label)
1407
+ should_approve = await self.on_mcp_approval_request(
1408
+ context, arguments=arguments, name=name, server_label=server_label
1409
+ )
1015
1410
  if should_approve:
1016
1411
  logger.info("approval granted for MCP tool {server_label}.{name}")
1017
1412
  return {
1018
1413
  "type": "mcp_approval_response",
1019
1414
  "approve": True,
1020
- "approval_request_id": id
1415
+ "approval_request_id": id,
1021
1416
  }
1022
1417
  else:
1023
1418
  logger.info("approval denied for MCP tool {server_label}.{name}")
1024
1419
  return {
1025
1420
  "type": "mcp_approval_response",
1026
1421
  "approve": False,
1027
- "approval_request_id": id
1422
+ "approval_request_id": id,
1028
1423
  }
1029
1424
 
1030
1425
 
@@ -1032,96 +1427,190 @@ class ReasoningTool(OpenAIResponsesTool):
1032
1427
  def __init__(self):
1033
1428
  super().__init__(name="reasoning")
1034
1429
 
1035
-
1036
1430
  def get_open_ai_output_handlers(self):
1037
1431
  return {
1038
- "reasoning" : self.handle_reasoning,
1432
+ "reasoning": self.handle_reasoning,
1039
1433
  }
1040
-
1434
+
1041
1435
  def get_open_ai_stream_callbacks(self):
1042
1436
  return {
1043
- "response.reasoning_summary_text.done" : self.on_reasoning_summary_text_done,
1044
- "response.reasoning_summary_text.delta" : self.on_reasoning_summary_text_delta,
1045
- "response.reasoning_summary_part.done" : self.on_reasoning_summary_part_done,
1046
- "response.reasoning_summary_part.added" : self.on_reasoning_summary_part_added,
1437
+ "response.reasoning_summary_text.done": self.on_reasoning_summary_text_done,
1438
+ "response.reasoning_summary_text.delta": self.on_reasoning_summary_text_delta,
1439
+ "response.reasoning_summary_part.done": self.on_reasoning_summary_part_done,
1440
+ "response.reasoning_summary_part.added": self.on_reasoning_summary_part_added,
1047
1441
  }
1048
1442
 
1049
- async def on_reasoning_summary_part_added(self, context: ToolContext, *, item_id: str, output_index: int, part: dict, sequence_number: int, summary_index: int, text: str, type: str, **extra):
1443
+ async def on_reasoning_summary_part_added(
1444
+ self,
1445
+ context: ToolContext,
1446
+ *,
1447
+ item_id: str,
1448
+ output_index: int,
1449
+ part: dict,
1450
+ sequence_number: int,
1451
+ summary_index: int,
1452
+ text: str,
1453
+ type: str,
1454
+ **extra,
1455
+ ):
1050
1456
  pass
1051
1457
 
1052
-
1053
- async def on_reasoning_summary_part_done(self, context: ToolContext, *, item_id: str, output_index: int, part: dict, sequence_number: int, summary_index: int, text: str, type: str, **extra):
1458
+ async def on_reasoning_summary_part_done(
1459
+ self,
1460
+ context: ToolContext,
1461
+ *,
1462
+ item_id: str,
1463
+ output_index: int,
1464
+ part: dict,
1465
+ sequence_number: int,
1466
+ summary_index: int,
1467
+ text: str,
1468
+ type: str,
1469
+ **extra,
1470
+ ):
1054
1471
  pass
1055
1472
 
1056
- async def on_reasoning_summary_text_delta(self, context: ToolContext, *, delta: str, output_index: int, sequence_number: int, summary_index: int, text: str, type: str, **extra):
1473
+ async def on_reasoning_summary_text_delta(
1474
+ self,
1475
+ context: ToolContext,
1476
+ *,
1477
+ delta: str,
1478
+ output_index: int,
1479
+ sequence_number: int,
1480
+ summary_index: int,
1481
+ text: str,
1482
+ type: str,
1483
+ **extra,
1484
+ ):
1057
1485
  pass
1058
1486
 
1059
- async def on_reasoning_summary_text_done(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, summary_index: int, text: str, type: str, **extra):
1487
+ async def on_reasoning_summary_text_done(
1488
+ self,
1489
+ context: ToolContext,
1490
+ *,
1491
+ item_id: str,
1492
+ output_index: int,
1493
+ sequence_number: int,
1494
+ summary_index: int,
1495
+ text: str,
1496
+ type: str,
1497
+ **extra,
1498
+ ):
1060
1499
  pass
1061
1500
 
1062
- async def on_reasoning(self, context: ToolContext, *, summary: str, encrypted_content: str | None, status: Literal["in_progress", "completed", "incomplete"]):
1501
+ async def on_reasoning(
1502
+ self,
1503
+ context: ToolContext,
1504
+ *,
1505
+ summary: str,
1506
+ encrypted_content: str | None,
1507
+ status: Literal["in_progress", "completed", "incomplete"],
1508
+ ):
1063
1509
  pass
1064
-
1065
- async def handle_reasoning(self, context: ToolContext, *, id: str, summary: str, type: str, encrypted_content: str | None, status: str, **extra):
1066
-
1067
- await self.on_reasoning(context, summary=summary, encrypted_content=encrypted_content, status=status)
1510
+
1511
+ async def handle_reasoning(
1512
+ self,
1513
+ context: ToolContext,
1514
+ *,
1515
+ id: str,
1516
+ summary: str,
1517
+ type: str,
1518
+ encrypted_content: str | None,
1519
+ status: str,
1520
+ **extra,
1521
+ ):
1522
+ await self.on_reasoning(
1523
+ context, summary=summary, encrypted_content=encrypted_content, status=status
1524
+ )
1068
1525
 
1069
1526
 
1070
1527
  # TODO: computer tool call
1071
1528
 
1529
+
1072
1530
  class WebSearchTool(OpenAIResponsesTool):
1073
1531
  def __init__(self):
1074
1532
  super().__init__(name="web_search")
1075
1533
 
1076
-
1077
1534
  def get_open_ai_tool_definitions(self) -> list[dict]:
1078
- return [
1079
- {
1080
- "type" : "web_search_preview"
1081
- }
1082
- ]
1083
-
1535
+ return [{"type": "web_search_preview"}]
1084
1536
 
1085
1537
  def get_open_ai_stream_callbacks(self):
1086
1538
  return {
1087
- "response.web_search_call.in_progress" : self.on_web_search_call_in_progress,
1088
- "response.web_search_call.searching" : self.on_web_search_call_searching,
1089
- "response.web_search_call.completed" : self.on_web_search_call_completed,
1539
+ "response.web_search_call.in_progress": self.on_web_search_call_in_progress,
1540
+ "response.web_search_call.searching": self.on_web_search_call_searching,
1541
+ "response.web_search_call.completed": self.on_web_search_call_completed,
1090
1542
  }
1091
-
1543
+
1092
1544
  def get_open_ai_output_handlers(self):
1093
- return {
1094
- "web_search_call" : self.handle_web_search_call
1095
- }
1096
-
1097
- async def on_web_search_call_in_progress(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1545
+ return {"web_search_call": self.handle_web_search_call}
1546
+
1547
+ async def on_web_search_call_in_progress(
1548
+ self,
1549
+ context: ToolContext,
1550
+ *,
1551
+ item_id: str,
1552
+ output_index: int,
1553
+ sequence_number: int,
1554
+ type: str,
1555
+ **extra,
1556
+ ):
1098
1557
  pass
1099
1558
 
1100
- async def on_web_search_call_searching(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1559
+ async def on_web_search_call_searching(
1560
+ self,
1561
+ context: ToolContext,
1562
+ *,
1563
+ item_id: str,
1564
+ output_index: int,
1565
+ sequence_number: int,
1566
+ type: str,
1567
+ **extra,
1568
+ ):
1101
1569
  pass
1102
1570
 
1103
- async def on_web_search_call_completed(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1571
+ async def on_web_search_call_completed(
1572
+ self,
1573
+ context: ToolContext,
1574
+ *,
1575
+ item_id: str,
1576
+ output_index: int,
1577
+ sequence_number: int,
1578
+ type: str,
1579
+ **extra,
1580
+ ):
1104
1581
  pass
1105
1582
 
1106
1583
  async def on_web_search(self, context: ToolContext, *, status: str, **extra):
1107
1584
  pass
1108
-
1109
- async def handle_web_search_call(self, context: ToolContext, *, id: str, status: str, type: str, **extra):
1110
-
1585
+
1586
+ async def handle_web_search_call(
1587
+ self, context: ToolContext, *, id: str, status: str, type: str, **extra
1588
+ ):
1111
1589
  await self.on_web_search(context, status=status)
1112
1590
 
1591
+
1113
1592
  class FileSearchResult:
1114
- def __init__(self, *, attributes: dict, file_id: str, filename: str, score: float, text: str):
1593
+ def __init__(
1594
+ self, *, attributes: dict, file_id: str, filename: str, score: float, text: str
1595
+ ):
1115
1596
  self.attributes = attributes
1116
1597
  self.file_id = file_id
1117
1598
  self.filename = filename
1118
1599
  self.score = score
1119
1600
  self.text = text
1120
1601
 
1602
+
1121
1603
  class FileSearchTool(OpenAIResponsesTool):
1122
- def __init__(self, *, vector_store_ids: list[str], filters: Optional[dict] = None, max_num_results: Optional[int] = None, ranking_options: Optional[dict] = None):
1604
+ def __init__(
1605
+ self,
1606
+ *,
1607
+ vector_store_ids: list[str],
1608
+ filters: Optional[dict] = None,
1609
+ max_num_results: Optional[int] = None,
1610
+ ranking_options: Optional[dict] = None,
1611
+ ):
1123
1612
  super().__init__(name="file_search")
1124
-
1613
+
1125
1614
  self.vector_store_ids = vector_store_ids
1126
1615
  self.filters = filters
1127
1616
  self.max_num_results = max_num_results
@@ -1130,48 +1619,87 @@ class FileSearchTool(OpenAIResponsesTool):
1130
1619
  def get_open_ai_tool_definitions(self) -> list[dict]:
1131
1620
  return [
1132
1621
  {
1133
- "type" : "file_search",
1134
- "vector_store_ids" : self.vector_store_ids,
1135
- "filters" : self.filters,
1136
- "max_num_results" : self.max_num_results,
1137
- "ranking_options" : self.ranking_options
1622
+ "type": "file_search",
1623
+ "vector_store_ids": self.vector_store_ids,
1624
+ "filters": self.filters,
1625
+ "max_num_results": self.max_num_results,
1626
+ "ranking_options": self.ranking_options,
1138
1627
  }
1139
1628
  ]
1140
1629
 
1141
-
1142
1630
  def get_open_ai_stream_callbacks(self):
1143
1631
  return {
1144
- "response.file_search_call.in_progress" : self.on_file_search_call_in_progress,
1145
- "response.file_search_call.searching" : self.on_file_search_call_searching,
1146
- "response.file_search_call.completed" : self.on_file_search_call_completed,
1632
+ "response.file_search_call.in_progress": self.on_file_search_call_in_progress,
1633
+ "response.file_search_call.searching": self.on_file_search_call_searching,
1634
+ "response.file_search_call.completed": self.on_file_search_call_completed,
1147
1635
  }
1148
-
1149
1636
 
1150
1637
  def get_open_ai_output_handlers(self):
1151
- return {
1152
- "handle_file_search_call" : self.handle_file_search_call
1153
- }
1638
+ return {"handle_file_search_call": self.handle_file_search_call}
1154
1639
 
1155
-
1156
- async def on_file_search_call_in_progress(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1640
+ async def on_file_search_call_in_progress(
1641
+ self,
1642
+ context: ToolContext,
1643
+ *,
1644
+ item_id: str,
1645
+ output_index: int,
1646
+ sequence_number: int,
1647
+ type: str,
1648
+ **extra,
1649
+ ):
1157
1650
  pass
1158
1651
 
1159
- async def on_file_search_call_searching(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1652
+ async def on_file_search_call_searching(
1653
+ self,
1654
+ context: ToolContext,
1655
+ *,
1656
+ item_id: str,
1657
+ output_index: int,
1658
+ sequence_number: int,
1659
+ type: str,
1660
+ **extra,
1661
+ ):
1160
1662
  pass
1161
1663
 
1162
- async def on_file_search_call_completed(self, context: ToolContext, *, item_id: str, output_index: int, sequence_number: int, type: str, **extra):
1664
+ async def on_file_search_call_completed(
1665
+ self,
1666
+ context: ToolContext,
1667
+ *,
1668
+ item_id: str,
1669
+ output_index: int,
1670
+ sequence_number: int,
1671
+ type: str,
1672
+ **extra,
1673
+ ):
1163
1674
  pass
1164
1675
 
1165
- async def on_file_search(self, context: ToolContext, *, queries: list, results: list[FileSearchResult], status: Literal["in_progress", "searching", "incomplete", "failed"]):
1676
+ async def on_file_search(
1677
+ self,
1678
+ context: ToolContext,
1679
+ *,
1680
+ queries: list,
1681
+ results: list[FileSearchResult],
1682
+ status: Literal["in_progress", "searching", "incomplete", "failed"],
1683
+ ):
1166
1684
  pass
1167
-
1168
- async def handle_file_search_call(self, context: ToolContext, *, id: str, queries: list, status: str, results: dict | None, type: str, **extra):
1169
-
1685
+
1686
+ async def handle_file_search_call(
1687
+ self,
1688
+ context: ToolContext,
1689
+ *,
1690
+ id: str,
1691
+ queries: list,
1692
+ status: str,
1693
+ results: dict | None,
1694
+ type: str,
1695
+ **extra,
1696
+ ):
1170
1697
  search_results = None
1171
- if results != None:
1698
+ if results is not None:
1172
1699
  search_results = []
1173
1700
  for result in results:
1174
1701
  search_results.append(FileSearchResult(**result))
1175
1702
 
1176
- await self.on_file_search(context, queries=queries, results=search_results, status=status)
1177
-
1703
+ await self.on_file_search(
1704
+ context, queries=queries, results=search_results, status=status
1705
+ )