meshagent-openai 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- meshagent/openai/__init__.py +16 -0
- meshagent/openai/proxy/__init__.py +3 -0
- meshagent/openai/proxy/proxy.py +79 -0
- meshagent/openai/tools/__init__.py +18 -0
- meshagent/openai/tools/apply_patch.py +344 -0
- meshagent/openai/tools/completions_adapter.py +437 -0
- meshagent/openai/tools/responses_adapter.py +2369 -0
- meshagent/openai/tools/schema.py +253 -0
- meshagent/openai/tools/stt.py +118 -0
- meshagent/openai/tools/stt_test.py +87 -0
- meshagent/openai/version.py +1 -0
- meshagent_openai-0.18.0.dist-info/METADATA +50 -0
- meshagent_openai-0.18.0.dist-info/RECORD +16 -0
- meshagent_openai-0.18.0.dist-info/WHEEL +5 -0
- meshagent_openai-0.18.0.dist-info/licenses/LICENSE +201 -0
- meshagent_openai-0.18.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2369 @@
|
|
|
1
|
+
from meshagent.agents.agent import AgentChatContext
|
|
2
|
+
from meshagent.api import RoomClient, RoomException, RemoteParticipant
|
|
3
|
+
from meshagent.tools import Toolkit, ToolContext, Tool, BaseTool
|
|
4
|
+
from meshagent.api.messaging import (
|
|
5
|
+
Response,
|
|
6
|
+
LinkResponse,
|
|
7
|
+
FileResponse,
|
|
8
|
+
JsonResponse,
|
|
9
|
+
TextResponse,
|
|
10
|
+
EmptyResponse,
|
|
11
|
+
RawOutputs,
|
|
12
|
+
ensure_response,
|
|
13
|
+
)
|
|
14
|
+
from meshagent.agents.adapter import (
|
|
15
|
+
ToolResponseAdapter,
|
|
16
|
+
LLMAdapter,
|
|
17
|
+
ToolkitBuilder,
|
|
18
|
+
ToolkitConfig,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from meshagent.api.specs.service import ContainerMountSpec, RoomStorageMountSpec
|
|
22
|
+
import json
|
|
23
|
+
from typing import List, Literal
|
|
24
|
+
from meshagent.openai.proxy import get_client
|
|
25
|
+
from openai import AsyncOpenAI, NOT_GIVEN, APIStatusError
|
|
26
|
+
from openai.types.responses import ResponseFunctionToolCall, ResponseStreamEvent
|
|
27
|
+
import os
|
|
28
|
+
from typing import Optional, Callable
|
|
29
|
+
import base64
|
|
30
|
+
|
|
31
|
+
import logging
|
|
32
|
+
import re
|
|
33
|
+
import asyncio
|
|
34
|
+
from pydantic import BaseModel
|
|
35
|
+
import copy
|
|
36
|
+
from opentelemetry import trace
|
|
37
|
+
|
|
38
|
+
# Module-level logger; records are emitted under the "openai_agent" logger name.
logger = logging.getLogger("openai_agent")
# OpenTelemetry tracer used for the spans started in this module.
tracer = trace.get_tracer("openai.llm.responses")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def safe_json_dump(data: dict):
    """Return ``data`` serialized as a JSON string.

    A deep copy of ``data`` is taken first and the copy is what gets
    serialized, so the caller's object is never handed to the encoder.
    """
    snapshot = copy.deepcopy(data)
    return json.dumps(snapshot)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def safe_model_dump(model: BaseModel) -> str:
    """Serialize a pydantic model to a JSON string without ever raising.

    Parameters:
    -----------
    model : BaseModel
        The model to dump; it is dumped with ``model_dump(mode="json")``.

    Returns:
    --------
    str
        The JSON text of the model, or the JSON text of an ``{"error": ...}``
        object when the model could not be dumped.
    """
    try:
        return safe_json_dump(model.model_dump(mode="json"))
    except Exception:
        # Deliberately broad: this is a best-effort helper whose result is
        # attached to tracing spans. The fallback is JSON *text* (not a dict)
        # so the return type is the same on both paths.
        return json.dumps({"error": "unable to dump json for model"})
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _replace_non_matching(text: str, allowed_chars: str, replacement: str) -> str:
    """
    Substitute `replacement` for every character of `text` that is not
    part of the `allowed_chars` character set.

    Parameters:
    -----------
    text : str
        The string to sanitize.
    allowed_chars : str
        The inside of a regex character class describing what may stay,
        e.g. "a-zA-Z0-9" keeps only letters and digits.
    replacement : str
        What each disallowed character is replaced with.

    Returns:
    --------
    str
        A copy of `text` with all disallowed characters replaced.
    """
    # Negated character class: matches exactly the characters to replace.
    disallowed = re.compile("[^" + allowed_chars + "]")
    return disallowed.sub(replacement, text)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def safe_tool_name(name: str):
    """Return `name` with every character outside ``[a-zA-Z0-9_-]`` turned into ``_``."""
    allowed = "a-zA-Z0-9_-"
    substitute = "_"
    return _replace_non_matching(name, allowed, substitute)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Collects a group of tool proxies and manages execution of openai tool calls
|
|
83
|
+
class ResponsesToolBundle:
    """Flattens the tools of several toolkits into one OpenAI tool list.

    Each tool is registered under an OpenAI-safe name (see ``safe_tool_name``).
    The bundle keeps the mapping back to the original tool name and to the
    toolkit that executes it, so a function call coming back from the model
    can be routed with ``execute``.
    """

    def __init__(self, toolkits: List[Toolkit]):
        """Index every tool of ``toolkits`` and build the OpenAI definitions.

        Raises:
            Exception: if two tools share a name, or if two different names
                collapse to the same OpenAI-safe name.
            RoomException: if a tool is neither an ``OpenAIResponsesTool``
                nor a ``Tool``.
        """
        self._toolkits = toolkits
        # original tool name -> toolkit that executes it
        self._executors = dict[str, Toolkit]()
        # OpenAI-safe name -> original tool name
        self._safe_names = {}
        # OpenAI-safe name -> tool instance
        self._tools_by_name = {}

        open_ai_tools = []

        for toolkit in toolkits:
            for v in toolkit.tools:
                k = v.name

                name = safe_tool_name(k)

                if k in self._executors:
                    raise Exception(
                        f"duplicate in bundle '{k}', tool names must be unique."
                    )

                # Distinct names such as "a.b" and "a_b" sanitize to the same
                # safe name; without this check the later tool would silently
                # replace the earlier one in the lookup tables.
                if name in self._safe_names:
                    raise Exception(
                        f"duplicate in bundle '{k}', tool names must be unique "
                        f"(conflicts with '{self._safe_names[name]}' as '{name}')."
                    )

                self._executors[k] = toolkit

                self._safe_names[name] = k
                self._tools_by_name[name] = v

                if isinstance(v, OpenAIResponsesTool):
                    # These tools supply their own, already formatted definitions.
                    open_ai_tools.extend(v.get_open_ai_tool_definitions())

                elif isinstance(v, Tool):
                    # Strict schemas unless the tool explicitly says otherwise.
                    strict = getattr(v, "strict", True)

                    fn = {
                        "type": "function",
                        "name": name,
                        "description": v.description,
                        "parameters": {
                            **v.input_schema,
                        },
                        "strict": strict,
                    }

                    if v.defs is not None:
                        fn["parameters"]["$defs"] = v.defs

                    open_ai_tools.append(fn)

                else:
                    raise RoomException(f"unsupported tool type {type(v)}")

        # An empty list is stored as None so callers can tell "no tools" apart.
        if len(open_ai_tools) == 0:
            open_ai_tools = None

        self._open_ai_tools = open_ai_tools

    async def execute(
        self, *, context: ToolContext, tool_call: ResponseFunctionToolCall
    ) -> Response:
        """Run the tool requested by ``tool_call`` and return its response.

        ``tool_call.name`` is the OpenAI-safe name; it is mapped back to the
        original tool name before the owning toolkit is invoked. The call's
        ``arguments`` string is decoded as JSON.

        Raises:
            RoomException: if the name is not a known OpenAI-safe name.
            Exception: if no toolkit is registered for the resolved name.
        """
        name = tool_call.name
        arguments = json.loads(tool_call.arguments)

        if name not in self._safe_names:
            raise RoomException(f"Invalid tool name {name}, check the name of the tool")

        name = self._safe_names[name]

        if name not in self._executors:
            raise Exception(f"Unregistered tool name {name}")

        proxy = self._executors[name]
        result = await proxy.execute(context=context, name=name, arguments=arguments)
        return ensure_response(result)

    def get_tool(self, name: str) -> BaseTool | None:
        """Return the tool registered under the OpenAI-safe ``name``, or None."""
        return self._tools_by_name.get(name, None)

    def contains(self, name: str) -> bool:
        """Return True when a tool is registered under the OpenAI-safe ``name``."""
        # The definitions list holds dicts (or is None), so membership has to
        # be checked against the name index instead.
        return name in self._tools_by_name

    def to_json(self) -> List[dict] | None:
        """Return a shallow copy of the OpenAI tool definitions, or None if there are none."""
        if self._open_ai_tools is None:
            return None
        return self._open_ai_tools.copy()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Converts a tool response into a series of messages that can be inserted into the openai context
|
|
172
|
+
class OpenAIResponsesToolResponseAdapter(ToolResponseAdapter):
    """Converts a tool response into items that can be appended to the OpenAI input."""

    def __init__(self):
        """Nothing to initialise."""

    @staticmethod
    def _log_llm_message(room: RoomClient, context: AgentChatContext, message) -> None:
        """Send one ``llm.message`` developer-log entry for ``message``."""
        room.developer.log_nowait(
            type="llm.message",
            data={
                "context": context.id,
                "participant_id": room.local_participant.id,
                "participant_name": room.local_participant.get_attribute("name"),
                "message": message,
            },
        )

    @staticmethod
    def _call_output(tool_call: ResponseFunctionToolCall, output) -> dict:
        """Build the ``function_call_output`` item that answers ``tool_call``."""
        return {
            "output": output,
            "call_id": tool_call.call_id,
            "type": "function_call_output",
        }

    async def to_plain_text(self, *, room: RoomClient, response: Response) -> str:
        """Render ``response`` as the text placed in a function call output.

        Raises:
            Exception: when ``response`` is of a type this adapter does not know.
        """
        if isinstance(response, LinkResponse):
            link = {
                "name": response.name,
                "url": response.url,
            }
            return json.dumps(link)

        if isinstance(response, JsonResponse):
            return json.dumps(response.json)

        if isinstance(response, TextResponse):
            return response.text

        if isinstance(response, FileResponse):
            return f"{response.name}"

        if isinstance(response, EmptyResponse):
            return "ok"

        # NOTE: an ImageResponse branch used to be sketched here as
        # commented-out code; image responses are not handled by this method.

        if isinstance(response, dict):
            return json.dumps(response)

        if isinstance(response, str):
            return response

        if response is None:
            return "ok"

        raise Exception("unexpected return type: {type}".format(type=type(response)))

    async def create_messages(
        self,
        *,
        context: AgentChatContext,
        tool_call: ResponseFunctionToolCall,
        room: RoomClient,
        response: Response,
    ) -> list:
        """Build the input items that report ``response`` for ``tool_call``.

        ``RawOutputs`` are passed through untouched. A ``FileResponse`` becomes
        an image input, a PDF file input, decoded text, or an "unsupported
        format" notice depending on its mime type. Anything else is rendered
        with ``to_plain_text``. Every produced item is also written to the
        developer log.
        """
        with tracer.start_as_current_span("llm.tool_adapter.create_messages") as span:
            # Raw outputs are already in the shape the API expects.
            if isinstance(response, RawOutputs):
                span.set_attribute("kind", "raw")
                for raw_item in response.outputs:
                    self._log_llm_message(room, context, raw_item)

                return response.outputs

            span.set_attribute("kind", "text")

            # Everything that is not a file is reported as plain text.
            if not isinstance(response, FileResponse):
                plain = await self.to_plain_text(room=room, response=response)
                span.set_attribute("output", plain)

                item = self._call_output(tool_call, plain)
                self._log_llm_message(room, context, item)
                return [item]

            if response.mime_type and response.mime_type.startswith("image/"):
                span.set_attribute(
                    "output", f"image: {response.name}, {response.mime_type}"
                )
                # Images are inlined as a base64 data URL.
                item = self._call_output(
                    tool_call,
                    [
                        {
                            "type": "input_image",
                            "image_url": f"data:{response.mime_type};base64,{base64.b64encode(response.data).decode()}",
                        }
                    ],
                )
            else:
                span.set_attribute(
                    "output", f"file: {response.name}, {response.mime_type}"
                )

                is_textual = response.mime_type is not None and (
                    response.mime_type.startswith("text/")
                    or response.mime_type == "application/json"
                )

                if response.mime_type == "application/pdf":
                    item = self._call_output(
                        tool_call,
                        [
                            {
                                "type": "input_file",
                                "filename": response.name,
                                "file_data": f"data:{response.mime_type or 'text/plain'};base64,{base64.b64encode(response.data).decode()}",
                            }
                        ],
                    )
                elif is_textual:
                    item = self._call_output(tool_call, response.data.decode())
                else:
                    item = self._call_output(
                        tool_call, f"{response.name} was not in a supported format"
                    )

            self._log_llm_message(room, context, item)
            return [item]
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class OpenAIResponsesAdapter(LLMAdapter[ResponseStreamEvent]):
    """LLM adapter that runs a chat turn through the OpenAI Responses API.

    ``next`` sends the chat context to the model, executes any tool calls the
    model requests, feeds the results back, and repeats until the turn
    terminates with text or with a schema-validated JSON value.
    """

    def __init__(
        self,
        model: str = os.getenv("OPENAI_MODEL", "gpt-5.2"),
        parallel_tool_calls: Optional[bool] = None,
        client: Optional[AsyncOpenAI] = None,
        response_options: Optional[dict] = None,
        reasoning_effort: Optional[str] = None,
        provider: str = "openai",
    ):
        """Store the adapter configuration.

        Args:
            model: default model name. NOTE: the default value is read from
                the ``OPENAI_MODEL`` environment variable once, when this
                function definition is evaluated, not on every construction.
            parallel_tool_calls: when not None (and the model name does not
                start with "o"), sent as ``parallel_tool_calls`` on requests.
            client: client to use; when None one is obtained from
                ``get_client(room=...)`` for every request.
            response_options: extra keyword arguments for
                ``responses.create``; deep-copied per request.
            reasoning_effort: when not None, sent as the reasoning effort
                together with ``summary="detailed"``.
            provider: value recorded on tracing spans.
        """
        self._model = model
        self._parallel_tool_calls = parallel_tool_calls
        self._client = client
        self._response_options = response_options
        self._provider = provider
        self._reasoning_effort = reasoning_effort

    def default_model(self) -> str:
        """Return the model used when ``next`` is called without one."""
        return self._model

    def create_chat_context(self):
        """Create an empty chat context with no system role."""
        context = AgentChatContext(system_role=None)
        return context

    async def check_for_termination(
        self, *, context: AgentChatContext, room: RoomClient
    ) -> bool:
        """Return True when every pending item in the context is a plain message.

        An item without a ``type`` key counts as a message. Any other item
        type (for example a function call output) means the loop in ``next``
        has more to send to the model.
        """
        for message in context.messages:
            if message.get("type", "message") != "message":
                return False

        return True

    # Takes the current chat context, executes a completion request and processes the response.
    # If tool calls are requested, invokes the tools, processes the tool call results, and appends the tool call results to the context
    async def next(
        self,
        *,
        model: Optional[str] = None,
        context: AgentChatContext,
        room: RoomClient,
        toolkits: list[Toolkit],
        tool_adapter: Optional[ToolResponseAdapter] = None,
        output_schema: Optional[dict] = None,
        event_handler: Optional[Callable[[ResponseStreamEvent], None]] = None,
        on_behalf_of: Optional[RemoteParticipant] = None,
    ):
        """Run requests against the model until the turn terminates.

        Args:
            model: model name; defaults to ``default_model()``.
            context: chat context whose messages are sent and extended.
            room: room used for logging, the default client and tool contexts.
            toolkits: toolkits whose tools are offered to the model; the tool
                bundle is rebuilt on every iteration.
            tool_adapter: converts tool responses into input items; defaults
                to ``OpenAIResponsesToolResponseAdapter``.
            output_schema: when given, a strict ``json_schema`` text format is
                requested and message content is parsed as JSON.
            event_handler: when given, the request is streamed and every
                stream event is passed to this callable.
            on_behalf_of: when given, its name is sent in the
                ``Meshagent-On-Behalf-Of`` header and passed to tool contexts
                of function calls.

        Returns:
            The first parsed JSON value when a schema response was produced,
            otherwise the concatenated text of the response's message
            outputs once ``check_for_termination`` reports True.

        Raises:
            RoomException: when the OpenAI API returns an error status, or
                when a schema response contains no JSON text at all.
        """
        if model is None:
            model = self.default_model()

        with tracer.start_as_current_span("llm.turn") as span:
            span.set_attributes({"chat_context": context.id, "api": "responses"})

            if tool_adapter is None:
                tool_adapter = OpenAIResponsesToolResponseAdapter()

            try:
                while True:
                    with tracer.start_as_current_span("llm.turn.iteration") as span:
                        span.set_attributes(
                            {"model": model, "provider": self._provider}
                        )

                        openai = (
                            self._client
                            if self._client is not None
                            else get_client(room=room)
                        )

                        response_schema = output_schema
                        response_name = "response"

                        # We need to do this inside the loop because tools can change mid loop
                        # for example computer use adds goto tools after the first interaction
                        tool_bundle = ResponsesToolBundle(
                            toolkits=[
                                *toolkits,
                            ]
                        )
                        open_ai_tools = tool_bundle.to_json()

                        if open_ai_tools is None:
                            open_ai_tools = NOT_GIVEN

                        ptc = self._parallel_tool_calls
                        extra = {}
                        if ptc is not None and not model.startswith("o"):
                            extra["parallel_tool_calls"] = ptc
                            span.set_attribute("parallel_tool_calls", ptc)
                        else:
                            span.set_attribute("parallel_tool_calls", False)

                        text = NOT_GIVEN
                        if output_schema is not None:
                            span.set_attribute("response_format", "json_schema")
                            text = {
                                "format": {
                                    "type": "json_schema",
                                    "name": response_name,
                                    "schema": response_schema,
                                    "strict": True,
                                }
                            }
                        else:
                            span.set_attribute("response_format", "text")

                        previous_response_id = NOT_GIVEN
                        instructions = context.get_system_instructions()
                        if context.previous_response_id is not None:
                            previous_response_id = context.previous_response_id

                        # Streaming is used exactly when the caller wants events.
                        stream = event_handler is not None

                        with tracer.start_as_current_span("llm.invoke") as span:
                            response_options = copy.deepcopy(self._response_options)
                            if response_options is None:
                                response_options = {}

                            if self._reasoning_effort is not None:
                                response_options["reasoning"] = {
                                    "effort": self._reasoning_effort,
                                    "summary": "detailed",
                                }

                            # `extra` used to be built and then dropped, so the
                            # parallel_tool_calls setting never reached the API.
                            # Explicit response options win over it on conflict.
                            request_options = {**extra, **response_options}

                            extra_headers = {}
                            if on_behalf_of is not None:
                                on_behalf_of_name = on_behalf_of.get_attribute("name")
                                logger.info(
                                    f"{room.local_participant.get_attribute('name')} making openai request on behalf of {on_behalf_of_name}"
                                )
                                extra_headers["Meshagent-On-Behalf-Of"] = (
                                    on_behalf_of_name
                                )

                            logger.info(
                                f"requesting response from openai with model: {model}"
                            )

                            # The result is a response object when not
                            # streaming and an async event iterator otherwise.
                            response = await openai.responses.create(
                                extra_headers=extra_headers,
                                stream=stream,
                                model=model,
                                input=context.messages,
                                tools=open_ai_tools,
                                text=text,
                                previous_response_id=previous_response_id,
                                instructions=instructions or NOT_GIVEN,
                                **request_options,
                            )

                            async def handle_message(message: BaseModel):
                                """Process one output item.

                                Returns ``(outputs, done)``: when ``done`` is
                                True the outputs are final results of the
                                turn, otherwise they are items to append to
                                the chat context.
                                """
                                with tracer.start_as_current_span(
                                    "llm.handle_response"
                                ) as span:
                                    span.set_attributes(
                                        {
                                            "type": message.type,
                                            "message": safe_model_dump(message),
                                        }
                                    )

                                    room.developer.log_nowait(
                                        type="llm.message",
                                        data={
                                            "context": context.id,
                                            "participant_id": room.local_participant.id,
                                            "participant_name": room.local_participant.get_attribute(
                                                "name"
                                            ),
                                            "message": message.to_dict(),
                                        },
                                    )

                                    if message.type == "function_call":
                                        tasks = []

                                        async def do_tool_call(
                                            tool_call: ResponseFunctionToolCall,
                                        ):
                                            """Execute one function call; failures become an error output."""
                                            try:
                                                with tracer.start_as_current_span(
                                                    "llm.handle_tool_call"
                                                ) as span:
                                                    span.set_attributes(
                                                        {
                                                            "id": tool_call.id,
                                                            "name": tool_call.name,
                                                            "call_id": tool_call.call_id,
                                                            "arguments": json.dumps(
                                                                tool_call.arguments
                                                            ),
                                                        }
                                                    )

                                                    tool_context = ToolContext(
                                                        room=room,
                                                        caller=room.local_participant,
                                                        on_behalf_of=on_behalf_of,
                                                        caller_context={
                                                            "chat": context.to_json()
                                                        },
                                                    )
                                                    tool_response = (
                                                        await tool_bundle.execute(
                                                            context=tool_context,
                                                            tool_call=tool_call,
                                                        )
                                                    )
                                                    # A tool may hand back a chat
                                                    # context of its own; if it
                                                    # carries a response id, track it.
                                                    if (
                                                        tool_response.caller_context
                                                        is not None
                                                    ):
                                                        if (
                                                            tool_response.caller_context.get(
                                                                "chat", None
                                                            )
                                                            is not None
                                                        ):
                                                            tool_chat_context = AgentChatContext.from_json(
                                                                tool_response.caller_context[
                                                                    "chat"
                                                                ]
                                                            )
                                                            if (
                                                                tool_chat_context.previous_response_id
                                                                is not None
                                                            ):
                                                                context.track_response(
                                                                    tool_chat_context.previous_response_id
                                                                )

                                                    logger.info(
                                                        f"tool response {tool_response}"
                                                    )
                                                    return await tool_adapter.create_messages(
                                                        context=context,
                                                        tool_call=tool_call,
                                                        room=room,
                                                        response=tool_response,
                                                    )

                                            except Exception as e:
                                                # Best effort: report the failure
                                                # to the model instead of aborting
                                                # the whole turn.
                                                logger.error(
                                                    f"unable to complete tool call {tool_call}",
                                                    exc_info=e,
                                                )
                                                room.developer.log_nowait(
                                                    type="llm.error",
                                                    data={
                                                        "participant_id": room.local_participant.id,
                                                        "participant_name": room.local_participant.get_attribute(
                                                            "name"
                                                        ),
                                                        "error": f"{e}",
                                                    },
                                                )

                                                return [
                                                    {
                                                        "output": json.dumps(
                                                            {
                                                                "error": f"unable to complete tool call: {e}"
                                                            }
                                                        ),
                                                        "call_id": tool_call.call_id,
                                                        "type": "function_call_output",
                                                    }
                                                ]

                                        tasks.append(
                                            asyncio.create_task(do_tool_call(message))
                                        )

                                        results = await asyncio.gather(*tasks)

                                        all_results = []
                                        for result in results:
                                            room.developer.log_nowait(
                                                type="llm.message",
                                                data={
                                                    "context": context.id,
                                                    "participant_id": room.local_participant.id,
                                                    "participant_name": room.local_participant.get_attribute(
                                                        "name"
                                                    ),
                                                    "message": result,
                                                },
                                            )
                                            all_results.extend(result)

                                        return all_results, False

                                    elif message.type == "message":
                                        contents = message.content
                                        if response_schema is None:
                                            return [], False
                                        else:
                                            for content in contents:
                                                parsed = False

                                                # First try to parse the result
                                                try:
                                                    full_response = json.loads(
                                                        content.text
                                                    )
                                                    parsed = True

                                                # sometimes open ai packs two JSON chunks separated by newline, check if that's why we couldn't parse
                                                except json.decoder.JSONDecodeError:
                                                    for (
                                                        part
                                                    ) in content.text.splitlines():
                                                        if len(part.strip()) > 0:
                                                            full_response = json.loads(
                                                                part
                                                            )
                                                            parsed = True

                                                            try:
                                                                self.validate(
                                                                    response=full_response,
                                                                    output_schema=response_schema,
                                                                )
                                                            except Exception as e:
                                                                logger.error(
                                                                    "recieved invalid response, retrying",
                                                                    exc_info=e,
                                                                )
                                                                error = {
                                                                    "role": "user",
                                                                    "content": "encountered a validation error with the output: {error}".format(
                                                                        error=e
                                                                    ),
                                                                }
                                                                room.developer.log_nowait(
                                                                    type="llm.message",
                                                                    data={
                                                                        # chat context id, as in every
                                                                        # other llm.message log entry
                                                                        "context": context.id,
                                                                        "participant_id": room.local_participant.id,
                                                                        "participant_name": room.local_participant.get_attribute(
                                                                            "name"
                                                                        ),
                                                                        "message": error,
                                                                    },
                                                                )
                                                                context.messages.append(
                                                                    error
                                                                )
                                                                continue

                                                # Text with no non-empty line leaves
                                                # nothing to return; fail with a clear
                                                # error instead of an unbound variable.
                                                if not parsed:
                                                    raise RoomException(
                                                        "OpenAI returned an empty message where a JSON response was expected"
                                                    )

                                                return [full_response], True

                                            # A message without content items yields
                                            # nothing; the caller unpacks a 2-tuple,
                                            # so never fall through to None.
                                            return [], False

                                    # NOTE: a dedicated "computer_call" branch used
                                    # to be sketched here as commented-out code; such
                                    # items now go through the generic handler lookup.
                                    else:
                                        with tracer.start_as_current_span(
                                            "llm.handle_tool_call"
                                        ) as span:
                                            # Offer the item to the first OpenAI
                                            # tool that registered a handler for
                                            # this output type.
                                            for toolkit in toolkits:
                                                for tool in toolkit.tools:
                                                    if isinstance(
                                                        tool, OpenAIResponsesTool
                                                    ):
                                                        arguments = message.model_dump(
                                                            mode="json"
                                                        )
                                                        span.set_attributes(
                                                            {
                                                                "type": message.type,
                                                                "arguments": safe_json_dump(
                                                                    arguments
                                                                ),
                                                            }
                                                        )

                                                        handlers = tool.get_open_ai_output_handlers()
                                                        if message.type in handlers:
                                                            tool_context = ToolContext(
                                                                room=room,
                                                                caller=room.local_participant,
                                                                caller_context={
                                                                    "chat": context.to_json()
                                                                },
                                                            )

                                                            try:
                                                                if (
                                                                    event_handler
                                                                    is not None
                                                                ):
                                                                    event_handler(
                                                                        {
                                                                            "type": "meshagent.handler.added",
                                                                            "item": message.model_dump(
                                                                                mode="json"
                                                                            ),
                                                                        }
                                                                    )

                                                                result = await handlers[
                                                                    message.type
                                                                ](
                                                                    tool_context,
                                                                    **arguments,
                                                                )

                                                            except Exception as e:
                                                                if (
                                                                    event_handler
                                                                    is not None
                                                                ):
                                                                    event_handler(
                                                                        {
                                                                            "type": "meshagent.handler.done",
                                                                            "error": f"{e}",
                                                                        }
                                                                    )

                                                                raise

                                                            if (
                                                                event_handler
                                                                is not None
                                                            ):
                                                                event_handler(
                                                                    {
                                                                        "type": "meshagent.handler.done",
                                                                        "item": result,
                                                                    }
                                                                )

                                                            if result is not None:
                                                                span.set_attribute(
                                                                    "result",
                                                                    safe_json_dump(
                                                                        result
                                                                    ),
                                                                )
                                                                return [result], False

                                                            return [], False

                                            logger.warning(
                                                f"OpenAI response handler was not registered for {message.type}"
                                            )

                                            return [], False

                            if not stream:
                                room.developer.log_nowait(
                                    type="llm.message",
                                    data={
                                        "context": context.id,
                                        "participant_id": room.local_participant.id,
                                        "participant_name": room.local_participant.get_attribute(
                                            "name"
                                        ),
                                        "response": response.to_dict(),
                                    },
                                )

                                context.track_response(response.id)

                                final_outputs = []

                                for message in response.output:
                                    context.previous_messages.append(message.to_dict())
                                    outputs, done = await handle_message(
                                        message=message
                                    )
                                    if done:
                                        final_outputs.extend(outputs)
                                    else:
                                        for output in outputs:
                                            context.messages.append(output)

                                if len(final_outputs) > 0:
                                    return final_outputs[0]

                                with tracer.start_as_current_span(
                                    "llm.turn.check_for_termination"
                                ) as span:
                                    term = await self.check_for_termination(
                                        context=context, room=room
                                    )
                                    if term:
                                        span.set_attribute("terminate", True)
                                        text = ""
                                        for output in response.output:
                                            if output.type == "message":
                                                for content in output.content:
                                                    text += content.text

                                        return text
                                    else:
                                        span.set_attribute("terminate", False)

                            else:
                                final_outputs = []
                                # Items produced while streaming; they are added
                                # to the context once the response completes.
                                all_outputs = []
                                async for e in response:
                                    with tracer.start_as_current_span(
                                        "llm.stream.event"
                                    ) as span:
                                        event: ResponseStreamEvent = e
                                        span.set_attributes(
                                            {
                                                "type": event.type,
                                                "event": safe_model_dump(event),
                                            }
                                        )
                                        event_handler(event)

                                        if event.type == "response.completed":
                                            context.track_response(event.response.id)

                                            context.messages.extend(all_outputs)

                                            with tracer.start_as_current_span(
                                                "llm.turn.check_for_termination"
                                            ) as span:
                                                term = await self.check_for_termination(
                                                    context=context, room=room
                                                )

                                                if term:
                                                    span.set_attribute(
                                                        "terminate", True
                                                    )

                                                    text = ""
                                                    for output in event.response.output:
                                                        if output.type == "message":
                                                            for (
                                                                content
                                                            ) in output.content:
                                                                text += content.text

                                                    return text

                                                span.set_attribute("terminate", False)

                                            all_outputs = []

                                        elif event.type == "response.output_item.done":
                                            context.previous_messages.append(
                                                event.item.to_dict()
                                            )

                                            outputs, done = await handle_message(
                                                message=event.item
                                            )
                                            if done:
                                                final_outputs.extend(outputs)
                                            else:
                                                for output in outputs:
                                                    all_outputs.append(output)

                                        else:
                                            # Any other event goes to the stream
                                            # callbacks the OpenAI tools registered.
                                            for toolkit in toolkits:
                                                for tool in toolkit.tools:
                                                    if isinstance(
                                                        tool, OpenAIResponsesTool
                                                    ):
                                                        callbacks = tool.get_open_ai_stream_callbacks()

                                                        if event.type in callbacks:
                                                            tool_context = ToolContext(
                                                                room=room,
                                                                caller=room.local_participant,
                                                                caller_context={
                                                                    "chat": context.to_json()
                                                                },
                                                            )

                                                            await callbacks[event.type](
                                                                tool_context,
                                                                **event.to_dict(),
                                                            )

                                if len(final_outputs) > 0:
                                    return final_outputs[0]

            except APIStatusError as e:
                # Chain the cause so the original API error stays in the traceback.
                raise RoomException(f"Error from OpenAI: {e}") from e
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
class OpenAIResponsesTool(BaseTool):
    """Base class for tools that describe themselves to the Responses API.

    Every hook returns an empty collection here; subclasses override the
    ones they need.
    """

    def get_open_ai_tool_definitions(self) -> list[dict]:
        """Return the tool definitions to send to OpenAI (none by default)."""
        definitions: list[dict] = []
        return definitions

    def get_open_ai_stream_callbacks(self) -> dict[str, Callable]:
        """Return callbacks keyed by stream event type (none by default)."""
        callbacks: dict[str, Callable] = {}
        return callbacks

    def get_open_ai_output_handlers(self) -> dict[str, Callable]:
        """Return handlers keyed by output item type (none by default)."""
        handlers: dict[str, Callable] = {}
        return handlers
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
class ImageGenerationConfig(ToolkitConfig):
    """Options for the ``image_generation`` toolkit.

    Every option is optional. ``ImageGenerationTool`` copies the ones that
    are set into the tool definition it produces.
    """

    name: Literal["image_generation"] = "image_generation"
    # Declared Optional like every other field: the default is None, so an
    # explicit None must also be a valid value.
    background: Optional[Literal["transparent", "opaque", "auto"]] = None
    # Sent as {"image_url": ...} under the "input_image_mask" key.
    input_image_mask_url: Optional[str] = None
    model: Optional[str] = None
    moderation: Optional[str] = None
    output_compression: Optional[int] = None
    output_format: Optional[Literal["png", "webp", "jpeg"]] = None
    # ImageGenerationTool substitutes 1 when this is left as None.
    partial_images: Optional[int] = None
    quality: Optional[Literal["auto", "low", "medium", "high"]] = None
    size: Optional[Literal["1024x1024", "1024x1536", "1536x1024", "auto"]] = None
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
class ImageGenerationToolkitBuilder(ToolkitBuilder):
    """Builds the ``image_generation`` toolkit from an ``ImageGenerationConfig``."""

    def __init__(self):
        super().__init__(name="image_generation", type=ImageGenerationConfig)

    async def make(
        self, *, room: RoomClient, model: str, config: ImageGenerationConfig
    ):
        """Return a toolkit holding one ``ImageGenerationTool`` built from ``config``.

        ``room`` and ``model`` are accepted but not used by this builder.
        """
        image_tool = ImageGenerationTool(config=config)
        return Toolkit(name="image_generation", tools=[image_tool])
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
class ImageGenerationTool(OpenAIResponsesTool):
    """Exposes the OpenAI ``image_generation`` tool.

    The ``on_*`` methods are no-op hooks meant to be overridden; the
    ``handle_*`` method adapts the raw output item before calling a hook.
    """

    def __init__(
        self,
        *,
        config: ImageGenerationConfig,
    ):
        super().__init__(name="image_generation")
        self.background = config.background
        self.input_image_mask_url = config.input_image_mask_url
        self.model = config.model
        self.moderation = config.moderation
        self.output_compression = config.output_compression
        self.output_format = config.output_format
        # Ask for one partial image unless the config says otherwise.
        if config.partial_images is None:
            self.partial_images = 1
        else:
            self.partial_images = config.partial_images
        self.quality = config.quality
        self.size = config.size

    def get_open_ai_tool_definitions(self):
        """Return a single tool definition containing only the options that are set."""
        mask = (
            None
            if self.input_image_mask_url is None
            else {"image_url": self.input_image_mask_url}
        )
        # Insertion order here is the key order of the resulting definition.
        candidates = {
            "background": self.background,
            "input_image_mask": mask,
            "model": self.model,
            "moderation": self.moderation,
            "output_compression": self.output_compression,
            "output_format": self.output_format,
            "partial_images": self.partial_images,
            "quality": self.quality,
            "size": self.size,
        }

        definition = {"type": "image_generation"}
        definition.update(
            {key: value for key, value in candidates.items() if value is not None}
        )
        return [definition]

    def get_open_ai_stream_callbacks(self):
        """Map image generation streaming event types to the matching hooks."""
        prefix = "response.image_generation_call."
        return {
            prefix + "completed": self.on_image_generation_completed,
            prefix + "in_progress": self.on_image_generation_in_progress,
            prefix + "generating": self.on_image_generation_generating,
            prefix + "partial_image": self.on_image_generation_partial,
        }

    def get_open_ai_output_handlers(self):
        """Route ``image_generation_call`` entries to ``handle_image_generated``."""
        return {"image_generation_call": self.handle_image_generated}

    async def on_image_generation_completed(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.image_generation_call.completed``; does nothing by default."""
        pass

    async def on_image_generation_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.image_generation_call.in_progress``; does nothing by default."""
        pass

    async def on_image_generation_generating(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.image_generation_call.generating``; does nothing by default."""
        pass

    async def on_image_generation_partial(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        partial_image_b64: str,
        partial_image_index: int,
        size: str,
        quality: str,
        background: str,
        output_format: str,
        **extra,
    ):
        """Hook for ``response.image_generation_call.partial_image``; does nothing by default."""
        pass

    async def on_image_generated(
        self,
        context: ToolContext,
        *,
        item_id: str,
        data: bytes,
        status: str,
        size: str,
        quality: str,
        background: str,
        output_format: str,
        **extra,
    ):
        """Hook called with the decoded image bytes; does nothing by default."""
        pass

    async def handle_image_generated(
        self,
        context: ToolContext,
        *,
        id: str,
        result: str | None,
        status: str,
        type: str,
        size: str,
        quality: str,
        background: str,
        output_format: str,
        **extra,
    ):
        """Decode the base64 ``result`` and pass it to ``on_image_generated``.

        Nothing happens when ``result`` is ``None``. Always returns ``None``.
        """
        if result is None:
            return

        image_bytes = base64.b64decode(result)
        await self.on_image_generated(
            context,
            item_id=id,
            data=image_bytes,
            status=status,
            size=size,
            quality=quality,
            background=background,
            output_format=output_format,
        )
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
class LocalShellConfig(ToolkitConfig):
    """Configuration selecting the ``local_shell`` toolkit; it has no options of its own."""

    name: Literal["local_shell"] = "local_shell"
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
class LocalShellToolkitBuilder(ToolkitBuilder):
    """Builds the ``local_shell`` toolkit.

    ``working_directory`` is handed to every ``LocalShellTool`` this
    builder creates.
    """

    def __init__(self, *, working_directory: Optional[str] = None):
        super().__init__(name="local_shell", type=LocalShellConfig)
        self.working_directory = working_directory

    async def make(self, *, room: RoomClient, model: str, config: LocalShellConfig):
        """Return a toolkit holding one ``LocalShellTool``.

        ``room`` and ``model`` are accepted but not used by this builder.
        """
        shell_tool = LocalShellTool(
            config=config, working_directory=self.working_directory
        )
        return Toolkit(name="local_shell", tools=[shell_tool])
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
# Largest combined stdout + stderr length (measured with len() on the decoded
# text) that LocalShellTool returns; beyond this it returns an error string.
MAX_SHELL_OUTPUT_SIZE = 1024 * 100
|
|
1160
|
+
|
|
1161
|
+
|
|
1162
|
+
class LocalShellTool(OpenAIResponsesTool):
    """Runs the model's ``local_shell`` calls as subprocesses of this process.

    Commands run without a shell (argument list passed straight to
    ``asyncio.create_subprocess_exec``).
    """

    def __init__(
        self,
        *,
        config: Optional[LocalShellConfig] = None,
        working_directory: Optional[str] = None,
    ):
        super().__init__(name="local_shell")
        if config is None:
            config = LocalShellConfig(name="local_shell")

        # Fallback working directory for calls that do not name one.
        self.working_directory = working_directory

    def get_open_ai_tool_definitions(self):
        """Return the ``local_shell`` tool definition."""
        return [{"type": "local_shell"}]

    def get_open_ai_output_handlers(self):
        """Route ``local_shell_call`` entries to ``handle_local_shell_call``."""
        return {"local_shell_call": self.handle_local_shell_call}

    async def execute_shell_command(
        self,
        context: ToolContext,
        *,
        command: list[str],
        env: dict,
        type: str,
        timeout_ms: int | None = None,
        user: str | None = None,
        working_directory: str | None = None,
    ):
        """Run one command and return its output as text.

        Args:
            command: Argument list; a non-list value is run as a single
                program name.
            env: Extra environment variables layered over ``os.environ``.
            type: Accepted but not used.
            timeout_ms: Timeout in milliseconds; 20 seconds when falsy.
            user: Accepted but not used.
            working_directory: Directory to run in. Falls back to the
                tool's ``working_directory``, then the current directory.

        Returns:
            stdout followed by stderr, decoded with replacement of invalid
            bytes. Failures are reported as a message string rather than
            raised: timeout, any other exception, or output longer than
            ``MAX_SHELL_OUTPUT_SIZE``.
        """
        merged_env = {**os.environ, **(env or {})}
        proc = None

        try:
            argv = command if isinstance(command, (list, tuple)) else [command]
            proc = await asyncio.create_subprocess_exec(
                *argv,
                cwd=working_directory or self.working_directory or os.getcwd(),
                env=merged_env,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            timeout = float(timeout_ms) / 1000.0 if timeout_ms else 20.0

            logger.info(f"executing command {command} with timeout: {timeout}s")

            stdout, stderr = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout,
            )
        except asyncio.TimeoutError as ex:
            if proc is None:
                # The timeout came from spawning, not from waiting: there is
                # no process to kill and no timeout value to report.
                return f"The command failed: {ex}"
            proc.kill()  # send SIGKILL / TerminateProcess
            logger.info(f"The command timed out after {timeout}s")
            # Wait for the killed process to exit; its output is discarded.
            await proc.communicate()
            return f"The command timed out after {timeout}s"
        except Exception as ex:
            return f"The command failed: {ex}"

        encoding = os.device_encoding(1) or "utf-8"
        result = stdout.decode(encoding, errors="replace") + stderr.decode(
            encoding, errors="replace"
        )

        if len(result) > MAX_SHELL_OUTPUT_SIZE:
            # Report the size only; echoing the output itself would defeat
            # the limit.
            return (
                f"Error: the command returned too much data ({len(result)} characters)"
            )

        return result

    async def handle_local_shell_call(
        self,
        context,
        *,
        id: str,
        action: dict,
        call_id: str,
        status: str,
        type: str,
        **extra,
    ):
        """Execute ``action`` and return a ``local_shell_call_output`` item.

        ``action`` is expanded into keyword arguments for
        ``execute_shell_command``; the result is tied to ``call_id``.
        """
        result = await self.execute_shell_command(context, **action)

        return {
            "type": "local_shell_call_output",
            "call_id": call_id,
            "output": result,
        }
|
|
1251
|
+
|
|
1252
|
+
|
|
1253
|
+
class ShellConfig(ToolkitConfig):
    """Configuration selecting the ``shell`` toolkit; it has no options of its own."""

    # Must be the plain string: a trailing comma after the value would turn
    # the default into the tuple ("shell",), which is not a valid Literal value.
    name: Literal["shell"] = "shell"
|
|
1255
|
+
|
|
1256
|
+
|
|
1257
|
+
# Default ``mounts`` value for ShellToolkitBuilder and ShellTool: a single
# room storage mount at path "/data". The same object is shared by every
# caller that relies on the default.
DEFAULT_CONTAINER_MOUNT_SPEC = ContainerMountSpec(
    room=[
        RoomStorageMountSpec(path="/data"),
    ],
)
|
|
1260
|
+
|
|
1261
|
+
|
|
1262
|
+
class ShellToolkitBuilder(ToolkitBuilder):
    """Builds the ``shell`` toolkit.

    The constructor arguments are stored and handed to every ``ShellTool``
    this builder creates.
    """

    def __init__(
        self,
        *,
        working_directory: Optional[str] = None,
        image: Optional[str] = "ubuntu:latest",
        mounts: Optional[ContainerMountSpec] = DEFAULT_CONTAINER_MOUNT_SPEC,
    ):
        super().__init__(name="shell", type=ShellConfig)
        self.working_directory = working_directory
        self.image = image
        self.mounts = mounts

    # ``config`` is a ShellConfig: that is the type registered in __init__
    # and the one ShellTool accepts.
    async def make(self, *, room: RoomClient, model: str, config: ShellConfig):
        """Return a toolkit holding one ``ShellTool``.

        ``room`` and ``model`` are accepted but not used by this builder.
        """
        return Toolkit(
            name="shell",
            tools=[
                ShellTool(
                    config=config,
                    working_directory=self.working_directory,
                    image=self.image,
                    mounts=self.mounts,
                )
            ],
        )
|
|
1287
|
+
|
|
1288
|
+
|
|
1289
|
+
class ShellTool(OpenAIResponsesTool):
    """Executes the model's ``shell`` calls in a room container or locally.

    With ``image`` set, every call starts a container from that image,
    runs the commands in it and then stops and deletes the container.
    With ``image=None`` the commands run through the local shell.

    NOTE(review): the commands come from the model and are passed to a
    shell unmodified; this is the tool's purpose, but deploy accordingly.
    """

    def __init__(
        self,
        *,
        config: Optional[ShellConfig] = None,
        working_directory: Optional[str] = None,
        image: Optional[str] = "ubuntu:latest",
        mounts: Optional[ContainerMountSpec] = DEFAULT_CONTAINER_MOUNT_SPEC,
    ):
        super().__init__(name="shell")
        if config is None:
            config = ShellConfig(name="shell")
        # Only used when commands run locally (image is None).
        self.working_directory = working_directory
        self.image = image
        self.mounts = mounts

    def get_open_ai_tool_definitions(self):
        """Return the ``shell`` tool definition."""
        return [{"type": "shell"}]

    def get_open_ai_output_handlers(self):
        """Route ``shell_call`` entries to ``handle_shell_call``."""
        return {"shell_call": self.handle_shell_call}

    @staticmethod
    def _exit_result(exit_code, stdout: str, stderr: str) -> dict:
        """Build the result entry for a command that finished (or failed to run)."""
        return {
            "outcome": {"type": "exit", "exit_code": exit_code},
            "stdout": stdout,
            "stderr": stderr,
        }

    @staticmethod
    def _timeout_result(stdout: str, stderr: str) -> dict:
        """Build the result entry for a command that hit the timeout."""
        return {
            "outcome": {"type": "timeout"},
            "stdout": stdout,
            "stderr": stderr,
        }

    async def _run_in_container(
        self, context: ToolContext, commands: list[str], timeout: float, decode
    ) -> list[dict]:
        """Run ``commands`` one by one in a fresh container and collect results."""
        results = []

        # A failure to start the container propagates to the caller; there
        # is nothing to clean up in that case.
        container_id = await context.room.containers.run(
            command="sleep infinity",
            image=self.image,
            mounts=self.mounts,
        )

        try:
            logger.info(f"executing shell commands in container {container_id}")

            for command in commands:
                session = await context.room.containers.exec(
                    container_id=container_id, command=command, tty=False
                )

                stdout = bytearray()
                stderr = bytearray()

                try:
                    # Reading the streams is part of running the command, so
                    # it happens under the timeout as well. The buffers live
                    # outside the block so partial output survives a timeout.
                    async with asyncio.timeout(timeout):
                        async for chunk in session.stderr():
                            stderr.extend(chunk)

                        async for chunk in session.stdout():
                            stdout.extend(chunk)

                        exit_code = await session.result

                except asyncio.TimeoutError:
                    logger.info(f"The command timed out after {timeout}s")
                    await session.close()
                    results.append(
                        self._timeout_result(decode(stdout), decode(stderr))
                    )
                    break

                except Exception as ex:
                    results.append(self._exit_result(1, "", f"{ex}"))
                    break

                results.append(
                    self._exit_result(exit_code, decode(stdout), decode(stderr))
                )

        except Exception as ex:
            # Raised by containers.exec or by the timeout handler above.
            results.append(self._exit_result(1, "", f"{ex}"))

        finally:
            # In ``finally`` so the container is also removed when the task
            # is cancelled, which ``except Exception`` does not catch.
            if container_id is not None:
                await context.room.containers.stop(container_id=container_id)
                await context.room.containers.delete(container_id=container_id)

        return results

    async def _run_locally(
        self, commands: list[str], timeout: float, decode
    ) -> list[dict]:
        """Run ``commands`` one by one through the local shell and collect results."""
        results = []
        merged_env = {**os.environ}

        for command in commands:
            logger.info(f"executing command {command} with timeout: {timeout}s")

            try:
                proc = await asyncio.create_subprocess_shell(
                    command,
                    cwd=self.working_directory or os.getcwd(),
                    env=merged_env,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                stdout, stderr = await asyncio.wait_for(
                    proc.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                logger.info(f"The command timed out after {timeout}s")
                proc.kill()  # send SIGKILL / TerminateProcess

                # Collect whatever the process wrote before it was killed.
                stdout, stderr = await proc.communicate()
                results.append(self._timeout_result(decode(stdout), decode(stderr)))
                break

            except Exception as ex:
                results.append(self._exit_result(1, "", f"{ex}"))
                break

            results.append(
                self._exit_result(proc.returncode, decode(stdout), decode(stderr))
            )

        return results

    async def execute_shell_command(
        self,
        context: ToolContext,
        *,
        commands: list[str],
        max_output_length: Optional[int] = None,
        timeout_ms: Optional[int] = None,
    ):
        """Run ``commands`` in order and return one result dict per command run.

        Args:
            commands: Shell command lines, executed sequentially.
            max_output_length: Total number of characters of stdout/stderr
                to return across all commands; ``None`` means unlimited.
            timeout_ms: Per-command timeout in milliseconds; 20 seconds
                when falsy.

        Returns:
            A list of dicts with ``outcome``, ``stdout`` and ``stderr``.
            Processing stops after the first command that times out or
            raises; a non-zero exit code does not stop it.
        """
        encoding = os.device_encoding(1) or "utf-8"
        timeout = float(timeout_ms) / 1000.0 if timeout_ms else 20.0

        # Remaining output budget, shared by every decode() call below.
        left = max_output_length

        def limit(s: str):
            nonlocal left
            if left is None:
                return s
            s = s[0:left]
            left -= len(s)
            return s

        def decode(raw) -> str:
            return limit(raw.decode(encoding, errors="replace"))

        if self.image is not None:
            return await self._run_in_container(context, commands, timeout, decode)

        return await self._run_locally(commands, timeout, decode)

    async def handle_shell_call(
        self,
        context,
        *,
        id: str,
        action: dict,
        call_id: str,
        status: str,
        type: str,
        **extra,
    ):
        """Execute ``action`` and return a ``shell_call_output`` item.

        ``action`` is expanded into keyword arguments for
        ``execute_shell_command``; the result is tied to ``call_id``.
        """
        result = await self.execute_shell_command(context, **action)

        return {
            "type": "shell_call_output",
            "call_id": call_id,
            "output": result,
        }
|
|
1504
|
+
|
|
1505
|
+
|
|
1506
|
+
class ContainerFile:
    """Reference to a file produced inside a code interpreter container."""

    def __init__(self, *, file_id: str, mime_type: str, container_id: str):
        # Stored as given; no validation is performed.
        self.container_id: str = container_id
        self.file_id: str = file_id
        self.mime_type: str = mime_type
|
|
1511
|
+
|
|
1512
|
+
|
|
1513
|
+
class CodeInterpreterTool(OpenAIResponsesTool):
    """Exposes the OpenAI ``code_interpreter`` tool.

    Override ``on_code_interpreter_result`` to receive the code, logs and
    files of each completed call.
    """

    def __init__(
        self,
        *,
        container_id: Optional[str] = None,
        file_ids: Optional[List[str]] = None,
    ):
        super().__init__(name="code_interpreter_call")
        self.container_id = container_id
        self.file_ids = file_ids

    def get_open_ai_tool_definitions(self):
        """Return the ``code_interpreter`` tool definition.

        The ``container`` entry is the existing container ID when one was
        given, otherwise an ``auto`` container (with ``file_ids`` when
        provided).

        Raises:
            ValueError: if both ``container_id`` and ``file_ids`` are set.
        """
        if self.container_id is not None and self.file_ids is not None:
            raise ValueError(
                "Cannot specify both an existing container and files to upload in a code interpreter tool"
            )

        if self.container_id is not None:
            container = self.container_id
        else:
            container = {"type": "auto"}
            if self.file_ids is not None:
                container["file_ids"] = self.file_ids

        return [{"type": "code_interpreter", "container": container}]

    def get_open_ai_output_handlers(self):
        """Route ``code_interpreter_call`` entries to ``handle_code_interpreter_call``."""
        return {"code_interpreter_call": self.handle_code_interpreter_call}

    async def on_code_interpreter_result(
        self,
        context: ToolContext,
        *,
        code: str,
        logs: list[str],
        files: list[ContainerFile],
    ):
        """Hook called once per code interpreter call; does nothing by default."""
        pass

    async def handle_code_interpreter_call(
        self,
        context,
        *,
        code: str,
        id: str,
        results: list[dict],
        call_id: str,
        status: str,
        type: str,
        container_id: str,
        **extra,
    ):
        """Split ``results`` into logs and files and call the result hook.

        Entries are dicts. ``"logs"`` entries contribute their ``logs``
        value. ``"files"`` entries contribute one ``ContainerFile`` per
        file. Other entry types are ignored. A ``None`` ``results`` is
        treated as empty.
        """
        logs = []
        files = []

        for result in results or []:
            kind = result.get("type")

            if kind == "logs":
                logs.append(result["logs"])

            elif kind == "files":
                # NOTE(review): accepts both a nested "files" list and the
                # flat file_id/mime_type shape this code originally read --
                # confirm which one the API actually sends.
                entries = result["files"] if "files" in result else [result]
                for entry in entries:
                    files.append(
                        ContainerFile(
                            container_id=container_id,
                            file_id=entry["file_id"],
                            mime_type=entry["mime_type"],
                        )
                    )

        await self.on_code_interpreter_result(
            context, code=code, logs=logs, files=files
        )
|
|
1585
|
+
|
|
1586
|
+
|
|
1587
|
+
class MCPToolDefinition:
|
|
1588
|
+
def __init__(
|
|
1589
|
+
self,
|
|
1590
|
+
*,
|
|
1591
|
+
input_schema: dict,
|
|
1592
|
+
name: str,
|
|
1593
|
+
annotations: dict | None,
|
|
1594
|
+
description: str | None,
|
|
1595
|
+
):
|
|
1596
|
+
self.input_schema = input_schema
|
|
1597
|
+
self.name = name
|
|
1598
|
+
self.annotations = annotations
|
|
1599
|
+
self.description = description
|
|
1600
|
+
|
|
1601
|
+
|
|
1602
|
+
class MCPServer(BaseModel):
    """One MCP server entry; ``MCPTool`` turns each into an ``mcp`` tool definition."""

    server_label: str
    server_url: Optional[str] = None
    allowed_tools: Optional[list[str]] = None
    authorization: Optional[str] = None
    headers: Optional[dict] = None

    # Blanket approval policy for all tools. When set, MCPTool uses it in
    # place of the two per-tool lists below.
    require_approval: Optional[Literal["always", "never"]] = None
    # Tool names that always require approval.
    always_require_approval: Optional[list[str]] = None
    # Tool names that never require approval.
    never_require_approval: Optional[list[str]] = None

    # Sent as "connector_id" in the tool definition.
    openai_connector_id: Optional[str] = None
|
|
1617
|
+
|
|
1618
|
+
|
|
1619
|
+
class MCPConfig(ToolkitConfig):
    """Configuration for the ``mcp`` toolkit: the list of MCP servers to expose."""

    name: Literal["mcp"] = "mcp"
    # Required; MCPTool emits one tool definition per entry.
    servers: list[MCPServer]
|
|
1622
|
+
|
|
1623
|
+
|
|
1624
|
+
class MCPToolkitBuilder(ToolkitBuilder):
    """Builds the ``mcp`` toolkit from an ``MCPConfig``."""

    def __init__(self):
        super().__init__(name="mcp", type=MCPConfig)

    async def make(self, *, room: RoomClient, model: str, config: MCPConfig):
        """Return a toolkit holding one ``MCPTool`` built from ``config``.

        ``room`` and ``model`` are accepted but not used by this builder.
        """
        mcp_tool = MCPTool(config=config)
        return Toolkit(name="mcp", tools=[mcp_tool])
|
|
1630
|
+
|
|
1631
|
+
|
|
1632
|
+
class MCPTool(OpenAIResponsesTool):
    """Exposes the configured MCP servers through the OpenAI ``mcp`` tool.

    The ``on_*`` methods are hooks meant to be overridden; the ``handle_*``
    methods adapt raw output items before calling a hook.
    """

    def __init__(self, *, config: MCPConfig):
        super().__init__(name="mcp")
        self.servers = config.servers

    def get_open_ai_tool_definitions(self):
        """Return one ``mcp`` tool definition per configured server.

        Only options that are set are included. A blanket
        ``require_approval`` on the server replaces the per-tool
        ``always`` / ``never`` lists.
        """
        defs = []
        for server in self.servers:
            opts = {
                "type": "mcp",
                "server_label": server.server_label,
            }

            optional = (
                ("server_url", server.server_url),
                ("connector_id", server.openai_connector_id),
                ("allowed_tools", server.allowed_tools),
                ("authorization", server.authorization),
                ("headers", server.headers),
            )
            opts.update({key: value for key, value in optional if value is not None})

            approval = {}
            if server.always_require_approval is not None:
                approval["always"] = {"tool_names": server.always_require_approval}
            if server.never_require_approval is not None:
                approval["never"] = {"tool_names": server.never_require_approval}
            if approval:
                opts["require_approval"] = approval

            if server.require_approval:
                opts["require_approval"] = server.require_approval

            defs.append(opts)

        return defs

    def get_open_ai_stream_callbacks(self):
        """Map MCP streaming event types to the matching hooks."""
        return {
            "response.mcp_list_tools.in_progress": self.on_mcp_list_tools_in_progress,
            "response.mcp_list_tools.failed": self.on_mcp_list_tools_failed,
            "response.mcp_list_tools.completed": self.on_mcp_list_tools_completed,
            "response.mcp_call.in_progress": self.on_mcp_call_in_progress,
            "response.mcp_call.failed": self.on_mcp_call_failed,
            "response.mcp_call.completed": self.on_mcp_call_completed,
            "response.mcp_call.arguments.done": self.on_mcp_call_arguments_done,
            "response.mcp_call.arguments.delta": self.on_mcp_call_arguments_delta,
            # NOTE(review): the published event names appear to be spelled
            # "response.mcp_call_arguments.*"; registered in addition to the
            # spellings above -- confirm against the API reference.
            "response.mcp_call_arguments.done": self.on_mcp_call_arguments_done,
            "response.mcp_call_arguments.delta": self.on_mcp_call_arguments_delta,
        }

    async def on_mcp_list_tools_in_progress(
        self, context: ToolContext, *, sequence_number: int, type: str, **extra
    ):
        """Hook for ``response.mcp_list_tools.in_progress``; does nothing by default."""
        pass

    async def on_mcp_list_tools_failed(
        self, context: ToolContext, *, sequence_number: int, type: str, **extra
    ):
        """Hook for ``response.mcp_list_tools.failed``; does nothing by default."""
        pass

    async def on_mcp_list_tools_completed(
        self, context: ToolContext, *, sequence_number: int, type: str, **extra
    ):
        """Hook for ``response.mcp_list_tools.completed``; does nothing by default."""
        pass

    async def on_mcp_call_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.mcp_call.in_progress``; does nothing by default."""
        pass

    async def on_mcp_call_failed(
        self, context: ToolContext, *, sequence_number: int, type: str, **extra
    ):
        """Hook for ``response.mcp_call.failed``; does nothing by default."""
        pass

    async def on_mcp_call_completed(
        self, context: ToolContext, *, sequence_number: int, type: str, **extra
    ):
        """Hook for ``response.mcp_call.completed``; does nothing by default."""
        pass

    async def on_mcp_call_arguments_done(
        self,
        context: ToolContext,
        *,
        arguments: dict,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for the MCP call "arguments done" event; does nothing by default."""
        pass

    async def on_mcp_call_arguments_delta(
        self,
        context: ToolContext,
        *,
        delta: dict,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook for the MCP call "arguments delta" event; does nothing by default."""
        pass

    def get_open_ai_output_handlers(self):
        """Route MCP output entries to the matching ``handle_*`` method."""
        return {
            "mcp_call": self.handle_mcp_call,
            "mcp_list_tools": self.handle_mcp_list_tools,
            "mcp_approval_request": self.handle_mcp_approval_request,
        }

    async def on_mcp_list_tools(
        self,
        context: ToolContext,
        *,
        server_label: str,
        tools: list[MCPToolDefinition],
        error: str | None,
        **extra,
    ):
        """Hook called with the tools a server listed; does nothing by default."""
        pass

    async def handle_mcp_list_tools(
        self,
        context,
        *,
        id: str,
        server_label: str,
        tools: list,
        type: str,
        error: str | None = None,
        **extra,
    ):
        """Convert each tool dict to ``MCPToolDefinition`` and call ``on_mcp_list_tools``."""
        # ``annotations`` and ``description`` are nullable on
        # MCPToolDefinition, so a missing key is treated like None instead
        # of raising KeyError.
        mcp_tools = [
            MCPToolDefinition(
                input_schema=tool["input_schema"],
                name=tool["name"],
                annotations=tool.get("annotations"),
                description=tool.get("description"),
            )
            for tool in tools
        ]

        await self.on_mcp_list_tools(
            context, server_label=server_label, tools=mcp_tools, error=error
        )

    async def on_mcp_call(
        self,
        context: ToolContext,
        *,
        name: str,
        arguments: str,
        server_label: str,
        error: str | None,
        output: str | None,
        **extra,
    ):
        """Hook called with the outcome of an MCP tool call; does nothing by default."""
        pass

    async def handle_mcp_call(
        self,
        context,
        *,
        arguments: str,
        id: str,
        name: str,
        server_label: str,
        type: str,
        error: str | None = None,
        output: str | None = None,
        **extra,
    ):
        """Forward an ``mcp_call`` entry to ``on_mcp_call``.

        ``error`` and ``output`` default to ``None``, matching
        ``handle_mcp_list_tools``, so an entry that omits them is accepted.
        """
        await self.on_mcp_call(
            context,
            name=name,
            arguments=arguments,
            server_label=server_label,
            error=error,
            output=output,
        )

    async def on_mcp_approval_request(
        self,
        context: ToolContext,
        *,
        name: str,
        arguments: str,
        server_label: str,
        **extra,
    ) -> bool:
        """Decide whether an MCP tool call may run; approves everything by default."""
        return True

    async def handle_mcp_approval_request(
        self,
        context: ToolContext,
        *,
        arguments: str,
        id: str,
        name: str,
        server_label: str,
        type: str,
        **extra,
    ):
        """Ask ``on_mcp_approval_request`` and return an ``mcp_approval_response`` item."""
        logger.info(f"approval requested for MCP tool {server_label}.{name}")
        should_approve = await self.on_mcp_approval_request(
            context, arguments=arguments, name=name, server_label=server_label
        )

        if should_approve:
            logger.info(f"approval granted for MCP tool {server_label}.{name}")
        else:
            logger.info(f"approval denied for MCP tool {server_label}.{name}")

        return {
            "type": "mcp_approval_response",
            # Normalised so the response always carries a real boolean.
            "approve": bool(should_approve),
            "approval_request_id": id,
        }
|
|
1877
|
+
|
|
1878
|
+
|
|
1879
|
+
class ReasoningTool(OpenAIResponsesTool):
    """Surfaces reasoning output and reasoning summary stream events.

    Contributes no tool definitions; it only registers hooks. All ``on_*``
    methods do nothing by default and are meant to be overridden.
    """

    def __init__(self):
        super().__init__(name="reasoning")

    def get_open_ai_output_handlers(self):
        """Route ``reasoning`` entries to ``handle_reasoning``."""
        return {
            "reasoning": self.handle_reasoning,
        }

    def get_open_ai_stream_callbacks(self):
        """Map reasoning summary streaming event types to the matching hooks."""
        return {
            "response.reasoning_summary_text.done": self.on_reasoning_summary_text_done,
            "response.reasoning_summary_text.delta": self.on_reasoning_summary_text_delta,
            "response.reasoning_summary_part.done": self.on_reasoning_summary_part_done,
            "response.reasoning_summary_part.added": self.on_reasoning_summary_part_added,
        }

    async def on_reasoning_summary_part_added(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        part: dict,
        sequence_number: int,
        summary_index: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.reasoning_summary_part.added``."""
        pass

    async def on_reasoning_summary_part_done(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        part: dict,
        sequence_number: int,
        summary_index: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.reasoning_summary_part.done``."""
        pass

    async def on_reasoning_summary_text_delta(
        self,
        context: ToolContext,
        *,
        delta: str,
        output_index: int,
        sequence_number: int,
        summary_index: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.reasoning_summary_text.delta``."""
        pass

    async def on_reasoning_summary_text_done(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        summary_index: int,
        type: str,
        **extra,
    ):
        """Hook for ``response.reasoning_summary_text.done``."""
        pass

    async def on_reasoning(
        self,
        context: ToolContext,
        *,
        summary: list[str],
        content: Optional[list[str]] = None,
        encrypted_content: str | None,
        status: Literal["in_progress", "completed", "incomplete"],
    ):
        """Hook called with a complete reasoning entry."""
        pass

    async def handle_reasoning(
        self,
        context: ToolContext,
        *,
        id: str,
        summary: list[dict],
        type: str,
        content: Optional[list[dict]] = None,
        encrypted_content: str | None = None,
        status: str,
        **extra,
    ):
        """Forward a ``reasoning`` entry to ``on_reasoning``.

        ``content`` and ``encrypted_content`` are nullable, so they default
        to ``None`` and an entry that omits them is accepted.
        ``summary`` is passed through unchanged.
        """
        await self.on_reasoning(
            context,
            summary=summary,
            content=content,
            encrypted_content=encrypted_content,
            status=status,
        )
|
|
1980
|
+
|
|
1981
|
+
|
|
1982
|
+
# TODO: computer tool call
|
|
1983
|
+
|
|
1984
|
+
|
|
1985
|
+
# Toolkit configuration for the web search tool. Its ``name`` field is
# annotated as the literal "web_search" and defaults to that same value.
class WebSearchConfig(ToolkitConfig):
    name: Literal["web_search"] = "web_search"
|
|
1987
|
+
|
|
1988
|
+
|
|
1989
|
+
class WebSearchToolkitBuilder(ToolkitBuilder):
    """Builder registered under ``web_search`` with ``WebSearchConfig`` as its config type."""

    def __init__(self):
        super().__init__(name="web_search", type=WebSearchConfig)

    async def make(self, *, room: RoomClient, model: str, config: WebSearchConfig):
        """Return a ``web_search`` toolkit holding a single ``WebSearchTool``.

        ``room`` and ``model`` are accepted but not used here.
        """
        search_tool = WebSearchTool(config=config)
        return Toolkit(name="web_search", tools=[search_tool])
|
|
1995
|
+
|
|
1996
|
+
|
|
1997
|
+
class WebSearchTool(OpenAIResponsesTool):
    """Wrapper for the built-in web search tool.

    All ``on_*`` hooks are no-ops by default; subclass and override the ones
    you need.
    """

    def __init__(self, *, config: Optional[WebSearchConfig] = None):
        # A default config is still constructed when none is given, but the
        # config is not stored on the instance.
        config = WebSearchConfig(name="web_search") if config is None else config
        super().__init__(name="web_search")

    def get_open_ai_tool_definitions(self) -> list[dict]:
        """Return the tool definition advertised to OpenAI."""
        definition = {"type": "web_search_preview"}
        return [definition]

    def get_open_ai_stream_callbacks(self):
        """Return the callback table keyed by stream event type."""
        callbacks = {}
        callbacks["response.web_search_call.in_progress"] = (
            self.on_web_search_call_in_progress
        )
        callbacks["response.web_search_call.searching"] = (
            self.on_web_search_call_searching
        )
        callbacks["response.web_search_call.completed"] = (
            self.on_web_search_call_completed
        )
        return callbacks

    def get_open_ai_output_handlers(self):
        """Return the handler table keyed by output item type."""
        handlers = {}
        handlers["web_search_call"] = self.handle_web_search_call
        return handlers

    async def on_web_search_call_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.web_search_call.in_progress``; no-op."""

    async def on_web_search_call_searching(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.web_search_call.searching``; no-op."""

    async def on_web_search_call_completed(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.web_search_call.completed``; no-op."""

    async def on_web_search(self, context: ToolContext, *, status: str, **extra):
        """High-level hook invoked by ``handle_web_search_call``; no-op by default."""

    async def handle_web_search_call(
        self, context: ToolContext, *, id: str, status: str, type: str, **extra
    ):
        """Handle a ``web_search_call`` output item.

        Only ``status`` is forwarded to ``on_web_search``; ``id``, ``type``
        and any extra fields are dropped.
        """
        await self.on_web_search(context, status=status)
|
|
2059
|
+
|
|
2060
|
+
|
|
2061
|
+
class FileSearchResult:
|
|
2062
|
+
def __init__(
|
|
2063
|
+
self, *, attributes: dict, file_id: str, filename: str, score: float, text: str
|
|
2064
|
+
):
|
|
2065
|
+
self.attributes = attributes
|
|
2066
|
+
self.file_id = file_id
|
|
2067
|
+
self.filename = filename
|
|
2068
|
+
self.score = score
|
|
2069
|
+
self.text = text
|
|
2070
|
+
|
|
2071
|
+
|
|
2072
|
+
class FileSearchTool(OpenAIResponsesTool):
    """Wrapper for the built-in ``file_search`` tool.

    All ``on_*`` hooks are no-ops by default; subclass and override the ones
    you need.

    Args:
        vector_store_ids: Vector stores to search.
        filters: Optional filter object passed through to the tool definition.
        max_num_results: Optional cap on the number of results.
        ranking_options: Optional ranking options passed through as given.
    """

    def __init__(
        self,
        *,
        vector_store_ids: list[str],
        filters: Optional[dict] = None,
        max_num_results: Optional[int] = None,
        ranking_options: Optional[dict] = None,
    ):
        super().__init__(name="file_search")

        self.vector_store_ids = vector_store_ids
        self.filters = filters
        self.max_num_results = max_num_results
        self.ranking_options = ranking_options

    def get_open_ai_tool_definitions(self) -> list[dict]:
        """Return the tool definition advertised to OpenAI.

        Options left as ``None`` are omitted instead of being sent as
        explicit nulls, so the API applies its own defaults for them.
        """
        definition: dict = {
            "type": "file_search",
            "vector_store_ids": self.vector_store_ids,
        }
        optional = {
            "filters": self.filters,
            "max_num_results": self.max_num_results,
            "ranking_options": self.ranking_options,
        }
        definition.update(
            {key: value for key, value in optional.items() if value is not None}
        )
        return [definition]

    def get_open_ai_stream_callbacks(self):
        """Return the callback table keyed by stream event type."""
        return {
            "response.file_search_call.in_progress": self.on_file_search_call_in_progress,
            "response.file_search_call.searching": self.on_file_search_call_searching,
            "response.file_search_call.completed": self.on_file_search_call_completed,
        }

    def get_open_ai_output_handlers(self):
        """Return the handler table keyed by output item type."""
        # The key must be the output item type ("file_search_call"), matching
        # how the sibling tools key their handlers; keying by the method name
        # meant this handler could never be matched.
        return {"file_search_call": self.handle_file_search_call}

    async def on_file_search_call_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.file_search_call.in_progress``; no-op."""

    async def on_file_search_call_searching(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.file_search_call.searching``; no-op."""

    async def on_file_search_call_completed(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        """Hook registered for ``response.file_search_call.completed``; no-op."""

    async def on_file_search(
        self,
        context: ToolContext,
        *,
        queries: list,
        results: list[FileSearchResult] | None,
        status: Literal[
            "in_progress", "searching", "completed", "incomplete", "failed"
        ],
    ):
        """High-level hook invoked by ``handle_file_search_call``; no-op by default.

        ``results`` is ``None`` when the output item carried no results.
        """

    async def handle_file_search_call(
        self,
        context: ToolContext,
        *,
        id: str,
        queries: list,
        status: str,
        results: list[dict] | None = None,
        type: str,
        **extra,
    ):
        """Handle a ``file_search_call`` output item.

        Each raw result dict is converted to a ``FileSearchResult`` before
        ``on_file_search`` is called. ``results`` defaults to ``None`` so an
        item that omits the field is still handled.
        """
        search_results = None
        if results is not None:
            search_results = [FileSearchResult(**result) for result in results]

        await self.on_file_search(
            context, queries=queries, results=search_results, status=status
        )
|
|
2175
|
+
|
|
2176
|
+
|
|
2177
|
+
# Toolkit configuration for the apply_patch tool. Its ``name`` field is
# annotated as the literal "apply_patch" and defaults to that same value.
class ApplyPatchConfig(ToolkitConfig):
    name: Literal["apply_patch"] = "apply_patch"
|
|
2179
|
+
|
|
2180
|
+
|
|
2181
|
+
class ApplyPatchToolkitBuilder(ToolkitBuilder):
    """Builder registered under ``apply_patch`` with ``ApplyPatchConfig`` as its config type."""

    def __init__(self):
        super().__init__(name="apply_patch", type=ApplyPatchConfig)

    async def make(self, *, room: RoomClient, model: str, config: ApplyPatchConfig):
        """Return an ``apply_patch`` toolkit holding a single ``ApplyPatchTool``.

        ``room`` and ``model`` are accepted but not used here.
        """
        patch_tool = ApplyPatchTool(config=config)
        return Toolkit(name="apply_patch", tools=[patch_tool])
|
|
2187
|
+
|
|
2188
|
+
|
|
2189
|
+
class ApplyPatchTool(OpenAIResponsesTool):
    """
    Wrapper for the built-in `apply_patch` tool.

    The model will emit `apply_patch_call` items whenever it wants to create,
    update, or delete a file using a unified diff. The default
    implementation applies the operation to room storage and reports the
    outcome back as an `apply_patch_call_output` item.

    The key entrypoint you can override is:

    * `on_apply_patch_call` – called when the model requests a patch; must
      return a dict with `status` and `output` keys
    """

    def __init__(self, *, config: ApplyPatchConfig):
        # `config` is accepted for builder symmetry but is not stored.
        super().__init__(name="apply_patch")

    # Tool definition advertised to OpenAI
    def get_open_ai_tool_definitions(self) -> list[dict]:
        # No extra options for now – the built-in tool just needs the type
        return [{"type": "apply_patch"}]

    # Stream callbacks for `response.apply_patch_call.*` events
    def get_open_ai_stream_callbacks(self):
        return {
            "response.apply_patch_call.in_progress": self.on_apply_patch_call_in_progress,
            "response.apply_patch_call.completed": self.on_apply_patch_call_completed,
        }

    # Output handlers for item types
    def get_open_ai_output_handlers(self):
        return {
            # The tool call itself (what to apply)
            "apply_patch_call": self.handle_apply_patch_call,
        }

    # --- Stream callbacks -------------------------------------------------

    # response.apply_patch_call.in_progress
    async def on_apply_patch_call_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        # Default: no-op, but you can log progress / show UI here if you want
        pass

    # response.apply_patch_call.completed
    async def on_apply_patch_call_completed(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        # Default: no-op
        pass

    # --- Storage helper ---------------------------------------------------

    async def _write_file(
        self, context: ToolContext, *, path: str, data: bytes, overwrite: bool
    ) -> None:
        """Open `path` in room storage, write `data`, and always close the handle."""
        storage = context.room.storage
        handle = await storage.open(path=path, overwrite=overwrite)
        try:
            await storage.write(handle=handle, data=data)
        finally:
            # Close even when the write fails so the handle is never leaked.
            await storage.close(handle=handle)

    # --- High-level hooks -------------------------------------------------

    async def on_apply_patch_call(
        self,
        context: ToolContext,
        *,
        call_id: str,
        operation: dict,
        status: str,
        **extra,
    ):
        """
        Called when the model requests an apply_patch operation.

        operation looks like one of:

        create_file:
            {
                "type": "create_file",
                "path": "relative/path/to/file",
                "diff": "...unified diff..."
            }

        update_file:
            {
                "type": "update_file",
                "path": "relative/path/to/file",
                "diff": "...unified diff..."
            }

        delete_file:
            {
                "type": "delete_file",
                "path": "relative/path/to/file"
            }

        Returns a dict with `status` ("completed" or "failed") and `output`
        (a human-readable log line, or the diff error text on failure).

        Raises `Exception` for an unrecognised operation type. Errors from
        room storage are not caught here and propagate to the caller.
        """
        # Override this to apply patches somewhere other than room storage.

        from meshagent.openai.tools.apply_patch import apply_diff

        op_type = operation["type"]

        if op_type == "delete_file":
            path = operation["path"]
            logger.info(f"applying patch: deleting file {path}")
            await context.room.storage.delete(path=path)
            log = f"Deleted file: {path}"
            logger.info(log)
            return {"status": "completed", "output": log}

        if op_type == "create_file":
            diff = operation["diff"]
            path = operation["path"]
            logger.info(f"applying patch: creating file {path} with {diff}")

            # Build the content before touching storage: a diff that fails to
            # apply must not leave an opened (and never closed) handle behind.
            try:
                patched = apply_diff("", diff, "create")
            except Exception as ex:
                return {"status": "failed", "output": f"{ex}"}

            data = patched.encode()
            await self._write_file(context, path=path, data=data, overwrite=False)

            # Report the encoded size so the "bytes" figure is accurate for
            # non-ASCII content.
            log = f"Created file: {path} ({len(data)} bytes)"
            logger.info(log)
            return {"status": "completed", "output": log}

        if op_type == "update_file":
            path = operation["path"]
            content = await context.room.storage.download(path=path)
            text = content.data.decode()
            diff = operation["diff"]

            logger.info(f"applying patch: updating file {path} with {diff}")

            try:
                patched = apply_diff(text, diff)
            except Exception as ex:
                return {"status": "failed", "output": f"{ex}"}

            data = patched.encode()
            await self._write_file(context, path=path, data=data, overwrite=True)

            log = f"Updated file: {path} ({len(content.data)} -> {len(data)} bytes)"
            logger.info(log)
            return {"status": "completed", "output": log}

        raise Exception(f"Unexpected patch operation {operation}")

    async def handle_apply_patch_call(
        self,
        context: ToolContext,
        *,
        call_id: str,
        operation: dict,
        status: str,
        type: str,
        id: str | None = None,
        **extra,
    ):
        """Handle an `apply_patch_call` item and build its output item.

        Delegates to `on_apply_patch_call` and merges the returned dict into
        an `apply_patch_call_output` item carrying the same `call_id`.
        """
        result = await self.on_apply_patch_call(
            context,
            call_id=call_id,
            operation=operation,
            status=status,
            **extra,
        )

        return {
            "type": "apply_patch_call_output",
            "call_id": call_id,
            **result,
        }
|