ailoy-py 0.0.1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ailoy/agent.py ADDED
@@ -0,0 +1,729 @@
1
+ import json
2
+ import subprocess
3
+ import warnings
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Awaitable, Callable, Generator
6
+ from pathlib import Path
7
+ from typing import (
8
+ Any,
9
+ Literal,
10
+ Optional,
11
+ TypeVar,
12
+ Union,
13
+ )
14
+ from urllib.parse import urlencode, urlparse, urlunparse
15
+
16
+ import jmespath
17
+ import mcp
18
+ import mcp.types as mcp_types
19
+ from pydantic import BaseModel, ConfigDict, Field
20
+ from rich.console import Console
21
+ from rich.panel import Panel
22
+
23
+ from ailoy.ailoy_py import generate_uuid
24
+ from ailoy.runtime import Runtime
25
+
26
# Only `Agent` is exported by `from ailoy.agent import *`; every other public name in
# this module must be imported explicitly.
__all__ = [
    "Agent",
]
27
+
28
+ ## Types for OpenAI API-compatible data structures
29
+
30
+
31
class SystemMessage(BaseModel):
    """System prompt message (``role == "system"``)."""

    role: Literal["system"]
    content: str


class UserMessage(BaseModel):
    """Message authored by the user (``role == "user"``)."""

    role: Literal["user"]
    content: str


class AIOutputTextMessage(BaseModel):
    """Assistant text output; a truthy ``reasoning`` flag marks reasoning text."""

    role: Literal["assistant"]
    content: str
    reasoning: Optional[bool] = None


class AIToolCallMessage(BaseModel):
    """Assistant message requesting one or more tool calls.

    ``content`` is always ``None`` for this message kind. It defaults to ``None``
    because OpenAI-style tool-call messages may omit the key entirely.
    """

    role: Literal["assistant"]
    content: None = None
    tool_calls: list["ToolCall"]


class ToolCall(BaseModel):
    """A single function call requested by the assistant, identified by ``id``."""

    id: str
    type: Literal["function"] = "function"
    function: "ToolCallFunction"


class ToolCallFunction(BaseModel):
    """Name of the tool to invoke and its already-decoded keyword arguments."""

    name: str
    arguments: dict[str, Any]


class ToolCallResultMessage(BaseModel):
    """Result of a tool call, linked back to the request via ``tool_call_id``."""

    role: Literal["tool"]
    name: str
    tool_call_id: str
    content: str


# Any message that can appear in the conversation history.
Message = Union[
    SystemMessage,
    UserMessage,
    AIOutputTextMessage,
    AIToolCallMessage,
    ToolCallResultMessage,
]


class MessageDelta(BaseModel):
    """One item of the runtime's streamed ``infer`` output.

    ``finish_reason`` is ``None`` for intermediate chunks and set on the final one.
    """

    finish_reason: Optional[Literal["stop", "tool_calls", "length", "error"]]
    message: Message
83
+
84
+
85
+ ## Types for LLM Model Definitions
86
+
87
# Models served by the "tvm_language_model" runtime component (see `model_descriptions`).
TVMModelName = Literal[
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-1.7B",
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-8B",
]
# Models served by the "openai" runtime component.
OpenAIModelName = Literal["gpt-4o"]
# Every model name accepted by `Agent`.
ModelName = Union[TVMModelName, OpenAIModelName]


# Configuration of a locally-run TVM model.
class TVMModel(BaseModel):
    name: TVMModelName
    quantization: Union[Literal["q4f16_1"], None] = None
    mode: Union[Literal["interactive"], None] = None


# Configuration of a hosted OpenAI model.
class OpenAIModel(BaseModel):
    name: OpenAIModelName
    api_key: str


# How a model name maps onto a runtime component, plus its optional default system prompt.
class ModelDescription(BaseModel):
    model_id: str
    component_type: str
    default_system_message: Union[str, None] = None


# Registry consulted by `Agent.define`; insertion order: the Qwen sizes, then gpt-4o.
model_descriptions: dict[ModelName, ModelDescription] = {
    **{
        qwen_id: ModelDescription(
            model_id=qwen_id,
            component_type="tvm_language_model",
            default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
        )
        for qwen_id in ("Qwen/Qwen3-0.6B", "Qwen/Qwen3-1.7B", "Qwen/Qwen3-4B", "Qwen/Qwen3-8B")
    },
    "gpt-4o": ModelDescription(model_id="gpt-4o", component_type="openai"),
}


# Name of the runtime component owned by an agent and whether it is currently defined.
class ComponentState(BaseModel):
    name: str
    valid: bool
140
+
141
+
142
+ ## Types for agent's responses
143
+
144
_console = Console(highlight=False)


class AgentResponseBase(BaseModel):
    """Common fields of every item yielded by ``Agent.query``.

    Subclasses narrow ``type``/``role``/``content`` and implement ``print``.
    """

    type: Literal["output_text", "tool_call", "tool_call_result", "reasoning", "error"]
    end_of_turn: bool
    role: Literal["assistant", "tool"]
    content: Any

    def print(self):
        """Render this response on the shared console (implemented by subclasses)."""
        raise NotImplementedError


class AgentResponseOutputText(AgentResponseBase):
    """A chunk of assistant text; ``type == "reasoning"`` marks reasoning text (printed in yellow)."""

    type: Literal["output_text", "reasoning"]
    role: Literal["assistant"]
    content: str

    def print(self):
        # Model output is raw text: disable Rich markup/emoji parsing so sequences such as
        # "[/INST]" are printed verbatim instead of being styled or raising MarkupError.
        _console.print(
            self.content,
            end="",
            style=("yellow" if self.type == "reasoning" else None),
            markup=False,
            emoji=False,
        )
        if self.end_of_turn:
            # Chunks are printed without newlines; finish the line when the turn ends.
            _console.print()


class AgentResponseToolCall(AgentResponseBase):
    """A tool call requested by the assistant."""

    type: Literal["tool_call"]
    role: Literal["assistant"]
    content: ToolCall

    def print(self):
        from rich.markup import escape
        from rich.text import Text

        # Pretty JSON, keeping non-ASCII characters readable. Wrapping in Text keeps brackets
        # inside argument values from being parsed as Rich markup.
        arguments = json.dumps(self.content.function.arguments, indent=2, ensure_ascii=False)
        panel = Panel(
            Text(arguments),
            title=(
                f"[magenta]Tool Call[/magenta]: [bold]{escape(self.content.function.name)}[/bold] "
                f"({escape(self.content.id)})"
            ),
            title_align="left",
        )
        _console.print(panel)


class AgentResponseToolCallResult(AgentResponseBase):
    """The result message produced by running a tool."""

    type: Literal["tool_call_result"]
    role: Literal["tool"]
    content: ToolCallResultMessage

    def print(self):
        from rich.markup import escape
        from rich.text import Text

        raw = self.content.content
        try:
            # Pretty-print results that are valid JSON
            content = json.dumps(json.loads(raw), indent=2, ensure_ascii=False)
        except json.JSONDecodeError:
            # Otherwise show the original text
            content = raw
        # Keep the panel readable for large results
        if len(content) > 500:
            content = content[:500] + "...(truncated)"

        panel = Panel(
            Text(content),
            title=(
                f"[green]Tool Result[/green]: [bold]{escape(self.content.name)}[/bold] "
                f"({escape(self.content.tool_call_id)})"
            ),
            title_align="left",
        )
        _console.print(panel)


class AgentResponseError(AgentResponseBase):
    """An error message attributed to the assistant."""

    type: Literal["error"]
    role: Literal["assistant"]
    content: str

    def print(self):
        from rich.text import Text

        # Error text frequently contains brackets; render it literally.
        panel = Panel(
            Text(self.content),
            title="[bold red]Error[/bold red]",
        )
        _console.print(panel)


# Any item yielded by `Agent.query`.
AgentResponse = Union[
    AgentResponseOutputText,
    AgentResponseToolCall,
    AgentResponseToolCallResult,
    AgentResponseError,
]
225
+
226
+ ## Types and functions related to Tools
227
+
228
ToolDefinition = Union["BuiltinToolDefinition", "RESTAPIToolDefinition"]


class ToolDescription(BaseModel):
    """Function-style description of a tool as presented to the model.

    ``return_type`` may be supplied under the ``return`` key (alias) or by field name.
    """

    name: str
    description: str
    parameters: "ToolParameters"
    return_type: Optional[dict[str, Any]] = Field(default=None, alias="return")
    model_config = ConfigDict(populate_by_name=True)


# JSON Schema primitive type names accepted for a single tool parameter.
_JSONSchemaTypeName = Literal["string", "integer", "number", "boolean", "object", "array", "null"]


class ToolParameters(BaseModel):
    """JSON Schema ``object`` describing a tool's keyword arguments."""

    type: Literal["object"]
    # `properties` is optional in JSON Schema (and in MCP `inputSchema`): tools without
    # arguments may omit it entirely.
    properties: dict[str, "ToolParametersProperty"] = Field(default_factory=dict)
    required: Optional[list[str]] = Field(default_factory=list)


class ToolParametersProperty(BaseModel):
    """Schema of a single argument; any extra JSON Schema keywords are preserved."""

    # "integer" is a distinct JSON Schema type, and `type` may also be a list of names
    # (e.g. ["string", "null"]); both were previously rejected.
    type: Union[_JSONSchemaTypeName, list[_JSONSchemaTypeName]]
    description: Optional[str] = None
    model_config = ConfigDict(extra="allow")


class BuiltinToolDefinition(BaseModel):
    """Definition of a tool implemented by a runtime call named after the tool."""

    type: Literal["builtin"]
    description: ToolDescription
    behavior: "BuiltinToolBehavior"


class BuiltinToolBehavior(BaseModel):
    """Post-processing for builtin tools: optional JMESPath ``outputPath`` applied to the result."""

    output_path: Optional[str] = Field(default=None, alias="outputPath")
    model_config = ConfigDict(populate_by_name=True)


class RESTAPIToolDefinition(BaseModel):
    """Definition of a tool that performs an HTTP request."""

    type: Literal["restapi"]
    description: ToolDescription
    behavior: "RESTAPIBehavior"


class RESTAPIBehavior(BaseModel):
    """HTTP request template: ``${name}`` placeholders in ``baseURL``/``body`` are filled from arguments."""

    base_url: str = Field(alias="baseURL")
    method: Literal["GET", "POST", "PUT", "DELETE"]
    authentication: Optional[Literal["bearer"]] = None
    headers: Optional[dict[str, str]] = None
    body: Optional[str] = None
    output_path: Optional[str] = Field(default=None, alias="outputPath")
    model_config = ConfigDict(populate_by_name=True)


class Tool:
    """A tool the agent can offer to the model: its description plus the callable that runs it."""

    def __init__(
        self,
        desc: ToolDescription,
        call_fn: Callable[..., Any],
    ):
        self.desc = desc
        # Invoked with the model-provided arguments as keyword arguments.
        self.call = call_fn
286
+
287
+
288
class ToolAuthenticator(ABC):
    """Callable that returns an authenticated copy of an outgoing REST tool request dict."""

    def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
        return self.apply(request)

    @abstractmethod
    def apply(self, request: dict[str, Any]) -> dict[str, Any]:
        """Return *request* with authentication applied."""


class BearerAuthenticator(ToolAuthenticator):
    """Adds an ``Authorization: <bearer_format> <token>`` header to the request."""

    def __init__(self, token: str, bearer_format: str = "Bearer"):
        self.token = token
        self.bearer_format = bearer_format

    def apply(self, request: dict[str, Any]) -> dict[str, Any]:
        # "headers" may be absent or explicitly None (REST tools without configured headers).
        # Copy so neither the caller's request nor a tool definition's dict is mutated.
        headers = dict(request.get("headers") or {})
        headers["Authorization"] = f"{self.bearer_format} {self.token}"
        return {**request, "headers": headers}
306
+
307
+
308
T_Retval = TypeVar("T_Retval")


def run_async(
    coro: Union[Awaitable[T_Retval], Callable[[], Awaitable[T_Retval]]],
) -> T_Retval:
    """Synchronously run an awaitable (or a zero-argument async callable) and return its result.

    If no event loop is running in the calling thread, the awaitable runs on a fresh loop
    right here. If one is already running (e.g. inside Jupyter or an async framework), that
    loop cannot be re-entered, so the awaitable runs on a private loop in a helper thread
    while the caller blocks.

    Exceptions raised by the awaitable propagate unchanged.
    """
    import asyncio
    import concurrent.futures

    async def _main() -> T_Retval:
        awaitable = coro() if callable(coro) else coro
        return await awaitable

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread.
        return asyncio.run(_main())

    # Only the loop check is guarded above, so errors from the awaitable itself are never
    # mistaken for "a loop is already running".
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _main()).result()
322
+
323
+
324
class Agent:
    """
    The `Agent` class provides a high-level interface for interacting with large language models (LLMs) in Ailoy.
    It abstracts the underlying runtime and VM logic, allowing users to easily send queries and receive streaming
    responses.

    Agents can be extended with external tools or APIs to provide real-time or domain-specific knowledge, enabling
    more powerful and context-aware interactions.
    """

    def __init__(
        self,
        runtime: Runtime,
        model_name: ModelName,
        system_message: Optional[str] = None,
        api_key: Optional[str] = None,
        attrs: Optional[dict[str, Any]] = None,
    ):
        """
        Create an instance.

        :param runtime: The runtime environment associated with the agent.
        :param model_name: The name of the LLM model to use.
        :param system_message: Optional system message to set the initial assistant context.
        :param api_key: (web agent only) The API key for AI API.
        :param attrs: Additional initialization parameters (for `define_component` runtime call).
            The dict is copied; the caller's object is never modified.
        :raises ValueError: If model name is not supported or validation fails.
        """
        self._runtime = runtime

        # Each agent owns one runtime component, named by a fresh UUID
        self._component_state = ComponentState(
            name=generate_uuid(),
            valid=False,
        )

        # Conversation history, sent in full on every inference call
        self._messages: list[Message] = []
        if system_message:
            self._messages.append(SystemMessage(role="system", content=system_message))

        # Tools offered to the model
        self._tools: list[Tool] = []

        # Define the component
        self.define(model_name, api_key=api_key, attrs=attrs)

    def __del__(self):
        # `__init__` may have raised before the component state existed; nothing to release then.
        if getattr(self, "_component_state", None) is not None:
            self.delete()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.delete()

    def define(
        self,
        model_name: ModelName,
        api_key: Optional[str] = None,
        attrs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initializes the agent by defining its model in the runtime.
        This must be called before running the agent. If already initialized, this is a no-op.

        :param model_name: The name of the LLM model to use.
        :param api_key: (web agent only) The API key for AI API.
        :param attrs: Additional initialization parameters (for `define_component` runtime call).
            A copy is made, so the caller's dict is left untouched.
        :raises ValueError: If `model_name` is not a supported model.
        """
        if self._component_state.valid:
            return

        if model_name not in model_descriptions:
            raise ValueError(f"Model `{model_name}` not supported")
        model_desc = model_descriptions[model_name]

        # Work on a private copy: writing "model"/"api_key" into a shared default or the
        # caller's dict would leak those values into later agents.
        attrs = dict(attrs) if attrs is not None else {}
        attrs.setdefault("model", model_desc.model_id)
        if api_key:
            attrs["api_key"] = api_key

        # Use the model's default system prompt when no system message was provided
        if len(self._messages) == 0 and model_desc.default_system_message:
            self._messages.append(SystemMessage(role="system", content=model_desc.default_system_message))

        self._runtime.define(
            model_desc.component_type,
            self._component_state.name,
            attrs,
        )

        # Mark as defined
        self._component_state.valid = True

    def delete(self) -> None:
        """
        Deinitializes the agent and releases resources in the runtime.
        This should be called when the agent is no longer needed. If already deinitialized, this is a no-op.
        The conversation history is cleared, except for a leading system message.
        """
        if not self._component_state.valid:
            return
        self._runtime.delete(self._component_state.name)
        if len(self._messages) > 0 and self._messages[0].role == "system":
            self._messages = [self._messages[0]]
        else:
            self._messages = []
        self._component_state.valid = False

    def query(
        self,
        message: str,
        enable_reasoning: bool = False,
        ignore_reasoning_messages: bool = False,
    ) -> Generator[AgentResponse, None, None]:
        """
        Runs the agent with a new user message and yields streamed responses.

        Tool calls requested by the model are executed and their results are fed back
        into a new inference round until the model finishes its turn.

        :param message: The user message to send to the model.
        :param enable_reasoning: If True, enables reasoning capabilities. (default: False)
        :param ignore_reasoning_messages: If True, reasoning steps are not included in the response stream. (default: False)
        :yield: AgentResponse output of the LLM inference or tool calls
        :raises RuntimeError: If the model calls a tool that is not registered.
        """  # noqa: E501
        self._messages.append(UserMessage(role="user", content=message))

        def run_tool(tool_call: ToolCall) -> ToolCallResultMessage:
            tool_ = next(
                (t for t in self._tools if t.desc.name == tool_call.function.name),
                None,
            )
            if tool_ is None:
                raise RuntimeError(f'Tool "{tool_call.function.name}" not found')
            resp = tool_.call(**tool_call.function.arguments)
            return ToolCallResultMessage(
                role="tool",
                name=tool_call.function.name,
                tool_call_id=tool_call.id,
                content=json.dumps(resp),
            )

        while True:
            infer_args: dict[str, Any] = {
                "messages": [msg.model_dump() for msg in self._messages],
                "tools": [{"type": "function", "function": t.desc.model_dump()} for t in self._tools],
            }
            if enable_reasoning:
                infer_args["enable_reasoning"] = enable_reasoning
            if ignore_reasoning_messages:
                infer_args["ignore_reasoning_messages"] = ignore_reasoning_messages

            for resp in self._runtime.call_iter_method(self._component_state.name, "infer", infer_args):
                delta = MessageDelta.model_validate(resp)

                if delta.finish_reason is None:
                    # Intermediate chunk of assistant (or reasoning) text
                    output_msg = AIOutputTextMessage.model_validate(delta.message)
                    yield AgentResponseOutputText(
                        type="reasoning" if output_msg.reasoning else "output_text",
                        end_of_turn=False,
                        role="assistant",
                        content=output_msg.content,
                    )
                    continue

                if delta.finish_reason == "tool_calls":
                    tool_call_message = AIToolCallMessage.model_validate(delta.message)
                    self._messages.append(tool_call_message)

                    for tool_call in tool_call_message.tool_calls:
                        yield AgentResponseToolCall(
                            type="tool_call",
                            end_of_turn=True,
                            role="assistant",
                            content=tool_call,
                        )

                    tool_call_results: list[ToolCallResultMessage] = [
                        run_tool(tc) for tc in tool_call_message.tool_calls
                    ]

                    for result_msg in tool_call_results:
                        self._messages.append(result_msg)
                        yield AgentResponseToolCallResult(
                            type="tool_call_result",
                            end_of_turn=True,
                            role="tool",
                            content=result_msg,
                        )

                    # Run infer again with the tool results in the history
                    break

                if delta.finish_reason in ["stop", "length", "error"]:
                    output_msg = AIOutputTextMessage.model_validate(delta.message)
                    yield AgentResponseOutputText(
                        type="reasoning" if output_msg.reasoning else "output_text",
                        end_of_turn=True,
                        role="assistant",
                        content=output_msg.content,
                    )
                    # NOTE(review): the final assistant text is not appended to `self._messages`,
                    # so later `query()` calls will not see this reply — confirm this is intended.
                    return
            else:
                # The stream ended without a terminal delta; stop instead of calling `infer`
                # again with unchanged messages (which could loop forever).
                return

    def print(self, resp: AgentResponse):
        """Print a response yielded by `query` using its own `print` method."""
        resp.print()

    def add_tool(self, tool: Tool) -> bool:
        """
        Adds a custom tool to the agent.

        :param tool: Tool instance to be added.
        :returns: True if the tool was added, False if a tool with the same name already exists
            (a warning is emitted in that case).
        """
        if any(t.desc.name == tool.desc.name for t in self._tools):
            warnings.warn(f'Tool "{tool.desc.name}" is already added.')
            return False
        self._tools.append(tool)
        return True

    def add_py_function_tool(self, desc: dict, f: Callable[..., Any]) -> bool:
        """
        Adds a Python function as a tool using callable.

        :param desc: Tool description (validated as a `ToolDescription`).
        :param f: Function called with the model-provided arguments as keyword arguments when the tool is invoked.
        :returns: True if the tool was added, False if a tool with the same name already exists.
        """
        return self.add_tool(Tool(desc=ToolDescription.model_validate(desc), call_fn=f))

    def add_builtin_tool(self, tool_def: BuiltinToolDefinition) -> bool:
        """
        Adds a built-in tool, executed as a runtime call named after the tool.

        :param tool_def: The built-in tool definition.
        :returns: True if the tool was added, False if a tool with the same name already exists.
        :raises ValueError: If the tool type is not "builtin". The resulting tool raises ValueError
            at call time when required inputs are missing.
        """
        if tool_def.type != "builtin":
            raise ValueError('Tool type is not "builtin"')

        def call(**inputs: Any) -> Any:
            required = tool_def.description.parameters.required or []
            for param_name in required:
                if param_name not in inputs:
                    raise ValueError(f'Parameter "{param_name}" is required but not provided')

            output = self._runtime.call(tool_def.description.name, inputs)
            if tool_def.behavior.output_path is not None:
                output = jmespath.search(tool_def.behavior.output_path, output)

            return output

        return self.add_tool(Tool(desc=tool_def.description, call_fn=call))

    def add_restapi_tool(
        self,
        tool_def: RESTAPIToolDefinition,
        authenticator: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
    ) -> bool:
        """
        Adds a REST API tool that performs external HTTP requests.

        Inputs referenced as `${name}` in `baseURL` or `body` are substituted there; all other
        inputs are appended to the URL as query parameters.

        :param tool_def: REST API tool definition.
        :param authenticator: Optional authenticator to inject into the request.
        :returns: True if the tool was added, False if a tool with the same name already exists.
        :raises ValueError: If the tool type is not "restapi".
        """
        import re

        if tool_def.type != "restapi":
            raise ValueError('Tool type is not "restapi"')

        behavior = tool_def.behavior

        def render_template(template: str, context: dict[str, Any]) -> tuple[str, list[str]]:
            # Substitute `${name}` placeholders; return the rendered text and the names it referenced.
            variables = set()

            def replacer(match: re.Match) -> str:
                key = match.group(1).strip()
                variables.add(key)
                return str(context.get(key, f"{{{key}}}"))

            rendered = re.sub(r"\$\{\s*([^}\s]+)\s*\}", replacer, template)
            return rendered, list(variables)

        def call(**inputs: Any) -> Any:
            # Handle path parameters
            url, path_vars = render_template(behavior.base_url, inputs)

            # Handle body
            if behavior.body is not None:
                body, body_vars = render_template(behavior.body, inputs)
            else:
                body, body_vars = None, []

            # Every input not consumed by the URL or body becomes a query parameter
            consumed = set(path_vars) | set(body_vars)
            query_params = {k: v for k, v in inputs.items() if k not in consumed}

            # Append to, rather than replace, any query string already present in the base URL
            parsed = urlparse(url)
            query = "&".join(part for part in (parsed.query, urlencode(query_params)) if part)
            full_url = urlunparse(parsed._replace(query=query))

            request: dict[str, Any] = {
                "url": full_url,
                "method": behavior.method,
                # Fresh dict per call so an authenticator can add headers without mutating
                # the tool definition (and without failing on a missing header map).
                "headers": dict(behavior.headers or {}),
            }
            if body:
                request["body"] = body

            # Apply authentication
            if callable(authenticator):
                request = authenticator(request)

            resp = self._runtime.call("http_request", request)
            try:
                output = json.loads(resp["body"])
            except json.JSONDecodeError:
                # Non-JSON response: pass the raw body through
                output = resp["body"]

            # Parse output path if defined
            if behavior.output_path is not None:
                output = jmespath.search(behavior.output_path, output)

            return output

        return self.add_tool(Tool(desc=tool_def.description, call_fn=call))

    def add_tools_from_preset(
        self, preset_name: str, authenticator: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
    ):
        """
        Loads tools from a predefined JSON preset file (`presets/tools/<preset_name>.json` next to this module).

        :param preset_name: Name of the tool preset.
        :param authenticator: Optional authenticator to use for REST API tools.
        :raises ValueError: If the preset file is not found.
        """
        tool_presets_path = Path(__file__).parent / "presets" / "tools"
        preset_json = tool_presets_path / f"{preset_name}.json"
        if not preset_json.exists():
            raise ValueError(f'Tool preset "{preset_name}" does not exist')

        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) cannot decode all presets
        data: dict[str, dict[str, Any]] = json.loads(preset_json.read_text(encoding="utf-8"))
        for tool_name, tool_def in data.items():
            tool_type = tool_def.get("type", None)
            if tool_type == "builtin":
                self.add_builtin_tool(BuiltinToolDefinition.model_validate(tool_def))
            elif tool_type == "restapi":
                self.add_restapi_tool(RESTAPIToolDefinition.model_validate(tool_def), authenticator=authenticator)
            else:
                warnings.warn(f'Tool type "{tool_type}" is not supported. Skip adding tool "{tool_name}".')

    def add_mcp_tool(self, params: mcp.StdioServerParameters, tool: mcp_types.Tool) -> bool:
        """
        Adds a tool from an MCP (Model Context Protocol) server.

        Each invocation starts the stdio server, initializes a session, calls the tool and
        returns its content items as a list of strings.

        :param params: Parameters for connecting to the MCP stdio server.
        :param tool: Tool metadata as defined by MCP.
        :returns: True if the tool was added, False if a tool with the same name already exists.
        """
        from mcp.client.stdio import stdio_client

        def call(**inputs: Any) -> Any:
            async def _inner():
                # Discard the server's stderr: merging it into stdout would interleave log
                # lines with the JSON-RPC message stream.
                async with stdio_client(params, errlog=subprocess.DEVNULL) as streams:
                    async with mcp.ClientSession(*streams) as session:
                        await session.initialize()

                        result = await session.call_tool(tool.name, inputs)
                        contents: list[str] = []
                        for item in result.content:
                            if isinstance(item, mcp_types.TextContent):
                                contents.append(item.text)
                            elif isinstance(item, mcp_types.ImageContent):
                                contents.append(item.data)
                            elif isinstance(item, mcp_types.EmbeddedResource):
                                if isinstance(item.resource, mcp_types.TextResourceContents):
                                    contents.append(item.resource.text)
                                else:
                                    contents.append(item.resource.blob)

                        return contents

            return run_async(_inner())

        # MCP tool descriptions are optional, but ToolDescription requires a string
        desc = ToolDescription(name=tool.name, description=tool.description or "", parameters=tool.inputSchema)
        return self.add_tool(Tool(desc=desc, call_fn=call))

    def add_tools_from_mcp_server(self, params: mcp.StdioServerParameters, tools_to_add: Optional[list[str]] = None):
        """
        Fetches tools from an MCP stdio server and registers them with the agent.

        :param params: Parameters for connecting to the MCP stdio server.
        :param tools_to_add: Optional list of tool names to add. If None, all tools are added.
        :returns: list of all tools returned by the server.
        """
        from mcp.client.stdio import stdio_client

        async def _inner():
            # Discard the server's stderr so it cannot corrupt the JSON-RPC stdout stream
            async with stdio_client(params, errlog=subprocess.DEVNULL) as streams:
                async with mcp.ClientSession(*streams) as session:
                    await session.initialize()
                    resp = await session.list_tools()
                    for tool in resp.tools:
                        if tools_to_add is None or tool.name in tools_to_add:
                            self.add_mcp_tool(params, tool)
                    return resp.tools

        tools = run_async(_inner())
        return tools