amsdal_ml 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
- amsdal_ml/__about__.py +1 -1
- amsdal_ml/agents/__init__.py +13 -0
- amsdal_ml/agents/agent.py +5 -7
- amsdal_ml/agents/default_qa_agent.py +108 -143
- amsdal_ml/agents/functional_calling_agent.py +233 -0
- amsdal_ml/agents/mcp_client_tool.py +46 -0
- amsdal_ml/agents/python_tool.py +86 -0
- amsdal_ml/agents/retriever_tool.py +5 -6
- amsdal_ml/agents/tool_adapters.py +98 -0
- amsdal_ml/fileio/base_loader.py +7 -5
- amsdal_ml/fileio/openai_loader.py +16 -17
- amsdal_ml/mcp_client/base.py +2 -0
- amsdal_ml/mcp_client/http_client.py +7 -1
- amsdal_ml/mcp_client/stdio_client.py +19 -16
- amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
- amsdal_ml/ml_ingesting/__init__.py +29 -0
- amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
- amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
- amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
- amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
- amsdal_ml/ml_ingesting/embedding_data.py +3 -0
- amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
- amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
- amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
- amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
- amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
- amsdal_ml/ml_ingesting/model_ingester.py +278 -0
- amsdal_ml/ml_ingesting/pipeline.py +131 -0
- amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
- amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
- amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
- amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
- amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
- amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
- amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
- amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
- amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
- amsdal_ml/ml_ingesting/stores/store.py +22 -0
- amsdal_ml/ml_ingesting/types.py +40 -0
- amsdal_ml/ml_models/models.py +96 -4
- amsdal_ml/ml_models/openai_model.py +430 -122
- amsdal_ml/ml_models/utils.py +7 -0
- amsdal_ml/ml_retrievers/__init__.py +17 -0
- amsdal_ml/ml_retrievers/adapters.py +93 -0
- amsdal_ml/ml_retrievers/default_retriever.py +11 -1
- amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
- amsdal_ml/ml_retrievers/query_retriever.py +487 -0
- amsdal_ml/ml_retrievers/retriever.py +12 -0
- amsdal_ml/models/embedding_model.py +7 -7
- amsdal_ml/prompts/__init__.py +77 -0
- amsdal_ml/prompts/database_query_agent.prompt +14 -0
- amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
- amsdal_ml/prompts/nl_query_filter.prompt +318 -0
- amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
- amsdal_ml/utils/__init__.py +5 -0
- amsdal_ml/utils/query_utils.py +189 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/METADATA +61 -3
- amsdal_ml-0.2.1.dist-info/RECORD +72 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/WHEEL +1 -1
- amsdal_ml/agents/promts/__init__.py +0 -58
- amsdal_ml-0.1.4.dist-info/RECORD +0 -39
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from typing import no_type_check
|
|
6
|
+
|
|
7
|
+
from amsdal_ml.agents.agent import Agent
|
|
8
|
+
from amsdal_ml.agents.agent import AgentOutput
|
|
9
|
+
from amsdal_ml.agents.mcp_client_tool import ClientToolProxy
|
|
10
|
+
from amsdal_ml.agents.python_tool import PythonTool
|
|
11
|
+
from amsdal_ml.agents.python_tool import _PythonToolProxy
|
|
12
|
+
from amsdal_ml.agents.tool_adapters import ToolAdapter
|
|
13
|
+
from amsdal_ml.agents.tool_adapters import get_tool_adapter
|
|
14
|
+
from amsdal_ml.fileio.base_loader import PLAIN_TEXT
|
|
15
|
+
from amsdal_ml.fileio.base_loader import FileAttachment
|
|
16
|
+
from amsdal_ml.mcp_client.base import ToolClient
|
|
17
|
+
from amsdal_ml.mcp_client.base import ToolInfo
|
|
18
|
+
from amsdal_ml.ml_models.models import MLModel
|
|
19
|
+
from amsdal_ml.ml_models.models import StructuredMessage
|
|
20
|
+
from amsdal_ml.ml_models.utils import ResponseFormat
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FunctionalCallingAgent(Agent):
    """
    An agent that uses the native function calling capabilities of LLMs (e.g., OpenAI)
    to execute tools and answer user queries.

    Attributes:
        model (MLModel): The LLM model instance to use (must support function calling).
        max_steps (int): Maximum number of tool execution steps allowed.
        per_call_timeout (float | None): Timeout in seconds for each tool execution.
    """

    def __init__(
        self,
        *,
        model: MLModel,
        tools: list[PythonTool | ToolClient] | None = None,
        max_steps: int = 6,
        per_call_timeout: float | None = 20.0,
        adapter: ToolAdapter | None = None,
    ):
        """
        Initialize the FunctionalCallingAgent.

        Args:
            model: The LLM model to use.
            tools: A list of tools (PythonTool or ToolClient) available to the agent.
            max_steps: The maximum number of iterations (model -> tool -> model) allowed.
            per_call_timeout: Timeout for individual tool calls.
            adapter: Optional tool adapter. If None, will auto-detect based on LLM type.
        """
        self._tools: list[PythonTool | ToolClient] = tools or []
        self._indexed_tools: dict[str, ClientToolProxy | _PythonToolProxy] = {}
        self.model = model
        self.model.setup()
        self.max_steps = max_steps
        self.per_call_timeout = per_call_timeout
        # Tool index is built lazily on the first arun() call because resolving
        # ToolClients requires awaiting list_tools().
        self._is_tools_index_built = False
        self.adapter = adapter or get_tool_adapter(model)
        self._response_format: ResponseFormat = self._select_response_format()

    async def arun(
        self,
        user_query: str,
        *,
        history: list[StructuredMessage] | None = None,
        attachments: list[FileAttachment] | None = None,
    ) -> AgentOutput:
        """
        Run the agent asynchronously to answer a user query.

        This method executes the main loop:
        1. Send query and tools to the model.
        2. If model requests tool calls, execute them and report back.
        3. Repeat until the model provides a final answer or max_steps is reached.

        Args:
            user_query: The question or instruction from the user.
            history: Optional chat history to continue the conversation.
            attachments: Optional list of files/documents to include in context.

        Returns:
            AgentOutput: The final answer and metadata about used tools.
        """
        if not self._is_tools_index_built:
            await self._build_clients_index()

        content = self._merge_attachments(user_query, attachments)
        messages = history.copy() if history else []

        # TODO: JSON markdown tables support for nlqretriever
        # if self._response_format == ResponseFormat.JSON_OBJECT:
        #     messages.append({'role': 'system', 'content': 'Please respond in json format.'})

        messages.append({self.model.role_field: self.model.input_role, self.model.content_field: content})  # type: ignore[misc]
        used_tools: list[str] = []
        tools_schema = self.adapter.get_tools_schema(self._indexed_tools)

        for _ in range(self.max_steps):
            response_str = await self.model.ainvoke(
                input=messages,
                tools=tools_schema if tools_schema else None,
                response_format=self._response_format,
            )

            # Tool-calling responses come back as JSON; a plain-text final
            # answer may not parse, in which case wrap it as a normal message.
            try:
                response_data = json.loads(response_str)
            except json.JSONDecodeError:
                response_data = {self.model.role_field: self.model.output_role, self.model.content_field: response_str}

            messages.append(response_data)

            content_text, tool_calls = self.adapter.parse_response(response_data)

            if not tool_calls:
                return AgentOutput(answer=content_text or '', used_tools=used_tools)

            for tool_call in tool_calls:
                function_name, arguments_str, call_id = self.adapter.get_tool_call_info(tool_call)

                used_tools.append(function_name)

                try:
                    args = json.loads(arguments_str)
                    tool = self._indexed_tools[function_name]
                    result = await tool.run(args)

                    if isinstance(result, (dict, list)):
                        content_str = json.dumps(result, ensure_ascii=False)
                    else:
                        content_str = str(result)
                except Exception as e:
                    # Report tool failures back to the model instead of
                    # aborting the run, so the model may recover or rephrase.
                    content_str = f'Error: {e!s}'

                messages.append({
                    self.model.role_field: self.model.tool_role,  # type: ignore[misc]
                    self.model.tool_call_id_field: call_id,
                    self.model.tool_name_field: function_name,
                    self.model.content_field: content_str,
                })

        return AgentOutput(
            answer='Agent stopped due to iteration limit.',
            used_tools=used_tools,
        )

    @no_type_check
    async def astream(
        self,
        user_query: str,
        *,
        history: list[StructuredMessage] | None = None,
        attachments: list[FileAttachment] | None = None,
    ) -> AsyncIterator[str]:
        """
        Stream the agent's response asynchronously.

        Currently, this method buffers the full execution and yields the final answer
        at once. True streaming of intermediate steps is not yet implemented.

        Args:
            user_query: The question or instruction from the user.
            history: Optional chat history to continue the conversation.
            attachments: Optional list of files/documents.

        Yields:
            str: Chunks of the final answer (currently just the full answer).
        """
        output = await self.arun(user_query, history=history, attachments=attachments)
        yield output.answer

    def _select_response_format(self) -> ResponseFormat:
        """
        Select the best response format supported by the model.

        Returns:
            ResponseFormat: PLAIN_TEXT to allow raw Markdown tables.
        """
        return ResponseFormat.PLAIN_TEXT

    async def _build_clients_index(self) -> None:
        """
        Build the internal index of tools.

        Iterates through the provided tools, resolving ToolClients into individual
        callable proxies and indexing PythonTools directly.

        Raises:
            ValueError: If two tools resolve to the same (qualified) name.
            TypeError: If a tool is neither a ToolClient nor a PythonTool.
        """
        self._indexed_tools.clear()

        for tool in self._tools:
            if isinstance(tool, ToolClient):
                infos: list[ToolInfo] = await tool.list_tools()
                for ti in infos:
                    qname = f'{ti.alias}.{ti.name}'
                    # Mirror the PythonTool branch: a silent overwrite here
                    # would make one of two same-named remote tools unreachable.
                    if qname in self._indexed_tools:
                        msg = f'Tool name conflict: {qname} is already defined.'
                        raise ValueError(msg)
                    proxy = ClientToolProxy(
                        client=tool,
                        alias=ti.alias,
                        name=ti.name,
                        schema=ti.input_schema or {},
                        description=ti.description or '',
                    )
                    proxy.set_timeout(self.per_call_timeout)
                    self._indexed_tools[qname] = proxy
            elif isinstance(tool, PythonTool):
                if tool.name in self._indexed_tools:
                    msg = f'Tool name conflict: {tool.name} is already defined.'
                    raise ValueError(msg)
                proxy = _PythonToolProxy(tool, timeout=self.per_call_timeout)  # type: ignore[assignment]
                self._indexed_tools[tool.name] = proxy
            else:
                msg = f'Unsupported tool type: {type(tool)}'
                raise TypeError(msg)

        self._is_tools_index_built = True

    def _merge_attachments(self, query: str, attachments: list[FileAttachment] | None) -> str:
        """
        Merge plain text attachments into the user query.

        Args:
            query: The original user query.
            attachments: Optional list of file attachments.

        Returns:
            str: The query augmented with attachment content.
        """
        if not attachments:
            return query
        extras = [str(a.content) for a in attachments if a.type == PLAIN_TEXT]
        if not extras:
            return query
        return f'{query}\n\n[ATTACHMENTS]\n' + '\n\n'.join(extras)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from amsdal_ml.mcp_client.base import ToolClient
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ClientToolProxy:
    """
    Proxy exposing a single remote tool (served by a ToolClient) through the
    same interface as local tool proxies: `run`, `parameters`, `description`.
    """

    def __init__(
        self,
        client: ToolClient,
        alias: str,
        name: str,
        schema: dict[str, Any],
        description: str,
    ):
        self.client = client
        self.alias = alias
        self.name = name
        # Fully qualified name used for indexing and error reporting, e.g. 'fs.read'.
        self.qualified = f'{alias}.{name}'
        self.parameters = schema
        self.description = description
        self._default_timeout: float | None = 20.0

    def set_timeout(self, timeout: float | None) -> None:
        """Set the timeout (seconds) forwarded to the client on each call."""
        self._default_timeout = timeout

    async def run(
        self,
        args: dict[str, Any],
        context: Any = None,
        *,
        convert_result: bool = True,
    ) -> Any:
        """
        Validate `args` against the tool's input schema (when the optional
        `jsonschema` package is installed) and forward the call to the client.

        Args:
            args: Keyword arguments for the remote tool.
            context: Unused; accepted for interface compatibility.
            convert_result: Unused; accepted for interface compatibility.

        Returns:
            Any: Whatever the remote tool returns.

        Raises:
            ValueError: If the arguments do not satisfy the input schema.
        """
        _ = (context, convert_result)

        if self.parameters:
            # Import separately from validation: previously a missing
            # jsonschema package (ImportError) was mis-reported as a
            # validation failure of the caller's arguments.
            try:
                import jsonschema
            except ImportError:
                jsonschema = None  # optional dependency; server-side validation still applies

            if jsonschema is not None:
                try:
                    jsonschema.validate(instance=args, schema=self.parameters)
                except Exception as exc:
                    msg = f'Tool input validation failed for {self.qualified}: {exc}'
                    raise ValueError(msg) from exc

        return await self.client.call(self.name, args, timeout=self._default_timeout)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from collections.abc import Coroutine
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import create_model
|
|
9
|
+
from pydantic.fields import FieldInfo
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PythonTool:
    """
    Wraps a Python callable as an agent tool, deriving a JSON-schema
    description of its parameters from the callable's signature via pydantic.
    """

    def __init__(
        self,
        func: Callable[..., Coroutine[Any, Any, Any] | Any],
        name: str,
        description: str,
    ):
        """
        Args:
            func: The callable implementing the tool. Plain functions,
                coroutine functions, lambdas, `functools.partial` objects and
                other callables are accepted, as long as `inspect.signature`
                can introspect them.
            name: Tool name exposed to the model.
            description: Human/model-readable description of the tool.

        Raises:
            TypeError: If `func` is not callable.
        """
        # Generalized from iscoroutinefunction/isfunction: those checks
        # rejected partials, bound methods and callable objects even though
        # signature introspection and invocation handle them fine.
        if not callable(func):
            msg = 'Tool must be a function or coroutine function'
            raise TypeError(msg)

        self.func = func
        self.name = name
        self.description = description
        self.is_async = inspect.iscoroutinefunction(func)
        self.parameters = self._build_schema()

    def _build_schema(self) -> dict[str, Any]:
        """
        Build a JSON-schema 'object' description of the callable's parameters.

        Positional-only, *args and **kwargs parameters are excluded; a
        `pydantic.FieldInfo` default on a parameter is used as-is so callers
        can attach descriptions/constraints per argument.

        Returns:
            dict[str, Any]: Schema with 'type', 'properties' and 'required' keys.
        """
        sig = inspect.signature(self.func)
        fields: dict[str, Any] = {}
        for param in sig.parameters.values():
            if param.kind not in (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            ):
                continue

            field_info = (
                param.default
                if isinstance(param.default, FieldInfo)
                else FieldInfo(
                    # `...` marks the field as required in pydantic.
                    default=param.default if param.default is not inspect.Parameter.empty else ...,
                    description=None,
                )
            )
            fields[param.name] = (
                param.annotation if param.annotation is not inspect.Parameter.empty else Any,
                field_info,
            )

        model = create_model(f'{self.name}Input', **fields)
        schema = model.model_json_schema()

        # Reduce to the canonical function-calling parameter shape.
        return {
            'type': 'object',
            'properties': schema.get('properties', {}),
            'required': schema.get('required', []),
        }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _PythonToolProxy:
|
|
63
|
+
def __init__(self, tool: PythonTool, timeout: float | None = 20.0):
|
|
64
|
+
self.tool = tool
|
|
65
|
+
self.name = tool.name
|
|
66
|
+
self.qualified = tool.name
|
|
67
|
+
self.parameters = tool.parameters
|
|
68
|
+
self.description = tool.description
|
|
69
|
+
self._default_timeout = timeout
|
|
70
|
+
|
|
71
|
+
def set_timeout(self, timeout: float | None) -> None:
|
|
72
|
+
self._default_timeout = timeout
|
|
73
|
+
|
|
74
|
+
async def run(
|
|
75
|
+
self,
|
|
76
|
+
args: dict[str, Any],
|
|
77
|
+
context: Any = None,
|
|
78
|
+
*,
|
|
79
|
+
convert_result: bool = True,
|
|
80
|
+
) -> Any:
|
|
81
|
+
_ = (context, convert_result)
|
|
82
|
+
|
|
83
|
+
if self.tool.is_async:
|
|
84
|
+
return await self.tool.func(**args)
|
|
85
|
+
else:
|
|
86
|
+
return self.tool.func(**args)
|
|
@@ -14,12 +14,10 @@ from amsdal_ml.ml_retrievers.openai_retriever import OpenAIRetriever
|
|
|
14
14
|
logging.basicConfig(
|
|
15
15
|
level=logging.INFO,
|
|
16
16
|
format='%(asctime)s [%(levelname)s] %(message)s',
|
|
17
|
-
handlers=[
|
|
18
|
-
logging.FileHandler("server2.log"),
|
|
19
|
-
logging.StreamHandler(sys.stdout)
|
|
20
|
-
]
|
|
17
|
+
handlers=[logging.FileHandler('server2.log'), logging.StreamHandler(sys.stdout)],
|
|
21
18
|
)
|
|
22
19
|
|
|
20
|
+
|
|
23
21
|
class RetrieverArgs(BaseModel):
|
|
24
22
|
query: str = Field(..., description='User search query')
|
|
25
23
|
k: int = 5
|
|
@@ -29,6 +27,7 @@ class RetrieverArgs(BaseModel):
|
|
|
29
27
|
|
|
30
28
|
class _RetrieverSingleton:
|
|
31
29
|
"""Singleton holder for lazy retriever initialization."""
|
|
30
|
+
|
|
32
31
|
_instance: Optional[OpenAIRetriever] = None
|
|
33
32
|
|
|
34
33
|
@classmethod
|
|
@@ -46,7 +45,7 @@ async def retriever_search(
|
|
|
46
45
|
exclude_tags: Optional[list[str]] = None,
|
|
47
46
|
) -> list[dict[str, Any]]:
|
|
48
47
|
logging.info(
|
|
49
|
-
f
|
|
48
|
+
f'retriever_search called with query={query}, k={k}, include_tags={include_tags}, exclude_tags={exclude_tags}'
|
|
50
49
|
)
|
|
51
50
|
retriever = _RetrieverSingleton.get()
|
|
52
51
|
chunks = await retriever.asimilarity_search(
|
|
@@ -55,7 +54,7 @@ async def retriever_search(
|
|
|
55
54
|
include_tags=include_tags,
|
|
56
55
|
exclude_tags=exclude_tags,
|
|
57
56
|
)
|
|
58
|
-
logging.info(f
|
|
57
|
+
logging.info(f'retriever_search found {len(chunks)} chunks: {chunks}')
|
|
59
58
|
out: list[dict[str, Any]] = []
|
|
60
59
|
for c in chunks:
|
|
61
60
|
if hasattr(c, 'model_dump'):
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from amsdal_ml.ml_models.models import MLModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ToolAdapter(ABC):
    """
    Strategy interface translating between the agent's indexed tool proxies
    and a specific LLM provider's function-calling wire format.
    """

    @abstractmethod
    def get_tools_schema(self, tools: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Render indexed tools as provider-specific tool definitions.

        Args:
            tools: Mapping of tool name to proxy object; each proxy must
                expose `description` and `parameters` attributes.

        Returns:
            list[dict[str, Any]]: Tool definitions ready to send to the model API.
        """
        raise NotImplementedError

    @abstractmethod
    def parse_response(self, response_data: dict[str, Any]) -> tuple[str | None, list[dict[str, Any]] | None]:
        """
        Split a model response into its text content and requested tool calls.

        Args:
            response_data: The model's response as a decoded JSON object.

        Returns:
            tuple[str | None, list[dict[str, Any]] | None]: The text content
            (or None) and the list of tool calls (or None).
        """
        raise NotImplementedError

    @abstractmethod
    def get_tool_call_info(self, tool_call: dict[str, Any]) -> tuple[str, str, str]:
        """
        Unpack a single tool-call object.

        Args:
            tool_call: One tool call taken from a parsed response.

        Returns:
            tuple[str, str, str]: (function_name, arguments_str, call_id).
        """
        raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class OpenAIToolAdapter(ToolAdapter):
    """ToolAdapter for OpenAI-style chat-completions function calling."""

    def get_tools_schema(self, tools: dict[str, Any]) -> list[dict[str, Any]]:
        """Render each indexed tool as an OpenAI 'function' tool definition."""
        return [
            {
                'type': 'function',
                'function': {
                    'name': tool_name,
                    'description': proxy.description,
                    'parameters': proxy.parameters,
                },
            }
            for tool_name, proxy in tools.items()
        ]

    def parse_response(self, response_data: dict[str, Any]) -> tuple[str | None, list[dict[str, Any]] | None]:
        """Split an assistant message into its text content and tool calls."""
        return response_data.get('content'), response_data.get('tool_calls')

    def get_tool_call_info(self, tool_call: dict[str, Any]) -> tuple[str, str, str]:
        """Extract (function_name, arguments_str, call_id) from one tool call."""
        function = tool_call['function']
        return function['name'], function['arguments'], tool_call['id']
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_tool_adapter(model: MLModel) -> ToolAdapter:
    """
    Factory function to get the appropriate ToolAdapter for a given model.

    Args:
        model: The MLModel instance.

    Returns:
        ToolAdapter: An instance of a ToolAdapter subclass.
    """
    # OpenAIToolAdapter is currently the only implementation: the previous
    # dispatch on the model class name fell through to it for every model
    # anyway, so the 'openai' branch was dead code. Behavior is unchanged;
    # add real dispatch here when a second adapter exists.
    return OpenAIToolAdapter()
|
amsdal_ml/fileio/base_loader.py
CHANGED
|
@@ -9,8 +9,8 @@ from typing import Optional
|
|
|
9
9
|
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
|
-
PLAIN_TEXT =
|
|
13
|
-
FILE_ID =
|
|
12
|
+
PLAIN_TEXT = 'plain_text'
|
|
13
|
+
FILE_ID = 'file_id'
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class FileData(BaseModel):
|
|
@@ -39,18 +39,20 @@ class FileItem:
|
|
|
39
39
|
@staticmethod
|
|
40
40
|
def from_path(path: str, *, filedata: FileData | None = None) -> FileItem:
|
|
41
41
|
# Caller is responsible for lifecycle; loaders may close after upload.
|
|
42
|
-
f = open(path,
|
|
43
|
-
return FileItem(file=f, filename=path.split(
|
|
42
|
+
f = open(path, 'rb')
|
|
43
|
+
return FileItem(file=f, filename=path.split('/')[-1], filedata=filedata)
|
|
44
44
|
|
|
45
45
|
@staticmethod
|
|
46
46
|
def from_bytes(data: bytes, *, filename: str | None = None, filedata: FileData | None = None) -> FileItem:
|
|
47
47
|
import io
|
|
48
|
+
|
|
48
49
|
return FileItem(file=io.BytesIO(data), filename=filename, filedata=filedata)
|
|
49
50
|
|
|
50
51
|
@staticmethod
|
|
51
52
|
def from_str(text: str, *, filename: str | None = None, filedata: FileData | None = None) -> FileItem:
|
|
52
53
|
import io
|
|
53
|
-
|
|
54
|
+
|
|
55
|
+
return FileItem(file=io.BytesIO(text.encode('utf-8')), filename=filename, filedata=filedata)
|
|
54
56
|
|
|
55
57
|
|
|
56
58
|
class BaseFileLoader(ABC):
|
|
@@ -18,7 +18,7 @@ from amsdal_ml.fileio.base_loader import FileItem
|
|
|
18
18
|
|
|
19
19
|
logger = logging.getLogger(__name__)
|
|
20
20
|
|
|
21
|
-
AllowedPurpose = Literal[
|
|
21
|
+
AllowedPurpose = Literal['assistants', 'batch', 'fine-tune', 'vision', 'user_data', 'evals']
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class OpenAIFileLoader(BaseFileLoader):
|
|
@@ -26,35 +26,35 @@ class OpenAIFileLoader(BaseFileLoader):
|
|
|
26
26
|
Loader which uploads files into OpenAI Files API and returns openai_file_id.
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
|
-
def __init__(self, client: AsyncOpenAI, *, purpose: AllowedPurpose =
|
|
29
|
+
def __init__(self, client: AsyncOpenAI, *, purpose: AllowedPurpose = 'assistants') -> None:
|
|
30
30
|
self.client = client
|
|
31
31
|
self.purpose: AllowedPurpose = purpose # mypy: Literal union, matches SDK
|
|
32
32
|
|
|
33
33
|
async def _upload_one(self, file: BinaryIO, *, filename: str | None, filedata: FileData | None) -> FileAttachment:
|
|
34
34
|
try:
|
|
35
|
-
if hasattr(file,
|
|
35
|
+
if hasattr(file, 'seek'):
|
|
36
36
|
file.seek(0)
|
|
37
37
|
except Exception as exc: # pragma: no cover
|
|
38
|
-
logger.debug(
|
|
38
|
+
logger.debug('seek(0) failed for %r: %s', filename or file, exc)
|
|
39
39
|
|
|
40
40
|
buf = file if isinstance(file, io.BytesIO) else io.BytesIO(file.read())
|
|
41
41
|
|
|
42
|
-
up = await self.client.files.create(file=(filename or
|
|
42
|
+
up = await self.client.files.create(file=(filename or 'upload.bin', buf), purpose=self.purpose)
|
|
43
43
|
|
|
44
44
|
meta: dict[str, Any] = {
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
45
|
+
'filename': filename,
|
|
46
|
+
'provider': 'openai',
|
|
47
|
+
'file': {
|
|
48
|
+
'id': up.id,
|
|
49
|
+
'bytes': getattr(up, 'bytes', None),
|
|
50
|
+
'purpose': getattr(up, 'purpose', self.purpose),
|
|
51
|
+
'created_at': getattr(up, 'created_at', None),
|
|
52
|
+
'status': getattr(up, 'status', None),
|
|
53
|
+
'status_details': getattr(up, 'status_details', None),
|
|
54
54
|
},
|
|
55
55
|
}
|
|
56
56
|
if filedata is not None:
|
|
57
|
-
meta[
|
|
57
|
+
meta['filedata'] = filedata.model_dump()
|
|
58
58
|
|
|
59
59
|
return FileAttachment(type=FILE_ID, content=up.id, metadata=meta)
|
|
60
60
|
|
|
@@ -63,7 +63,6 @@ class OpenAIFileLoader(BaseFileLoader):
|
|
|
63
63
|
|
|
64
64
|
async def load_batch(self, items: Sequence[FileItem]) -> list[FileAttachment]:
|
|
65
65
|
tasks = [
|
|
66
|
-
asyncio.create_task(self._upload_one(it.file, filename=it.filename, filedata=it.filedata))
|
|
67
|
-
for it in items
|
|
66
|
+
asyncio.create_task(self._upload_one(it.file, filename=it.filename, filedata=it.filedata)) for it in items
|
|
68
67
|
]
|
|
69
68
|
return await asyncio.gather(*tasks)
|
amsdal_ml/mcp_client/base.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from typing import Any
|
|
5
5
|
from typing import Protocol
|
|
6
|
+
from typing import runtime_checkable
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
@dataclass
|
|
@@ -13,6 +14,7 @@ class ToolInfo:
|
|
|
13
14
|
input_schema: dict[str, Any]
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
@runtime_checkable
|
|
16
18
|
class ToolClient(Protocol):
|
|
17
19
|
alias: str
|
|
18
20
|
|
|
@@ -42,7 +42,13 @@ class HttpClient(ToolClient):
|
|
|
42
42
|
finally:
|
|
43
43
|
await stack.aclose()
|
|
44
44
|
|
|
45
|
-
async def call(
|
|
45
|
+
async def call(
|
|
46
|
+
self,
|
|
47
|
+
tool_name: str,
|
|
48
|
+
args: dict[str, Any],
|
|
49
|
+
*,
|
|
50
|
+
timeout: float | None = None,
|
|
51
|
+
) -> Any:
|
|
46
52
|
_ = timeout # ARG002
|
|
47
53
|
stack, s = await self._session()
|
|
48
54
|
try:
|