meshagent_openai-0.18.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2369 @@
1
+ from meshagent.agents.agent import AgentChatContext
2
+ from meshagent.api import RoomClient, RoomException, RemoteParticipant
3
+ from meshagent.tools import Toolkit, ToolContext, Tool, BaseTool
4
+ from meshagent.api.messaging import (
5
+ Response,
6
+ LinkResponse,
7
+ FileResponse,
8
+ JsonResponse,
9
+ TextResponse,
10
+ EmptyResponse,
11
+ RawOutputs,
12
+ ensure_response,
13
+ )
14
+ from meshagent.agents.adapter import (
15
+ ToolResponseAdapter,
16
+ LLMAdapter,
17
+ ToolkitBuilder,
18
+ ToolkitConfig,
19
+ )
20
+
21
+ from meshagent.api.specs.service import ContainerMountSpec, RoomStorageMountSpec
22
+ import json
23
+ from typing import List, Literal
24
+ from meshagent.openai.proxy import get_client
25
+ from openai import AsyncOpenAI, NOT_GIVEN, APIStatusError
26
+ from openai.types.responses import ResponseFunctionToolCall, ResponseStreamEvent
27
+ import os
28
+ from typing import Optional, Callable
29
+ import base64
30
+
31
+ import logging
32
+ import re
33
+ import asyncio
34
+ from pydantic import BaseModel
35
+ import copy
36
+ from opentelemetry import trace
37
+
38
+ logger = logging.getLogger("openai_agent")
39
+ tracer = trace.get_tracer("openai.llm.responses")
40
+
41
+
42
+ def safe_json_dump(data: dict):
43
+ return json.dumps(copy.deepcopy(data))
44
+
45
+
46
+ def safe_model_dump(model: BaseModel):
47
+ try:
48
+ return safe_json_dump(model.model_dump(mode="json"))
49
+ except Exception:
50
+ return {"error": "unable to dump json for model"}
51
+
52
+
53
+ def _replace_non_matching(text: str, allowed_chars: str, replacement: str) -> str:
54
+ """
55
+ Replaces every character in `text` that does not match the given
56
+ `allowed_chars` regex set with `replacement`.
57
+
58
+ Parameters:
59
+ -----------
60
+ text : str
61
+ The input string on which the replacement is to be done.
62
+ allowed_chars : str
63
+ A string defining the set of allowed characters (part of a character set).
64
+ For example, "a-zA-Z0-9" will keep only letters and digits.
65
+ replacement : str
66
+ The string to replace non-matching characters with.
67
+
68
+ Returns:
69
+ --------
70
+ str
71
+ A new string where all characters not in `allowed_chars` are replaced.
72
+ """
73
+ # Build a regex that matches any character NOT in allowed_chars
74
+ pattern = rf"[^{allowed_chars}]"
75
+ return re.sub(pattern, replacement, text)
76
+
77
+
78
+ def safe_tool_name(name: str):
79
+ return _replace_non_matching(name, "a-zA-Z0-9_-", "_")
80
+
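# [editor's note: illustrative example, not part of the package]
# safe_tool_name replaces every character outside [a-zA-Z0-9_-] with "_", e.g.
#   safe_tool_name("web.search tool") -> "web_search_tool"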
81
+
82
+ # Collects a group of tool proxies and manages execution of openai tool calls
83
+ class ResponsesToolBundle:
84
+ def __init__(self, toolkits: List[Toolkit]):
85
+ self._toolkits = toolkits
86
+ self._executors = dict[str, Toolkit]()
87
+ self._safe_names = {}
88
+ self._tools_by_name = {}
89
+
90
+ open_ai_tools = []
91
+
92
+ for toolkit in toolkits:
93
+ for v in toolkit.tools:
94
+ k = v.name
95
+
96
+ name = safe_tool_name(k)
97
+
98
+ if k in self._executors:
99
+ raise Exception(
100
+ f"duplicate in bundle '{k}', tool names must be unique."
101
+ )
102
+
103
+ self._executors[k] = toolkit
104
+
105
+ self._safe_names[name] = k
106
+ self._tools_by_name[name] = v
107
+
108
+ if isinstance(v, OpenAIResponsesTool):
109
+ fns = v.get_open_ai_tool_definitions()
110
+ for fn in fns:
111
+ open_ai_tools.append(fn)
112
+
113
+ elif isinstance(v, Tool):
114
+ strict = True
115
+ if hasattr(v, "strict"):
116
+ strict = getattr(v, "strict")
117
+
118
+ fn = {
119
+ "type": "function",
120
+ "name": name,
121
+ "description": v.description,
122
+ "parameters": {
123
+ **v.input_schema,
124
+ },
125
+ "strict": strict,
126
+ }
127
+
128
+ if v.defs is not None:
129
+ fn["parameters"]["$defs"] = v.defs
130
+
131
+ open_ai_tools.append(fn)
132
+
133
+ else:
134
+ raise RoomException(f"unsupported tool type {type(v)}")
135
+
136
+ if len(open_ai_tools) == 0:
137
+ open_ai_tools = None
138
+
139
+ self._open_ai_tools = open_ai_tools
140
+
141
+ async def execute(
142
+ self, *, context: ToolContext, tool_call: ResponseFunctionToolCall
143
+ ) -> Response:
144
+ name = tool_call.name
145
+ arguments = json.loads(tool_call.arguments)
146
+
147
+ if name not in self._safe_names:
148
+ raise RoomException(f"Invalid tool name {name}, check the name of the tool")
149
+
150
+ name = self._safe_names[name]
151
+
152
+ if name not in self._executors:
153
+ raise Exception(f"Unregistered tool name {name}")
154
+
155
+ proxy = self._executors[name]
156
+ result = await proxy.execute(context=context, name=name, arguments=arguments)
157
+ return ensure_response(result)
158
+
159
+ def get_tool(self, name: str) -> BaseTool | None:
160
+ return self._tools_by_name.get(name, None)
161
+
162
+ def contains(self, name: str) -> bool:
163
+ return name in self._tools_by_name
164
+
165
+ def to_json(self) -> List[dict] | None:
166
+ if self._open_ai_tools is None:
167
+ return None
168
+ return self._open_ai_tools.copy()
169
+
170
+
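# [editor's note: the sketch below is an editorial illustration, not part of the
# package. It shows one way ResponsesToolBundle can be driven against the
# Responses API, assuming `client`, `room`, and `my_toolkit` already exist and
# that ToolContext can be built from just the room and caller.]
#
#   async def run_one_iteration(client: AsyncOpenAI, room: RoomClient, my_toolkit: Toolkit, messages: list):
#       bundle = ResponsesToolBundle(toolkits=[my_toolkit])
#       tools = bundle.to_json()  # OpenAI tool definitions, or None when the toolkits are empty
#       response = await client.responses.create(model="gpt-4o", input=messages, tools=tools or NOT_GIVEN)
#       for item in response.output:
#           if item.type == "function_call":
#               ctx = ToolContext(room=room, caller=room.local_participant)  # minimal assumed construction
#               result = await bundle.execute(context=ctx, tool_call=item)  # returns a meshagent Response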
171
+ # Converts a tool response into a series of messages that can be inserted into the openai context
172
+ class OpenAIResponsesToolResponseAdapter(ToolResponseAdapter):
173
+ def __init__(self):
174
+ pass
175
+
176
+ async def to_plain_text(self, *, room: RoomClient, response: Response) -> str:
177
+ if isinstance(response, LinkResponse):
178
+ return json.dumps(
179
+ {
180
+ "name": response.name,
181
+ "url": response.url,
182
+ }
183
+ )
184
+
185
+ elif isinstance(response, JsonResponse):
186
+ return json.dumps(response.json)
187
+
188
+ elif isinstance(response, TextResponse):
189
+ return response.text
190
+
191
+ elif isinstance(response, FileResponse):
192
+ return f"{response.name}"
193
+
194
+ elif isinstance(response, EmptyResponse):
195
+ return "ok"
196
+
197
+ # elif isinstance(response, ImageResponse):
198
+ # context.messages.append({
199
+ # "role" : "assistant",
200
+ # "content" : "the user will upload the image",
201
+ # "tool_call_id" : tool_call.id,
202
+ # })
203
+ # context.messages.append({
204
+ # "role" : "user",
205
+ # "content" : [
206
+ # { "type" : "text", "text": "this is the image from tool call id {tool_call.id}" },
207
+ # { "type" : "image_url", "image_url": {"url": response.url, "detail": "auto"} }
208
+ # ]
209
+ # })
210
+
211
+ elif isinstance(response, dict):
212
+ return json.dumps(response)
213
+
214
+ elif isinstance(response, str):
215
+ return response
216
+
217
+ elif response is None:
218
+ return "ok"
219
+
220
+ else:
221
+ raise Exception(
222
+ "unexpected return type: {type}".format(type=type(response))
223
+ )
224
+
225
+ async def create_messages(
226
+ self,
227
+ *,
228
+ context: AgentChatContext,
229
+ tool_call: ResponseFunctionToolCall,
230
+ room: RoomClient,
231
+ response: Response,
232
+ ) -> list:
233
+ with tracer.start_as_current_span("llm.tool_adapter.create_messages") as span:
234
+ if isinstance(response, RawOutputs):
235
+ span.set_attribute("kind", "raw")
236
+ for output in response.outputs:
237
+ room.developer.log_nowait(
238
+ type="llm.message",
239
+ data={
240
+ "context": context.id,
241
+ "participant_id": room.local_participant.id,
242
+ "participant_name": room.local_participant.get_attribute(
243
+ "name"
244
+ ),
245
+ "message": output,
246
+ },
247
+ )
248
+
249
+ return response.outputs
250
+
251
+ else:
252
+ span.set_attribute("kind", "text")
253
+
254
+ if isinstance(response, FileResponse):
255
+ if response.mime_type and response.mime_type.startswith("image/"):
256
+ span.set_attribute(
257
+ "output", f"image: {response.name}, {response.mime_type}"
258
+ )
259
+
260
+ message = {
261
+ "output": [
262
+ {
263
+ "type": "input_image",
264
+ "image_url": f"data:{response.mime_type};base64,{base64.b64encode(response.data).decode()}",
265
+ }
266
+ ],
267
+ "call_id": tool_call.call_id,
268
+ "type": "function_call_output",
269
+ }
270
+ else:
271
+ span.set_attribute(
272
+ "output", f"file: {response.name}, {response.mime_type}"
273
+ )
274
+
275
+ if response.mime_type == "application/pdf":
276
+ message = {
277
+ "output": [
278
+ {
279
+ "type": "input_file",
280
+ "filename": response.name,
281
+ "file_data": f"data:{response.mime_type or 'text/plain'};base64,{base64.b64encode(response.data).decode()}",
282
+ }
283
+ ],
284
+ "call_id": tool_call.call_id,
285
+ "type": "function_call_output",
286
+ }
287
+ elif response.mime_type is not None and (
288
+ response.mime_type.startswith("text/")
289
+ or response.mime_type == "application/json"
290
+ ):
291
+ message = {
292
+ "output": response.data.decode(),
293
+ "call_id": tool_call.call_id,
294
+ "type": "function_call_output",
295
+ }
296
+
297
+ else:
298
+ message = {
299
+ "output": f"{response.name} was not in a supported format",
300
+ "call_id": tool_call.call_id,
301
+ "type": "function_call_output",
302
+ }
303
+
304
+ room.developer.log_nowait(
305
+ type="llm.message",
306
+ data={
307
+ "context": context.id,
308
+ "participant_id": room.local_participant.id,
309
+ "participant_name": room.local_participant.get_attribute(
310
+ "name"
311
+ ),
312
+ "message": message,
313
+ },
314
+ )
315
+
316
+ return [message]
317
+ else:
318
+ output = await self.to_plain_text(room=room, response=response)
319
+ span.set_attribute("output", output)
320
+
321
+ message = {
322
+ "output": output,
323
+ "call_id": tool_call.call_id,
324
+ "type": "function_call_output",
325
+ }
326
+
327
+ room.developer.log_nowait(
328
+ type="llm.message",
329
+ data={
330
+ "context": context.id,
331
+ "participant_id": room.local_participant.id,
332
+ "participant_name": room.local_participant.get_attribute(
333
+ "name"
334
+ ),
335
+ "message": message,
336
+ },
337
+ )
338
+
339
+ return [message]
340
+
341
+
342
+ class OpenAIResponsesAdapter(LLMAdapter[ResponseStreamEvent]):
343
+ def __init__(
344
+ self,
345
+ model: str = os.getenv("OPENAI_MODEL", "gpt-5.2"),
346
+ parallel_tool_calls: Optional[bool] = None,
347
+ client: Optional[AsyncOpenAI] = None,
348
+ response_options: Optional[dict] = None,
349
+ reasoning_effort: Optional[str] = None,
350
+ provider: str = "openai",
351
+ ):
352
+ self._model = model
353
+ self._parallel_tool_calls = parallel_tool_calls
354
+ self._client = client
355
+ self._response_options = response_options
356
+ self._provider = provider
357
+ self._reasoning_effort = reasoning_effort
358
+
359
+ def default_model(self) -> str:
360
+ return self._model
361
+
362
+ def create_chat_context(self):
363
+ context = AgentChatContext(system_role=None)
364
+ return context
365
+
366
+ async def check_for_termination(
367
+ self, *, context: AgentChatContext, room: RoomClient
368
+ ) -> bool:
369
+ for message in context.messages:
370
+ if message.get("type", "message") != "message":
371
+ return False
372
+
373
+ return True
374
+
375
+ # Takes the current chat context, executes a completion request and processes the response.
376
+ # If tool calls are requested, invokes the tools, processes their results, and appends the tool call results to the context
377
+ async def next(
378
+ self,
379
+ *,
380
+ model: Optional[str] = None,
381
+ context: AgentChatContext,
382
+ room: RoomClient,
383
+ toolkits: list[Toolkit],
384
+ tool_adapter: Optional[ToolResponseAdapter] = None,
385
+ output_schema: Optional[dict] = None,
386
+ event_handler: Optional[Callable[[ResponseStreamEvent], None]] = None,
387
+ on_behalf_of: Optional[RemoteParticipant] = None,
388
+ ):
389
+ if model is None:
390
+ model = self.default_model()
391
+
392
+ with tracer.start_as_current_span("llm.turn") as span:
393
+ span.set_attributes({"chat_context": context.id, "api": "responses"})
394
+
395
+ if tool_adapter is None:
396
+ tool_adapter = OpenAIResponsesToolResponseAdapter()
397
+
398
+ try:
399
+ while True:
400
+ with tracer.start_as_current_span("llm.turn.iteration") as span:
401
+ span.set_attributes(
402
+ {"model": model, "provider": self._provider}
403
+ )
404
+
405
+ openai = (
406
+ self._client
407
+ if self._client is not None
408
+ else get_client(room=room)
409
+ )
410
+
411
+ response_schema = output_schema
412
+ response_name = "response"
413
+
414
+ # We need to do this inside the loop because tools can change mid-loop;
415
+ # for example, computer use adds goto tools after the first interaction
416
+ tool_bundle = ResponsesToolBundle(
417
+ toolkits=[
418
+ *toolkits,
419
+ ]
420
+ )
421
+ open_ai_tools = tool_bundle.to_json()
422
+
423
+ if open_ai_tools is None:
424
+ open_ai_tools = NOT_GIVEN
425
+
426
+ ptc = self._parallel_tool_calls
427
+ extra = {}
428
+ if ptc is not None and not model.startswith("o"):
429
+ extra["parallel_tool_calls"] = ptc
430
+ span.set_attribute("parallel_tool_calls", ptc)
431
+ else:
432
+ span.set_attribute("parallel_tool_calls", False)
433
+
434
+ text = NOT_GIVEN
435
+ if output_schema is not None:
436
+ span.set_attribute("response_format", "json_schema")
437
+ text = {
438
+ "format": {
439
+ "type": "json_schema",
440
+ "name": response_name,
441
+ "schema": response_schema,
442
+ "strict": True,
443
+ }
444
+ }
445
+ else:
446
+ span.set_attribute("response_format", "text")
447
+
448
+ previous_response_id = NOT_GIVEN
449
+ instructions = context.get_system_instructions()
450
+ if context.previous_response_id is not None:
451
+ previous_response_id = context.previous_response_id
452
+
453
+ stream = event_handler is not None
454
+
455
+ with tracer.start_as_current_span("llm.invoke") as span:
456
+ response_options = copy.deepcopy(self._response_options)
457
+ if response_options is None:
458
+ response_options = {}
459
+
460
+ if self._reasoning_effort is not None:
461
+ response_options["reasoning"] = {
462
+ "effort": self._reasoning_effort,
463
+ "summary": "detailed",
464
+ }
465
+
466
+ extra_headers = {}
467
+ if on_behalf_of is not None:
468
+ on_behalf_of_name = on_behalf_of.get_attribute("name")
469
+ logger.info(
470
+ f"{room.local_participant.get_attribute('name')} making openai request on behalf of {on_behalf_of_name}"
471
+ )
472
+ extra_headers["Meshagent-On-Behalf-Of"] = (
473
+ on_behalf_of_name
474
+ )
475
+
476
+ logger.info(
477
+ f"requesting response from openai with model: {model}"
478
+ )
479
+
480
+ response = await openai.responses.create(
481
+ extra_headers=extra_headers,
482
+ stream=stream,
483
+ model=model,
484
+ input=context.messages,
485
+ tools=open_ai_tools,
486
+ text=text,
487
+ previous_response_id=previous_response_id,
488
+ instructions=instructions or NOT_GIVEN,
489
+ **response_options,
490
+ )
491
+
492
+ async def handle_message(message: BaseModel):
493
+ with tracer.start_as_current_span(
494
+ "llm.handle_response"
495
+ ) as span:
496
+ span.set_attributes(
497
+ {
498
+ "type": message.type,
499
+ "message": safe_model_dump(message),
500
+ }
501
+ )
502
+
503
+ room.developer.log_nowait(
504
+ type="llm.message",
505
+ data={
506
+ "context": context.id,
507
+ "participant_id": room.local_participant.id,
508
+ "participant_name": room.local_participant.get_attribute(
509
+ "name"
510
+ ),
511
+ "message": message.to_dict(),
512
+ },
513
+ )
514
+
515
+ if message.type == "function_call":
516
+ tasks = []
517
+
518
+ async def do_tool_call(
519
+ tool_call: ResponseFunctionToolCall,
520
+ ):
521
+ try:
522
+ with tracer.start_as_current_span(
523
+ "llm.handle_tool_call"
524
+ ) as span:
525
+ span.set_attributes(
526
+ {
527
+ "id": tool_call.id,
528
+ "name": tool_call.name,
529
+ "call_id": tool_call.call_id,
530
+ "arguments": json.dumps(
531
+ tool_call.arguments
532
+ ),
533
+ }
534
+ )
535
+
536
+ tool_context = ToolContext(
537
+ room=room,
538
+ caller=room.local_participant,
539
+ on_behalf_of=on_behalf_of,
540
+ caller_context={
541
+ "chat": context.to_json()
542
+ },
543
+ )
544
+ tool_response = (
545
+ await tool_bundle.execute(
546
+ context=tool_context,
547
+ tool_call=tool_call,
548
+ )
549
+ )
550
+ if (
551
+ tool_response.caller_context
552
+ is not None
553
+ ):
554
+ if (
555
+ tool_response.caller_context.get(
556
+ "chat", None
557
+ )
558
+ is not None
559
+ ):
560
+ tool_chat_context = AgentChatContext.from_json(
561
+ tool_response.caller_context[
562
+ "chat"
563
+ ]
564
+ )
565
+ if (
566
+ tool_chat_context.previous_response_id
567
+ is not None
568
+ ):
569
+ context.track_response(
570
+ tool_chat_context.previous_response_id
571
+ )
572
+
573
+ logger.info(
574
+ f"tool response {tool_response}"
575
+ )
576
+ return await tool_adapter.create_messages(
577
+ context=context,
578
+ tool_call=tool_call,
579
+ room=room,
580
+ response=tool_response,
581
+ )
582
+
583
+ except Exception as e:
584
+ logger.error(
585
+ f"unable to complete tool call {tool_call}",
586
+ exc_info=e,
587
+ )
588
+ room.developer.log_nowait(
589
+ type="llm.error",
590
+ data={
591
+ "participant_id": room.local_participant.id,
592
+ "participant_name": room.local_participant.get_attribute(
593
+ "name"
594
+ ),
595
+ "error": f"{e}",
596
+ },
597
+ )
598
+
599
+ return [
600
+ {
601
+ "output": json.dumps(
602
+ {
603
+ "error": f"unable to complete tool call: {e}"
604
+ }
605
+ ),
606
+ "call_id": tool_call.call_id,
607
+ "type": "function_call_output",
608
+ }
609
+ ]
610
+
611
+ tasks.append(
612
+ asyncio.create_task(do_tool_call(message))
613
+ )
614
+
615
+ results = await asyncio.gather(*tasks)
616
+
617
+ all_results = []
618
+ for result in results:
619
+ room.developer.log_nowait(
620
+ type="llm.message",
621
+ data={
622
+ "context": context.id,
623
+ "participant_id": room.local_participant.id,
624
+ "participant_name": room.local_participant.get_attribute(
625
+ "name"
626
+ ),
627
+ "message": result,
628
+ },
629
+ )
630
+ all_results.extend(result)
631
+
632
+ return all_results, False
633
+
634
+ elif message.type == "message":
635
+ contents = message.content
636
+ if response_schema is None:
637
+ return [], False
638
+ else:
639
+ for content in contents:
640
+ # First try to parse the result
641
+ try:
642
+ full_response = json.loads(
643
+ content.text
644
+ )
645
+
646
+ # sometimes OpenAI packs two JSON chunks separated by a newline; check if that's why we couldn't parse
647
+ except json.decoder.JSONDecodeError:
648
+ for (
649
+ part
650
+ ) in content.text.splitlines():
651
+ if len(part.strip()) > 0:
652
+ full_response = json.loads(
653
+ part
654
+ )
655
+
656
+ try:
657
+ self.validate(
658
+ response=full_response,
659
+ output_schema=response_schema,
660
+ )
661
+ except Exception as e:
662
+ logger.error(
663
+ "recieved invalid response, retrying",
664
+ exc_info=e,
665
+ )
666
+ error = {
667
+ "role": "user",
668
+ "content": "encountered a validation error with the output: {error}".format(
669
+ error=e
670
+ ),
671
+ }
672
+ room.developer.log_nowait(
673
+ type="llm.message",
674
+ data={
675
+ "context": message.id,
676
+ "participant_id": room.local_participant.id,
677
+ "participant_name": room.local_participant.get_attribute(
678
+ "name"
679
+ ),
680
+ "message": error,
681
+ },
682
+ )
683
+ context.messages.append(
684
+ error
685
+ )
686
+ continue
687
+
688
+ return [full_response], True
689
+ # elif message.type == "computer_call" and tool_bundle.get_tool("computer_call"):
690
+ # with tracer.start_as_current_span("llm.handle_computer_call") as span:
691
+ #
692
+ # computer_call :ResponseComputerToolCall = message
693
+ # span.set_attributes({
694
+ # "id": computer_call.id,
695
+ # "action": computer_call.action,
696
+ # "call_id": computer_call.call_id,
697
+ # "type": json.dumps(computer_call.type)
698
+ # })
699
+
700
+ # tool_context = ToolContext(
701
+ # room=room,
702
+ # caller=room.local_participant,
703
+ # caller_context={ "chat" : context.to_json }
704
+ # )
705
+ # outputs = (await tool_bundle.get_tool("computer_call").execute(context=tool_context, arguments=message.model_dump(mode="json"))).outputs
706
+
707
+ # return outputs, False
708
+
709
+ else:
710
+ with tracer.start_as_current_span(
711
+ "llm.handle_tool_call"
712
+ ) as span:
713
+ for toolkit in toolkits:
714
+ for tool in toolkit.tools:
715
+ if isinstance(
716
+ tool, OpenAIResponsesTool
717
+ ):
718
+ arguments = message.model_dump(
719
+ mode="json"
720
+ )
721
+ span.set_attributes(
722
+ {
723
+ "type": message.type,
724
+ "arguments": safe_json_dump(
725
+ arguments
726
+ ),
727
+ }
728
+ )
729
+
730
+ handlers = tool.get_open_ai_output_handlers()
731
+ if message.type in handlers:
732
+ tool_context = ToolContext(
733
+ room=room,
734
+ caller=room.local_participant,
735
+ caller_context={
736
+ "chat": context.to_json()
737
+ },
738
+ )
739
+
740
+ try:
741
+ if (
742
+ event_handler
743
+ is not None
744
+ ):
745
+ event_handler(
746
+ {
747
+ "type": "meshagent.handler.added",
748
+ "item": message.model_dump(
749
+ mode="json"
750
+ ),
751
+ }
752
+ )
753
+
754
+ result = await handlers[
755
+ message.type
756
+ ](
757
+ tool_context,
758
+ **arguments,
759
+ )
760
+
761
+ except Exception as e:
762
+ if (
763
+ event_handler
764
+ is not None
765
+ ):
766
+ event_handler(
767
+ {
768
+ "type": "meshagent.handler.done",
769
+ "error": f"{e}",
770
+ }
771
+ )
772
+
773
+ raise
774
+
775
+ if (
776
+ event_handler
777
+ is not None
778
+ ):
779
+ event_handler(
780
+ {
781
+ "type": "meshagent.handler.done",
782
+ "item": result,
783
+ }
784
+ )
785
+
786
+ if result is not None:
787
+ span.set_attribute(
788
+ "result",
789
+ safe_json_dump(
790
+ result
791
+ ),
792
+ )
793
+ return [result], False
794
+
795
+ return [], False
796
+
797
+ logger.warning(
798
+ f"OpenAI response handler was not registered for {message.type}"
799
+ )
800
+
801
+ return [], False
802
+
803
+ if not stream:
804
+ room.developer.log_nowait(
805
+ type="llm.message",
806
+ data={
807
+ "context": context.id,
808
+ "participant_id": room.local_participant.id,
809
+ "participant_name": room.local_participant.get_attribute(
810
+ "name"
811
+ ),
812
+ "response": response.to_dict(),
813
+ },
814
+ )
815
+
816
+ context.track_response(response.id)
817
+
818
+ final_outputs = []
819
+
820
+ for message in response.output:
821
+ context.previous_messages.append(message.to_dict())
822
+ outputs, done = await handle_message(
823
+ message=message
824
+ )
825
+ if done:
826
+ final_outputs.extend(outputs)
827
+ else:
828
+ for output in outputs:
829
+ context.messages.append(output)
830
+
831
+ if len(final_outputs) > 0:
832
+ return final_outputs[0]
833
+
834
+ with tracer.start_as_current_span(
835
+ "llm.turn.check_for_termination"
836
+ ) as span:
837
+ term = await self.check_for_termination(
838
+ context=context, room=room
839
+ )
840
+ if term:
841
+ span.set_attribute("terminate", True)
842
+ text = ""
843
+ for output in response.output:
844
+ if output.type == "message":
845
+ for content in output.content:
846
+ text += content.text
847
+
848
+ return text
849
+ else:
850
+ span.set_attribute("terminate", False)
851
+
852
+ else:
853
+ final_outputs = []
854
+ all_outputs = []
855
+ async for e in response:
856
+ with tracer.start_as_current_span(
857
+ "llm.stream.event"
858
+ ) as span:
859
+ event: ResponseStreamEvent = e
860
+ span.set_attributes(
861
+ {
862
+ "type": event.type,
863
+ "event": safe_model_dump(event),
864
+ }
865
+ )
866
+ event_handler(event)
867
+
868
+ if event.type == "response.completed":
869
+ context.track_response(event.response.id)
870
+
871
+ context.messages.extend(all_outputs)
872
+
873
+ with tracer.start_as_current_span(
874
+ "llm.turn.check_for_termination"
875
+ ) as span:
876
+ term = await self.check_for_termination(
877
+ context=context, room=room
878
+ )
879
+
880
+ if term:
881
+ span.set_attribute(
882
+ "terminate", True
883
+ )
884
+
885
+ text = ""
886
+ for output in event.response.output:
887
+ if output.type == "message":
888
+ for (
889
+ content
890
+ ) in output.content:
891
+ text += content.text
892
+
893
+ return text
894
+
895
+ span.set_attribute("terminate", False)
896
+
897
+ all_outputs = []
898
+
899
+ elif event.type == "response.output_item.done":
900
+ context.previous_messages.append(
901
+ event.item.to_dict()
902
+ )
903
+
904
+ outputs, done = await handle_message(
905
+ message=event.item
906
+ )
907
+ if done:
908
+ final_outputs.extend(outputs)
909
+ else:
910
+ for output in outputs:
911
+ all_outputs.append(output)
912
+
913
+ else:
914
+ for toolkit in toolkits:
915
+ for tool in toolkit.tools:
916
+ if isinstance(
917
+ tool, OpenAIResponsesTool
918
+ ):
919
+ callbacks = tool.get_open_ai_stream_callbacks()
920
+
921
+ if event.type in callbacks:
922
+ tool_context = ToolContext(
923
+ room=room,
924
+ caller=room.local_participant,
925
+ caller_context={
926
+ "chat": context.to_json()
927
+ },
928
+ )
929
+
930
+ await callbacks[event.type](
931
+ tool_context,
932
+ **event.to_dict(),
933
+ )
934
+
935
+ if len(final_outputs) > 0:
936
+ return final_outputs[0]
937
+
938
+ except APIStatusError as e:
939
+ raise RoomException(f"Error from OpenAI: {e}")
940
+
941
+
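# [editor's note: illustrative sketch, not part of the package. A minimal way to
# drive OpenAIResponsesAdapter for one turn, assuming `room` is a connected
# RoomClient and `my_toolkit` is a meshagent Toolkit (both hypothetical).]
#
#   adapter = OpenAIResponsesAdapter(model="gpt-4o-mini")
#   chat = adapter.create_chat_context()
#   chat.messages.append({"role": "user", "content": "List the files in room storage"})
#   reply = await adapter.next(context=chat, room=room, toolkits=[my_toolkit])
#   # `reply` is the assistant's text, or the parsed JSON object when output_schema is supplied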
942
+ class OpenAIResponsesTool(BaseTool):
943
+ def get_open_ai_tool_definitions(self) -> list[dict]:
944
+ return []
945
+
946
+ def get_open_ai_stream_callbacks(self) -> dict[str, Callable]:
947
+ return {}
948
+
949
+ def get_open_ai_output_handlers(self) -> dict[str, Callable]:
950
+ return {}
951
+
952
+
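# [editor's note: illustrative sketch, not part of the package. Built-in OpenAI
# tools are exposed by subclassing OpenAIResponsesTool: advertise the tool via
# get_open_ai_tool_definitions() and claim response output items by their type
# via get_open_ai_output_handlers(). The class and handler names below are
# hypothetical.]
#
#   class MyWebSearch(OpenAIResponsesTool):
#       def __init__(self):
#           super().__init__(name="my_web_search")
#
#       def get_open_ai_tool_definitions(self):
#           return [{"type": "web_search_preview"}]
#
#       def get_open_ai_output_handlers(self):
#           return {"web_search_call": self.handle_web_search_call}
#
#       async def handle_web_search_call(self, context, *, id, status, type, **extra):
#           pass  # inspect the completed web_search_call output item here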
953
+ class ImageGenerationConfig(ToolkitConfig):
954
+ name: Literal["image_generation"] = "image_generation"
955
+ background: Optional[Literal["transparent", "opaque", "auto"]] = None
956
+ input_image_mask_url: Optional[str] = None
957
+ model: Optional[str] = None
958
+ moderation: Optional[str] = None
959
+ output_compression: Optional[int] = None
960
+ output_format: Optional[Literal["png", "webp", "jpeg"]] = None
961
+ partial_images: Optional[int] = None
962
+ quality: Optional[Literal["auto", "low", "medium", "high"]] = None
963
+ size: Optional[Literal["1024x1024", "1024x1536", "1536x1024", "auto"]] = None
964
+
965
+
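# [editor's note: illustrative example, not part of the package; field values
# are taken from the Literal choices declared above.]
#   ImageGenerationConfig(quality="high", size="1024x1024", output_format="png", background="transparent")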
966
+ class ImageGenerationToolkitBuilder(ToolkitBuilder):
967
+ def __init__(self):
968
+ super().__init__(name="image_generation", type=ImageGenerationConfig)
969
+
970
+ async def make(
971
+ self, *, room: RoomClient, model: str, config: ImageGenerationConfig
972
+ ):
973
+ return Toolkit(
974
+ name="image_generation", tools=[ImageGenerationTool(config=config)]
975
+ )
976
+
977
+
978
+ class ImageGenerationTool(OpenAIResponsesTool):
979
+ def __init__(
980
+ self,
981
+ *,
982
+ config: ImageGenerationConfig,
983
+ ):
984
+ super().__init__(name="image_generation")
985
+ self.background = config.background
986
+ self.input_image_mask_url = config.input_image_mask_url
987
+ self.model = config.model
988
+ self.moderation = config.moderation
989
+ self.output_compression = config.output_compression
990
+ self.output_format = config.output_format
991
+ self.partial_images = (
992
+ config.partial_images if config.partial_images is not None else 1
993
+ )
994
+ self.quality = config.quality
995
+ self.size = config.size
996
+
997
+ def get_open_ai_tool_definitions(self):
998
+ opts = {"type": "image_generation"}
999
+
1000
+ if self.background is not None:
1001
+ opts["background"] = self.background
1002
+
1003
+ if self.input_image_mask_url is not None:
1004
+ opts["input_image_mask"] = {"image_url": self.input_image_mask_url}
1005
+
1006
+ if self.model is not None:
1007
+ opts["model"] = self.model
1008
+
1009
+ if self.moderation is not None:
1010
+ opts["moderation"] = self.moderation
1011
+
1012
+ if self.output_compression is not None:
1013
+ opts["output_compression"] = self.output_compression
1014
+
1015
+ if self.output_format is not None:
1016
+ opts["output_format"] = self.output_format
1017
+
1018
+ if self.partial_images is not None:
1019
+ opts["partial_images"] = self.partial_images
1020
+
1021
+ if self.quality is not None:
1022
+ opts["quality"] = self.quality
1023
+
1024
+ if self.size is not None:
1025
+ opts["size"] = self.size
1026
+
1027
+ return [opts]
1028
+
1029
+ def get_open_ai_stream_callbacks(self):
1030
+ return {
1031
+ "response.image_generation_call.completed": self.on_image_generation_completed,
1032
+ "response.image_generation_call.in_progress": self.on_image_generation_in_progress,
1033
+ "response.image_generation_call.generating": self.on_image_generation_generating,
1034
+ "response.image_generation_call.partial_image": self.on_image_generation_partial,
1035
+ }
1036
+
1037
+ def get_open_ai_output_handlers(self):
1038
+ return {"image_generation_call": self.handle_image_generated}
1039
+
1040
+ # response.image_generation_call.completed
1041
+ async def on_image_generation_completed(
1042
+ self,
1043
+ context: ToolContext,
1044
+ *,
1045
+ item_id: str,
1046
+ output_index: int,
1047
+ sequence_number: int,
1048
+ type: str,
1049
+ **extra,
1050
+ ):
1051
+ pass
1052
+
1053
+ # response.image_generation_call.in_progress
1054
+ async def on_image_generation_in_progress(
1055
+ self,
1056
+ context: ToolContext,
1057
+ *,
1058
+ item_id: str,
1059
+ output_index: int,
1060
+ sequence_number: int,
1061
+ type: str,
1062
+ **extra,
1063
+ ):
1064
+ pass
1065
+
1066
+ # response.image_generation_call.generating
1067
+ async def on_image_generation_generating(
1068
+ self,
1069
+ context: ToolContext,
1070
+ *,
1071
+ item_id: str,
1072
+ output_index: int,
1073
+ sequence_number: int,
1074
+ type: str,
1075
+ **extra,
1076
+ ):
1077
+ pass
1078
+
1079
+ # response.image_generation_call.partial_image
1080
+ async def on_image_generation_partial(
1081
+ self,
1082
+ context: ToolContext,
1083
+ *,
1084
+ item_id: str,
1085
+ output_index: int,
1086
+ sequence_number: int,
1087
+ type: str,
1088
+ partial_image_b64: str,
1089
+ partial_image_index: int,
1090
+ size: str,
1091
+ quality: str,
1092
+ background: str,
1093
+ output_format: str,
1094
+ **extra,
1095
+ ):
1096
+ pass
1097
+
1098
+ async def on_image_generated(
1099
+ self,
1100
+ context: ToolContext,
1101
+ *,
1102
+ item_id: str,
1103
+ data: bytes,
1104
+ status: str,
1105
+ size: str,
1106
+ quality: str,
1107
+ background: str,
1108
+ output_format: str,
1109
+ **extra,
1110
+ ):
1111
+ pass
1112
+
1113
+ async def handle_image_generated(
1114
+ self,
1115
+ context: ToolContext,
1116
+ *,
1117
+ id: str,
1118
+ result: str | None,
1119
+ status: str,
1120
+ type: str,
1121
+ size: str,
1122
+ quality: str,
1123
+ background: str,
1124
+ output_format: str,
1125
+ **extra,
1126
+ ):
1127
+ if result is not None:
1128
+ data = base64.b64decode(result)
1129
+ await self.on_image_generated(
1130
+ context,
1131
+ item_id=id,
1132
+ data=data,
1133
+ status=status,
1134
+ size=size,
1135
+ quality=quality,
1136
+ background=background,
1137
+ output_format=output_format,
1138
+ )
1139
+
1140
+
1141
+ class LocalShellConfig(ToolkitConfig):
1142
+ name: Literal["local_shell"] = "local_shell"
1143
+
1144
+
1145
+ class LocalShellToolkitBuilder(ToolkitBuilder):
1146
+ def __init__(self, *, working_directory: Optional[str] = None):
1147
+ super().__init__(name="local_shell", type=LocalShellConfig)
1148
+ self.working_directory = working_directory
1149
+
1150
+ async def make(self, *, room: RoomClient, model: str, config: LocalShellConfig):
1151
+ return Toolkit(
1152
+ name="local_shell",
1153
+ tools=[
1154
+ LocalShellTool(config=config, working_directory=self.working_directory)
1155
+ ],
1156
+ )
1157
+
1158
+
1159
+ MAX_SHELL_OUTPUT_SIZE = 1024 * 100
1160
+
1161
+
1162
+ class LocalShellTool(OpenAIResponsesTool):
1163
+ def __init__(
1164
+ self,
1165
+ *,
1166
+ config: Optional[LocalShellConfig] = None,
1167
+ working_directory: Optional[str] = None,
1168
+ ):
1169
+ super().__init__(name="local_shell")
1170
+ if config is None:
1171
+ config = LocalShellConfig(name="local_shell")
1172
+
1173
+ self.working_directory = working_directory
1174
+
1175
+ def get_open_ai_tool_definitions(self):
1176
+ return [{"type": "local_shell"}]
1177
+
1178
+ def get_open_ai_output_handlers(self):
1179
+ return {"local_shell_call": self.handle_local_shell_call}
1180
+
1181
+ async def execute_shell_command(
1182
+ self,
1183
+ context: ToolContext,
1184
+ *,
1185
+ command: list[str],
1186
+ env: dict,
1187
+ type: str,
1188
+ timeout_ms: int | None = None,
1189
+ user: str | None = None,
1190
+ working_directory: str | None = None,
1191
+ ):
1192
+ merged_env = {**os.environ, **(env or {})}
1193
+
1194
+ try:
1195
+ # Spawn the process
1196
+ proc = await asyncio.create_subprocess_exec(
1197
+ *(command if isinstance(command, (list, tuple)) else [command]),
1198
+ cwd=working_directory or self.working_directory or os.getcwd(),
1199
+ env=merged_env,
1200
+ stdout=asyncio.subprocess.PIPE,
1201
+ stderr=asyncio.subprocess.PIPE,
1202
+ )
1203
+
1204
+ timeout = float(timeout_ms) / 1000.0 if timeout_ms else 20.0
1205
+
1206
+ logger.info(f"executing command {command} with timeout: {timeout}s")
1207
+
1208
+ stdout, stderr = await asyncio.wait_for(
1209
+ proc.communicate(),
1210
+ timeout=timeout,
1211
+ )
1212
+ except asyncio.TimeoutError:
1213
+ proc.kill() # send SIGKILL / TerminateProcess
1214
+ logger.info(f"The command timed out after {timeout}s")
1215
+ stdout, stderr = await proc.communicate()
1216
+ return f"The command timed out after {timeout}s"
1217
+ # the timeout is reported back to the model as plain text instead of re-raising
1218
+ except Exception as ex:
1219
+ return f"The command failed: {ex}"
1220
+
1221
+ encoding = os.device_encoding(1) or "utf-8"
1222
+ stdout = stdout.decode(encoding, errors="replace")
1223
+ stderr = stderr.decode(encoding, errors="replace")
1224
+
1225
+ result = stdout + stderr
1226
+ if len(result) > MAX_SHELL_OUTPUT_SIZE:
1227
+ return f"Error: the command returned too much data ({result} bytes)"
1228
+
1229
+ return result
1230
+
1231
+ async def handle_local_shell_call(
1232
+ self,
1233
+ context,
1234
+ *,
1235
+ id: str,
1236
+ action: dict,
1237
+ call_id: str,
1238
+ status: str,
1239
+ type: str,
1240
+ **extra,
1241
+ ):
1242
+ result = await self.execute_shell_command(context, **action)
1243
+
1244
+ output_item = {
1245
+ "type": "local_shell_call_output",
1246
+ "call_id": call_id,
1247
+ "output": result,
1248
+ }
1249
+
1250
+ return output_item
1251
+
1252
+
1253
+ class ShellConfig(ToolkitConfig):
1254
+ name: Literal["shell"] = ("shell",)
1255
+
1256
+
1257
+ DEFAULT_CONTAINER_MOUNT_SPEC = ContainerMountSpec(
1258
+ room=[RoomStorageMountSpec(path="/data")]
1259
+ )
1260
+
1261
+
1262
+ class ShellToolkitBuilder(ToolkitBuilder):
1263
+ def __init__(
1264
+ self,
1265
+ *,
1266
+ working_directory: Optional[str] = None,
1267
+ image: Optional[str] = "ubuntu:latest",
1268
+ mounts: Optional[ContainerMountSpec] = DEFAULT_CONTAINER_MOUNT_SPEC,
1269
+ ):
1270
+ super().__init__(name="shell", type=ShellConfig)
1271
+ self.working_directory = working_directory
1272
+ self.image = image
1273
+ self.mounts = mounts
1274
+
1275
+ async def make(self, *, room: RoomClient, model: str, config: ShellConfig):
1276
+ return Toolkit(
1277
+ name="shell",
1278
+ tools=[
1279
+ ShellTool(
1280
+ config=config,
1281
+ working_directory=self.working_directory,
1282
+ image=self.image,
1283
+ mounts=self.mounts,
1284
+ )
1285
+ ],
1286
+ )
1287
+
1288
+
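# [editor's note: illustrative sketch, not part of the package; the image and
# mount path below are hypothetical.]
#   builder = ShellToolkitBuilder(
#       image="python:3.12-slim",
#       mounts=ContainerMountSpec(room=[RoomStorageMountSpec(path="/workspace")]),
#   )
#   toolkit = await builder.make(room=room, model="gpt-4o", config=ShellConfig(name="shell"))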
1289
+ class ShellTool(OpenAIResponsesTool):
1290
+ def __init__(
1291
+ self,
1292
+ *,
1293
+ config: Optional[ShellConfig] = None,
1294
+ working_directory: Optional[str] = None,
1295
+ image: Optional[str] = "ubuntu:latest",
1296
+ mounts: Optional[ContainerMountSpec] = DEFAULT_CONTAINER_MOUNT_SPEC,
1297
+ ):
1298
+ super().__init__(name="shell")
1299
+ if config is None:
1300
+ config = ShellConfig(name="shell")
1301
+ self.working_directory = working_directory
1302
+ self.image = image
1303
+ self.mounts = mounts
1304
+
1305
+ def get_open_ai_tool_definitions(self):
1306
+ return [{"type": "shell"}]
1307
+
1308
+ def get_open_ai_output_handlers(self):
1309
+ return {"shell_call": self.handle_shell_call}
1310
+
1311
+ async def execute_shell_command(
1312
+ self,
1313
+ context: ToolContext,
1314
+ *,
1315
+ commands: list[str],
1316
+ max_output_length: Optional[int] = None,
1317
+ timeout_ms: Optional[int] = None,
1318
+ ):
1319
+ merged_env = {**os.environ}
1320
+
1321
+ results = []
1322
+ encoding = os.device_encoding(1) or "utf-8"
1323
+
1324
+ left = max_output_length
1325
+
1326
+ def limit(s: str):
1327
+ nonlocal left
1328
+ if left is not None:
1329
+ s = s[0:left]
1330
+ left -= len(s)
1331
+ return s
1332
+ else:
1333
+ return s
1334
+
1335
+ timeout = float(timeout_ms) / 1000.0 if timeout_ms else 20.0
1336
+
1337
+ if self.image is not None:
1338
+ container_id = await context.room.containers.run(
1339
+ command="sleep infinity",
1340
+ image=self.image,
1341
+ mounts=self.mounts,
1342
+ )
1343
+
1344
+ try:
1345
+ # TODO: what if container start fails
1346
+
1347
+ logger.info(f"executing shell commands in container {container_id}")
1348
+
1349
+ for command in commands:
1350
+ exec = await context.room.containers.exec(
1351
+ container_id=container_id, command=command, tty=False
1352
+ )
1353
+
1354
+ stdout = bytearray()
1355
+ stderr = bytearray()
1356
+
1357
+ async for se in exec.stderr():
1358
+ stderr.extend(se)
1359
+
1360
+ async for so in exec.stdout():
1361
+ stdout.extend(so)
1362
+
1363
+ try:
1364
+ async with asyncio.timeout(timeout):
1365
+ exit_code = await exec.result
1366
+
1367
+ results.append(
1368
+ {
1369
+ "outcome": {
1370
+ "type": "exit",
1371
+ "exit_code": exit_code,
1372
+ },
1373
+ "stdout": stdout.decode(),
1374
+ "stderr": stderr.decode(),
1375
+ }
1376
+ )
1377
+
1378
+ except asyncio.TimeoutError:
1379
+ logger.info(f"The command timed out after {timeout}s")
1380
+ await exec.close()
1381
+
1382
+ results.append(
1383
+ {
1384
+ "outcome": {"type": "timeout"},
1385
+ "stdout": limit(
1386
+ stdout.decode(encoding, errors="replace")
1387
+ ),
1388
+ "stderr": limit(
1389
+ stderr.decode(encoding, errors="replace")
1390
+ ),
1391
+ }
1392
+ )
1393
+ break
1394
+
1395
+ except Exception as ex:
1396
+ results.append(
1397
+ {
1398
+ "outcome": {
1399
+ "type": "exit",
1400
+ "exit_code": 1,
1401
+ },
1402
+ "stdout": "",
1403
+ "stderr": f"{ex}",
1404
+ }
1405
+ )
1406
+ break
1407
+
1408
+ except Exception as ex:
1409
+ results.append(
1410
+ {
1411
+ "outcome": {
1412
+ "type": "exit",
1413
+ "exit_code": 1,
1414
+ },
1415
+ "stdout": "",
1416
+ "stderr": f"{ex}",
1417
+ }
1418
+ )
1419
+
1420
+ if container_id is not None:
1421
+ await context.room.containers.stop(container_id=container_id)
1422
+ await context.room.containers.delete(container_id=container_id)
1423
+
1424
+ else:
1425
+ for command in commands:
1426
+ logger.info(f"executing command {command} with timeout: {timeout}s")
1427
+
1428
+ # Spawn the process
1429
+ try:
1430
+ proc = await asyncio.create_subprocess_shell(
1431
+ command,
1432
+ cwd=self.working_directory or os.getcwd(),
1433
+ env=merged_env,
1434
+ stdout=asyncio.subprocess.PIPE,
1435
+ stderr=asyncio.subprocess.PIPE,
1436
+ )
1437
+
1438
+ stdout, stderr = await asyncio.wait_for(
1439
+ proc.communicate(),
1440
+ timeout=timeout,
1441
+ )
1442
+ except asyncio.TimeoutError:
1443
+ logger.info(f"The command timed out after {timeout}s")
1444
+ proc.kill() # send SIGKILL / TerminateProcess
1445
+
1446
+ stdout, stderr = await proc.communicate()
1447
+
1448
+ results.append(
1449
+ {
1450
+ "outcome": {"type": "timeout"},
1451
+ "stdout": limit(stdout.decode(encoding, errors="replace")),
1452
+ "stderr": limit(stderr.decode(encoding, errors="replace")),
1453
+ }
1454
+ )
1455
+
1456
+ break
1457
+
1458
+ except Exception as ex:
1459
+ results.append(
1460
+ {
1461
+ "outcome": {
1462
+ "type": "exit",
1463
+ "exit_code": 1,
1464
+ },
1465
+ "stdout": "",
1466
+ "stderr": f"{ex}",
1467
+ }
1468
+ )
1469
+ break
1470
+
1471
+ results.append(
1472
+ {
1473
+ "outcome": {
1474
+ "type": "exit",
1475
+ "exit_code": proc.returncode,
1476
+ },
1477
+ "stdout": limit(stdout.decode(encoding, errors="replace")),
1478
+ "stderr": limit(stderr.decode(encoding, errors="replace")),
1479
+ }
1480
+ )
1481
+
1482
+ return results
1483
+
1484
+ async def handle_shell_call(
1485
+ self,
1486
+ context,
1487
+ *,
1488
+ id: str,
1489
+ action: dict,
1490
+ call_id: str,
1491
+ status: str,
1492
+ type: str,
1493
+ **extra,
1494
+ ):
1495
+ result = await self.execute_shell_command(context, **action)
1496
+
1497
+ output_item = {
1498
+ "type": "shell_call_output",
1499
+ "call_id": call_id,
1500
+ "output": result,
1501
+ }
1502
+
1503
+ return output_item
1504
+
1505
+
1506
+ class ContainerFile:
1507
+ def __init__(self, *, file_id: str, mime_type: str, container_id: str):
1508
+ self.file_id = file_id
1509
+ self.mime_type = mime_type
1510
+ self.container_id = container_id
1511
+
1512
+
1513
+ class CodeInterpreterTool(OpenAIResponsesTool):
1514
+ def __init__(
1515
+ self,
1516
+ *,
1517
+ container_id: Optional[str] = None,
1518
+ file_ids: Optional[List[str]] = None,
1519
+ ):
1520
+ super().__init__(name="code_interpreter_call")
1521
+ self.container_id = container_id
1522
+ self.file_ids = file_ids
1523
+
1524
+ def get_open_ai_tool_definitions(self):
1525
+ opts = {"type": "code_interpreter"}
1526
+
1527
+ if self.container_id is not None:
1528
+ opts["container_id"] = self.container_id
1529
+
1530
+ if self.file_ids is not None:
1531
+ if self.container_id is not None:
1532
+ raise Exception(
1533
+ "Cannot specify both an existing container and files to upload in a code interpreter tool"
1534
+ )
1535
+
1536
+ opts["container"] = {"type": "auto", "file_ids": self.file_ids}
1537
+
1538
+ return [opts]
1539
+
1540
+ def get_open_ai_output_handlers(self):
1541
+ return {"code_interpreter_call": self.handle_code_interpreter_call}
1542
+
1543
+ async def on_code_interpreter_result(
1544
+ self,
1545
+ context: ToolContext,
1546
+ *,
1547
+ code: str,
1548
+ logs: list[str],
1549
+ files: list[ContainerFile],
1550
+ ):
1551
+ pass
1552
+
1553
+ async def handle_code_interpreter_call(
1554
+ self,
1555
+ context,
1556
+ *,
1557
+ code: str,
1558
+ id: str,
1559
+ results: list[dict],
1560
+ call_id: str,
1561
+ status: str,
1562
+ type: str,
1563
+ container_id: str,
1564
+ **extra,
1565
+ ):
1566
+ logs = []
1567
+ files = []
1568
+
1569
+ for result in results:
1570
+ if result.type == "logs":
1571
+ logs.append(results["logs"])
1572
+
1573
+ elif result.type == "files":
1574
+ files.append(
1575
+ ContainerFile(
1576
+ container_id=container_id,
1577
+ file_id=result["file_id"],
1578
+ mime_type=result["mime_type"],
1579
+ )
1580
+ )
1581
+
1582
+ await self.on_code_interpreter_result(
1583
+ context, code=code, logs=logs, files=files
1584
+ )
1585
+
1586
+
1587
+ class MCPToolDefinition:
1588
+ def __init__(
1589
+ self,
1590
+ *,
1591
+ input_schema: dict,
1592
+ name: str,
1593
+ annotations: dict | None,
1594
+ description: str | None,
1595
+ ):
1596
+ self.input_schema = input_schema
1597
+ self.name = name
1598
+ self.annotations = annotations
1599
+ self.description = description
1600
+
1601
+
1602
+ class MCPServer(BaseModel):
1603
+ server_label: str
1604
+ server_url: Optional[str] = None
1605
+ allowed_tools: Optional[list[str]] = None
1606
+ authorization: Optional[str] = None
1607
+ headers: Optional[dict] = None
1608
+
1609
+ # require approval for all tools
1610
+ require_approval: Optional[Literal["always", "never"]] = None
1611
+ # list of tools that always require approval
1612
+ always_require_approval: Optional[list[str]] = None
1613
+ # list of tools that never require approval
1614
+ never_require_approval: Optional[list[str]] = None
1615
+
1616
+ openai_connector_id: Optional[str] = None
1617
+
1618
+
1619
+ class MCPConfig(ToolkitConfig):
1620
+ name: Literal["mcp"] = "mcp"
1621
+ servers: list[MCPServer]
1622
+
1623
+
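# [editor's note: illustrative sketch, not part of the package; the server label
# and URL below are hypothetical.]
#   config = MCPConfig(servers=[
#       MCPServer(
#           server_label="docs",
#           server_url="https://mcp.example.com/sse",
#           allowed_tools=["search_docs"],
#           require_approval="never",
#       )
#   ])
#   toolkit = await MCPToolkitBuilder().make(room=room, model="gpt-4o", config=config)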
1624
+ class MCPToolkitBuilder(ToolkitBuilder):
1625
+ def __init__(self):
1626
+ super().__init__(name="mcp", type=MCPConfig)
1627
+
1628
+ async def make(self, *, room: RoomClient, model: str, config: MCPConfig):
1629
+ return Toolkit(name="mcp", tools=[MCPTool(config=config)])
1630
+
1631
+
1632
+ class MCPTool(OpenAIResponsesTool):
1633
+ def __init__(self, *, config: MCPConfig):
1634
+ super().__init__(name="mcp")
1635
+ self.servers = config.servers
1636
+
1637
+ def get_open_ai_tool_definitions(self):
1638
+ defs = []
1639
+ for server in self.servers:
1640
+ opts = {
1641
+ "type": "mcp",
1642
+ "server_label": server.server_label,
1643
+ }
1644
+
1645
+ if server.server_url is not None:
1646
+ opts["server_url"] = server.server_url
1647
+
1648
+ if server.openai_connector_id is not None:
1649
+ opts["connector_id"] = server.openai_connector_id
1650
+
1651
+ if server.allowed_tools is not None:
1652
+ opts["allowed_tools"] = server.allowed_tools
1653
+
1654
+ if server.authorization is not None:
1655
+ opts["authorization"] = server.authorization
1656
+
1657
+ if server.headers is not None:
1658
+ opts["headers"] = server.headers
1659
+
1660
+ if (
1661
+ server.always_require_approval is not None
1662
+ or server.never_require_approval is not None
1663
+ ):
1664
+ opts["require_approval"] = {}
1665
+
1666
+ if server.always_require_approval is not None:
1667
+ opts["require_approval"]["always"] = {
1668
+ "tool_names": server.always_require_approval
1669
+ }
1670
+
1671
+ if server.never_require_approval is not None:
1672
+ opts["require_approval"]["never"] = {
1673
+ "tool_names": server.never_require_approval
1674
+ }
1675
+
1676
+ if server.require_approval:
1677
+ opts["require_approval"] = server.require_approval
1678
+
1679
+ defs.append(opts)
1680
+
1681
+ return defs
1682
+
1683
+ def get_open_ai_stream_callbacks(self):
1684
+ return {
1685
+ "response.mcp_list_tools.in_progress": self.on_mcp_list_tools_in_progress,
1686
+ "response.mcp_list_tools.failed": self.on_mcp_list_tools_failed,
1687
+ "response.mcp_list_tools.completed": self.on_mcp_list_tools_completed,
1688
+ "response.mcp_call.in_progress": self.on_mcp_call_in_progress,
1689
+ "response.mcp_call.failed": self.on_mcp_call_failed,
1690
+ "response.mcp_call.completed": self.on_mcp_call_completed,
1691
+ "response.mcp_call.arguments.done": self.on_mcp_call_arguments_done,
1692
+ "response.mcp_call.arguments.delta": self.on_mcp_call_arguments_delta,
1693
+ }
1694
+
1695
+ async def on_mcp_list_tools_in_progress(
1696
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1697
+ ):
1698
+ pass
1699
+
1700
+ async def on_mcp_list_tools_failed(
1701
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1702
+ ):
1703
+ pass
1704
+
1705
+ async def on_mcp_list_tools_completed(
1706
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1707
+ ):
1708
+ pass
1709
+
1710
+ async def on_mcp_call_in_progress(
1711
+ self,
1712
+ context: ToolContext,
1713
+ *,
1714
+ item_id: str,
1715
+ output_index: int,
1716
+ sequence_number: int,
1717
+ type: str,
1718
+ **extra,
1719
+ ):
1720
+ pass
1721
+
1722
+ async def on_mcp_call_failed(
1723
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1724
+ ):
1725
+ pass
1726
+
1727
+ async def on_mcp_call_completed(
1728
+ self, context: ToolContext, *, sequence_number: int, type: str, **extra
1729
+ ):
1730
+ pass
1731
+
1732
+ async def on_mcp_call_arguments_done(
1733
+ self,
1734
+ context: ToolContext,
1735
+ *,
1736
+ arguments: dict,
1737
+ item_id: str,
1738
+ output_index: int,
1739
+ sequence_number: int,
1740
+ type: str,
1741
+ **extra,
1742
+ ):
1743
+ pass
1744
+
1745
+ async def on_mcp_call_arguments_delta(
1746
+ self,
1747
+ context: ToolContext,
1748
+ *,
1749
+ delta: dict,
1750
+ item_id: str,
1751
+ output_index: int,
1752
+ sequence_number: int,
1753
+ type: str,
1754
+ **extra,
1755
+ ):
1756
+ pass
1757
+
1758
+ def get_open_ai_output_handlers(self):
1759
+ return {
1760
+ "mcp_call": self.handle_mcp_call,
1761
+ "mcp_list_tools": self.handle_mcp_list_tools,
1762
+ "mcp_approval_request": self.handle_mcp_approval_request,
1763
+ }
1764
+
1765
+ async def on_mcp_list_tools(
1766
+ self,
1767
+ context: ToolContext,
1768
+ *,
1769
+ server_label: str,
1770
+ tools: list[MCPToolDefinition],
1771
+ error: str | None,
1772
+ **extra,
1773
+ ):
1774
+ pass
1775
+
1776
+ async def handle_mcp_list_tools(
1777
+ self,
1778
+ context,
1779
+ *,
1780
+ id: str,
1781
+ server_label: str,
1782
+ tools: list,
1783
+ type: str,
1784
+ error: str | None = None,
1785
+ **extra,
1786
+ ):
1787
+ mcp_tools = []
1788
+ for tool in tools:
1789
+ mcp_tools.append(
1790
+ MCPToolDefinition(
1791
+ input_schema=tool["input_schema"],
1792
+ name=tool["name"],
1793
+ annotations=tool["annotations"],
1794
+ description=tool["description"],
1795
+ )
1796
+ )
1797
+
1798
+ await self.on_mcp_list_tools(
1799
+ context, server_label=server_label, tools=mcp_tools, error=error
1800
+ )
1801
+
1802
+ async def on_mcp_call(
1803
+ self,
1804
+ context: ToolContext,
1805
+ *,
1806
+ name: str,
1807
+ arguments: str,
1808
+ server_label: str,
1809
+ error: str | None,
1810
+ output: str | None,
1811
+ **extra,
1812
+ ):
1813
+ pass
1814
+
1815
+ async def handle_mcp_call(
1816
+ self,
1817
+ context,
1818
+ *,
1819
+ arguments: str,
1820
+ id: str,
1821
+ name: str,
1822
+ server_label: str,
1823
+ type: str,
1824
+ error: str | None,
1825
+ output: str | None,
1826
+ **extra,
1827
+ ):
1828
+ await self.on_mcp_call(
1829
+ context,
1830
+ name=name,
1831
+ arguments=arguments,
1832
+ server_label=server_label,
1833
+ error=error,
1834
+ output=output,
1835
+ )
1836
+
1837
+ async def on_mcp_approval_request(
1838
+ self,
1839
+ context: ToolContext,
1840
+ *,
1841
+ name: str,
1842
+ arguments: str,
1843
+ server_label: str,
1844
+ **extra,
1845
+ ) -> bool:
1846
+ return True
1847
+
1848
+ async def handle_mcp_approval_request(
1849
+ self,
1850
+ context: ToolContext,
1851
+ *,
1852
+ arguments: str,
1853
+ id: str,
1854
+ name: str,
1855
+ server_label: str,
1856
+ type: str,
1857
+ **extra,
1858
+ ):
1859
+ logger.info(f"approval requested for MCP tool {server_label}.{name}")
1860
+ should_approve = await self.on_mcp_approval_request(
1861
+ context, arguments=arguments, name=name, server_label=server_label
1862
+ )
1863
+ if should_approve:
1864
+ logger.info(f"approval granted for MCP tool {server_label}.{name}")
1865
+ return {
1866
+ "type": "mcp_approval_response",
1867
+ "approve": True,
1868
+ "approval_request_id": id,
1869
+ }
1870
+ else:
1871
+ logger.info(f"approval denied for MCP tool {server_label}.{name}")
1872
+ return {
1873
+ "type": "mcp_approval_response",
1874
+ "approve": False,
1875
+ "approval_request_id": id,
1876
+ }
1877
+
1878
+
1879
+ class ReasoningTool(OpenAIResponsesTool):
1880
+ def __init__(self):
1881
+ super().__init__(name="reasoning")
1882
+
1883
+ def get_open_ai_output_handlers(self):
1884
+ return {
1885
+ "reasoning": self.handle_reasoning,
1886
+ }
1887
+
1888
+ def get_open_ai_stream_callbacks(self):
1889
+ return {
1890
+ "response.reasoning_summary_text.done": self.on_reasoning_summary_text_done,
1891
+ "response.reasoning_summary_text.delta": self.on_reasoning_summary_text_delta,
1892
+ "response.reasoning_summary_part.done": self.on_reasoning_summary_part_done,
1893
+ "response.reasoning_summary_part.added": self.on_reasoning_summary_part_added,
1894
+ }
1895
+
1896
+ async def on_reasoning_summary_part_added(
1897
+ self,
1898
+ context: ToolContext,
1899
+ *,
1900
+ item_id: str,
1901
+ output_index: int,
1902
+ part: dict,
1903
+ sequence_number: int,
1904
+ summary_index: int,
1905
+ type: str,
1906
+ **extra,
1907
+ ):
1908
+ pass
1909
+
1910
+ async def on_reasoning_summary_part_done(
1911
+ self,
1912
+ context: ToolContext,
1913
+ *,
1914
+ item_id: str,
1915
+ output_index: int,
1916
+ part: dict,
1917
+ sequence_number: int,
1918
+ summary_index: int,
1919
+ type: str,
1920
+ **extra,
1921
+ ):
1922
+ pass
1923
+
1924
+ async def on_reasoning_summary_text_delta(
1925
+ self,
1926
+ context: ToolContext,
1927
+ *,
1928
+ delta: str,
1929
+ output_index: int,
1930
+ sequence_number: int,
1931
+ summary_index: int,
1932
+ type: str,
1933
+ **extra,
1934
+ ):
1935
+ pass
1936
+
1937
+ async def on_reasoning_summary_text_done(
1938
+ self,
1939
+ context: ToolContext,
1940
+ *,
1941
+ item_id: str,
1942
+ output_index: int,
1943
+ sequence_number: int,
1944
+ summary_index: int,
1945
+ type: str,
1946
+ **extra,
1947
+ ):
1948
+ pass
1949
+
1950
+ async def on_reasoning(
1951
+ self,
1952
+ context: ToolContext,
1953
+ *,
1954
+ summary: list[str],
1955
+ content: Optional[list[str]] = None,
1956
+ encrypted_content: str | None,
1957
+ status: Literal["in_progress", "completed", "incomplete"],
1958
+ ):
1959
+ pass
1960
+
1961
+ async def handle_reasoning(
1962
+ self,
1963
+ context: ToolContext,
1964
+ *,
1965
+ id: str,
1966
+ summary: list[dict],
1967
+ type: str,
1968
+ content: Optional[list[dict]],
1969
+ encrypted_content: str | None,
1970
+ status: str,
1971
+ **extra,
1972
+ ):
1973
+ await self.on_reasoning(
1974
+ context,
1975
+ summary=summary,
1976
+ content=content,
1977
+ encrypted_content=encrypted_content,
1978
+ status=status,
1979
+ )
1980
+
1981
+
1982
+ # TODO: computer tool call
1983
+
1984
+
1985
+ class WebSearchConfig(ToolkitConfig):
1986
+ name: Literal["web_search"] = "web_search"
1987
+
1988
+
1989
+ class WebSearchToolkitBuilder(ToolkitBuilder):
1990
+ def __init__(self):
1991
+ super().__init__(name="web_search", type=WebSearchConfig)
1992
+
1993
+ async def make(self, *, room: RoomClient, model: str, config: WebSearchConfig):
1994
+ return Toolkit(name="web_search", tools=[WebSearchTool(config=config)])
1995
+
1996
+
1997
+ class WebSearchTool(OpenAIResponsesTool):
1998
+ def __init__(self, *, config: Optional[WebSearchConfig] = None):
1999
+ if config is None:
2000
+ config = WebSearchConfig(name="web_search")
2001
+ super().__init__(name="web_search")
2002
+
2003
+ def get_open_ai_tool_definitions(self) -> list[dict]:
2004
+ return [{"type": "web_search_preview"}]
2005
+
2006
+ def get_open_ai_stream_callbacks(self):
2007
+ return {
2008
+ "response.web_search_call.in_progress": self.on_web_search_call_in_progress,
2009
+ "response.web_search_call.searching": self.on_web_search_call_searching,
2010
+ "response.web_search_call.completed": self.on_web_search_call_completed,
2011
+ }
2012
+
2013
+ def get_open_ai_output_handlers(self):
2014
+ return {"web_search_call": self.handle_web_search_call}
2015
+
2016
+ async def on_web_search_call_in_progress(
2017
+ self,
2018
+ context: ToolContext,
2019
+ *,
2020
+ item_id: str,
2021
+ output_index: int,
2022
+ sequence_number: int,
2023
+ type: str,
2024
+ **extra,
2025
+ ):
2026
+ pass
2027
+
2028
+ async def on_web_search_call_searching(
2029
+ self,
2030
+ context: ToolContext,
2031
+ *,
2032
+ item_id: str,
2033
+ output_index: int,
2034
+ sequence_number: int,
2035
+ type: str,
2036
+ **extra,
2037
+ ):
2038
+ pass
2039
+
2040
+ async def on_web_search_call_completed(
2041
+ self,
2042
+ context: ToolContext,
2043
+ *,
2044
+ item_id: str,
2045
+ output_index: int,
2046
+ sequence_number: int,
2047
+ type: str,
2048
+ **extra,
2049
+ ):
2050
+ pass
2051
+
2052
+ async def on_web_search(self, context: ToolContext, *, status: str, **extra):
2053
+ pass
2054
+
2055
+ async def handle_web_search_call(
2056
+ self, context: ToolContext, *, id: str, status: str, type: str, **extra
2057
+ ):
2058
+ await self.on_web_search(context, status=status)
2059
+
2060
+
2061
+ class FileSearchResult:
2062
+ def __init__(
2063
+ self, *, attributes: dict, file_id: str, filename: str, score: float, text: str
2064
+ ):
2065
+ self.attributes = attributes
2066
+ self.file_id = file_id
2067
+ self.filename = filename
2068
+ self.score = score
2069
+ self.text = text
2070
+
2071
+
2072
+ class FileSearchTool(OpenAIResponsesTool):
2073
+ def __init__(
2074
+ self,
2075
+ *,
2076
+ vector_store_ids: list[str],
2077
+ filters: Optional[dict] = None,
2078
+ max_num_results: Optional[int] = None,
2079
+ ranking_options: Optional[dict] = None,
2080
+ ):
2081
+ super().__init__(name="file_search")
2082
+
2083
+ self.vector_store_ids = vector_store_ids
2084
+ self.filters = filters
2085
+ self.max_num_results = max_num_results
2086
+ self.ranking_options = ranking_options
2087
+
2088
+ def get_open_ai_tool_definitions(self) -> list[dict]:
2089
+ return [
2090
+ {
2091
+ "type": "file_search",
2092
+ "vector_store_ids": self.vector_store_ids,
2093
+ "filters": self.filters,
2094
+ "max_num_results": self.max_num_results,
2095
+ "ranking_options": self.ranking_options,
2096
+ }
2097
+ ]
2098
+
2099
+ def get_open_ai_stream_callbacks(self):
2100
+ return {
2101
+ "response.file_search_call.in_progress": self.on_file_search_call_in_progress,
2102
+ "response.file_search_call.searching": self.on_file_search_call_searching,
2103
+ "response.file_search_call.completed": self.on_file_search_call_completed,
2104
+ }
2105
+
2106
+ def get_open_ai_output_handlers(self):
2107
+ return {"file_search_call": self.handle_file_search_call}
2108
+
2109
+ async def on_file_search_call_in_progress(
2110
+ self,
2111
+ context: ToolContext,
2112
+ *,
2113
+ item_id: str,
2114
+ output_index: int,
2115
+ sequence_number: int,
2116
+ type: str,
2117
+ **extra,
2118
+ ):
2119
+ pass
2120
+
2121
+ async def on_file_search_call_searching(
2122
+ self,
2123
+ context: ToolContext,
2124
+ *,
2125
+ item_id: str,
2126
+ output_index: int,
2127
+ sequence_number: int,
2128
+ type: str,
2129
+ **extra,
2130
+ ):
2131
+ pass
2132
+
2133
+ async def on_file_search_call_completed(
2134
+ self,
2135
+ context: ToolContext,
2136
+ *,
2137
+ item_id: str,
2138
+ output_index: int,
2139
+ sequence_number: int,
2140
+ type: str,
2141
+ **extra,
2142
+ ):
2143
+ pass
2144
+
2145
+ async def on_file_search(
2146
+ self,
2147
+ context: ToolContext,
2148
+ *,
2149
+ queries: list,
2150
+ results: list[FileSearchResult] | None,
2151
+ status: Literal["in_progress", "searching", "completed", "incomplete", "failed"],
2152
+ ):
2153
+ pass
2154
+
2155
+ async def handle_file_search_call(
2156
+ self,
2157
+ context: ToolContext,
2158
+ *,
2159
+ id: str,
2160
+ queries: list,
2161
+ status: str,
2162
+ results: list[dict] | None,
2163
+ type: str,
2164
+ **extra,
2165
+ ):
2166
+ search_results = None
2167
+ if results is not None:
2168
+ search_results = []
2169
+ for result in results:
2170
+ search_results.append(FileSearchResult(**result))
2171
+
2172
+ await self.on_file_search(
2173
+ context, queries=queries, results=search_results, status=status
2174
+ )
2175
+
2176
+
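# Example (illustrative, not part of the package): a minimal sketch of wiring
# FileSearchTool to a hypothetical vector store and inspecting results; the
# subclass name and the vector store id are assumptions.
class PrintingFileSearchTool(FileSearchTool):
    async def on_file_search(self, context: ToolContext, *, queries, results, status):
        # results may be None when the output item carries no result payload
        for result in results or []:
            logger.info(
                f"{result.filename} (score={result.score}): {result.text[:80]}"
            )


file_search_tool = PrintingFileSearchTool(
    vector_store_ids=["vs_example"],  # hypothetical vector store id
    max_num_results=5,
)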
2177
+ class ApplyPatchConfig(ToolkitConfig):
2178
+ name: Literal["apply_patch"] = "apply_patch"
2179
+
2180
+
2181
+ class ApplyPatchToolkitBuilder(ToolkitBuilder):
2182
+ def __init__(self):
2183
+ super().__init__(name="apply_patch", type=ApplyPatchConfig)
2184
+
2185
+ async def make(self, *, room: RoomClient, model: str, config: ApplyPatchConfig):
2186
+ return Toolkit(name="apply_patch", tools=[ApplyPatchTool(config=config)])
2187
+
2188
+
2189
+ class ApplyPatchTool(OpenAIResponsesTool):
2190
+ """
2191
+ Wrapper for the built-in `apply_patch` tool.
2192
+
2193
+ The model will emit `apply_patch_call` items whenever it wants to create,
2194
+ update, or delete a file using a unified diff. The server / host
2195
+ environment is expected to actually apply the patch and, if desired,
2196
+ log results via `apply_patch_call_output`.
2197
+ The key handler entrypoints you can override are:
2199
+
2200
+ * `on_apply_patch_call` – called when the model requests a patch; the default applies it via room storage
2201
+ * `on_apply_patch_call_in_progress` / `on_apply_patch_call_completed` – stream callbacks fired while a patch call runs
2201
+ * `on_apply_patch_call_output` – called when the tool emits a log/output item
2202
+ """
2203
+
2204
+ def __init__(self, *, config: ApplyPatchConfig):
2205
+ super().__init__(name="apply_patch")
2206
+
2207
+ # Tool definition advertised to OpenAI
2208
+ def get_open_ai_tool_definitions(self) -> list[dict]:
2209
+ # No extra options for now – the built-in tool just needs the type
2210
+ return [{"type": "apply_patch"}]
2211
+
2212
+ # Stream callbacks for `response.apply_patch_call.*` events
2213
+ def get_open_ai_stream_callbacks(self):
2214
+ return {
2215
+ "response.apply_patch_call.in_progress": self.on_apply_patch_call_in_progress,
2216
+ "response.apply_patch_call.completed": self.on_apply_patch_call_completed,
2217
+ }
2218
+
2219
+ # Output handlers for item types
2220
+ def get_open_ai_output_handlers(self):
2221
+ return {
2222
+ # The tool call itself (what to apply)
2223
+ "apply_patch_call": self.handle_apply_patch_call,
2224
+ }
2225
+
2226
+ # --- Stream callbacks -------------------------------------------------
2227
+
2228
+ # response.apply_patch_call.in_progress
2229
+ async def on_apply_patch_call_in_progress(
2230
+ self,
2231
+ context: ToolContext,
2232
+ *,
2233
+ item_id: str,
2234
+ output_index: int,
2235
+ sequence_number: int,
2236
+ type: str,
2237
+ **extra,
2238
+ ):
2239
+ # Default: no-op, but you can log progress / show UI here if you want
2240
+ pass
2241
+
2242
+ # response.apply_patch_call.completed
2243
+ async def on_apply_patch_call_completed(
2244
+ self,
2245
+ context: ToolContext,
2246
+ *,
2247
+ item_id: str,
2248
+ output_index: int,
2249
+ sequence_number: int,
2250
+ type: str,
2251
+ **extra,
2252
+ ):
2253
+ # Default: no-op
2254
+ pass
2255
+
2256
+ # --- High-level hooks -------------------------------------------------
2257
+
2258
+ async def on_apply_patch_call(
2259
+ self,
2260
+ context: ToolContext,
2261
+ *,
2262
+ call_id: str,
2263
+ operation: dict,
2264
+ status: str,
2265
+ **extra,
2266
+ ):
2267
+ """
2268
+ Called when the model requests an apply_patch operation.
2269
+
2270
+ operation looks like one of:
2271
+
2272
+ create_file:
2273
+ {
2274
+ "type": "create_file",
2275
+ "path": "relative/path/to/file",
2276
+ "diff": "...unified diff..."
2277
+ }
2278
+
2279
+ update_file:
2280
+ {
2281
+ "type": "update_file",
2282
+ "path": "relative/path/to/file",
2283
+ "diff": "...unified diff..."
2284
+ }
2285
+
2286
+ delete_file:
2287
+ {
2288
+ "type": "delete_file",
2289
+ "path": "relative/path/to/file"
2290
+ }
2291
+ """
2292
+ # The default implementation below applies the patch via room storage;
2293
+ # override this hook to customize how patches are applied in your workspace.
2294
+
2295
+ from meshagent.openai.tools.apply_patch import apply_diff
2296
+
2297
+ if operation["type"] == "delete_file":
2298
+ path = operation["path"]
2299
+ logger.info(f"applying patch: deleting file {path}")
2300
+ await context.room.storage.delete(path=path)
2301
+ log = f"Deleted file: {path}"
2302
+ logger.info(log)
2303
+ return {"status": "completed", "output": log}
2304
+
2305
+ elif operation["type"] == "create_file":
2306
+ diff = operation["diff"]
2307
+ path = operation["path"]
2308
+ logger.info(f"applying patch: creating file {path} with {diff}")
2309
+ try:
2310
+ patched = apply_diff("", diff, "create")
2311
+ except Exception as ex:
2312
+ return {"status": "failed", "output": f"{ex}"}
2313
+ handle = await context.room.storage.open(path=path, overwrite=False)
2314
+ await context.room.storage.write(handle=handle, data=patched.encode())
2315
+ await context.room.storage.close(handle=handle)
2316
+
2317
+ log = f"Created file: {path} ({len(patched)} bytes)"
2318
+ logger.info(log)
2319
+ return {"status": "completed", "output": log}
2320
+
2321
+ elif operation["type"] == "update_file":
2322
+ path = operation["path"]
2323
+ content = await context.room.storage.download(path=path)
2324
+ text = content.data.decode()
2325
+ diff = operation["diff"]
2326
+
2327
+ logger.info(f"applying patch: updating file {path} with {diff}")
2328
+
2329
+ try:
2330
+ patched = apply_diff(text, diff)
2331
+ except Exception as ex:
2332
+ return {"status": "failed", "output": f"{ex}"}
2333
+
2334
+ handle = await context.room.storage.open(path=path, overwrite=True)
2335
+ await context.room.storage.write(handle=handle, data=patched.encode())
2336
+ await context.room.storage.close(handle=handle)
2337
+
2338
+ log = f"Updated file: {path} ({len(text)} -> {len(patched)} bytes)"
2339
+ logger.info(log)
2340
+ return {"status": "completed", "output": log}
2341
+
2342
+ # unknown operation type
2343
+ else:
2344
+ raise Exception(f"Unexpected patch operation {operation}")
2345
+
2346
+ async def handle_apply_patch_call(
2347
+ self,
2348
+ context: ToolContext,
2349
+ *,
2350
+ call_id: str,
2351
+ operation: dict,
2352
+ status: str,
2353
+ type: str,
2354
+ id: str | None = None,
2355
+ **extra,
2356
+ ):
2357
+ result = await self.on_apply_patch_call(
2358
+ context,
2359
+ call_id=call_id,
2360
+ operation=operation,
2361
+ status=status,
2362
+ **extra,
2363
+ )
2364
+
2365
+ return {
2366
+ "type": "apply_patch_call_output",
2367
+ "call_id": call_id,
2368
+ **result,
2369
+ }
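# Example (illustrative, not part of the package): a minimal sketch of
# customizing the apply_patch progress callback and assembling the toolkit by
# hand; the subclass name is hypothetical.
class VerboseApplyPatchTool(ApplyPatchTool):
    async def on_apply_patch_call_in_progress(
        self,
        context: ToolContext,
        *,
        item_id: str,
        output_index: int,
        sequence_number: int,
        type: str,
        **extra,
    ):
        # surfaces streaming progress instead of the default no-op
        logger.info(f"apply_patch call {item_id} in progress")


apply_patch_toolkit = Toolkit(
    name="apply_patch",
    tools=[VerboseApplyPatchTool(config=ApplyPatchConfig())],
)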