ailoy-py 0.0.1__cp310-cp310-win_amd64.whl → 0.0.3__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ailoy/__init__.py +20 -1
- ailoy/agent.py +349 -309
- ailoy/ailoy_py.cp310-win_amd64.pyd +0 -0
- ailoy/mcp.py +171 -0
- ailoy/models/__init__.py +7 -0
- ailoy/models/api_model.py +71 -0
- ailoy/models/local_model.py +44 -0
- ailoy/runtime.py +34 -19
- ailoy/tools.py +205 -0
- ailoy/utils/__init__.py +0 -0
- ailoy/utils/image.py +11 -0
- ailoy/vector_store.py +10 -9
- ailoy_py-0.0.3.dist-info/DELVEWHEEL +2 -0
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.3.dist-info}/METADATA +5 -4
- ailoy_py-0.0.3.dist-info/RECORD +27 -0
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.3.dist-info}/WHEEL +1 -1
- ailoy_py.libs/msvcp140-0c97ddc05c5b9024aa6af9538804ea77.dll +0 -0
- ailoy_py.libs/tvm_runtime-b9e3c7109c2f4b1e95a6f576ff368094.dll +0 -0
- ailoy_py-0.0.1.dist-info/DELVEWHEEL +0 -2
- ailoy_py-0.0.1.dist-info/RECORD +0 -20
- ailoy_py.libs/msvcp140-9867ece6bcf7e4746fa7e6671b0a17bd.dll +0 -0
- ailoy_py.libs/tvm_runtime-781b77698d9c76cd695ed4ae13795465.dll +0 -0
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.3.dist-info}/entry_points.txt +0 -0
ailoy/agent.py
CHANGED
@@ -1,211 +1,185 @@
|
|
1
|
+
import base64
|
1
2
|
import json
|
2
|
-
import subprocess
|
3
3
|
import warnings
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from collections.abc import
|
5
|
+
from collections.abc import Callable, Generator
|
6
|
+
from functools import partial
|
6
7
|
from pathlib import Path
|
7
8
|
from typing import (
|
9
|
+
Annotated,
|
8
10
|
Any,
|
9
11
|
Literal,
|
10
12
|
Optional,
|
11
|
-
TypeVar,
|
12
13
|
Union,
|
13
14
|
)
|
14
15
|
from urllib.parse import urlencode, urlparse, urlunparse
|
15
16
|
|
16
17
|
import jmespath
|
17
|
-
import
|
18
|
-
import
|
19
|
-
from pydantic import BaseModel, ConfigDict, Field
|
18
|
+
from PIL.Image import Image
|
19
|
+
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
20
20
|
from rich.console import Console
|
21
21
|
from rich.panel import Panel
|
22
22
|
|
23
23
|
from ailoy.ailoy_py import generate_uuid
|
24
|
+
from ailoy.mcp import MCPServer, MCPTool, StdioServerParameters
|
25
|
+
from ailoy.models import APIModel, LocalModel
|
24
26
|
from ailoy.runtime import Runtime
|
27
|
+
from ailoy.tools import DocstringParsingException, TypeHintParsingException, get_json_schema
|
28
|
+
from ailoy.utils.image import pillow_image_to_base64
|
25
29
|
|
26
|
-
|
30
|
+
## Types for internal data structures
|
27
31
|
|
28
|
-
## Types for OpenAI API-compatible data structures
|
29
32
|
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
content: str
|
33
|
+
class TextContent(BaseModel):
|
34
|
+
type: Literal["text"] = "text"
|
35
|
+
text: str
|
34
36
|
|
35
37
|
|
36
|
-
class
|
37
|
-
|
38
|
-
|
38
|
+
class ImageContent(BaseModel):
|
39
|
+
class UrlData(BaseModel):
|
40
|
+
url: str
|
39
41
|
|
42
|
+
type: Literal["image_url"] = "image_url"
|
43
|
+
image_url: UrlData
|
40
44
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
reasoning: Optional[bool] = None
|
45
|
-
|
46
|
-
|
47
|
-
class AIToolCallMessage(BaseModel):
|
48
|
-
role: Literal["assistant"]
|
49
|
-
content: None
|
50
|
-
tool_calls: list["ToolCall"]
|
51
|
-
|
52
|
-
|
53
|
-
class ToolCall(BaseModel):
|
54
|
-
id: str
|
55
|
-
type: Literal["function"] = "function"
|
56
|
-
function: "ToolCallFunction"
|
45
|
+
@staticmethod
|
46
|
+
def from_url(url: str):
|
47
|
+
return ImageContent(image_url={"url": url})
|
57
48
|
|
49
|
+
@staticmethod
|
50
|
+
def from_pillow(image: Image):
|
51
|
+
return ImageContent(image_url={"url": pillow_image_to_base64(image)})
|
58
52
|
|
59
|
-
class ToolCallFunction(BaseModel):
|
60
|
-
name: str
|
61
|
-
arguments: dict[str, Any]
|
62
53
|
|
54
|
+
class AudioContent(BaseModel):
|
55
|
+
class AudioData(BaseModel):
|
56
|
+
data: str
|
57
|
+
format: Literal["mp3", "wav"]
|
63
58
|
|
64
|
-
|
65
|
-
|
66
|
-
name: str
|
67
|
-
tool_call_id: str
|
68
|
-
content: str
|
59
|
+
type: Literal["input_audio"] = "input_audio"
|
60
|
+
input_audio: AudioData
|
69
61
|
|
62
|
+
@staticmethod
|
63
|
+
def from_bytes(data: bytes, format: Literal["mp3", "wav"]):
|
64
|
+
return AudioContent(input_audio={"data": base64.b64encode(data).decode("utf-8"), "format": format})
|
70
65
|
|
71
|
-
Message = Union[
|
72
|
-
SystemMessage,
|
73
|
-
UserMessage,
|
74
|
-
AIOutputTextMessage,
|
75
|
-
AIToolCallMessage,
|
76
|
-
ToolCallResultMessage,
|
77
|
-
]
|
78
66
|
|
67
|
+
class FunctionData(BaseModel):
|
68
|
+
class FunctionBody(BaseModel):
|
69
|
+
name: str
|
70
|
+
arguments: Any
|
79
71
|
|
80
|
-
|
81
|
-
|
82
|
-
|
72
|
+
type: Literal["function"] = "function"
|
73
|
+
id: Optional[str] = None
|
74
|
+
function: FunctionBody
|
83
75
|
|
84
76
|
|
85
|
-
|
77
|
+
class SystemMessage(BaseModel):
|
78
|
+
role: Literal["system"] = "system"
|
79
|
+
content: str | list[TextContent]
|
86
80
|
|
87
|
-
TVMModelName = Literal["Qwen/Qwen3-0.6B", "Qwen/Qwen3-1.7B", "Qwen/Qwen3-4B", "Qwen/Qwen3-8B"]
|
88
|
-
OpenAIModelName = Literal["gpt-4o"]
|
89
|
-
ModelName = Union[TVMModelName, OpenAIModelName]
|
90
81
|
|
82
|
+
class UserMessage(BaseModel):
|
83
|
+
role: Literal["user"] = "user"
|
84
|
+
content: str | list[TextContent | ImageContent | AudioContent]
|
91
85
|
|
92
|
-
class TVMModel(BaseModel):
|
93
|
-
name: TVMModelName
|
94
|
-
quantization: Optional[Literal["q4f16_1"]] = None
|
95
|
-
mode: Optional[Literal["interactive"]] = None
|
96
86
|
|
87
|
+
class AssistantMessage(BaseModel):
|
88
|
+
role: Literal["assistant"] = "assistant"
|
89
|
+
content: Optional[str | list[TextContent]] = None
|
90
|
+
name: Optional[str] = None
|
91
|
+
tool_calls: Optional[list[FunctionData]] = None
|
97
92
|
|
98
|
-
|
99
|
-
|
100
|
-
api_key: str
|
93
|
+
# Non-OpenAI fields
|
94
|
+
reasoning: Optional[list[TextContent]] = None
|
101
95
|
|
102
96
|
|
103
|
-
class
|
104
|
-
|
105
|
-
|
106
|
-
|
97
|
+
class ToolMessage(BaseModel):
|
98
|
+
role: Literal["tool"] = "tool"
|
99
|
+
content: str | list[TextContent]
|
100
|
+
tool_call_id: Optional[str] = None
|
107
101
|
|
108
102
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
"Qwen/Qwen3-1.7B": ModelDescription(
|
116
|
-
model_id="Qwen/Qwen3-1.7B",
|
117
|
-
component_type="tvm_language_model",
|
118
|
-
default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
|
119
|
-
),
|
120
|
-
"Qwen/Qwen3-4B": ModelDescription(
|
121
|
-
model_id="Qwen/Qwen3-4B",
|
122
|
-
component_type="tvm_language_model",
|
123
|
-
default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
|
124
|
-
),
|
125
|
-
"Qwen/Qwen3-8B": ModelDescription(
|
126
|
-
model_id="Qwen/Qwen3-8B",
|
127
|
-
component_type="tvm_language_model",
|
128
|
-
default_system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
|
129
|
-
),
|
130
|
-
"gpt-4o": ModelDescription(
|
131
|
-
model_id="gpt-4o",
|
132
|
-
component_type="openai",
|
133
|
-
),
|
134
|
-
}
|
103
|
+
Message = Union[
|
104
|
+
SystemMessage,
|
105
|
+
UserMessage,
|
106
|
+
AssistantMessage,
|
107
|
+
ToolMessage,
|
108
|
+
]
|
135
109
|
|
136
110
|
|
137
|
-
class
|
138
|
-
|
139
|
-
|
111
|
+
class MessageOutput(BaseModel):
|
112
|
+
message: AssistantMessage
|
113
|
+
finish_reason: Optional[Literal["stop", "tool_calls", "invalid_tool_call", "length", "error"]] = None
|
140
114
|
|
141
115
|
|
142
116
|
## Types for agent's responses
|
143
117
|
|
144
|
-
_console = Console(highlight=False)
|
145
|
-
|
118
|
+
_console = Console(highlight=False, force_jupyter=False, force_terminal=True)
|
146
119
|
|
147
|
-
class AgentResponseBase(BaseModel):
|
148
|
-
type: Literal["output_text", "tool_call", "tool_call_result", "reasoning", "error"]
|
149
|
-
end_of_turn: bool
|
150
|
-
role: Literal["assistant", "tool"]
|
151
|
-
content: Any
|
152
120
|
|
153
|
-
|
154
|
-
raise NotImplementedError
|
155
|
-
|
156
|
-
|
157
|
-
class AgentResponseOutputText(AgentResponseBase):
|
121
|
+
class AgentResponseOutputText(BaseModel):
|
158
122
|
type: Literal["output_text", "reasoning"]
|
159
|
-
role: Literal["assistant"]
|
123
|
+
role: Literal["assistant"] = "assistant"
|
124
|
+
is_type_switched: bool = False
|
160
125
|
content: str
|
161
126
|
|
162
127
|
def print(self):
|
128
|
+
if self.is_type_switched:
|
129
|
+
_console.print() # add newline if type has been switched
|
163
130
|
_console.print(self.content, end="", style=("yellow" if self.type == "reasoning" else None))
|
164
|
-
if self.end_of_turn:
|
165
|
-
_console.print()
|
166
131
|
|
167
132
|
|
168
|
-
class AgentResponseToolCall(
|
169
|
-
type: Literal["tool_call"]
|
170
|
-
role: Literal["assistant"]
|
171
|
-
|
133
|
+
class AgentResponseToolCall(BaseModel):
|
134
|
+
type: Literal["tool_call"] = "tool_call"
|
135
|
+
role: Literal["assistant"] = "assistant"
|
136
|
+
is_type_switched: bool = False
|
137
|
+
content: FunctionData
|
172
138
|
|
173
139
|
def print(self):
|
140
|
+
title = f"[magenta]Tool Call[/magenta]: [bold]{self.content.function.name}[/bold]"
|
141
|
+
if self.content.id is not None and len(self.content.id) > 0:
|
142
|
+
title += f" ({self.content.id})"
|
174
143
|
panel = Panel(
|
175
144
|
json.dumps(self.content.function.arguments, indent=2),
|
176
|
-
title=
|
145
|
+
title=title,
|
177
146
|
title_align="left",
|
178
147
|
)
|
179
148
|
_console.print(panel)
|
180
149
|
|
181
150
|
|
182
|
-
class
|
183
|
-
type: Literal["tool_call_result"]
|
184
|
-
role: Literal["tool"]
|
185
|
-
|
151
|
+
class AgentResponseToolResult(BaseModel):
|
152
|
+
type: Literal["tool_call_result"] = "tool_call_result"
|
153
|
+
role: Literal["tool"] = "tool"
|
154
|
+
is_type_switched: bool = False
|
155
|
+
content: ToolMessage
|
186
156
|
|
187
157
|
def print(self):
|
188
158
|
try:
|
189
159
|
# Try to parse as json
|
190
|
-
content = json.dumps(json.loads(self.content.content), indent=2)
|
160
|
+
content = json.dumps(json.loads(self.content.content[0].text), indent=2)
|
191
161
|
except json.JSONDecodeError:
|
192
162
|
# Use original content if not json deserializable
|
193
|
-
content = self.content.content
|
163
|
+
content = self.content.content[0].text
|
194
164
|
# Truncate long contents
|
195
165
|
if len(content) > 500:
|
196
166
|
content = content[:500] + "...(truncated)"
|
197
167
|
|
168
|
+
title = "[green]Tool Result[/green]"
|
169
|
+
if self.content.tool_call_id is not None and len(self.content.tool_call_id) > 0:
|
170
|
+
title += f" ({self.content.tool_call_id})"
|
198
171
|
panel = Panel(
|
199
172
|
content,
|
200
|
-
title=
|
173
|
+
title=title,
|
201
174
|
title_align="left",
|
202
175
|
)
|
203
176
|
_console.print(panel)
|
204
177
|
|
205
178
|
|
206
|
-
class AgentResponseError(
|
207
|
-
type: Literal["error"]
|
208
|
-
role: Literal["assistant"]
|
179
|
+
class AgentResponseError(BaseModel):
|
180
|
+
type: Literal["error"] = "error"
|
181
|
+
role: Literal["assistant"] = "assistant"
|
182
|
+
is_type_switched: bool = False
|
209
183
|
content: str
|
210
184
|
|
211
185
|
def print(self):
|
@@ -219,7 +193,7 @@ class AgentResponseError(AgentResponseBase):
|
|
219
193
|
AgentResponse = Union[
|
220
194
|
AgentResponseOutputText,
|
221
195
|
AgentResponseToolCall,
|
222
|
-
|
196
|
+
AgentResponseToolResult,
|
223
197
|
AgentResponseError,
|
224
198
|
]
|
225
199
|
|
@@ -305,22 +279,6 @@ class BearerAuthenticator(ToolAuthenticator):
|
|
305
279
|
return {**request, "headers": headers}
|
306
280
|
|
307
281
|
|
308
|
-
T_Retval = TypeVar("T_Retval")
|
309
|
-
|
310
|
-
|
311
|
-
def run_async(coro: Callable[..., Awaitable[T_Retval]]) -> T_Retval:
|
312
|
-
try:
|
313
|
-
import anyio
|
314
|
-
|
315
|
-
# Running outside async loop
|
316
|
-
return anyio.run(lambda: coro)
|
317
|
-
except RuntimeError:
|
318
|
-
import anyio.from_thread
|
319
|
-
|
320
|
-
# Already in a running event loop: use anyio from_thread
|
321
|
-
return anyio.from_thread.run(coro)
|
322
|
-
|
323
|
-
|
324
282
|
class Agent:
|
325
283
|
"""
|
326
284
|
The `Agent` class provides a high-level interface for interacting with large language models (LLMs) in Ailoy.
|
@@ -334,39 +292,37 @@ class Agent:
|
|
334
292
|
def __init__(
|
335
293
|
self,
|
336
294
|
runtime: Runtime,
|
337
|
-
|
295
|
+
model: APIModel | LocalModel,
|
338
296
|
system_message: Optional[str] = None,
|
339
|
-
api_key: Optional[str] = None,
|
340
|
-
attrs: dict[str, Any] = dict(),
|
341
297
|
):
|
342
298
|
"""
|
343
299
|
Create an instance.
|
344
300
|
|
345
301
|
:param runtime: The runtime environment associated with the agent.
|
346
|
-
:param
|
302
|
+
:param model: The model instance.
|
347
303
|
:param system_message: Optional system message to set the initial assistant context.
|
348
|
-
:param api_key: (web agent only) The API key for AI API.
|
349
|
-
:param attrs: Additional initialization parameters (for `define_component` runtime call)
|
350
304
|
:raises ValueError: If model name is not supported or validation fails.
|
351
305
|
"""
|
352
306
|
self._runtime = runtime
|
353
307
|
|
354
308
|
# Initialize component state
|
355
|
-
self.
|
356
|
-
|
357
|
-
valid=False,
|
358
|
-
)
|
309
|
+
self._component_name = generate_uuid()
|
310
|
+
self._component_ready = False
|
359
311
|
|
360
312
|
# Initialize messages
|
361
313
|
self._messages: list[Message] = []
|
362
|
-
|
363
|
-
|
314
|
+
|
315
|
+
# Initialize system message
|
316
|
+
self._system_message = system_message
|
364
317
|
|
365
318
|
# Initialize tools
|
366
319
|
self._tools: list[Tool] = []
|
367
320
|
|
321
|
+
# Initialize MCP servers
|
322
|
+
self._mcp_servers: list[MCPServer] = []
|
323
|
+
|
368
324
|
# Define the component
|
369
|
-
self.define(
|
325
|
+
self.define(model)
|
370
326
|
|
371
327
|
def __del__(self):
|
372
328
|
self.delete()
|
@@ -377,151 +333,216 @@ class Agent:
|
|
377
333
|
def __exit__(self, type, value, traceback):
|
378
334
|
self.delete()
|
379
335
|
|
380
|
-
def define(self,
|
336
|
+
def define(self, model: APIModel | LocalModel) -> None:
|
381
337
|
"""
|
382
338
|
Initializes the agent by defining its model in the runtime.
|
383
339
|
This must be called before running the agent. If already initialized, this is a no-op.
|
384
|
-
:param
|
385
|
-
:param api_key: (web agent only) The API key for AI API.
|
386
|
-
:param attrs: Additional initialization parameters (for `define_component` runtime call)
|
340
|
+
:param model: The model instance.
|
387
341
|
"""
|
388
|
-
if self.
|
342
|
+
if self._component_ready:
|
389
343
|
return
|
390
344
|
|
391
|
-
if
|
392
|
-
raise ValueError(
|
345
|
+
if not self._runtime.is_alive():
|
346
|
+
raise ValueError("Runtime is currently stopped.")
|
393
347
|
|
394
|
-
|
348
|
+
# Set default system message if not given; still can be None
|
349
|
+
if self._system_message is None:
|
350
|
+
self._system_message = getattr(model, "default_system_message", None)
|
395
351
|
|
396
|
-
|
397
|
-
if "model" not in attrs:
|
398
|
-
attrs["model"] = model_desc.model_id
|
399
|
-
|
400
|
-
# Set default system message
|
401
|
-
if len(self._messages) == 0 and model_desc.default_system_message:
|
402
|
-
self._messages.append(SystemMessage(role="system", content=model_desc.default_system_message))
|
403
|
-
|
404
|
-
# Add API key
|
405
|
-
if api_key:
|
406
|
-
attrs["api_key"] = api_key
|
352
|
+
self.clear_messages()
|
407
353
|
|
408
354
|
# Call runtime's define
|
409
355
|
self._runtime.define(
|
410
|
-
|
411
|
-
self.
|
412
|
-
|
356
|
+
model.component_type,
|
357
|
+
self._component_name,
|
358
|
+
model.to_attrs(),
|
413
359
|
)
|
414
360
|
|
415
361
|
# Mark as defined
|
416
|
-
self.
|
362
|
+
self._component_ready = True
|
417
363
|
|
418
364
|
def delete(self) -> None:
|
419
365
|
"""
|
420
366
|
Deinitializes the agent and releases resources in the runtime.
|
421
367
|
This should be called when the agent is no longer needed. If already deinitialized, this is a no-op.
|
422
368
|
"""
|
423
|
-
if not self.
|
369
|
+
if not self._component_ready:
|
424
370
|
return
|
425
|
-
|
426
|
-
if
|
427
|
-
self.
|
428
|
-
|
429
|
-
|
430
|
-
|
371
|
+
|
372
|
+
if self._runtime.is_alive():
|
373
|
+
self._runtime.delete(self._component_name)
|
374
|
+
|
375
|
+
self.clear_messages()
|
376
|
+
|
377
|
+
for mcp_server in self._mcp_servers:
|
378
|
+
mcp_server.cleanup()
|
379
|
+
|
380
|
+
self._component_ready = False
|
431
381
|
|
432
382
|
def query(
|
433
383
|
self,
|
434
|
-
message: str,
|
435
|
-
|
436
|
-
ignore_reasoning_messages: bool = False,
|
384
|
+
message: str | list[str | Image | dict | TextContent | ImageContent | AudioContent],
|
385
|
+
reasoning: bool = False,
|
437
386
|
) -> Generator[AgentResponse, None, None]:
|
438
387
|
"""
|
439
388
|
Runs the agent with a new user message and yields streamed responses.
|
440
389
|
|
441
390
|
:param message: The user message to send to the model.
|
442
|
-
:param
|
443
|
-
:
|
444
|
-
:
|
391
|
+
:param reasoning: If True, enables reasoning capabilities. (Default: False)
|
392
|
+
:return: An iterator over the output, where each item represents either a generated token from the assistant or a tool call.
|
393
|
+
:rtype: Iterator[:class:`AgentResponse`]
|
445
394
|
""" # noqa: E501
|
446
|
-
self.
|
395
|
+
if not self._component_ready:
|
396
|
+
raise ValueError("Agent is not valid. Create one or define newly.")
|
397
|
+
|
398
|
+
if not self._runtime.is_alive():
|
399
|
+
raise ValueError("Runtime is currently stopped.")
|
400
|
+
|
401
|
+
if isinstance(message, str):
|
402
|
+
self._messages.append(UserMessage(content=[TextContent(text=message)]))
|
403
|
+
elif isinstance(message, list):
|
404
|
+
if len(message) == 0:
|
405
|
+
raise ValueError("Message is empty")
|
406
|
+
|
407
|
+
contents = []
|
408
|
+
for content in message:
|
409
|
+
if isinstance(content, str):
|
410
|
+
contents.append(TextContent(text=content))
|
411
|
+
elif isinstance(content, Image):
|
412
|
+
contents.append(ImageContent.from_pillow(image=content))
|
413
|
+
elif isinstance(content, dict):
|
414
|
+
ta: TypeAdapter[TextContent | ImageContent | AudioContent] = TypeAdapter(
|
415
|
+
Annotated[TextContent | ImageContent | AudioContent, Field(discriminator="type")]
|
416
|
+
)
|
417
|
+
validated_content = ta.validate_python(content)
|
418
|
+
contents.append(validated_content)
|
419
|
+
else:
|
420
|
+
contents.append(content)
|
421
|
+
|
422
|
+
self._messages.append(UserMessage(content=contents))
|
423
|
+
else:
|
424
|
+
raise ValueError(f"Invalid message type: {type(message)}")
|
425
|
+
|
426
|
+
prev_resp_type = None
|
447
427
|
|
448
428
|
while True:
|
449
429
|
infer_args = {
|
450
|
-
"messages": [msg.model_dump() for msg in self._messages],
|
451
|
-
"tools": [{"type": "function", "function": t.desc.model_dump()} for t in self._tools],
|
430
|
+
"messages": [msg.model_dump(exclude_none=True) for msg in self._messages],
|
431
|
+
"tools": [{"type": "function", "function": t.desc.model_dump(exclude_none=True)} for t in self._tools],
|
452
432
|
}
|
453
|
-
if
|
454
|
-
infer_args["
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
self._messages.append(tool_call_message)
|
474
|
-
|
475
|
-
for tool_call in tool_call_message.tool_calls:
|
476
|
-
yield AgentResponseToolCall(
|
477
|
-
type="tool_call",
|
478
|
-
end_of_turn=True,
|
479
|
-
role="assistant",
|
480
|
-
content=tool_call,
|
433
|
+
if reasoning:
|
434
|
+
infer_args["reasoning"] = reasoning
|
435
|
+
|
436
|
+
assistant_reasoning = None
|
437
|
+
assistant_content = None
|
438
|
+
assistant_tool_calls = None
|
439
|
+
finish_reason = ""
|
440
|
+
for result in self._runtime.call_iter_method(self._component_name, "infer", infer_args):
|
441
|
+
msg = MessageOutput.model_validate(result)
|
442
|
+
|
443
|
+
if msg.message.reasoning:
|
444
|
+
for v in msg.message.reasoning:
|
445
|
+
if not assistant_reasoning:
|
446
|
+
assistant_reasoning = [v]
|
447
|
+
else:
|
448
|
+
assistant_reasoning[0].text += v.text
|
449
|
+
resp = AgentResponseOutputText(
|
450
|
+
type="reasoning",
|
451
|
+
is_type_switched=(prev_resp_type != "reasoning"),
|
452
|
+
content=v.text,
|
481
453
|
)
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
454
|
+
prev_resp_type = resp.type
|
455
|
+
yield resp
|
456
|
+
if msg.message.content is not None:
|
457
|
+
# Canonicalize message content to the array of TextContent
|
458
|
+
if isinstance(msg.message.content, str):
|
459
|
+
msg.message.content = [TextContent(text=msg.message.content)]
|
460
|
+
|
461
|
+
for v in msg.message.content:
|
462
|
+
if not assistant_content:
|
463
|
+
assistant_content = [v]
|
464
|
+
else:
|
465
|
+
assistant_content[0].text += v.text
|
466
|
+
resp = AgentResponseOutputText(
|
467
|
+
type="output_text",
|
468
|
+
is_type_switched=(prev_resp_type != "output_text"),
|
469
|
+
content=v.text,
|
489
470
|
)
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
471
|
+
prev_resp_type = resp.type
|
472
|
+
yield resp
|
473
|
+
if msg.message.tool_calls:
|
474
|
+
for v in msg.message.tool_calls:
|
475
|
+
if not assistant_tool_calls:
|
476
|
+
assistant_tool_calls = [v]
|
477
|
+
else:
|
478
|
+
assistant_tool_calls.append(v)
|
479
|
+
resp = AgentResponseToolCall(
|
480
|
+
is_type_switched=True,
|
481
|
+
content=v,
|
498
482
|
)
|
483
|
+
prev_resp_type = resp.type
|
484
|
+
yield resp
|
485
|
+
if msg.finish_reason:
|
486
|
+
finish_reason = msg.finish_reason
|
487
|
+
break
|
488
|
+
|
489
|
+
# Append output
|
490
|
+
self._messages.append(
|
491
|
+
AssistantMessage(
|
492
|
+
reasoning=assistant_reasoning,
|
493
|
+
content=assistant_content,
|
494
|
+
tool_calls=assistant_tool_calls,
|
495
|
+
)
|
496
|
+
)
|
497
|
+
|
498
|
+
if finish_reason == "tool_calls":
|
499
|
+
|
500
|
+
def run_tool(tool_call: FunctionData) -> ToolMessage:
|
501
|
+
tool_ = next(
|
502
|
+
(t for t in self._tools if t.desc.name == tool_call.function.name),
|
503
|
+
None,
|
504
|
+
)
|
505
|
+
if not tool_:
|
506
|
+
raise RuntimeError("Tool not found")
|
507
|
+
tool_result = tool_.call(**tool_call.function.arguments)
|
508
|
+
return ToolMessage(
|
509
|
+
content=[TextContent(text=json.dumps(tool_result))],
|
510
|
+
tool_call_id=tool_call.id,
|
511
|
+
)
|
499
512
|
|
500
|
-
|
513
|
+
tool_call_results = [run_tool(tc) for tc in assistant_tool_calls]
|
514
|
+
for result_msg in tool_call_results:
|
515
|
+
self._messages.append(result_msg)
|
516
|
+
resp = AgentResponseToolResult(
|
517
|
+
is_type_switched=True,
|
518
|
+
content=result_msg,
|
519
|
+
)
|
520
|
+
prev_resp_type = resp.type
|
521
|
+
yield resp
|
522
|
+
# Infer again if tool calls happened
|
523
|
+
continue
|
501
524
|
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
type="tool_call_result",
|
506
|
-
end_of_turn=True,
|
507
|
-
role="tool",
|
508
|
-
content=result_msg,
|
509
|
-
)
|
525
|
+
# Finish this generator
|
526
|
+
yield AgentResponseOutputText(type="output_text", content="\n")
|
527
|
+
break
|
510
528
|
|
511
|
-
|
512
|
-
|
529
|
+
def get_messages(self) -> list[Message]:
|
530
|
+
"""
|
531
|
+
Get the current conversation history.
|
532
|
+
Each item in the list represents a message from either the user or the assistant.
|
513
533
|
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
end_of_turn=True,
|
519
|
-
role="assistant",
|
520
|
-
content=output_msg.content,
|
521
|
-
)
|
534
|
+
:return: The conversation history so far in the form of a list.
|
535
|
+
:rtype: list[Message]
|
536
|
+
"""
|
537
|
+
return self._messages
|
522
538
|
|
523
|
-
|
524
|
-
|
539
|
+
def clear_messages(self):
|
540
|
+
"""
|
541
|
+
Clear the history of conversation messages.
|
542
|
+
"""
|
543
|
+
self._messages.clear()
|
544
|
+
if self._system_message is not None:
|
545
|
+
self._messages.append(SystemMessage(role="system", content=[TextContent(text=self._system_message)]))
|
525
546
|
|
526
547
|
def print(self, resp: AgentResponse):
|
527
548
|
resp.print()
|
@@ -537,14 +558,29 @@ class Agent:
|
|
537
558
|
return
|
538
559
|
self._tools.append(tool)
|
539
560
|
|
540
|
-
def add_py_function_tool(self,
|
561
|
+
def add_py_function_tool(self, f: Callable[..., Any], desc: Optional[dict] = None):
|
541
562
|
"""
|
542
563
|
Adds a Python function as a tool using callable.
|
543
564
|
|
544
|
-
:param desc: Tool descriotion.
|
545
565
|
:param f: Function will be called when the tool invocation occured.
|
566
|
+
:param desc: Tool description to override. If not given, parsed from docstring of function `f`.
|
567
|
+
|
568
|
+
:raises ValueError: Docstring parsing is failed.
|
569
|
+
:raises ValidationError: Given or parsed description is not a valid `ToolDescription`.
|
546
570
|
"""
|
547
|
-
|
571
|
+
tool_description = None
|
572
|
+
if desc is not None:
|
573
|
+
tool_description = ToolDescription.model_validate(desc)
|
574
|
+
|
575
|
+
if tool_description is None:
|
576
|
+
try:
|
577
|
+
json_schema = get_json_schema(f)
|
578
|
+
except (TypeHintParsingException, DocstringParsingException) as e:
|
579
|
+
raise ValueError("Failed to parse docstring", e)
|
580
|
+
|
581
|
+
tool_description = ToolDescription.model_validate(json_schema.get("function", {}))
|
582
|
+
|
583
|
+
self.add_tool(Tool(desc=tool_description, call_fn=f))
|
548
584
|
|
549
585
|
def add_builtin_tool(self, tool_def: BuiltinToolDefinition) -> bool:
|
550
586
|
"""
|
@@ -669,61 +705,65 @@ class Agent:
|
|
669
705
|
else:
|
670
706
|
warnings.warn(f'Tool type "{tool_type}" is not supported. Skip adding tool "{tool_name}".')
|
671
707
|
|
672
|
-
def
|
708
|
+
def add_tools_from_mcp_server(
|
709
|
+
self, name: str, params: StdioServerParameters, tools_to_add: Optional[list[str]] = None
|
710
|
+
):
|
673
711
|
"""
|
674
|
-
|
712
|
+
Create a MCP server and register its tools to agent.
|
675
713
|
|
714
|
+
:param name: The unique name of the MCP server.
|
715
|
+
If there's already a MCP server with the same name, it raises RuntimeError.
|
676
716
|
:param params: Parameters for connecting to the MCP stdio server.
|
677
|
-
:param
|
678
|
-
:returns: True if the tool was successfully added.
|
717
|
+
:param tools_to_add: Optional list of tool names to add. If None, all tools are added.
|
679
718
|
"""
|
680
|
-
|
719
|
+
if any([s.name == name for s in self._mcp_servers]):
|
720
|
+
raise RuntimeError(f"MCP server with name '{name}' is already registered")
|
681
721
|
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
return run_async(_inner())
|
704
|
-
|
705
|
-
desc = ToolDescription(name=tool.name, description=tool.description, parameters=tool.inputSchema)
|
706
|
-
return self.add_tool(Tool(desc=desc, call_fn=call))
|
707
|
-
|
708
|
-
def add_tools_from_mcp_server(self, params: mcp.StdioServerParameters, tools_to_add: Optional[list[str]] = None):
|
722
|
+
# Create and register MCP server
|
723
|
+
mcp_server = MCPServer(name, params)
|
724
|
+
self._mcp_servers.append(mcp_server)
|
725
|
+
|
726
|
+
# Register tools
|
727
|
+
for tool in mcp_server.list_tools():
|
728
|
+
# Skip if this tool is not in the whitelist
|
729
|
+
if tools_to_add is not None and tool.name not in tools_to_add:
|
730
|
+
continue
|
731
|
+
|
732
|
+
desc = ToolDescription(
|
733
|
+
name=f"{name}-{tool.name}", description=tool.description, parameters=tool.inputSchema
|
734
|
+
)
|
735
|
+
|
736
|
+
def call(tool: MCPTool, **inputs: dict[str, Any]) -> list[str]:
|
737
|
+
return mcp_server.call_tool(tool, inputs)
|
738
|
+
|
739
|
+
self.add_tool(Tool(desc=desc, call_fn=partial(call, tool)))
|
740
|
+
|
741
|
+
def remove_mcp_server(self, name: str):
|
709
742
|
"""
|
710
|
-
|
743
|
+
Removes the MCP server and its tools from the agent, with terminating the MCP server process.
|
711
744
|
|
712
|
-
:param
|
713
|
-
|
714
|
-
|
745
|
+
:param name: The unique name of the MCP server.
|
746
|
+
If there's no MCP server matches the name, it raises RuntimeError.
|
747
|
+
"""
|
748
|
+
if all([s.name != name for s in self._mcp_servers]):
|
749
|
+
raise RuntimeError(f"MCP server with name '{name}' does not exist")
|
750
|
+
|
751
|
+
# Remove the MCP server
|
752
|
+
mcp_server = next(filter(lambda s: s.name == name, self._mcp_servers))
|
753
|
+
self._mcp_servers.remove(mcp_server)
|
754
|
+
mcp_server.cleanup()
|
755
|
+
|
756
|
+
# Remove tools registered from the MCP server
|
757
|
+
self._tools = list(filter(lambda t: not t.desc.name.startswith(f"{mcp_server.name}-"), self._tools))
|
758
|
+
|
759
|
+
def get_tools(self):
|
760
|
+
"""
|
761
|
+
Get the list of registered tools.
|
762
|
+
"""
|
763
|
+
return self._tools
|
764
|
+
|
765
|
+
def clear_tools(self):
|
766
|
+
"""
|
767
|
+
Clear the registered tools.
|
715
768
|
"""
|
716
|
-
|
717
|
-
|
718
|
-
async def _inner():
|
719
|
-
async with stdio_client(params, errlog=subprocess.STDOUT) as streams:
|
720
|
-
async with mcp.ClientSession(*streams) as session:
|
721
|
-
await session.initialize()
|
722
|
-
resp = await session.list_tools()
|
723
|
-
for tool in resp.tools:
|
724
|
-
if tools_to_add is None or tool.name in tools_to_add:
|
725
|
-
self.add_mcp_tool(params, tool)
|
726
|
-
return resp.tools
|
727
|
-
|
728
|
-
tools = run_async(_inner())
|
729
|
-
return tools
|
769
|
+
self._tools.clear()
|