ailoy-py 0.0.1__cp311-cp311-manylinux_2_28_x86_64.whl → 0.0.2__cp311-cp311-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ailoy/__init__.py +18 -0
- ailoy/agent.py +273 -196
- ailoy/ailoy_py.cpython-311-x86_64-linux-gnu.so +0 -0
- ailoy/mcp.py +159 -0
- ailoy/runtime.py +34 -19
- ailoy/tools.py +205 -0
- ailoy/vector_store.py +10 -9
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.2.dist-info}/METADATA +2 -2
- ailoy_py-0.0.2.dist-info/RECORD +20 -0
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.2.dist-info}/WHEEL +1 -1
- ailoy_py.libs/libtvm_runtime-2d14ca42.so +0 -0
- ailoy_py-0.0.1.dist-info/RECORD +0 -19
- ailoy_py.libs/libmvec-2-8eb5c230.28.so +0 -0
- ailoy_py.libs/libtvm_runtime-7067e461.so +0 -0
- {ailoy_py-0.0.1.dist-info → ailoy_py-0.0.2.dist-info}/entry_points.txt +0 -0
ailoy/__init__.py
CHANGED
@@ -1,3 +1,21 @@
|
|
1
|
+
if __doc__ is None:
|
2
|
+
try:
|
3
|
+
import importlib.metadata
|
4
|
+
|
5
|
+
meta = importlib.metadata.metadata("ailoy-py")
|
6
|
+
__doc__ = meta.get("Description")
|
7
|
+
except importlib.metadata.PackageNotFoundError:
|
8
|
+
pass
|
9
|
+
|
10
|
+
if __doc__ is None:
|
11
|
+
from pathlib import Path
|
12
|
+
|
13
|
+
readme = Path(__file__).parent.parent / "README.md"
|
14
|
+
if readme.exists():
|
15
|
+
__doc__ = readme.read_text()
|
16
|
+
else: # fallback docstring
|
17
|
+
__doc__ = "# ailoy-py\n\nPython binding for Ailoy runtime APIs"
|
18
|
+
|
1
19
|
from .agent import Agent # noqa: F401
|
2
20
|
from .runtime import AsyncRuntime, Runtime # noqa: F401
|
3
21
|
from .vector_store import VectorStore # noqa: F401
|
ailoy/agent.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import json
|
2
|
-
import subprocess
|
3
2
|
import warnings
|
4
3
|
from abc import ABC, abstractmethod
|
5
4
|
from collections.abc import Awaitable, Callable, Generator
|
5
|
+
from functools import partial
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import (
|
8
8
|
Any,
|
@@ -14,72 +14,75 @@ from typing import (
|
|
14
14
|
from urllib.parse import urlencode, urlparse, urlunparse
|
15
15
|
|
16
16
|
import jmespath
|
17
|
-
import mcp
|
18
|
-
import mcp.types as mcp_types
|
19
17
|
from pydantic import BaseModel, ConfigDict, Field
|
20
18
|
from rich.console import Console
|
21
19
|
from rich.panel import Panel
|
22
20
|
|
23
21
|
from ailoy.ailoy_py import generate_uuid
|
22
|
+
from ailoy.mcp import MCPServer, MCPTool, StdioServerParameters
|
24
23
|
from ailoy.runtime import Runtime
|
24
|
+
from ailoy.tools import DocstringParsingException, TypeHintParsingException, get_json_schema
|
25
25
|
|
26
26
|
__all__ = ["Agent"]
|
27
27
|
|
28
|
-
## Types for
|
28
|
+
## Types for internal data structures
|
29
29
|
|
30
30
|
|
31
|
-
class
|
32
|
-
|
33
|
-
|
31
|
+
class TextData(BaseModel):
|
32
|
+
type: Literal["text"]
|
33
|
+
text: str
|
34
34
|
|
35
35
|
|
36
|
-
class
|
37
|
-
|
38
|
-
|
39
|
-
|
36
|
+
class FunctionData(BaseModel):
|
37
|
+
class FunctionBody(BaseModel):
|
38
|
+
name: str
|
39
|
+
arguments: Any
|
40
40
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
reasoning: Optional[bool] = None
|
41
|
+
type: Literal["function"]
|
42
|
+
id: Optional[str] = None
|
43
|
+
function: FunctionBody
|
45
44
|
|
46
45
|
|
47
|
-
class
|
48
|
-
role: Literal["
|
49
|
-
content:
|
50
|
-
tool_calls: list["ToolCall"]
|
46
|
+
class SystemMessage(BaseModel):
|
47
|
+
role: Literal["system"]
|
48
|
+
content: list[TextData]
|
51
49
|
|
52
50
|
|
53
|
-
class
|
54
|
-
|
55
|
-
|
56
|
-
function: "ToolCallFunction"
|
51
|
+
class UserMessage(BaseModel):
|
52
|
+
role: Literal["user"]
|
53
|
+
content: list[TextData]
|
57
54
|
|
58
55
|
|
59
|
-
class
|
60
|
-
|
61
|
-
|
56
|
+
class AssistantMessage(BaseModel):
|
57
|
+
role: Literal["assistant"]
|
58
|
+
reasoning: Optional[list[TextData]] = None
|
59
|
+
content: Optional[list[TextData]] = None
|
60
|
+
tool_calls: Optional[list[FunctionData]] = None
|
62
61
|
|
63
62
|
|
64
|
-
class
|
63
|
+
class ToolMessage(BaseModel):
|
65
64
|
role: Literal["tool"]
|
66
65
|
name: str
|
67
|
-
|
68
|
-
|
66
|
+
content: list[TextData]
|
67
|
+
tool_call_id: Optional[str] = None
|
69
68
|
|
70
69
|
|
71
70
|
Message = Union[
|
72
71
|
SystemMessage,
|
73
72
|
UserMessage,
|
74
|
-
|
75
|
-
|
76
|
-
ToolCallResultMessage,
|
73
|
+
AssistantMessage,
|
74
|
+
ToolMessage,
|
77
75
|
]
|
78
76
|
|
79
77
|
|
80
|
-
class
|
81
|
-
|
82
|
-
|
78
|
+
class MessageOutput(BaseModel):
|
79
|
+
class AssistantMessageDelta(BaseModel):
|
80
|
+
content: Optional[list[TextData]] = None
|
81
|
+
reasoning: Optional[list[TextData]] = None
|
82
|
+
tool_calls: Optional[list[FunctionData]] = None
|
83
|
+
|
84
|
+
message: AssistantMessageDelta
|
85
|
+
finish_reason: Optional[Literal["stop", "tool_calls", "invalid_tool_call", "length", "error"]] = None
|
83
86
|
|
84
87
|
|
85
88
|
## Types for LLM Model Definitions
|
@@ -141,71 +144,71 @@ class ComponentState(BaseModel):
|
|
141
144
|
|
142
145
|
## Types for agent's responses
|
143
146
|
|
144
|
-
_console = Console(highlight=False)
|
147
|
+
_console = Console(highlight=False, force_jupyter=False, force_terminal=True)
|
145
148
|
|
146
149
|
|
147
|
-
class
|
148
|
-
type: Literal["output_text", "tool_call", "tool_call_result", "reasoning", "error"]
|
149
|
-
end_of_turn: bool
|
150
|
-
role: Literal["assistant", "tool"]
|
151
|
-
content: Any
|
152
|
-
|
153
|
-
def print(self):
|
154
|
-
raise NotImplementedError
|
155
|
-
|
156
|
-
|
157
|
-
class AgentResponseOutputText(AgentResponseBase):
|
150
|
+
class AgentResponseOutputText(BaseModel):
|
158
151
|
type: Literal["output_text", "reasoning"]
|
159
152
|
role: Literal["assistant"]
|
153
|
+
is_type_switched: bool = False
|
160
154
|
content: str
|
161
155
|
|
162
156
|
def print(self):
|
157
|
+
if self.is_type_switched:
|
158
|
+
_console.print() # add newline if type has been switched
|
163
159
|
_console.print(self.content, end="", style=("yellow" if self.type == "reasoning" else None))
|
164
|
-
if self.end_of_turn:
|
165
|
-
_console.print()
|
166
160
|
|
167
161
|
|
168
|
-
class AgentResponseToolCall(
|
162
|
+
class AgentResponseToolCall(BaseModel):
|
169
163
|
type: Literal["tool_call"]
|
170
164
|
role: Literal["assistant"]
|
171
|
-
|
165
|
+
is_type_switched: bool = False
|
166
|
+
content: FunctionData
|
172
167
|
|
173
168
|
def print(self):
|
169
|
+
title = f"[magenta]Tool Call[/magenta]: [bold]{self.content.function.name}[/bold]"
|
170
|
+
if self.content.id is not None:
|
171
|
+
title += f" ({self.content.id})"
|
174
172
|
panel = Panel(
|
175
173
|
json.dumps(self.content.function.arguments, indent=2),
|
176
|
-
title=
|
174
|
+
title=title,
|
177
175
|
title_align="left",
|
178
176
|
)
|
179
177
|
_console.print(panel)
|
180
178
|
|
181
179
|
|
182
|
-
class
|
180
|
+
class AgentResponseToolResult(BaseModel):
|
183
181
|
type: Literal["tool_call_result"]
|
184
182
|
role: Literal["tool"]
|
185
|
-
|
183
|
+
is_type_switched: bool = False
|
184
|
+
content: ToolMessage
|
186
185
|
|
187
186
|
def print(self):
|
188
187
|
try:
|
189
188
|
# Try to parse as json
|
190
|
-
content = json.dumps(json.loads(self.content.content), indent=2)
|
189
|
+
content = json.dumps(json.loads(self.content.content[0].text), indent=2)
|
191
190
|
except json.JSONDecodeError:
|
192
191
|
# Use original content if not json deserializable
|
193
|
-
content = self.content.content
|
192
|
+
content = self.content.content[0].text
|
194
193
|
# Truncate long contents
|
195
194
|
if len(content) > 500:
|
196
195
|
content = content[:500] + "...(truncated)"
|
197
196
|
|
197
|
+
title = f"[green]Tool Result[/green]: [bold]{self.content.name}[/bold]"
|
198
|
+
if self.content.tool_call_id is not None:
|
199
|
+
title += f" ({self.content.tool_call_id})"
|
198
200
|
panel = Panel(
|
199
201
|
content,
|
200
|
-
title=
|
202
|
+
title=title,
|
201
203
|
title_align="left",
|
202
204
|
)
|
203
205
|
_console.print(panel)
|
204
206
|
|
205
207
|
|
206
|
-
class AgentResponseError(
|
208
|
+
class AgentResponseError(BaseModel):
|
207
209
|
type: Literal["error"]
|
208
210
|
role: Literal["assistant"]
|
211
|
+
is_type_switched: bool = False
|
209
212
|
content: str
|
210
213
|
|
211
214
|
def print(self):
|
@@ -219,7 +222,7 @@ class AgentResponseError(AgentResponseBase):
|
|
219
222
|
AgentResponse = Union[
|
220
223
|
AgentResponseOutputText,
|
221
224
|
AgentResponseToolCall,
|
222
|
-
|
225
|
+
AgentResponseToolResult,
|
223
226
|
AgentResponseError,
|
224
227
|
]
|
225
228
|
|
@@ -337,7 +340,7 @@ class Agent:
|
|
337
340
|
model_name: ModelName,
|
338
341
|
system_message: Optional[str] = None,
|
339
342
|
api_key: Optional[str] = None,
|
340
|
-
attrs
|
343
|
+
**attrs,
|
341
344
|
):
|
342
345
|
"""
|
343
346
|
Create an instance.
|
@@ -359,14 +362,18 @@ class Agent:
|
|
359
362
|
|
360
363
|
# Initialize messages
|
361
364
|
self._messages: list[Message] = []
|
362
|
-
|
363
|
-
|
365
|
+
|
366
|
+
# Initialize system message
|
367
|
+
self._system_message = system_message
|
364
368
|
|
365
369
|
# Initialize tools
|
366
370
|
self._tools: list[Tool] = []
|
367
371
|
|
372
|
+
# Initialize MCP servers
|
373
|
+
self._mcp_servers: list[MCPServer] = []
|
374
|
+
|
368
375
|
# Define the component
|
369
|
-
self.define(model_name, api_key=api_key, attrs
|
376
|
+
self.define(model_name, api_key=api_key, **attrs)
|
370
377
|
|
371
378
|
def __del__(self):
|
372
379
|
self.delete()
|
@@ -377,7 +384,7 @@ class Agent:
|
|
377
384
|
def __exit__(self, type, value, traceback):
|
378
385
|
self.delete()
|
379
386
|
|
380
|
-
def define(self, model_name: ModelName, api_key: Optional[str] = None, attrs
|
387
|
+
def define(self, model_name: ModelName, api_key: Optional[str] = None, **attrs) -> None:
|
381
388
|
"""
|
382
389
|
Initializes the agent by defining its model in the runtime.
|
383
390
|
This must be called before running the agent. If already initialized, this is a no-op.
|
@@ -388,6 +395,9 @@ class Agent:
|
|
388
395
|
if self._component_state.valid:
|
389
396
|
return
|
390
397
|
|
398
|
+
if not self._runtime.is_alive():
|
399
|
+
raise ValueError("Runtime is currently stopped.")
|
400
|
+
|
391
401
|
if model_name not in model_descriptions:
|
392
402
|
raise ValueError(f"Model `{model_name}` not supported")
|
393
403
|
|
@@ -397,9 +407,11 @@ class Agent:
|
|
397
407
|
if "model" not in attrs:
|
398
408
|
attrs["model"] = model_desc.model_id
|
399
409
|
|
400
|
-
# Set default system message
|
401
|
-
if
|
402
|
-
self.
|
410
|
+
# Set default system message if not given; still can be None
|
411
|
+
if self._system_message is None:
|
412
|
+
self._system_message = model_desc.default_system_message
|
413
|
+
|
414
|
+
self.clear_messages()
|
403
415
|
|
404
416
|
# Add API key
|
405
417
|
if api_key:
|
@@ -422,106 +434,164 @@ class Agent:
|
|
422
434
|
"""
|
423
435
|
if not self._component_state.valid:
|
424
436
|
return
|
425
|
-
|
426
|
-
if
|
427
|
-
self.
|
428
|
-
|
429
|
-
|
437
|
+
|
438
|
+
if self._runtime.is_alive():
|
439
|
+
self._runtime.delete(self._component_state.name)
|
440
|
+
|
441
|
+
self.clear_messages()
|
442
|
+
|
443
|
+
for mcp_server in self._mcp_servers:
|
444
|
+
mcp_server.cleanup()
|
445
|
+
|
430
446
|
self._component_state.valid = False
|
431
447
|
|
432
448
|
def query(
|
433
449
|
self,
|
434
450
|
message: str,
|
435
|
-
|
436
|
-
ignore_reasoning_messages: bool = False,
|
451
|
+
reasoning: bool = False,
|
437
452
|
) -> Generator[AgentResponse, None, None]:
|
438
453
|
"""
|
439
454
|
Runs the agent with a new user message and yields streamed responses.
|
440
455
|
|
441
456
|
:param message: The user message to send to the model.
|
442
|
-
:param
|
443
|
-
:
|
444
|
-
:
|
457
|
+
:param reasoning: If True, enables reasoning capabilities. (Default: False)
|
458
|
+
:return: An iterator over the output, where each item represents either a generated token from the assistant or a tool call.
|
459
|
+
:rtype: Iterator[:class:`AgentResponse`]
|
445
460
|
""" # noqa: E501
|
446
|
-
self.
|
461
|
+
if not self._component_state.valid:
|
462
|
+
raise ValueError("Agent is not valid. Create one or define newly.")
|
463
|
+
|
464
|
+
if not self._runtime.is_alive():
|
465
|
+
raise ValueError("Runtime is currently stopped.")
|
466
|
+
|
467
|
+
self._messages.append(UserMessage(role="user", content=[{"type": "text", "text": message}]))
|
468
|
+
|
469
|
+
prev_resp_type = None
|
447
470
|
|
448
471
|
while True:
|
449
472
|
infer_args = {
|
450
|
-
"messages": [msg.model_dump() for msg in self._messages],
|
451
|
-
"tools": [{"type": "function", "function": t.desc.model_dump()} for t in self._tools],
|
473
|
+
"messages": [msg.model_dump(exclude_none=True) for msg in self._messages],
|
474
|
+
"tools": [{"type": "function", "function": t.desc.model_dump(exclude_none=True)} for t in self._tools],
|
452
475
|
}
|
453
|
-
if
|
454
|
-
infer_args["
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
if delta.finish_reason == "tool_calls":
|
472
|
-
tool_call_message = AIToolCallMessage.model_validate(delta.message)
|
473
|
-
self._messages.append(tool_call_message)
|
474
|
-
|
475
|
-
for tool_call in tool_call_message.tool_calls:
|
476
|
-
yield AgentResponseToolCall(
|
477
|
-
type="tool_call",
|
478
|
-
end_of_turn=True,
|
476
|
+
if reasoning:
|
477
|
+
infer_args["reasoning"] = reasoning
|
478
|
+
|
479
|
+
assistant_reasoning = None
|
480
|
+
assistant_content = None
|
481
|
+
assistant_tool_calls = None
|
482
|
+
finish_reason = ""
|
483
|
+
for result in self._runtime.call_iter_method(self._component_state.name, "infer", infer_args):
|
484
|
+
msg = MessageOutput.model_validate(result)
|
485
|
+
|
486
|
+
if msg.message.reasoning:
|
487
|
+
for v in msg.message.reasoning:
|
488
|
+
if not assistant_reasoning:
|
489
|
+
assistant_reasoning = [v]
|
490
|
+
else:
|
491
|
+
assistant_reasoning[0].text += v.text
|
492
|
+
resp = AgentResponseOutputText(
|
493
|
+
type="reasoning",
|
479
494
|
role="assistant",
|
480
|
-
|
495
|
+
is_type_switched=(prev_resp_type != "reasoning"),
|
496
|
+
content=v.text,
|
481
497
|
)
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
498
|
+
prev_resp_type = resp.type
|
499
|
+
yield resp
|
500
|
+
if msg.message.content:
|
501
|
+
for v in msg.message.content:
|
502
|
+
if not assistant_content:
|
503
|
+
assistant_content = [v]
|
504
|
+
else:
|
505
|
+
assistant_content[0].text += v.text
|
506
|
+
resp = AgentResponseOutputText(
|
507
|
+
type="output_text",
|
508
|
+
role="assistant",
|
509
|
+
is_type_switched=(prev_resp_type != "output_text"),
|
510
|
+
content=v.text,
|
489
511
|
)
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
512
|
+
prev_resp_type = resp.type
|
513
|
+
yield resp
|
514
|
+
if msg.message.tool_calls:
|
515
|
+
for v in msg.message.tool_calls:
|
516
|
+
if not assistant_tool_calls:
|
517
|
+
assistant_tool_calls = [v]
|
518
|
+
else:
|
519
|
+
assistant_tool_calls.append(v)
|
520
|
+
resp = AgentResponseToolCall(
|
521
|
+
type="tool_call",
|
522
|
+
role="assistant",
|
523
|
+
is_type_switched=True,
|
524
|
+
content=v,
|
498
525
|
)
|
526
|
+
prev_resp_type = resp.type
|
527
|
+
yield resp
|
528
|
+
if msg.finish_reason:
|
529
|
+
finish_reason = msg.finish_reason
|
530
|
+
break
|
531
|
+
|
532
|
+
# Append output
|
533
|
+
self._messages.append(
|
534
|
+
AssistantMessage(
|
535
|
+
role="assistant",
|
536
|
+
reasoning=assistant_reasoning,
|
537
|
+
content=assistant_content,
|
538
|
+
tool_calls=assistant_tool_calls,
|
539
|
+
)
|
540
|
+
)
|
541
|
+
|
542
|
+
if finish_reason == "tool_calls":
|
543
|
+
|
544
|
+
def run_tool(tool_call: FunctionData) -> ToolMessage:
|
545
|
+
tool_ = next(
|
546
|
+
(t for t in self._tools if t.desc.name == tool_call.function.name),
|
547
|
+
None,
|
548
|
+
)
|
549
|
+
if not tool_:
|
550
|
+
raise RuntimeError("Tool not found")
|
551
|
+
tool_result = tool_.call(**tool_call.function.arguments)
|
552
|
+
return ToolMessage(
|
553
|
+
role="tool",
|
554
|
+
name=tool_call.function.name,
|
555
|
+
content=[TextData(type="text", text=json.dumps(tool_result))],
|
556
|
+
tool_call_id=tool_call.id if tool_call.id else None,
|
557
|
+
)
|
499
558
|
|
500
|
-
|
559
|
+
tool_call_results = [run_tool(tc) for tc in assistant_tool_calls]
|
560
|
+
for result_msg in tool_call_results:
|
561
|
+
self._messages.append(result_msg)
|
562
|
+
resp = AgentResponseToolResult(
|
563
|
+
type="tool_call_result",
|
564
|
+
role="tool",
|
565
|
+
is_type_switched=True,
|
566
|
+
content=result_msg,
|
567
|
+
)
|
568
|
+
prev_resp_type = resp.type
|
569
|
+
yield resp
|
570
|
+
# Infer again if tool calls happened
|
571
|
+
continue
|
501
572
|
|
502
|
-
|
503
|
-
|
504
|
-
yield AgentResponseToolCallResult(
|
505
|
-
type="tool_call_result",
|
506
|
-
end_of_turn=True,
|
507
|
-
role="tool",
|
508
|
-
content=result_msg,
|
509
|
-
)
|
573
|
+
# Finish this generator
|
574
|
+
break
|
510
575
|
|
511
|
-
|
512
|
-
|
576
|
+
def get_messages(self) -> list[Message]:
|
577
|
+
"""
|
578
|
+
Get the current conversation history.
|
579
|
+
Each item in the list represents a message from either the user or the assistant.
|
513
580
|
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
end_of_turn=True,
|
519
|
-
role="assistant",
|
520
|
-
content=output_msg.content,
|
521
|
-
)
|
581
|
+
:return: The conversation history so far in the form of a list.
|
582
|
+
:rtype: list[Message]
|
583
|
+
"""
|
584
|
+
return self._messages
|
522
585
|
|
523
|
-
|
524
|
-
|
586
|
+
def clear_messages(self):
|
587
|
+
"""
|
588
|
+
Clear the history of conversation messages.
|
589
|
+
"""
|
590
|
+
self._messages.clear()
|
591
|
+
if self._system_message is not None:
|
592
|
+
self._messages.append(
|
593
|
+
SystemMessage(role="system", content=[TextData(type="text", text=self._system_message)])
|
594
|
+
)
|
525
595
|
|
526
596
|
def print(self, resp: AgentResponse):
|
527
597
|
resp.print()
|
@@ -537,14 +607,29 @@ class Agent:
|
|
537
607
|
return
|
538
608
|
self._tools.append(tool)
|
539
609
|
|
540
|
-
def add_py_function_tool(self,
|
610
|
+
def add_py_function_tool(self, f: Callable[..., Any], desc: Optional[dict] = None):
|
541
611
|
"""
|
542
612
|
Adds a Python function as a tool using callable.
|
543
613
|
|
544
|
-
:param desc: Tool descriotion.
|
545
614
|
:param f: Function will be called when the tool invocation occured.
|
615
|
+
:param desc: Tool description to override. If not given, parsed from docstring of function `f`.
|
616
|
+
|
617
|
+
:raises ValueError: Docstring parsing is failed.
|
618
|
+
:raises ValidationError: Given or parsed description is not a valid `ToolDescription`.
|
546
619
|
"""
|
547
|
-
|
620
|
+
tool_description = None
|
621
|
+
if desc is not None:
|
622
|
+
tool_description = ToolDescription.model_validate(desc)
|
623
|
+
|
624
|
+
if tool_description is None:
|
625
|
+
try:
|
626
|
+
json_schema = get_json_schema(f)
|
627
|
+
except (TypeHintParsingException, DocstringParsingException) as e:
|
628
|
+
raise ValueError("Failed to parse docstring", e)
|
629
|
+
|
630
|
+
tool_description = ToolDescription.model_validate(json_schema.get("function", {}))
|
631
|
+
|
632
|
+
self.add_tool(Tool(desc=tool_description, call_fn=f))
|
548
633
|
|
549
634
|
def add_builtin_tool(self, tool_def: BuiltinToolDefinition) -> bool:
|
550
635
|
"""
|
@@ -669,61 +754,53 @@ class Agent:
|
|
669
754
|
else:
|
670
755
|
warnings.warn(f'Tool type "{tool_type}" is not supported. Skip adding tool "{tool_name}".')
|
671
756
|
|
672
|
-
def
|
757
|
+
def add_tools_from_mcp_server(
|
758
|
+
self, name: str, params: StdioServerParameters, tools_to_add: Optional[list[str]] = None
|
759
|
+
):
|
673
760
|
"""
|
674
|
-
|
761
|
+
Create a MCP server and register its tools to agent.
|
675
762
|
|
763
|
+
:param name: The unique name of the MCP server.
|
764
|
+
If there's already a MCP server with the same name, it raises RuntimeError.
|
676
765
|
:param params: Parameters for connecting to the MCP stdio server.
|
677
|
-
:param
|
678
|
-
:returns: True if the tool was successfully added.
|
766
|
+
:param tools_to_add: Optional list of tool names to add. If None, all tools are added.
|
679
767
|
"""
|
680
|
-
|
768
|
+
if any([s.name == name for s in self._mcp_servers]):
|
769
|
+
raise RuntimeError(f"MCP server with name '{name}' is already registered")
|
681
770
|
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
return run_async(_inner())
|
704
|
-
|
705
|
-
desc = ToolDescription(name=tool.name, description=tool.description, parameters=tool.inputSchema)
|
706
|
-
return self.add_tool(Tool(desc=desc, call_fn=call))
|
707
|
-
|
708
|
-
def add_tools_from_mcp_server(self, params: mcp.StdioServerParameters, tools_to_add: Optional[list[str]] = None):
|
771
|
+
# Create and register MCP server
|
772
|
+
mcp_server = MCPServer(name, params)
|
773
|
+
self._mcp_servers.append(mcp_server)
|
774
|
+
|
775
|
+
# Register tools
|
776
|
+
for tool in mcp_server.list_tools():
|
777
|
+
# Skip if this tool is not in the whitelist
|
778
|
+
if tools_to_add is not None and tool.name not in tools_to_add:
|
779
|
+
continue
|
780
|
+
|
781
|
+
desc = ToolDescription(
|
782
|
+
name=f"{name}/{tool.name}", description=tool.description, parameters=tool.inputSchema
|
783
|
+
)
|
784
|
+
|
785
|
+
def call(tool: MCPTool, **inputs: dict[str, Any]) -> list[str]:
|
786
|
+
return mcp_server.call_tool(tool, inputs)
|
787
|
+
|
788
|
+
self.add_tool(Tool(desc=desc, call_fn=partial(call, tool)))
|
789
|
+
|
790
|
+
def remove_mcp_server(self, name: str):
|
709
791
|
"""
|
710
|
-
|
792
|
+
Removes the MCP server and its tools from the agent, with terminating the MCP server process.
|
711
793
|
|
712
|
-
:param
|
713
|
-
|
714
|
-
:returns: list of all tools returned by the server.
|
794
|
+
:param name: The unique name of the MCP server.
|
795
|
+
If there's no MCP server matches the name, it raises RuntimeError.
|
715
796
|
"""
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
return resp.tools
|
727
|
-
|
728
|
-
tools = run_async(_inner())
|
729
|
-
return tools
|
797
|
+
if all([s.name != name for s in self._mcp_servers]):
|
798
|
+
raise RuntimeError(f"MCP server with name '{name}' does not exist")
|
799
|
+
|
800
|
+
# Remove the MCP server
|
801
|
+
mcp_server = next(filter(lambda s: s.name == name, self._mcp_servers))
|
802
|
+
self._mcp_servers.remove(mcp_server)
|
803
|
+
mcp_server.cleanup()
|
804
|
+
|
805
|
+
# Remove tools registered from the MCP server
|
806
|
+
self._tools = list(filter(lambda t: not t.desc.name.startswith(f"{mcp_server.name}/"), self._tools))
|
Binary file
|
ailoy/mcp.py
ADDED
@@ -0,0 +1,159 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import multiprocessing
|
4
|
+
import platform
|
5
|
+
import subprocess
|
6
|
+
from multiprocessing.connection import Connection
|
7
|
+
from typing import Annotated, Any, Literal, Union
|
8
|
+
|
9
|
+
import mcp.types as mcp_types
|
10
|
+
from mcp import Tool as MCPTool
|
11
|
+
from mcp.client.session import ClientSession
|
12
|
+
from mcp.client.stdio import (
|
13
|
+
StdioServerParameters,
|
14
|
+
stdio_client,
|
15
|
+
)
|
16
|
+
from pydantic import BaseModel, Field, TypeAdapter
|
17
|
+
|
18
|
+
__all__ = ["MCPServer"]
|
19
|
+
|
20
|
+
|
21
|
+
class ListToolsRequest(BaseModel):
|
22
|
+
type: Literal["list_tools"] = "list_tools"
|
23
|
+
|
24
|
+
|
25
|
+
class CallToolRequest(BaseModel):
|
26
|
+
type: Literal["call_tool"] = "call_tool"
|
27
|
+
tool: MCPTool
|
28
|
+
arguments: dict[str, Any]
|
29
|
+
|
30
|
+
|
31
|
+
class ShutdownRequest(BaseModel):
|
32
|
+
type: Literal["shutdown"] = "shutdown"
|
33
|
+
|
34
|
+
|
35
|
+
# Requests (main -> subprocess)
|
36
|
+
RequestMessage = Annotated[Union[ListToolsRequest, CallToolRequest, ShutdownRequest], Field(discriminator="type")]
|
37
|
+
|
38
|
+
|
39
|
+
class ResultMessage(BaseModel):
|
40
|
+
type: Literal["result"] = "result"
|
41
|
+
result: Any
|
42
|
+
|
43
|
+
|
44
|
+
class ErrorMessage(BaseModel):
|
45
|
+
type: Literal["error"] = "error"
|
46
|
+
error: str
|
47
|
+
|
48
|
+
|
49
|
+
# Response (subprocess -> main)
|
50
|
+
ResponseMessage = Annotated[Union[ResultMessage, ErrorMessage], Field(discriminator="type")]
|
51
|
+
|
52
|
+
|
53
|
+
class MCPServer:
|
54
|
+
"""
|
55
|
+
MCPServer manages a subprocess that acts as a bridge between an MCP stdio server and the main process.
|
56
|
+
|
57
|
+
- The subprocess communicates with the MCP stdio server using the official MCP Python SDK.
|
58
|
+
- Communication between the main process and the subprocess is handled through a multiprocessing Pipe.
|
59
|
+
Messages sent over this Pipe are serialized and deserialized using structured Pydantic models:
|
60
|
+
- `RequestMessage` for requests from the main process to the subprocess.
|
61
|
+
- `ResponseMessage` for responses from the subprocess to the main process.
|
62
|
+
|
63
|
+
This design ensures:
|
64
|
+
- Type-safe, structured inter-process communication.
|
65
|
+
- Synchronous interaction with an asynchronous MCP session (via message passing).
|
66
|
+
- Subprocess lifecycle control (including initialization and shutdown).
|
67
|
+
"""
|
68
|
+
|
69
|
+
def __init__(self, name: str, params: StdioServerParameters):
|
70
|
+
self.name = name
|
71
|
+
self.params = params
|
72
|
+
|
73
|
+
self._parent_conn, self._child_conn = multiprocessing.Pipe()
|
74
|
+
|
75
|
+
ctx = multiprocessing.get_context("fork" if platform.system() != "Windows" else "spawn")
|
76
|
+
self._proc = ctx.Process(target=self._run_process, args=(self._child_conn,))
|
77
|
+
self._proc.start()
|
78
|
+
|
79
|
+
# Wait for subprocess to signal initialization complete
|
80
|
+
self._recv_response()
|
81
|
+
|
82
|
+
def __del__(self):
|
83
|
+
self.cleanup()
|
84
|
+
|
85
|
+
def _run_process(self, conn: Connection):
|
86
|
+
asyncio.run(self._process_main(conn))
|
87
|
+
|
88
|
+
async def _process_main(self, conn: Connection):
|
89
|
+
async with stdio_client(self.params, errlog=subprocess.PIPE) as (read, write):
|
90
|
+
async with ClientSession(read, write) as session:
|
91
|
+
# Notify to main process that the initialization has been finished and ready to receive requests
|
92
|
+
try:
|
93
|
+
await session.initialize()
|
94
|
+
conn.send(ResultMessage(result=True).model_dump())
|
95
|
+
except Exception as e:
|
96
|
+
conn.send(ErrorMessage(error=f"Failed to initialize MCP subprocess: {e}").model_dump())
|
97
|
+
|
98
|
+
while True:
|
99
|
+
if not conn.poll(0.1):
|
100
|
+
await asyncio.sleep(0.1)
|
101
|
+
continue
|
102
|
+
|
103
|
+
try:
|
104
|
+
raw = conn.recv()
|
105
|
+
req = TypeAdapter(RequestMessage).validate_python(raw)
|
106
|
+
|
107
|
+
if isinstance(req, ListToolsRequest):
|
108
|
+
result = await session.list_tools()
|
109
|
+
conn.send(ResultMessage(result=result.tools).model_dump())
|
110
|
+
|
111
|
+
elif isinstance(req, CallToolRequest):
|
112
|
+
result = await session.call_tool(req.tool.name, req.arguments)
|
113
|
+
contents: list[str] = []
|
114
|
+
for item in result.content:
|
115
|
+
if isinstance(item, mcp_types.TextContent):
|
116
|
+
try:
|
117
|
+
content = json.loads(item.text)
|
118
|
+
contents.append(json.dumps(content))
|
119
|
+
except json.JSONDecodeError:
|
120
|
+
contents.append(item.text)
|
121
|
+
elif isinstance(item, mcp_types.ImageContent):
|
122
|
+
contents.append(item.data)
|
123
|
+
elif isinstance(item, mcp_types.EmbeddedResource):
|
124
|
+
if isinstance(item.resource, mcp_types.TextResourceContents):
|
125
|
+
contents.append(item.resource.text)
|
126
|
+
else:
|
127
|
+
contents.append(item.resource.blob)
|
128
|
+
conn.send(ResultMessage(result=contents).model_dump())
|
129
|
+
|
130
|
+
elif isinstance(req, ShutdownRequest):
|
131
|
+
break
|
132
|
+
|
133
|
+
except Exception as e:
|
134
|
+
conn.send(ErrorMessage(error=str(e)).model_dump())
|
135
|
+
|
136
|
+
def _send_request(self, msg: RequestMessage):
|
137
|
+
self._parent_conn.send(msg.model_dump())
|
138
|
+
|
139
|
+
def _recv_response(self) -> ResultMessage:
|
140
|
+
raw = self._parent_conn.recv()
|
141
|
+
msg = TypeAdapter(ResponseMessage).validate_python(raw)
|
142
|
+
if isinstance(msg, ErrorMessage):
|
143
|
+
raise RuntimeError(msg.error)
|
144
|
+
return msg
|
145
|
+
|
146
|
+
def list_tools(self) -> list[MCPTool]:
|
147
|
+
self._send_request(ListToolsRequest())
|
148
|
+
msg = self._recv_response()
|
149
|
+
return [MCPTool.model_validate(tool) for tool in msg.result]
|
150
|
+
|
151
|
+
def call_tool(self, tool: MCPTool, arguments: dict[str, Any]) -> list[str]:
|
152
|
+
self._send_request(CallToolRequest(tool=tool, arguments=arguments))
|
153
|
+
msg = self._recv_response()
|
154
|
+
return msg.result
|
155
|
+
|
156
|
+
def cleanup(self) -> None:
|
157
|
+
if self._proc.is_alive():
|
158
|
+
self._send_request(ShutdownRequest())
|
159
|
+
self._proc.join()
|
ailoy/runtime.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import time
|
1
2
|
from asyncio import Event, to_thread
|
2
3
|
from collections import defaultdict
|
3
4
|
from typing import Any, AsyncGenerator, Generator, Literal, Optional, TypedDict
|
@@ -15,21 +16,25 @@ class Packet(TypedDict):
|
|
15
16
|
|
16
17
|
|
17
18
|
class RuntimeBase:
|
18
|
-
|
19
|
-
|
19
|
+
__client_count: dict[str, int] = {}
|
20
|
+
|
21
|
+
def __init__(self, url: str = "inproc://"):
|
22
|
+
self.url: str = url
|
20
23
|
self._responses: dict[str, Packet] = {}
|
21
24
|
self._exec_responses: defaultdict[str, dict[int, Packet]] = defaultdict(dict)
|
22
25
|
self._listen_lock: Optional[Event] = None
|
23
26
|
|
24
|
-
|
25
|
-
|
27
|
+
if RuntimeBase.__client_count.get(self.url, 0) == 0:
|
28
|
+
start_threads(self.url)
|
29
|
+
RuntimeBase.__client_count[self.url] = 0
|
30
|
+
|
31
|
+
self._client: BrokerClient = BrokerClient(self.url)
|
26
32
|
txid = self._send_type1("connect")
|
27
|
-
if not txid:
|
28
|
-
raise RuntimeError("Connection failed")
|
29
33
|
self._sync_listen()
|
30
34
|
if not self._responses[txid]["body"]["status"]:
|
31
35
|
raise RuntimeError("Connection failed")
|
32
36
|
del self._responses[txid]
|
37
|
+
RuntimeBase.__client_count[self.url] += 1
|
33
38
|
|
34
39
|
def __del__(self):
|
35
40
|
self.stop()
|
@@ -41,22 +46,32 @@ class RuntimeBase:
|
|
41
46
|
self.stop()
|
42
47
|
|
43
48
|
def stop(self):
|
44
|
-
if self.
|
49
|
+
if self.is_alive():
|
45
50
|
txid = self._send_type1("disconnect")
|
46
|
-
if not txid:
|
47
|
-
raise RuntimeError("Disconnection failed")
|
48
51
|
while txid not in self._responses:
|
49
52
|
self._sync_listen()
|
50
53
|
if not self._responses[txid]["body"]["status"]:
|
51
54
|
raise RuntimeError("Disconnection failed")
|
52
55
|
self._client = None
|
53
|
-
|
56
|
+
RuntimeBase.__client_count[self.url] -= 1
|
57
|
+
if RuntimeBase.__client_count.get(self.url, 0) <= 0:
|
58
|
+
stop_threads(self.url)
|
59
|
+
RuntimeBase.__client_count.pop(self.url, 0)
|
54
60
|
|
55
|
-
def
|
61
|
+
def is_alive(self):
|
62
|
+
return self._client is not None
|
63
|
+
|
64
|
+
def _send_type1(self, ptype: Literal["connect", "disconnect"]) -> str:
|
56
65
|
txid = generate_uuid()
|
57
|
-
|
58
|
-
|
59
|
-
|
66
|
+
retry_count = 0
|
67
|
+
# Since the broker thread might start slightly later than the runtime client,
|
68
|
+
# we retry sending the packat a few times to ensure delivery.
|
69
|
+
while retry_count < 3:
|
70
|
+
if self._client.send_type1(txid, ptype):
|
71
|
+
return txid
|
72
|
+
time.sleep(0.001)
|
73
|
+
retry_count += 1
|
74
|
+
raise RuntimeError(f'Failed to send packet "{ptype}"')
|
60
75
|
|
61
76
|
def _send_type2(
|
62
77
|
self,
|
@@ -76,7 +91,7 @@ class RuntimeBase:
|
|
76
91
|
*args,
|
77
92
|
):
|
78
93
|
txid = generate_uuid()
|
79
|
-
if self._client.
|
94
|
+
if self._client.send_type3(txid, ptype, status, *args):
|
80
95
|
return txid
|
81
96
|
raise RuntimeError("Failed to send packet")
|
82
97
|
|
@@ -112,8 +127,8 @@ class RuntimeBase:
|
|
112
127
|
|
113
128
|
|
114
129
|
class Runtime(RuntimeBase):
|
115
|
-
def __init__(self,
|
116
|
-
super().__init__(
|
130
|
+
def __init__(self, url: str = "inproc://"):
|
131
|
+
super().__init__(url)
|
117
132
|
|
118
133
|
def call(self, func_name: str, input: Any) -> Any:
|
119
134
|
rv = [v for v in self.call_iter(func_name, input)]
|
@@ -193,8 +208,8 @@ class Runtime(RuntimeBase):
|
|
193
208
|
|
194
209
|
|
195
210
|
class AsyncRuntime(RuntimeBase):
|
196
|
-
def __init__(self,
|
197
|
-
super().__init__(
|
211
|
+
def __init__(self, url: str = "inproc://"):
|
212
|
+
super().__init__(url)
|
198
213
|
|
199
214
|
async def call(self, func_name: str, input: Any) -> Any:
|
200
215
|
rv = [v async for v in self.call_iter(func_name, input)]
|
ailoy/tools.py
ADDED
@@ -0,0 +1,205 @@
|
|
1
|
+
import inspect
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
import types
|
5
|
+
from typing import (
|
6
|
+
Any,
|
7
|
+
Callable,
|
8
|
+
Optional,
|
9
|
+
Union,
|
10
|
+
get_args,
|
11
|
+
get_origin,
|
12
|
+
get_type_hints,
|
13
|
+
)
|
14
|
+
|
15
|
+
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
|
16
|
+
# Extracts the Args: block from the docstring
|
17
|
+
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
|
18
|
+
# Splits the Args: block into individual arguments
|
19
|
+
args_split_re = re.compile(
|
20
|
+
r"""
|
21
|
+
(?:^|\n) # Match the start of the args block, or a newline
|
22
|
+
\s*(\w+):\s* # Capture the argument name and strip spacing
|
23
|
+
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
|
24
|
+
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
|
25
|
+
""",
|
26
|
+
re.DOTALL | re.VERBOSE,
|
27
|
+
)
|
28
|
+
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
|
29
|
+
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
|
30
|
+
|
31
|
+
|
32
|
+
class TypeHintParsingException(Exception):
|
33
|
+
"""Exception raised for errors in parsing type hints to generate JSON schemas"""
|
34
|
+
|
35
|
+
pass
|
36
|
+
|
37
|
+
|
38
|
+
class DocstringParsingException(Exception):
|
39
|
+
"""Exception raised for errors in parsing docstrings to generate JSON schemas"""
|
40
|
+
|
41
|
+
pass
|
42
|
+
|
43
|
+
|
44
|
+
def _get_json_schema_type(param_type: str) -> dict[str, str]:
|
45
|
+
type_mapping = {
|
46
|
+
int: {"type": "integer"},
|
47
|
+
float: {"type": "number"},
|
48
|
+
str: {"type": "string"},
|
49
|
+
bool: {"type": "boolean"},
|
50
|
+
type(None): {"type": "null"},
|
51
|
+
Any: {},
|
52
|
+
}
|
53
|
+
# if is_vision_available():
|
54
|
+
# type_mapping[Image] = {"type": "image"}
|
55
|
+
# if is_torch_available():
|
56
|
+
# type_mapping[Tensor] = {"type": "audio"}
|
57
|
+
return type_mapping.get(param_type, {"type": "object"})
|
58
|
+
|
59
|
+
|
60
|
+
def _parse_type_hint(hint: str) -> dict:
|
61
|
+
origin = get_origin(hint)
|
62
|
+
args = get_args(hint)
|
63
|
+
|
64
|
+
if origin is None:
|
65
|
+
try:
|
66
|
+
return _get_json_schema_type(hint)
|
67
|
+
except KeyError:
|
68
|
+
raise TypeHintParsingException(
|
69
|
+
"Couldn't parse this type hint, likely due to a custom class or object: ", hint
|
70
|
+
)
|
71
|
+
|
72
|
+
elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
|
73
|
+
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
|
74
|
+
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
|
75
|
+
if len(subtypes) == 1:
|
76
|
+
# A single non-null type can be expressed directly
|
77
|
+
return_dict = subtypes[0]
|
78
|
+
elif all(isinstance(subtype["type"], str) for subtype in subtypes):
|
79
|
+
# A union of basic types can be expressed as a list in the schema
|
80
|
+
return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
|
81
|
+
else:
|
82
|
+
# A union of more complex types requires "anyOf"
|
83
|
+
return_dict = {"anyOf": subtypes}
|
84
|
+
if type(None) in args:
|
85
|
+
return_dict["nullable"] = True
|
86
|
+
return return_dict
|
87
|
+
|
88
|
+
elif origin is list:
|
89
|
+
if not args:
|
90
|
+
return {"type": "array"}
|
91
|
+
else:
|
92
|
+
# Lists can only have a single type argument, so recurse into it
|
93
|
+
return {"type": "array", "items": _parse_type_hint(args[0])}
|
94
|
+
|
95
|
+
elif origin is tuple:
|
96
|
+
if not args:
|
97
|
+
return {"type": "array"}
|
98
|
+
if len(args) == 1:
|
99
|
+
raise TypeHintParsingException(
|
100
|
+
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
|
101
|
+
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
|
102
|
+
"more than one element, we recommend "
|
103
|
+
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
|
104
|
+
"pass the element directly."
|
105
|
+
)
|
106
|
+
if ... in args:
|
107
|
+
raise TypeHintParsingException(
|
108
|
+
"Conversion of '...' is not supported in Tuple type hints. "
|
109
|
+
"Use List[] types for variable-length"
|
110
|
+
" inputs instead."
|
111
|
+
)
|
112
|
+
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
|
113
|
+
|
114
|
+
elif origin is dict:
|
115
|
+
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
|
116
|
+
# However, we can specify the type of the dict values with "additionalProperties"
|
117
|
+
out = {"type": "object"}
|
118
|
+
if len(args) == 2:
|
119
|
+
out["additionalProperties"] = _parse_type_hint(args[1])
|
120
|
+
return out
|
121
|
+
|
122
|
+
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
123
|
+
|
124
|
+
|
125
|
+
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
|
126
|
+
type_hints = get_type_hints(func)
|
127
|
+
signature = inspect.signature(func)
|
128
|
+
required = []
|
129
|
+
for param_name, param in signature.parameters.items():
|
130
|
+
if param.annotation == inspect.Parameter.empty:
|
131
|
+
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
132
|
+
if param.default == inspect.Parameter.empty:
|
133
|
+
required.append(param_name)
|
134
|
+
|
135
|
+
properties = {}
|
136
|
+
for param_name, param_type in type_hints.items():
|
137
|
+
properties[param_name] = _parse_type_hint(param_type)
|
138
|
+
|
139
|
+
schema = {"type": "object", "properties": properties}
|
140
|
+
if required:
|
141
|
+
schema["required"] = required
|
142
|
+
|
143
|
+
return schema
|
144
|
+
|
145
|
+
|
146
|
+
def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
|
147
|
+
"""
|
148
|
+
Parses a Google-style docstring to extract the function description,
|
149
|
+
argument descriptions, and return description.
|
150
|
+
|
151
|
+
Args:
|
152
|
+
docstring (str): The docstring to parse.
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
The function description, arguments, and return description.
|
156
|
+
"""
|
157
|
+
|
158
|
+
# Extract the sections
|
159
|
+
description_match = description_re.search(docstring)
|
160
|
+
args_match = args_re.search(docstring)
|
161
|
+
returns_match = returns_re.search(docstring)
|
162
|
+
|
163
|
+
# Clean and store the sections
|
164
|
+
description = description_match.group(1).strip() if description_match else None
|
165
|
+
docstring_args = args_match.group(1).strip() if args_match else None
|
166
|
+
returns = returns_match.group(1).strip() if returns_match else None
|
167
|
+
|
168
|
+
# Parsing the arguments into a dictionary
|
169
|
+
if docstring_args is not None:
|
170
|
+
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
|
171
|
+
matches = args_split_re.findall(docstring_args)
|
172
|
+
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
|
173
|
+
else:
|
174
|
+
args_dict = {}
|
175
|
+
|
176
|
+
return description, args_dict, returns
|
177
|
+
|
178
|
+
|
179
|
+
def get_json_schema(func: Callable) -> dict:
|
180
|
+
doc = inspect.getdoc(func)
|
181
|
+
if not doc:
|
182
|
+
raise DocstringParsingException(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
|
183
|
+
doc = doc.strip()
|
184
|
+
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
|
185
|
+
|
186
|
+
json_schema = _convert_type_hints_to_json_schema(func)
|
187
|
+
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
|
188
|
+
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
|
189
|
+
return_dict["description"] = return_doc
|
190
|
+
for arg, schema in json_schema["properties"].items():
|
191
|
+
if arg not in param_descriptions:
|
192
|
+
raise DocstringParsingException(
|
193
|
+
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
|
194
|
+
)
|
195
|
+
desc = param_descriptions[arg]
|
196
|
+
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
|
197
|
+
if enum_choices:
|
198
|
+
schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
|
199
|
+
desc = enum_choices.string[: enum_choices.start()].strip()
|
200
|
+
schema["description"] = desc
|
201
|
+
|
202
|
+
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
|
203
|
+
if return_dict is not None:
|
204
|
+
output["return"] = return_dict
|
205
|
+
return {"type": "function", "function": output}
|
ailoy/vector_store.py
CHANGED
@@ -46,8 +46,8 @@ class VectorStore:
|
|
46
46
|
vector_store_name: Literal["faiss", "chromadb"],
|
47
47
|
url: Optional[str] = None,
|
48
48
|
collection: Optional[str] = None,
|
49
|
-
embedding_model_attrs: dict[str, Any] =
|
50
|
-
vector_store_attrs: dict[str, Any] =
|
49
|
+
embedding_model_attrs: Optional[dict[str, Any]] = None,
|
50
|
+
vector_store_attrs: Optional[dict[str, Any]] = None,
|
51
51
|
):
|
52
52
|
"""
|
53
53
|
Creates an instance.
|
@@ -69,10 +69,10 @@ class VectorStore:
|
|
69
69
|
self.define(
|
70
70
|
embedding_model_name,
|
71
71
|
vector_store_name,
|
72
|
-
url,
|
73
|
-
collection,
|
74
|
-
embedding_model_attrs,
|
75
|
-
vector_store_attrs,
|
72
|
+
url=url,
|
73
|
+
collection=collection,
|
74
|
+
embedding_model_attrs=embedding_model_attrs,
|
75
|
+
vector_store_attrs=vector_store_attrs,
|
76
76
|
)
|
77
77
|
|
78
78
|
def __del__(self):
|
@@ -90,8 +90,8 @@ class VectorStore:
|
|
90
90
|
vector_store_name: Literal["faiss", "chromadb"],
|
91
91
|
url: Optional[str] = None,
|
92
92
|
collection: Optional[str] = None,
|
93
|
-
embedding_model_attrs: dict[str, Any] =
|
94
|
-
vector_store_attrs: dict[str, Any] =
|
93
|
+
embedding_model_attrs: Optional[dict[str, Any]] = None,
|
94
|
+
vector_store_attrs: Optional[dict[str, Any]] = None,
|
95
95
|
):
|
96
96
|
"""
|
97
97
|
Defines the embedding model and vector store components to the runtime.
|
@@ -111,13 +111,14 @@ class VectorStore:
|
|
111
111
|
self._component_state.embedding_model_name,
|
112
112
|
{
|
113
113
|
"model": "BAAI/bge-m3",
|
114
|
-
**embedding_model_attrs,
|
114
|
+
**(embedding_model_attrs or {}),
|
115
115
|
},
|
116
116
|
)
|
117
117
|
else:
|
118
118
|
raise NotImplementedError(f"Unsupprted embedding model: {embedding_model_name}")
|
119
119
|
|
120
120
|
# Initialize vector store
|
121
|
+
vector_store_attrs = vector_store_attrs or {}
|
121
122
|
if vector_store_name == "faiss":
|
122
123
|
if "dimension" not in vector_store_attrs:
|
123
124
|
vector_store_attrs["dimension"] = dimension
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ailoy-py
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: Python binding for Ailoy runtime APIs
|
5
5
|
Author-Email: "Brekkylab Inc." <contact@brekkylab.com>
|
6
6
|
License-Expression: Apache-2.0
|
@@ -65,7 +65,7 @@ rt.stop()
|
|
65
65
|
- LLVM Clang >= 17
|
66
66
|
- Apple Clang >= 15
|
67
67
|
- MSVC >= 19.29
|
68
|
-
- CMake >= 3.
|
68
|
+
- CMake >= 3.28.0
|
69
69
|
- Git
|
70
70
|
- OpenSSL
|
71
71
|
- Rust & Cargo >= 1.82.0
|
@@ -0,0 +1,20 @@
|
|
1
|
+
ailoy/__init__.py,sha256=mzkLUc95OCc2okURWm9iA5xR8WZdxwvPgaanc9fwoH4,647
|
2
|
+
ailoy/agent.py,sha256=uQ1o4CjQEO1vP7y0frGtVknN-A_iLNtkqt8h3vobgiM,27613
|
3
|
+
ailoy/ailoy_py.cpython-311-x86_64-linux-gnu.so,sha256=HewPp-NrYxV9_Un38rrHZ8cn_63MKRn2OaWUJ1X0g9I,21245329
|
4
|
+
ailoy/ailoy_py.pyi,sha256=Yf90FEXkslpCpr1r2eqQ3-_1jLo65zmG94bBXDRqinU,991
|
5
|
+
ailoy/mcp.py,sha256=bC58tAWqhvMdZVCKHSdOVNUoAuYfZiou1hSH1oa_9Ag,6190
|
6
|
+
ailoy/runtime.py,sha256=-75KawEMQSwxGvX5wtECVCWiTNdcHojsQ1e-OVB4IQ8,10545
|
7
|
+
ailoy/tools.py,sha256=RnTfmWlqYY1q0V377CpAAyAK-yET7k45GgEhgM9G8eI,8207
|
8
|
+
ailoy/vector_store.py,sha256=ZfIuGYKv2dQmjOuDlSKDc-BBPlQ8no_70mZwnPzbBzo,7515
|
9
|
+
ailoy/cli/__main__.py,sha256=HnBVb2em1F2NLPeNX5r3xRndRrnGaXVCduo8WBULAI0,179
|
10
|
+
ailoy/cli/model.py,sha256=cerCHE-VY9TOwqRcLBtmqnV-5vphpvyhtrfPFZiTKCM,2979
|
11
|
+
ailoy/presets/tools/calculator.json,sha256=ePnZsjZChnvS08s9eVdIp4Bys_PlJBXPHCCjv6oMvzA,1040
|
12
|
+
ailoy/presets/tools/frankfurter.json,sha256=bZ5vhszf_aR-B_QN4L2xrI5nR-f4AMZk41UUDq1dTXg,1152
|
13
|
+
ailoy/presets/tools/nytimes.json,sha256=wrfe9bnAlSPzHladoGEX2oCAeE0wed3BvgXQ_Z2PdXg,918
|
14
|
+
ailoy/presets/tools/tmdb.json,sha256=UGLN5uAJ2b-Hu3nLcW95WXDLB3mfC3rBYfQANp_e8Ps,7046
|
15
|
+
ailoy_py.libs/libgomp-870cb1d0.so.1.0.0,sha256=Ta6ZPLbakQH8LP74JzBt0DuJIBHS4nicjkSCjKnyWDw,253289
|
16
|
+
ailoy_py.libs/libtvm_runtime-2d14ca42.so,sha256=qPtn3HaKtxt-sL0wdu6Wqz7QsTmKY2ZWOPwO92TPfzU,5061889
|
17
|
+
ailoy_py-0.0.2.dist-info/METADATA,sha256=B5RbxeITquJfdiw9bhA6w02q9OvLkuFH7jMRg6Lxc2A,2010
|
18
|
+
ailoy_py-0.0.2.dist-info/WHEEL,sha256=pUMnbkEoOJH3JIiTuE-9tixQOeWRbAaVqA62Pyrra40,118
|
19
|
+
ailoy_py-0.0.2.dist-info/entry_points.txt,sha256=gVG45uDE6kef0wm6SEMYSgZgRNNRhSAeP2n2lPR00dI,50
|
20
|
+
ailoy_py-0.0.2.dist-info/RECORD,,
|
Binary file
|
ailoy_py-0.0.1.dist-info/RECORD
DELETED
@@ -1,19 +0,0 @@
|
|
1
|
-
ailoy/__init__.py,sha256=ArHu4OLbxU_re9cX-7zhshRqkn46gBcl-0SxDa5Wk00,148
|
2
|
-
ailoy/agent.py,sha256=orrfgeH6Xo9nPkHqJJ1YqS_8-5-0_tkpKlF8YcdRIqU,24998
|
3
|
-
ailoy/ailoy_py.cpython-311-x86_64-linux-gnu.so,sha256=IA5ayndZJMzsTlVcHMF7brPrZ2lSuMyZYv9qm_4uBf4,24729601
|
4
|
-
ailoy/ailoy_py.pyi,sha256=Yf90FEXkslpCpr1r2eqQ3-_1jLo65zmG94bBXDRqinU,991
|
5
|
-
ailoy/runtime.py,sha256=tVdUaqqx9NB-h4grRkW_R2XYW5ihFn95aUhnTZkl_Zg,9997
|
6
|
-
ailoy/vector_store.py,sha256=Ojhr4bcSfKKuMaldfsjz_G41AGgz4vyvpkWn5WMFE2c,7365
|
7
|
-
ailoy/cli/__main__.py,sha256=HnBVb2em1F2NLPeNX5r3xRndRrnGaXVCduo8WBULAI0,179
|
8
|
-
ailoy/cli/model.py,sha256=cerCHE-VY9TOwqRcLBtmqnV-5vphpvyhtrfPFZiTKCM,2979
|
9
|
-
ailoy/presets/tools/calculator.json,sha256=ePnZsjZChnvS08s9eVdIp4Bys_PlJBXPHCCjv6oMvzA,1040
|
10
|
-
ailoy/presets/tools/frankfurter.json,sha256=bZ5vhszf_aR-B_QN4L2xrI5nR-f4AMZk41UUDq1dTXg,1152
|
11
|
-
ailoy/presets/tools/nytimes.json,sha256=wrfe9bnAlSPzHladoGEX2oCAeE0wed3BvgXQ_Z2PdXg,918
|
12
|
-
ailoy/presets/tools/tmdb.json,sha256=UGLN5uAJ2b-Hu3nLcW95WXDLB3mfC3rBYfQANp_e8Ps,7046
|
13
|
-
ailoy_py.libs/libgomp-870cb1d0.so.1.0.0,sha256=Ta6ZPLbakQH8LP74JzBt0DuJIBHS4nicjkSCjKnyWDw,253289
|
14
|
-
ailoy_py.libs/libmvec-2-8eb5c230.28.so,sha256=65kXCJhVuWfybQTtrbZvv-omg-x4rUlzPkkPhbZxa7o,181969
|
15
|
-
ailoy_py.libs/libtvm_runtime-7067e461.so,sha256=Slb0RS5igenySBg490s7LPCKYmrZl10He1j2QsemmWM,4759545
|
16
|
-
ailoy_py-0.0.1.dist-info/METADATA,sha256=FAyzKTteIMgZse5H99IVQ5q2ScjU9hlTuOggBNrgTF0,2010
|
17
|
-
ailoy_py-0.0.1.dist-info/WHEEL,sha256=ynEJWBsXE4ohY630UToGhqtzU95gVdbtKm3DQK0bASo,118
|
18
|
-
ailoy_py-0.0.1.dist-info/entry_points.txt,sha256=gVG45uDE6kef0wm6SEMYSgZgRNNRhSAeP2n2lPR00dI,50
|
19
|
-
ailoy_py-0.0.1.dist-info/RECORD,,
|
Binary file
|
Binary file
|
File without changes
|