amsdal_ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +1477 -0
- amsdal_ml/__about__.py +1 -0
- amsdal_ml/__init__.py +0 -0
- amsdal_ml/agents/__init__.py +0 -0
- amsdal_ml/agents/agent.py +39 -0
- amsdal_ml/agents/default_qa_agent.py +376 -0
- amsdal_ml/agents/promts/__init__.py +58 -0
- amsdal_ml/agents/promts/react_chat.prompt +37 -0
- amsdal_ml/agents/retriever_tool.py +48 -0
- amsdal_ml/app.py +24 -0
- amsdal_ml/mcp_client/__init__.py +0 -0
- amsdal_ml/mcp_client/base.py +20 -0
- amsdal_ml/mcp_client/http_client.py +52 -0
- amsdal_ml/mcp_client/stdio_client.py +130 -0
- amsdal_ml/mcp_server/__init__.py +0 -0
- amsdal_ml/mcp_server/server_retriever_stdio.py +11 -0
- amsdal_ml/ml_config.py +57 -0
- amsdal_ml/ml_ingesting/__init__.py +0 -0
- amsdal_ml/ml_ingesting/default_ingesting.py +319 -0
- amsdal_ml/ml_ingesting/embedding_data.py +9 -0
- amsdal_ml/ml_ingesting/ingesting.py +52 -0
- amsdal_ml/ml_ingesting/openai_ingesting.py +38 -0
- amsdal_ml/ml_models/__init__.py +0 -0
- amsdal_ml/ml_models/models.py +50 -0
- amsdal_ml/ml_models/openai_model.py +171 -0
- amsdal_ml/ml_retrievers/__init__.py +0 -0
- amsdal_ml/ml_retrievers/default_retriever.py +105 -0
- amsdal_ml/ml_retrievers/openai_retriever.py +39 -0
- amsdal_ml/ml_retrievers/retriever.py +40 -0
- amsdal_ml/models/__init__.py +0 -0
- amsdal_ml/models/embedding_model.py +21 -0
- amsdal_ml/py.typed +0 -0
- amsdal_ml-0.1.0.dist-info/METADATA +69 -0
- amsdal_ml-0.1.0.dist-info/RECORD +35 -0
- amsdal_ml-0.1.0.dist-info/WHEEL +4 -0
amsdal_ml/__about__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.1.0'
|
amsdal_ml/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from collections.abc import Iterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
from typing import Literal
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from pydantic import Field
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AgentMessage(BaseModel):
    """One turn of a chat transcript exchanged with an agent."""

    # Uppercase role tag; pydantic validates it against the Literal choices.
    role: Literal["SYSTEM", "USER", "ASSISTANT"]
    content: str
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AgentOutput(BaseModel):
    """Structured result of a completed agent run."""

    # Final answer text returned to the caller.
    answer: str
    # Qualified names of tools invoked during the run, in call order.
    used_tools: list[str] = Field(default_factory=list)
    # Citation payloads; shape is tool-dependent. NOTE(review): callers in this
    # package appear to always pass an empty list here — confirm before relying on it.
    citations: list[dict[str, Any]] = Field(default_factory=list)
+
|
|
24
|
+
|
|
25
|
+
class Agent(ABC):
    """Abstract base for question-answering agents.

    Concrete agents implement the async entry points ``arun`` and ``astream``;
    the synchronous ``run``/``stream`` variants are deliberate stubs that
    direct callers to the async API.
    """

    @abstractmethod
    async def arun(self, user_query: str) -> AgentOutput:
        """Answer *user_query* and return a structured result."""

    @abstractmethod
    async def astream(self, user_query: str) -> AsyncIterator[str]:
        """Yield streamed chunks for the given query."""
        raise NotImplementedError

    def run(self, user_query: str) -> AgentOutput:
        """Synchronous execution is unsupported; use :meth:`arun`."""
        raise NotImplementedError("This agent is async-only. Use arun().")

    def stream(self, user_query: str) -> Iterator[str]:
        """Synchronous streaming is unsupported; use :meth:`astream`."""
        raise NotImplementedError("This agent is async-only. Use astream().")
|
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
import re
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any
|
|
11
|
+
from typing import no_type_check
|
|
12
|
+
|
|
13
|
+
from amsdal_ml.agents.agent import Agent
|
|
14
|
+
from amsdal_ml.agents.agent import AgentOutput
|
|
15
|
+
from amsdal_ml.agents.promts import get_prompt
|
|
16
|
+
from amsdal_ml.mcp_client.base import ToolClient
|
|
17
|
+
from amsdal_ml.mcp_client.base import ToolInfo
|
|
18
|
+
from amsdal_ml.ml_models.models import MLModel
|
|
19
|
+
|
|
20
|
+
# ---------- STRICT ReAct REGEX ----------
# Matches a strict ReAct tool-call reply:
#   Thought: Do I need to use a tool? Yes
#   Action: <tool name>
#   Action Input: {...one-line JSON object...}
# DOTALL lets the captured JSON span to the end of the text; the trailing $
# anchors the JSON object at the end of the reply.
_TOOL_CALL_RE = re.compile(
    r"Thought:\s*Do I need to use a tool\?\s*Yes\s*"
    r"Action:\s*(?P<action>[^\n]+)\s*"
    r"Action Input:\s*(?P<input>\{.*\})\s*$",
    re.DOTALL | re.IGNORECASE,
)

# Matches the terminal reply:
#   Thought: Do I need to use a tool? No
#   Final Answer: <free text, may span lines thanks to DOTALL>
_FINAL_RE = re.compile(
    r"Thought:\s*Do I need to use a tool\?\s*No\s*"
    r"Final Answer:\s*(?P<answer>.+)",
    re.DOTALL | re.IGNORECASE,
)
|
|
37
|
+
|
|
38
|
+
@dataclass
class Route:
    """A named dispatch rule: when ``match`` accepts a query, ``handler`` answers it.

    NOTE(review): not referenced anywhere else in this file — possibly reserved
    for future routing support; confirm before removing.
    """

    name: str
    match: Callable[[str], bool]
    handler: Callable[[str], AgentOutput]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ParseErrorMode(Enum):
    """Policy applied when the model's reply violates the strict ReAct format."""

    # Raise ValueError immediately on the first malformed reply.
    RAISE = "raise"
    # Intended to append a corrective note to the scratchpad and re-prompt
    # instead of raising.
    RETRY = "retry"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# === proxy-tool ===
|
|
51
|
+
class _ClientToolProxy:
|
|
52
|
+
def __init__(self, client: ToolClient, alias: str, name: str, schema: dict[str, Any], description: str):
|
|
53
|
+
self.client = client
|
|
54
|
+
self.alias = alias
|
|
55
|
+
self.name = name
|
|
56
|
+
self.qualified = f"{alias}.{name}"
|
|
57
|
+
self.parameters = schema
|
|
58
|
+
self.description = description
|
|
59
|
+
self._default_timeout: float | None = 20.0
|
|
60
|
+
|
|
61
|
+
def set_timeout(self, timeout: float | None) -> None:
|
|
62
|
+
self._default_timeout = timeout
|
|
63
|
+
|
|
64
|
+
async def run(
|
|
65
|
+
self,
|
|
66
|
+
args: dict[str, Any],
|
|
67
|
+
context=None,
|
|
68
|
+
*,
|
|
69
|
+
convert_result: bool = True,
|
|
70
|
+
):
|
|
71
|
+
_ = (context, convert_result)
|
|
72
|
+
|
|
73
|
+
if self.parameters:
|
|
74
|
+
try:
|
|
75
|
+
import jsonschema
|
|
76
|
+
jsonschema.validate(instance=args, schema=self.parameters)
|
|
77
|
+
except Exception as exc:
|
|
78
|
+
msg = f"Tool input validation failed for {self.qualified}: {exc}"
|
|
79
|
+
raise ValueError(
|
|
80
|
+
msg
|
|
81
|
+
) from exc
|
|
82
|
+
|
|
83
|
+
return await self.client.call(self.name, args, timeout=self._default_timeout)
|
|
84
|
+
|
|
85
|
+
class DefaultQAAgent(Agent):
    """ReAct-style question-answering agent backed by an MLModel and MCP tool clients.

    Control loop (shared by ``arun`` and the streaming ``astream``):
      1. Render the ``react_chat`` prompt with the tool catalogue and scratchpad.
      2. Ask the model. A strict "Final Answer" reply is returned/yielded; a
         strict tool-call reply triggers the tool, and the observation is
         appended to the scratchpad for the next step.
      3. After ``max_steps`` iterations without a final answer, a fixed
         "stopped" message is produced.

    Fix vs. the previous revision: under ``ParseErrorMode.RETRY`` the retry
    branch was unreachable — the guard used ``parse_retries >= 1``, which is
    always true immediately after the first increment, so RETRY behaved
    exactly like RAISE. The guards now use ``parse_retries > 1``: one
    corrective re-prompt is allowed before raising.
    """

    def __init__(
        self,
        *,
        model: MLModel,
        tool_clients: list[ToolClient] | None = None,
        max_steps: int = 6,
        on_parse_error: ParseErrorMode = ParseErrorMode.RAISE,
        enable_stop_guard: bool = True,
        per_call_timeout: float | None = 20.0,
    ):
        # Only MCP clients (stdio/sse) are supported as tool sources.
        self._tool_clients: list[ToolClient] = tool_clients or []
        self._indexed_tools: dict[str, Any] = {}  # qualified -> proxy

        self.model = model
        self.model.setup()
        self.max_steps = max_steps
        self.per_call_timeout = per_call_timeout
        self.on_parse_error = on_parse_error
        self.enable_stop_guard = enable_stop_guard

        # The tool index is built lazily on the first arun()/astream() call.
        self._tools_index_built = False

    # ---------- tools helpers ----------
    def _get_tool(self, name: str) -> Any:
        """
        Look up tools ONLY among client-indexed tools.
        Expected names are qualified: '<alias>.<tool_name>'.
        """
        if not self._tools_index_built:
            msg = "Tool index not built. Ensure arun()/astream() was used."
            raise RuntimeError(msg)
        if name in self._indexed_tools:
            return self._indexed_tools[name]
        available = sorted(self._indexed_tool_names())
        msg = f"Unknown tool: {name}. Available: {', '.join(available)}"
        raise KeyError(msg)

    def _indexed_tool_names(self) -> list[str]:
        """Qualified names of all indexed tools (empty before the index is built)."""
        return list(self._indexed_tools.keys()) if self._tools_index_built else []

    def _tool_names(self) -> str:
        """Sorted, comma-separated tool names for the prompt's {tool_names} slot."""
        return ", ".join(sorted(self._indexed_tool_names()))

    def _tool_descriptions(self) -> str:
        """Tool catalogue (name, description, args schema) for the prompt's {tools} slot."""
        parts: list[str] = []
        if self._tools_index_built:
            for qn, t in self._indexed_tools.items():
                desc = t.description or "No description."
                try:
                    schema_json = json.dumps(t.parameters or {}, ensure_ascii=False)
                except Exception:
                    # Schema may contain non-JSON-serializable values.
                    schema_json = str(t.parameters)
                parts.append(f"- {qn}: {desc}\n Args JSON schema: {schema_json}")
        return "\n".join(parts)

    async def _build_clients_index(self):
        """Query every client for its tools and index proxies under '<alias>.<name>'."""
        self._indexed_tools.clear()
        for client in self._tool_clients:
            infos: list[ToolInfo] = await client.list_tools()
            for ti in infos:
                qname = f"{ti.alias}.{ti.name}"
                proxy = _ClientToolProxy(
                    client=client,
                    alias=ti.alias,
                    name=ti.name,
                    schema=ti.input_schema or {},
                    description=ti.description or "",
                )
                proxy.set_timeout(self.per_call_timeout)
                self._indexed_tools[qname] = proxy
        self._tools_index_built = True

    # ---------- prompt composition ----------
    def _react_text(self, user_query: str, scratchpad: str) -> str:
        """Render the react_chat prompt for the current step."""
        tmpl = get_prompt("react_chat")
        return tmpl.render_text(
            user_query=user_query,
            tools=self._tool_descriptions(),
            tool_names=self._tool_names(),
            agent_scratchpad=scratchpad,
            chat_history="",
        )

    @staticmethod
    def _stopped_message() -> str:
        return "Agent stopped due to iteration limit or time limit."

    def _stopped_response(self, used_tools: list[str]) -> AgentOutput:
        return AgentOutput(answer=self._stopped_message(), used_tools=used_tools, citations=[])

    @staticmethod
    def _serialize_observation(content: Any) -> str:
        """Best-effort plain-text rendering of a tool result for the scratchpad."""
        if isinstance(content, str | bytes):
            return content if isinstance(content, str) else content.decode("utf-8", errors="ignore")
        try:
            return json.dumps(content, ensure_ascii=False)
        except Exception:
            return str(content)

    @staticmethod
    def _error_payload(tool: Any, exc: Exception) -> dict[str, Any]:
        """Uniform error observation handed back to the model after a tool failure."""
        return {
            "error": {
                "type": exc.__class__.__name__,
                "server": getattr(tool, "alias", "local"),
                "tool": getattr(tool, "name", getattr(tool, "qualified", "unknown")),
                "message": str(exc),
                "retryable": False,
            }
        }

    # ---------- core run ----------
    def run(self, _user_query: str) -> AgentOutput:
        msg = "DefaultQAAgent is async-only for now. Use arun()."
        raise NotImplementedError(msg)

    async def _run_async(self, user_query: str) -> AgentOutput:
        """Drive the ReAct loop to completion and return the final answer."""
        if self._tool_clients and not self._tools_index_built:
            await self._build_clients_index()

        scratch = ""
        used_tools: list[str] = []
        parse_retries = 0

        for _ in range(self.max_steps):
            prompt = self._react_text(user_query, scratch)
            out = await self.model.ainvoke(prompt)

            m_final = _FINAL_RE.search(out or "")
            if m_final:
                return AgentOutput(
                    answer=(m_final.group("answer") or "").strip(),
                    used_tools=used_tools,
                    citations=[],
                )

            m_tool = _TOOL_CALL_RE.search(out or "")
            if not m_tool:
                parse_retries += 1
                # RETRY allows exactly one corrective re-prompt; the second
                # malformed reply (parse_retries > 1) still raises.
                if self.on_parse_error == ParseErrorMode.RAISE or parse_retries > 1:
                    msg = (
                        "Invalid ReAct output. Expected EXACT format (Final or Tool-call). "
                        f"Got:\n{out}"
                    )
                    raise ValueError(msg)
                scratch += (
                    "\nThought: Previous output violated the strict format. "
                    "Reply again using EXACTLY one of the two specified formats.\n"
                )
                continue

            action = m_tool.group("action").strip()
            raw_input = m_tool.group("input").strip()

            try:
                args = json.loads(raw_input)
                if not isinstance(args, dict):
                    msg = "Action Input must be a JSON object."
                    raise ValueError(msg)
            except Exception as e:
                parse_retries += 1
                # Same retry policy as the format check above.
                if self.on_parse_error == ParseErrorMode.RAISE or parse_retries > 1:
                    msg = f"Invalid Action Input JSON: {raw_input!r} ({e})"
                    raise ValueError(msg) from e
                scratch += (
                    "\nThought: Action Input must be a ONE-LINE JSON object. "
                    "Retry with correct JSON.\n"
                )
                continue

            tool = self._get_tool(action)

            try:
                result = await tool.run(args, context=None, convert_result=True)
            except Exception as e:
                # Feed a structured error back to the model instead of crashing.
                result = self._error_payload(tool, e)

            used_tools.append(action)
            observation = self._serialize_observation(result)

            scratch += (
                "\nThought: Do I need to use a tool? Yes"
                f"\nAction: {action}"
                f"\nAction Input: {raw_input}"
                f"\nObservation: {observation}\n"
            )

        return self._stopped_response(used_tools)

    # ---------- public APIs ----------
    async def arun(self, user_query: str) -> AgentOutput:
        return await self._run_async(user_query)

    # ---------- streaming ----------
    @no_type_check
    async def astream(self, user_query: str) -> AsyncIterator[str]:
        """Streaming variant: yields the final answer text (or the stop message)."""
        if self._tool_clients and not self._tools_index_built:
            await self._build_clients_index()

        scratch = ""
        used_tools: list[str] = []
        parse_retries = 0

        for _ in range(self.max_steps):
            prompt = self._react_text(user_query, scratch)

            buffer = ""

            # Normalize model.astream: it might be an async iterator already,
            # or a coroutine (or nested coroutines) that resolves to one.
            _val = self.model.astream(prompt)
            while inspect.iscoroutine(_val):
                _val = await _val

            # Optional guard (helpful during tests)
            if not hasattr(_val, "__aiter__"):
                msg = f"model.astream() did not yield an AsyncIterator; got {type(_val)!r}"
                raise TypeError(msg)

            model_stream = _val  # now an AsyncIterator[str]

            async for chunk in model_stream:
                buffer += chunk

            m_final = _FINAL_RE.search(buffer or "")
            if m_final:
                answer = (m_final.group("answer") or "").strip()
                if answer:
                    yield answer
                return

            m_tool = _TOOL_CALL_RE.search(buffer or "")
            if not m_tool:
                parse_retries += 1
                # Same retry policy as _run_async: one re-prompt in RETRY mode.
                if self.on_parse_error == ParseErrorMode.RAISE or parse_retries > 1:
                    msg = f"Invalid ReAct output (stream). Expected EXACT format. Got:\n{buffer}"
                    raise ValueError(msg)
                scratch += (
                    "\nThought: Previous output violated the strict format. "
                    "Reply again using EXACTLY one of the two specified formats.\n"
                )
                continue

            action = m_tool.group("action").strip()
            raw_input = m_tool.group("input").strip()

            try:
                args = json.loads(raw_input)
                if not isinstance(args, dict):
                    msg = "Action Input must be a JSON object."
                    raise ValueError(msg)
            except Exception as e:
                parse_retries += 1
                if self.on_parse_error == ParseErrorMode.RAISE or parse_retries > 1:
                    msg = f"Invalid Action Input JSON: {raw_input!r} ({e})"
                    raise ValueError(msg) from e
                scratch += (
                    "\nThought: Action Input must be a ONE-LINE JSON object. "
                    "Retry with correct JSON.\n"
                )
                continue

            tool = self._get_tool(action)

            try:
                result = await tool.run(args, context=None, convert_result=True)
            except Exception as e:
                result = self._error_payload(tool, e)

            used_tools.append(action)
            observation = self._serialize_observation(result)

            scratch += (
                "\nThought: Do I need to use a tool? Yes"
                f"\nAction: {action}"
                f"\nAction Input: {raw_input}"
                f"\nObservation: {observation}\n"
            )

        yield self._stopped_message()
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import Final
|
|
7
|
+
|
|
8
|
+
# Number of sections produced when a prompt file contains the "\n---\n"
# separator (system text above, user text below); avoids a magic number.
_PARTS_EXPECTED: Final[int] = 2
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _SafeDict(dict[str, Any]):
|
|
13
|
+
def __missing__(self, key: str) -> str:
|
|
14
|
+
return "{" + key + "}"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class Prompt:
|
|
19
|
+
name: str
|
|
20
|
+
system: str
|
|
21
|
+
user: str
|
|
22
|
+
|
|
23
|
+
def render_text(self, **kwargs: Any) -> str:
|
|
24
|
+
data = _SafeDict(**kwargs)
|
|
25
|
+
sys_txt = self.system.format_map(data)
|
|
26
|
+
usr_txt = self.user.format_map(data)
|
|
27
|
+
return f"{sys_txt}\n\n{usr_txt}".strip()
|
|
28
|
+
|
|
29
|
+
def render_messages(self, **kwargs: Any) -> list[dict[str, str]]:
|
|
30
|
+
data = _SafeDict(**kwargs)
|
|
31
|
+
return [
|
|
32
|
+
{"role": "system", "content": self.system.format_map(data)},
|
|
33
|
+
{"role": "user", "content": self.user.format_map(data)},
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
_prompt_cache: dict[str, Prompt] = {}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _load_file(name: str) -> Prompt:
|
|
41
|
+
base = Path(__file__).resolve().parent
|
|
42
|
+
path = base / f"{name}.prompt"
|
|
43
|
+
if not path.exists():
|
|
44
|
+
msg = f"Prompt '{name}' not found at {path}"
|
|
45
|
+
raise FileNotFoundError(msg)
|
|
46
|
+
raw = path.read_text(encoding="utf-8")
|
|
47
|
+
parts = raw.split("\n---\n", 1)
|
|
48
|
+
if len(parts) == _PARTS_EXPECTED:
|
|
49
|
+
system, user = parts
|
|
50
|
+
else:
|
|
51
|
+
system, user = raw, "{input}"
|
|
52
|
+
return Prompt(name=name, system=system.strip(), user=user.strip())
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_prompt(name: str) -> Prompt:
|
|
56
|
+
if name not in _prompt_cache:
|
|
57
|
+
_prompt_cache[name] = _load_file(name)
|
|
58
|
+
return _prompt_cache[name]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
You are a concise, helpful assistant. Use tools only when they improve accuracy.
|
|
2
|
+
|
|
3
|
+
TOOLS
|
|
4
|
+
------
|
|
5
|
+
{tools}
|
|
6
|
+
|
|
7
|
+
FORMAT (STRICT — EXACTLY ONE; NO extra lines, NO Markdown)
|
|
8
|
+
----------------------------------------------------------
|
|
9
|
+
1. If you do NOT need a tool:
|
|
10
|
+
Thought: Do I need to use a tool? No
|
|
11
|
+
Final Answer: <your final answer in plain text>
|
|
12
|
+
|
|
13
|
+
2. If you DO need a tool:
|
|
14
|
+
Thought: Do I need to use a tool? Yes
|
|
15
|
+
Action: <one tool name from: {tool_names}>
|
|
16
|
+
Action Input: <ONE-LINE JSON object with the tool arguments>
|
|
17
|
+
|
|
18
|
+
RULES
|
|
19
|
+
-----
|
|
20
|
+
- Output MUST match exactly one of the two blocks above.
|
|
21
|
+
- `Action Input` MUST be a valid ONE-LINE JSON object (e.g. {{"a": 1, "b": 2}}).
|
|
22
|
+
- Do NOT add anything before/after the block.
|
|
23
|
+
- Do NOT print "Observation". The system will add it after tool execution.
|
|
24
|
+
|
|
25
|
+
PREVIOUS CONVERSATION
|
|
26
|
+
--------------------
|
|
27
|
+
{chat_history}
|
|
28
|
+
|
|
29
|
+
NEW INPUT
|
|
30
|
+
---------
|
|
31
|
+
{user_query}
|
|
32
|
+
|
|
33
|
+
SCRATCHPAD
|
|
34
|
+
----------
|
|
35
|
+
{agent_scratchpad}
|
|
36
|
+
|
|
37
|
+
Assistant:
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from mcp.server.fastmcp.tools.base import Tool
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from pydantic import Field
|
|
9
|
+
|
|
10
|
+
from amsdal_ml.ml_retrievers.openai_retriever import OpenAIRetriever
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RetrieverArgs(BaseModel):
    """Arguments accepted by the retriever search tool."""

    query: str = Field(..., description='User search query')
    # Number of results requested from the similarity search.
    k: int = 5
    # Tag filters forwarded to the retriever. NOTE(review): exact filter
    # semantics (any-of vs. all-of) are defined by the retriever — confirm there.
    include_tags: Optional[list[str]] = None
    exclude_tags: Optional[list[str]] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Module-level retriever instance shared by every tool invocation.
_retriever = OpenAIRetriever()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
async def retriever_search(args: RetrieverArgs) -> list[dict[str, Any]]:
    """Run a semantic similarity search and normalize each hit to a plain dict."""

    def _as_dict(chunk: Any) -> dict[str, Any]:
        # Prefer pydantic v2 (`model_dump`), then pydantic v1 (`dict`), then
        # plain dicts; anything else is wrapped under 'raw_text'.
        if hasattr(chunk, 'model_dump'):
            return chunk.model_dump()
        if hasattr(chunk, 'dict'):
            return chunk.dict()
        if isinstance(chunk, dict):
            return chunk
        return {'raw_text': str(chunk)}

    chunks = await _retriever.asimilarity_search(
        query=args.query,
        k=args.k,
        include_tags=args.include_tags,
        exclude_tags=args.exclude_tags,
    )
    return [_as_dict(chunk) for chunk in chunks]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# MCP tool wrapper so agents can invoke the search over the tools protocol.
retriever_tool = Tool.from_function(
    retriever_search,
    name='retriever.search',
    description='Semantic search in knowledge base (OpenAI embeddings)',
    structured_output=True,
)
|
amsdal_ml/app.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations

from typing import TYPE_CHECKING

# During type checking, point to the known-available fallback.
if TYPE_CHECKING:
    from amsdal.contrib.app_config import AppConfig as BaseAppConfig
else:
    try:
        # At runtime prefer the real AMSDAL core if present.
        from amsdal.configs.app import AppConfig as BaseAppConfig  # type: ignore[import-not-found]
    except Exception:  # pragma: no cover
        # Core package missing (or failed to import) - fall back to the contrib stub.
        from amsdal.contrib.app_config import AppConfig as BaseAppConfig
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MLPluginAppConfig(BaseAppConfig):
    """AMSDAL application config registering the ML plugin."""

    name = "amsdal_ml"
    verbose_name = "AMSDAL ML Plugin"

    def on_ready(self) -> None:
        """Lifecycle hook; intentionally a no-op for this plugin."""

    def on_server_startup(self) -> None:
        """Lifecycle hook; intentionally a no-op for this plugin."""
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ToolInfo:
    """Metadata describing one tool exposed by a tool server."""

    # Client alias; tools are addressed as '<alias>.<name>'.
    alias: str
    name: str
    description: str
    # JSON schema for the tool's arguments (used for client-side validation).
    input_schema: dict[str, Any]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ToolClient(Protocol):
    """Structural interface implemented by MCP tool clients (stdio/http)."""

    # Short identifier used to qualify this client's tool names.
    alias: str

    # Enumerate the tools the connected server offers.
    async def list_tools(self) -> list[ToolInfo]: ...
    # Invoke a tool by name; timeout semantics are client-specific.
    async def call(self, tool_name: str, args: dict[str, Any], *, timeout: float | None = None) -> Any: ...
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack
from typing import Any
from typing import Optional

from mcp import ClientSession
from mcp.client.sse import sse_client

from amsdal_ml.mcp_client.base import ToolClient
from amsdal_ml.mcp_client.base import ToolInfo
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HttpClient(ToolClient):
    """ToolClient over the MCP SSE transport.

    A fresh SSE session is opened and torn down for every operation.

    Fixes vs. the previous revision:
      * ``call()`` accepted a ``timeout`` argument but silently discarded it
        (``_ = timeout``); it is now enforced with ``asyncio.wait_for``, so the
        per-call deadlines set by callers (e.g. the agent's tool proxy) apply.
      * ``_session()`` leaked its ``AsyncExitStack`` when ``initialize()`` (or
        entering a context) raised; the stack is now closed on failure.
    """

    def __init__(self, *, alias: str, url: str, headers: Optional[dict[str, str]] = None):
        self.alias = alias
        self.url = url
        self.headers = headers or {}

    async def _session(self):
        """Open an initialized ClientSession; caller must aclose() the returned stack."""
        stack = AsyncExitStack()
        try:
            rx, tx = await stack.enter_async_context(sse_client(self.url, headers=self.headers))
            s = await stack.enter_async_context(ClientSession(rx, tx))
            await s.initialize()
        except BaseException:
            # Don't leak the transport/session if setup fails midway.
            await stack.aclose()
            raise
        return stack, s

    async def list_tools(self) -> list[ToolInfo]:
        """Fetch the server's tool list, tagged with this client's alias."""
        stack, s = await self._session()
        try:
            resp = await s.list_tools()
            return [
                ToolInfo(
                    alias=self.alias,
                    name=t.name,
                    description=t.description or '',
                    input_schema=(getattr(t, 'inputSchema', None) or {}),
                )
                for t in resp.tools
            ]
        finally:
            await stack.aclose()

    async def call(self, tool_name: str, args: dict[str, Any], *, timeout: float | None = None,) -> Any:
        """Invoke *tool_name* with *args*.

        Raises:
            asyncio.TimeoutError: when *timeout* (seconds) elapses first.
        """
        stack, s = await self._session()
        try:
            if timeout is not None:
                res = await asyncio.wait_for(s.call_tool(tool_name, args), timeout)
            else:
                res = await s.call_tool(tool_name, args)
            return getattr(res, 'content', res)
        finally:
            await stack.aclose()
|