ailoy-py 0.0.1__cp310-cp310-win_amd64.whl → 0.0.3__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ailoy/agent.py CHANGED
@@ -1,211 +1,185 @@
1
+ import base64
1
2
  import json
2
- import subprocess
3
3
  import warnings
4
4
  from abc import ABC, abstractmethod
5
- from collections.abc import Awaitable, Callable, Generator
5
+ from collections.abc import Callable, Generator
6
+ from functools import partial
6
7
  from pathlib import Path
7
8
  from typing import (
9
+ Annotated,
8
10
  Any,
9
11
  Literal,
10
12
  Optional,
11
- TypeVar,
12
13
  Union,
13
14
  )
14
15
  from urllib.parse import urlencode, urlparse, urlunparse
15
16
 
16
17
  import jmespath
17
- import mcp
18
- import mcp.types as mcp_types
19
- from pydantic import BaseModel, ConfigDict, Field
18
+ from PIL.Image import Image
19
+ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
20
20
  from rich.console import Console
21
21
  from rich.panel import Panel
22
22
 
23
23
  from ailoy.ailoy_py import generate_uuid
24
+ from ailoy.mcp import MCPServer, MCPTool, StdioServerParameters
25
+ from ailoy.models import APIModel, LocalModel
24
26
  from ailoy.runtime import Runtime
27
+ from ailoy.tools import DocstringParsingException, TypeHintParsingException, get_json_schema
28
+ from ailoy.utils.image import pillow_image_to_base64
25
29
 
26
- __all__ = ["Agent"]
30
+ ## Types for internal data structures
27
31
 
28
- ## Types for OpenAI API-compatible data structures
29
32
 
30
-
31
- class SystemMessage(BaseModel):
32
- role: Literal["system"]
33
- content: str
33
+ class TextContent(BaseModel):
34
+ type: Literal["text"] = "text"
35
+ text: str
34
36
 
35
37
 
36
- class UserMessage(BaseModel):
37
- role: Literal["user"]
38
- content: str
38
+ class ImageContent(BaseModel):
39
+ class UrlData(BaseModel):
40
+ url: str
39
41
 
42
+ type: Literal["image_url"] = "image_url"
43
+ image_url: UrlData
40
44
 
41
- class AIOutputTextMessage(BaseModel):
42
- role: Literal["assistant"]
43
- content: str
44
- reasoning: Optional[bool] = None
45
-
46
-
47
- class AIToolCallMessage(BaseModel):
48
- role: Literal["assistant"]
49
- content: None
50
- tool_calls: list["ToolCall"]
51
-
52
-
53
- class ToolCall(BaseModel):
54
- id: str
55
- type: Literal["function"] = "function"
56
- function: "ToolCallFunction"
45
+ @staticmethod
46
+ def from_url(url: str):
47
+ return ImageContent(image_url={"url": url})
57
48
 
49
+ @staticmethod
50
+ def from_pillow(image: Image):
51
+ return ImageContent(image_url={"url": pillow_image_to_base64(image)})
58
52
 
59
- class ToolCallFunction(BaseModel):
60
- name: str
61
- arguments: dict[str, Any]
62
53
 
54
+ class AudioContent(BaseModel):
55
+ class AudioData(BaseModel):
56
+ data: str
57
+ format: Literal["mp3", "wav"]
63
58
 
64
- class ToolCallResultMessage(BaseModel):
65
- role: Literal["tool"]
66
- name: str
67
- tool_call_id: str
68
- content: str
59
+ type: Literal["input_audio"] = "input_audio"
60
+ input_audio: AudioData
69
61
 
62
+ @staticmethod
63
+ def from_bytes(data: bytes, format: Literal["mp3", "wav"]):
64
+ return AudioContent(input_audio={"data": base64.b64encode(data).decode("utf-8"), "format": format})
70
65
 
71
- Message = Union[
72
- SystemMessage,
73
- UserMessage,
74
- AIOutputTextMessage,
75
- AIToolCallMessage,
76
- ToolCallResultMessage,
77
- ]
78
66
 
67
+ class FunctionData(BaseModel):
68
+ class FunctionBody(BaseModel):
69
+ name: str
70
+ arguments: Any
79
71
 
80
- class MessageDelta(BaseModel):
81
- finish_reason: Optional[Literal["stop", "tool_calls", "length", "error"]]
82
- message: Message
72
+ type: Literal["function"] = "function"
73
+ id: Optional[str] = None
74
+ function: FunctionBody
83
75
 
84
76
 
85
- ## Types for LLM Model Definitions
77
+ class SystemMessage(BaseModel):
78
+ role: Literal["system"] = "system"
79
+ content: str | list[TextContent]
86
80
 
87
- TVMModelName = Literal["Qwen/Qwen3-0.6B", "Qwen/Qwen3-1.7B", "Qwen/Qwen3-4B", "Qwen/Qwen3-8B"]
88
- OpenAIModelName = Literal["gpt-4o"]
89
- ModelName = Union[TVMModelName, OpenAIModelName]
90
81
 
82
+ class UserMessage(BaseModel):
83
+ role: Literal["user"] = "user"
84
+ content: str | list[TextContent | ImageContent | AudioContent]
91
85
 
92
- class TVMModel(BaseModel):
93
- name: TVMModelName
94
- quantization: Optional[Literal["q4f16_1"]] = None
95
- mode: Optional[Literal["interactive"]] = None
96
86
 
87
+ class AssistantMessage(BaseModel):
88
+ role: Literal["assistant"] = "assistant"
89
+ content: Optional[str | list[TextContent]] = None
90
+ name: Optional[str] = None
91
+ tool_calls: Optional[list[FunctionData]] = None
97
92
 
98
- class OpenAIModel(BaseModel):
99
- name: OpenAIModelName
100
- api_key: str
93
+ # Non-OpenAI fields
94
+ reasoning: Optional[list[TextContent]] = None
101
95
 
102
96
 
103
- class ModelDescription(BaseModel):
104
- model_id: str
105
- component_type: str
106
- default_system_message: Optional[str] = None
97
+ class ToolMessage(BaseModel):
98
+ role: Literal["tool"] = "tool"
99
+ content: str | list[TextContent]
100
+ tool_call_id: Optional[str] = None
107
101
 
108
102
 
109
- model_descriptions: dict[ModelName, ModelDescription] = {
110
- "Qwen/Qwen3-0.6B": ModelDescription(
111
- model_id="Qwen/Qwen3-0.6B",
112
- component_type="tvm_language_model",
113
- default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
114
- ),
115
- "Qwen/Qwen3-1.7B": ModelDescription(
116
- model_id="Qwen/Qwen3-1.7B",
117
- component_type="tvm_language_model",
118
- default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
119
- ),
120
- "Qwen/Qwen3-4B": ModelDescription(
121
- model_id="Qwen/Qwen3-4B",
122
- component_type="tvm_language_model",
123
- default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
124
- ),
125
- "Qwen/Qwen3-8B": ModelDescription(
126
- model_id="Qwen/Qwen3-8B",
127
- component_type="tvm_language_model",
128
- default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
129
- ),
130
- "gpt-4o": ModelDescription(
131
- model_id="gpt-4o",
132
- component_type="openai",
133
- ),
134
- }
103
+ Message = Union[
104
+ SystemMessage,
105
+ UserMessage,
106
+ AssistantMessage,
107
+ ToolMessage,
108
+ ]
135
109
 
136
110
 
137
- class ComponentState(BaseModel):
138
- name: str
139
- valid: bool
111
+ class MessageOutput(BaseModel):
112
+ message: AssistantMessage
113
+ finish_reason: Optional[Literal["stop", "tool_calls", "invalid_tool_call", "length", "error"]] = None
140
114
 
141
115
 
142
116
  ## Types for agent's responses
143
117
 
144
- _console = Console(highlight=False)
145
-
118
+ _console = Console(highlight=False, force_jupyter=False, force_terminal=True)
146
119
 
147
- class AgentResponseBase(BaseModel):
148
- type: Literal["output_text", "tool_call", "tool_call_result", "reasoning", "error"]
149
- end_of_turn: bool
150
- role: Literal["assistant", "tool"]
151
- content: Any
152
120
 
153
- def print(self):
154
- raise NotImplementedError
155
-
156
-
157
- class AgentResponseOutputText(AgentResponseBase):
121
+ class AgentResponseOutputText(BaseModel):
158
122
  type: Literal["output_text", "reasoning"]
159
- role: Literal["assistant"]
123
+ role: Literal["assistant"] = "assistant"
124
+ is_type_switched: bool = False
160
125
  content: str
161
126
 
162
127
  def print(self):
128
+ if self.is_type_switched:
129
+ _console.print() # add newline if type has been switched
163
130
  _console.print(self.content, end="", style=("yellow" if self.type == "reasoning" else None))
164
- if self.end_of_turn:
165
- _console.print()
166
131
 
167
132
 
168
- class AgentResponseToolCall(AgentResponseBase):
169
- type: Literal["tool_call"]
170
- role: Literal["assistant"]
171
- content: ToolCall
133
+ class AgentResponseToolCall(BaseModel):
134
+ type: Literal["tool_call"] = "tool_call"
135
+ role: Literal["assistant"] = "assistant"
136
+ is_type_switched: bool = False
137
+ content: FunctionData
172
138
 
173
139
  def print(self):
140
+ title = f"[magenta]Tool Call[/magenta]: [bold]{self.content.function.name}[/bold]"
141
+ if self.content.id is not None and len(self.content.id) > 0:
142
+ title += f" ({self.content.id})"
174
143
  panel = Panel(
175
144
  json.dumps(self.content.function.arguments, indent=2),
176
- title=f"[magenta]Tool Call[/magenta]: [bold]{self.content.function.name}[/bold] ({self.content.id})",
145
+ title=title,
177
146
  title_align="left",
178
147
  )
179
148
  _console.print(panel)
180
149
 
181
150
 
182
- class AgentResponseToolCallResult(AgentResponseBase):
183
- type: Literal["tool_call_result"]
184
- role: Literal["tool"]
185
- content: ToolCallResultMessage
151
+ class AgentResponseToolResult(BaseModel):
152
+ type: Literal["tool_call_result"] = "tool_call_result"
153
+ role: Literal["tool"] = "tool"
154
+ is_type_switched: bool = False
155
+ content: ToolMessage
186
156
 
187
157
  def print(self):
188
158
  try:
189
159
  # Try to parse as json
190
- content = json.dumps(json.loads(self.content.content), indent=2)
160
+ content = json.dumps(json.loads(self.content.content[0].text), indent=2)
191
161
  except json.JSONDecodeError:
192
162
  # Use original content if not json deserializable
193
- content = self.content.content
163
+ content = self.content.content[0].text
194
164
  # Truncate long contents
195
165
  if len(content) > 500:
196
166
  content = content[:500] + "...(truncated)"
197
167
 
168
+ title = "[green]Tool Result[/green]"
169
+ if self.content.tool_call_id is not None and len(self.content.tool_call_id) > 0:
170
+ title += f" ({self.content.tool_call_id})"
198
171
  panel = Panel(
199
172
  content,
200
- title=f"[green]Tool Result[/green]: [bold]{self.content.name}[/bold] ({self.content.tool_call_id})",
173
+ title=title,
201
174
  title_align="left",
202
175
  )
203
176
  _console.print(panel)
204
177
 
205
178
 
206
- class AgentResponseError(AgentResponseBase):
207
- type: Literal["error"]
208
- role: Literal["assistant"]
179
+ class AgentResponseError(BaseModel):
180
+ type: Literal["error"] = "error"
181
+ role: Literal["assistant"] = "assistant"
182
+ is_type_switched: bool = False
209
183
  content: str
210
184
 
211
185
  def print(self):
@@ -219,7 +193,7 @@ class AgentResponseError(AgentResponseBase):
219
193
  AgentResponse = Union[
220
194
  AgentResponseOutputText,
221
195
  AgentResponseToolCall,
222
- AgentResponseToolCallResult,
196
+ AgentResponseToolResult,
223
197
  AgentResponseError,
224
198
  ]
225
199
 
@@ -305,22 +279,6 @@ class BearerAuthenticator(ToolAuthenticator):
305
279
  return {**request, "headers": headers}
306
280
 
307
281
 
308
- T_Retval = TypeVar("T_Retval")
309
-
310
-
311
- def run_async(coro: Callable[..., Awaitable[T_Retval]]) -> T_Retval:
312
- try:
313
- import anyio
314
-
315
- # Running outside async loop
316
- return anyio.run(lambda: coro)
317
- except RuntimeError:
318
- import anyio.from_thread
319
-
320
- # Already in a running event loop: use anyio from_thread
321
- return anyio.from_thread.run(coro)
322
-
323
-
324
282
  class Agent:
325
283
  """
326
284
  The `Agent` class provides a high-level interface for interacting with large language models (LLMs) in Ailoy.
@@ -334,39 +292,37 @@ class Agent:
334
292
  def __init__(
335
293
  self,
336
294
  runtime: Runtime,
337
- model_name: ModelName,
295
+ model: APIModel | LocalModel,
338
296
  system_message: Optional[str] = None,
339
- api_key: Optional[str] = None,
340
- attrs: dict[str, Any] = dict(),
341
297
  ):
342
298
  """
343
299
  Create an instance.
344
300
 
345
301
  :param runtime: The runtime environment associated with the agent.
346
- :param model_name: The name of the LLM model to use.
302
+ :param model: The model instance.
347
303
  :param system_message: Optional system message to set the initial assistant context.
348
- :param api_key: (web agent only) The API key for AI API.
349
- :param attrs: Additional initialization parameters (for `define_component` runtime call)
350
304
  :raises ValueError: If model name is not supported or validation fails.
351
305
  """
352
306
  self._runtime = runtime
353
307
 
354
308
  # Initialize component state
355
- self._component_state = ComponentState(
356
- name=generate_uuid(),
357
- valid=False,
358
- )
309
+ self._component_name = generate_uuid()
310
+ self._component_ready = False
359
311
 
360
312
  # Initialize messages
361
313
  self._messages: list[Message] = []
362
- if system_message:
363
- self._messages.append(SystemMessage(role="system", content=system_message))
314
+
315
+ # Initialize system message
316
+ self._system_message = system_message
364
317
 
365
318
  # Initialize tools
366
319
  self._tools: list[Tool] = []
367
320
 
321
+ # Initialize MCP servers
322
+ self._mcp_servers: list[MCPServer] = []
323
+
368
324
  # Define the component
369
- self.define(model_name, api_key=api_key, attrs=attrs)
325
+ self.define(model)
370
326
 
371
327
  def __del__(self):
372
328
  self.delete()
@@ -377,151 +333,216 @@ class Agent:
377
333
  def __exit__(self, type, value, traceback):
378
334
  self.delete()
379
335
 
380
- def define(self, model_name: ModelName, api_key: Optional[str] = None, attrs: dict[str, Any] = dict()) -> None:
336
+ def define(self, model: APIModel | LocalModel) -> None:
381
337
  """
382
338
  Initializes the agent by defining its model in the runtime.
383
339
  This must be called before running the agent. If already initialized, this is a no-op.
384
- :param model_name: The name of the LLM model to use.
385
- :param api_key: (web agent only) The API key for AI API.
386
- :param attrs: Additional initialization parameters (for `define_component` runtime call)
340
+ :param model: The model instance.
387
341
  """
388
- if self._component_state.valid:
342
+ if self._component_ready:
389
343
  return
390
344
 
391
- if model_name not in model_descriptions:
392
- raise ValueError(f"Model `{model_name}` not supported")
345
+ if not self._runtime.is_alive():
346
+ raise ValueError("Runtime is currently stopped.")
393
347
 
394
- model_desc = model_descriptions[model_name]
348
+ # Set default system message if not given; still can be None
349
+ if self._system_message is None:
350
+ self._system_message = getattr(model, "default_system_message", None)
395
351
 
396
- # Add model name into attrs
397
- if "model" not in attrs:
398
- attrs["model"] = model_desc.model_id
399
-
400
- # Set default system message
401
- if len(self._messages) == 0 and model_desc.default_system_message:
402
- self._messages.append(SystemMessage(role="system", content=model_desc.default_system_message))
403
-
404
- # Add API key
405
- if api_key:
406
- attrs["api_key"] = api_key
352
+ self.clear_messages()
407
353
 
408
354
  # Call runtime's define
409
355
  self._runtime.define(
410
- model_descriptions[model_name].component_type,
411
- self._component_state.name,
412
- attrs,
356
+ model.component_type,
357
+ self._component_name,
358
+ model.to_attrs(),
413
359
  )
414
360
 
415
361
  # Mark as defined
416
- self._component_state.valid = True
362
+ self._component_ready = True
417
363
 
418
364
  def delete(self) -> None:
419
365
  """
420
366
  Deinitializes the agent and releases resources in the runtime.
421
367
  This should be called when the agent is no longer needed. If already deinitialized, this is a no-op.
422
368
  """
423
- if not self._component_state.valid:
369
+ if not self._component_ready:
424
370
  return
425
- self._runtime.delete(self._component_state.name)
426
- if len(self._messages) > 0 and self._messages[0].role == "system":
427
- self._messages = [self._messages[0]]
428
- else:
429
- self._messages = []
430
- self._component_state.valid = False
371
+
372
+ if self._runtime.is_alive():
373
+ self._runtime.delete(self._component_name)
374
+
375
+ self.clear_messages()
376
+
377
+ for mcp_server in self._mcp_servers:
378
+ mcp_server.cleanup()
379
+
380
+ self._component_ready = False
431
381
 
432
382
  def query(
433
383
  self,
434
- message: str,
435
- enable_reasoning: bool = False,
436
- ignore_reasoning_messages: bool = False,
384
+ message: str | list[str | Image | dict | TextContent | ImageContent | AudioContent],
385
+ reasoning: bool = False,
437
386
  ) -> Generator[AgentResponse, None, None]:
438
387
  """
439
388
  Runs the agent with a new user message and yields streamed responses.
440
389
 
441
390
  :param message: The user message to send to the model.
442
- :param enable_reasoning: If True, enables reasoning capabilities. (default: False)
443
- :param ignore_reasoning_messages: If True, reasoning steps are not included in the response stream. (default: False)
444
- :yield: AgentResponse output of the LLM inference or tool calls
391
+ :param reasoning: If True, enables reasoning capabilities. (Default: False)
392
+ :return: An iterator over the output, where each item represents either a generated token from the assistant or a tool call.
393
+ :rtype: Iterator[:class:`AgentResponse`]
445
394
  """ # noqa: E501
446
- self._messages.append(UserMessage(role="user", content=message))
395
+ if not self._component_ready:
396
+ raise ValueError("Agent is not valid. Create one or define newly.")
397
+
398
+ if not self._runtime.is_alive():
399
+ raise ValueError("Runtime is currently stopped.")
400
+
401
+ if isinstance(message, str):
402
+ self._messages.append(UserMessage(content=[TextContent(text=message)]))
403
+ elif isinstance(message, list):
404
+ if len(message) == 0:
405
+ raise ValueError("Message is empty")
406
+
407
+ contents = []
408
+ for content in message:
409
+ if isinstance(content, str):
410
+ contents.append(TextContent(text=content))
411
+ elif isinstance(content, Image):
412
+ contents.append(ImageContent.from_pillow(image=content))
413
+ elif isinstance(content, dict):
414
+ ta: TypeAdapter[TextContent | ImageContent | AudioContent] = TypeAdapter(
415
+ Annotated[TextContent | ImageContent | AudioContent, Field(discriminator="type")]
416
+ )
417
+ validated_content = ta.validate_python(content)
418
+ contents.append(validated_content)
419
+ else:
420
+ contents.append(content)
421
+
422
+ self._messages.append(UserMessage(content=contents))
423
+ else:
424
+ raise ValueError(f"Invalid message type: {type(message)}")
425
+
426
+ prev_resp_type = None
447
427
 
448
428
  while True:
449
429
  infer_args = {
450
- "messages": [msg.model_dump() for msg in self._messages],
451
- "tools": [{"type": "function", "function": t.desc.model_dump()} for t in self._tools],
430
+ "messages": [msg.model_dump(exclude_none=True) for msg in self._messages],
431
+ "tools": [{"type": "function", "function": t.desc.model_dump(exclude_none=True)} for t in self._tools],
452
432
  }
453
- if enable_reasoning:
454
- infer_args["enable_reasoning"] = enable_reasoning
455
- if ignore_reasoning_messages:
456
- infer_args["ignore_reasoning_messages"] = ignore_reasoning_messages
457
-
458
- for resp in self._runtime.call_iter_method(self._component_state.name, "infer", infer_args):
459
- delta = MessageDelta.model_validate(resp)
460
-
461
- if delta.finish_reason is None:
462
- output_msg = AIOutputTextMessage.model_validate(delta.message)
463
- yield AgentResponseOutputText(
464
- type="reasoning" if output_msg.reasoning else "output_text",
465
- end_of_turn=False,
466
- role="assistant",
467
- content=output_msg.content,
468
- )
469
- continue
470
-
471
- if delta.finish_reason == "tool_calls":
472
- tool_call_message = AIToolCallMessage.model_validate(delta.message)
473
- self._messages.append(tool_call_message)
474
-
475
- for tool_call in tool_call_message.tool_calls:
476
- yield AgentResponseToolCall(
477
- type="tool_call",
478
- end_of_turn=True,
479
- role="assistant",
480
- content=tool_call,
433
+ if reasoning:
434
+ infer_args["reasoning"] = reasoning
435
+
436
+ assistant_reasoning = None
437
+ assistant_content = None
438
+ assistant_tool_calls = None
439
+ finish_reason = ""
440
+ for result in self._runtime.call_iter_method(self._component_name, "infer", infer_args):
441
+ msg = MessageOutput.model_validate(result)
442
+
443
+ if msg.message.reasoning:
444
+ for v in msg.message.reasoning:
445
+ if not assistant_reasoning:
446
+ assistant_reasoning = [v]
447
+ else:
448
+ assistant_reasoning[0].text += v.text
449
+ resp = AgentResponseOutputText(
450
+ type="reasoning",
451
+ is_type_switched=(prev_resp_type != "reasoning"),
452
+ content=v.text,
481
453
  )
482
-
483
- tool_call_results: list[ToolCallResultMessage] = []
484
-
485
- def run_tool(tool_call: ToolCall):
486
- tool_ = next(
487
- (t for t in self._tools if t.desc.name == tool_call.function.name),
488
- None,
454
+ prev_resp_type = resp.type
455
+ yield resp
456
+ if msg.message.content is not None:
457
+ # Canonicalize message content to the array of TextContent
458
+ if isinstance(msg.message.content, str):
459
+ msg.message.content = [TextContent(text=msg.message.content)]
460
+
461
+ for v in msg.message.content:
462
+ if not assistant_content:
463
+ assistant_content = [v]
464
+ else:
465
+ assistant_content[0].text += v.text
466
+ resp = AgentResponseOutputText(
467
+ type="output_text",
468
+ is_type_switched=(prev_resp_type != "output_text"),
469
+ content=v.text,
489
470
  )
490
- if not tool_:
491
- raise RuntimeError("Tool not found")
492
- resp = tool_.call(**tool_call.function.arguments)
493
- return ToolCallResultMessage(
494
- role="tool",
495
- name=tool_call.function.name,
496
- tool_call_id=tool_call.id,
497
- content=json.dumps(resp),
471
+ prev_resp_type = resp.type
472
+ yield resp
473
+ if msg.message.tool_calls:
474
+ for v in msg.message.tool_calls:
475
+ if not assistant_tool_calls:
476
+ assistant_tool_calls = [v]
477
+ else:
478
+ assistant_tool_calls.append(v)
479
+ resp = AgentResponseToolCall(
480
+ is_type_switched=True,
481
+ content=v,
498
482
  )
483
+ prev_resp_type = resp.type
484
+ yield resp
485
+ if msg.finish_reason:
486
+ finish_reason = msg.finish_reason
487
+ break
488
+
489
+ # Append output
490
+ self._messages.append(
491
+ AssistantMessage(
492
+ reasoning=assistant_reasoning,
493
+ content=assistant_content,
494
+ tool_calls=assistant_tool_calls,
495
+ )
496
+ )
497
+
498
+ if finish_reason == "tool_calls":
499
+
500
+ def run_tool(tool_call: FunctionData) -> ToolMessage:
501
+ tool_ = next(
502
+ (t for t in self._tools if t.desc.name == tool_call.function.name),
503
+ None,
504
+ )
505
+ if not tool_:
506
+ raise RuntimeError("Tool not found")
507
+ tool_result = tool_.call(**tool_call.function.arguments)
508
+ return ToolMessage(
509
+ content=[TextContent(text=json.dumps(tool_result))],
510
+ tool_call_id=tool_call.id,
511
+ )
499
512
 
500
- tool_call_results = [run_tool(tc) for tc in tool_call_message.tool_calls]
513
+ tool_call_results = [run_tool(tc) for tc in assistant_tool_calls]
514
+ for result_msg in tool_call_results:
515
+ self._messages.append(result_msg)
516
+ resp = AgentResponseToolResult(
517
+ is_type_switched=True,
518
+ content=result_msg,
519
+ )
520
+ prev_resp_type = resp.type
521
+ yield resp
522
+ # Infer again if tool calls happened
523
+ continue
501
524
 
502
- for result_msg in tool_call_results:
503
- self._messages.append(result_msg)
504
- yield AgentResponseToolCallResult(
505
- type="tool_call_result",
506
- end_of_turn=True,
507
- role="tool",
508
- content=result_msg,
509
- )
525
+ # Finish this generator
526
+ yield AgentResponseOutputText(type="output_text", content="\n")
527
+ break
510
528
 
511
- # Run infer again with new messages
512
- break
529
+ def get_messages(self) -> list[Message]:
530
+ """
531
+ Get the current conversation history.
532
+ Each item in the list represents a message from either the user or the assistant.
513
533
 
514
- if delta.finish_reason in ["stop", "length", "error"]:
515
- output_msg = AIOutputTextMessage.model_validate(delta.message)
516
- yield AgentResponseOutputText(
517
- type="reasoning" if output_msg.reasoning else "output_text",
518
- end_of_turn=True,
519
- role="assistant",
520
- content=output_msg.content,
521
- )
534
+ :return: The conversation history so far in the form of a list.
535
+ :rtype: list[Message]
536
+ """
537
+ return self._messages
522
538
 
523
- # finish this Generator
524
- return
539
+ def clear_messages(self):
540
+ """
541
+ Clear the history of conversation messages.
542
+ """
543
+ self._messages.clear()
544
+ if self._system_message is not None:
545
+ self._messages.append(SystemMessage(role="system", content=[TextContent(text=self._system_message)]))
525
546
 
526
547
  def print(self, resp: AgentResponse):
527
548
  resp.print()
@@ -537,14 +558,29 @@ class Agent:
537
558
  return
538
559
  self._tools.append(tool)
539
560
 
540
- def add_py_function_tool(self, desc: dict, f: Callable[..., Any]):
561
+ def add_py_function_tool(self, f: Callable[..., Any], desc: Optional[dict] = None):
541
562
  """
542
563
  Adds a Python function as a tool using callable.
543
564
 
544
- :param desc: Tool descriotion.
545
565
  :param f: Function will be called when the tool invocation occured.
566
+ :param desc: Tool description to override. If not given, parsed from docstring of function `f`.
567
+
568
+ :raises ValueError: Docstring parsing is failed.
569
+ :raises ValidationError: Given or parsed description is not a valid `ToolDescription`.
546
570
  """
547
- self.add_tool(Tool(desc=ToolDescription.model_validate(desc), call_fn=f))
571
+ tool_description = None
572
+ if desc is not None:
573
+ tool_description = ToolDescription.model_validate(desc)
574
+
575
+ if tool_description is None:
576
+ try:
577
+ json_schema = get_json_schema(f)
578
+ except (TypeHintParsingException, DocstringParsingException) as e:
579
+ raise ValueError("Failed to parse docstring", e)
580
+
581
+ tool_description = ToolDescription.model_validate(json_schema.get("function", {}))
582
+
583
+ self.add_tool(Tool(desc=tool_description, call_fn=f))
548
584
 
549
585
  def add_builtin_tool(self, tool_def: BuiltinToolDefinition) -> bool:
550
586
  """
@@ -669,61 +705,65 @@ class Agent:
669
705
  else:
670
706
  warnings.warn(f'Tool type "{tool_type}" is not supported. Skip adding tool "{tool_name}".')
671
707
 
672
- def add_mcp_tool(self, params: mcp.StdioServerParameters, tool: mcp_types.Tool):
708
+ def add_tools_from_mcp_server(
709
+ self, name: str, params: StdioServerParameters, tools_to_add: Optional[list[str]] = None
710
+ ):
673
711
  """
674
- Adds a tool from an MCP (Model Context Protocol) server.
712
+ Create a MCP server and register its tools to agent.
675
713
 
714
+ :param name: The unique name of the MCP server.
715
+ If there's already a MCP server with the same name, it raises RuntimeError.
676
716
  :param params: Parameters for connecting to the MCP stdio server.
677
- :param tool: Tool metadata as defined by MCP.
678
- :returns: True if the tool was successfully added.
717
+ :param tools_to_add: Optional list of tool names to add. If None, all tools are added.
679
718
  """
680
- from mcp.client.stdio import stdio_client
719
+ if any([s.name == name for s in self._mcp_servers]):
720
+ raise RuntimeError(f"MCP server with name '{name}' is already registered")
681
721
 
682
- def call(**inputs: dict[str, Any]) -> Any:
683
- async def _inner():
684
- async with stdio_client(params, errlog=subprocess.STDOUT) as streams:
685
- async with mcp.ClientSession(*streams) as session:
686
- await session.initialize()
687
-
688
- result = await session.call_tool(tool.name, inputs)
689
- contents: list[str] = []
690
- for item in result.content:
691
- if isinstance(item, mcp_types.TextContent):
692
- contents.append(item.text)
693
- elif isinstance(item, mcp_types.ImageContent):
694
- contents.append(item.data)
695
- elif isinstance(item, mcp_types.EmbeddedResource):
696
- if isinstance(item.resource, mcp_types.TextResourceContents):
697
- contents.append(item.resource.text)
698
- else:
699
- contents.append(item.resource.blob)
700
-
701
- return contents
702
-
703
- return run_async(_inner())
704
-
705
- desc = ToolDescription(name=tool.name, description=tool.description, parameters=tool.inputSchema)
706
- return self.add_tool(Tool(desc=desc, call_fn=call))
707
-
708
- def add_tools_from_mcp_server(self, params: mcp.StdioServerParameters, tools_to_add: Optional[list[str]] = None):
722
+ # Create and register MCP server
723
+ mcp_server = MCPServer(name, params)
724
+ self._mcp_servers.append(mcp_server)
725
+
726
+ # Register tools
727
+ for tool in mcp_server.list_tools():
728
+ # Skip if this tool is not in the whitelist
729
+ if tools_to_add is not None and tool.name not in tools_to_add:
730
+ continue
731
+
732
+ desc = ToolDescription(
733
+ name=f"{name}-{tool.name}", description=tool.description, parameters=tool.inputSchema
734
+ )
735
+
736
+ def call(tool: MCPTool, **inputs: dict[str, Any]) -> list[str]:
737
+ return mcp_server.call_tool(tool, inputs)
738
+
739
+ self.add_tool(Tool(desc=desc, call_fn=partial(call, tool)))
740
+
741
+ def remove_mcp_server(self, name: str):
709
742
  """
710
- Fetches tools from an MCP stdio server and registers them with the agent.
743
+ Removes the MCP server and its tools from the agent, with terminating the MCP server process.
711
744
 
712
- :param params: Parameters for connecting to the MCP stdio server.
713
- :param tools_to_add: Optional list of tool names to add. If None, all tools are added.
714
- :returns: list of all tools returned by the server.
745
+ :param name: The unique name of the MCP server.
746
+ If there's no MCP server matches the name, it raises RuntimeError.
747
+ """
748
+ if all([s.name != name for s in self._mcp_servers]):
749
+ raise RuntimeError(f"MCP server with name '{name}' does not exist")
750
+
751
+ # Remove the MCP server
752
+ mcp_server = next(filter(lambda s: s.name == name, self._mcp_servers))
753
+ self._mcp_servers.remove(mcp_server)
754
+ mcp_server.cleanup()
755
+
756
+ # Remove tools registered from the MCP server
757
+ self._tools = list(filter(lambda t: not t.desc.name.startswith(f"{mcp_server.name}-"), self._tools))
758
+
759
+ def get_tools(self):
760
+ """
761
+ Get the list of registered tools.
762
+ """
763
+ return self._tools
764
+
765
+ def clear_tools(self):
766
+ """
767
+ Clear the registered tools.
715
768
  """
716
- from mcp.client.stdio import stdio_client
717
-
718
- async def _inner():
719
- async with stdio_client(params, errlog=subprocess.STDOUT) as streams:
720
- async with mcp.ClientSession(*streams) as session:
721
- await session.initialize()
722
- resp = await session.list_tools()
723
- for tool in resp.tools:
724
- if tools_to_add is None or tool.name in tools_to_add:
725
- self.add_mcp_tool(params, tool)
726
- return resp.tools
727
-
728
- tools = run_async(_inner())
729
- return tools
769
+ self._tools.clear()