amsdal_ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
amsdal_ml/__about__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
amsdal_ml/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from abc import abstractmethod
5
+ from collections.abc import AsyncIterator
6
+ from collections.abc import Iterator
7
+ from typing import Any
8
+ from typing import Literal
9
+
10
+ from pydantic import BaseModel
11
+ from pydantic import Field
12
+
13
+
14
class AgentMessage(BaseModel):
    """A single chat message exchanged with an agent."""

    # Author of the message; limited to the three supported chat roles.
    role: Literal["SYSTEM", "USER", "ASSISTANT"]
    # Plain-text payload of the message.
    content: str
17
+
18
+
19
class AgentOutput(BaseModel):
    """Structured result of an agent run."""

    # Final answer text produced by the agent.
    answer: str
    # Qualified names of tools invoked during the run (in call order).
    used_tools: list[str] = Field(default_factory=list)
    # Citation payloads; schema is tool-defined — TODO confirm expected keys.
    citations: list[dict[str, Any]] = Field(default_factory=list)
23
+
24
+
25
class Agent(ABC):
    """Abstract async-first agent interface.

    Concrete agents implement :meth:`arun` and :meth:`astream`; the
    synchronous entry points deliberately raise ``NotImplementedError``.
    """

    @abstractmethod
    async def arun(self, user_query: str) -> AgentOutput:
        """Answer *user_query* and return the structured output."""
        ...

    @abstractmethod
    async def astream(self, user_query: str) -> AsyncIterator[str]:
        """Yield streamed chunks for the given query."""
        raise NotImplementedError

    def run(self, user_query: str) -> AgentOutput:
        """Synchronous execution is unsupported; use :meth:`arun`."""
        raise NotImplementedError("This agent is async-only. Use arun().")

    def stream(self, user_query: str) -> Iterator[str]:
        """Synchronous streaming is unsupported; use :meth:`astream`."""
        raise NotImplementedError("This agent is async-only. Use astream().")
@@ -0,0 +1,376 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import json
5
+ import re
6
+ from collections.abc import AsyncIterator
7
+ from collections.abc import Callable
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ from typing import Any
11
+ from typing import no_type_check
12
+
13
+ from amsdal_ml.agents.agent import Agent
14
+ from amsdal_ml.agents.agent import AgentOutput
15
+ from amsdal_ml.agents.promts import get_prompt
16
+ from amsdal_ml.mcp_client.base import ToolClient
17
+ from amsdal_ml.mcp_client.base import ToolInfo
18
+ from amsdal_ml.ml_models.models import MLModel
19
+
20
# ---------- STRICT ReAct REGEX ----------
# Matches a strict ReAct tool invocation of the form:
#   Thought: Do I need to use a tool? Yes
#   Action: <tool name, single line>
#   Action Input: {<JSON object>}
# DOTALL lets the JSON object span to the end of the text; the trailing `$`
# anchors the match so content after the JSON is rejected.
_TOOL_CALL_RE = re.compile(
    r"Thought:\s*Do I need to use a tool\?\s*Yes\s*"
    r"Action:\s*(?P<action>[^\n]+)\s*"
    r"Action Input:\s*(?P<input>\{.*\})\s*$",
    re.DOTALL | re.IGNORECASE,
)

# Matches the strict ReAct final-answer block:
#   Thought: Do I need to use a tool? No
#   Final Answer: <answer text, may span multiple lines via DOTALL>
_FINAL_RE = re.compile(
    r"Thought:\s*Do I need to use a tool\?\s*No\s*"
    r"Final Answer:\s*(?P<answer>.+)",
    re.DOTALL | re.IGNORECASE,
)



# ---------- STRICT ReAct REGEX ----------
37
+
38
@dataclass
class Route:
    """A named predicate/handler pair for dispatching user queries."""

    # Human-readable route identifier.
    name: str
    # Predicate deciding whether this route applies to the query.
    match: Callable[[str], bool]
    # Produces the AgentOutput when the route matches.
    handler: Callable[[str], AgentOutput]
43
+
44
+
45
class ParseErrorMode(Enum):
    """Policy for handling model output that violates the strict ReAct format."""

    RAISE = "raise"  # fail fast with ValueError
    RETRY = "retry"  # ask the model to reformat before failing
48
+
49
+
50
+ # === proxy-tool ===
51
+ class _ClientToolProxy:
52
+ def __init__(self, client: ToolClient, alias: str, name: str, schema: dict[str, Any], description: str):
53
+ self.client = client
54
+ self.alias = alias
55
+ self.name = name
56
+ self.qualified = f"{alias}.{name}"
57
+ self.parameters = schema
58
+ self.description = description
59
+ self._default_timeout: float | None = 20.0
60
+
61
+ def set_timeout(self, timeout: float | None) -> None:
62
+ self._default_timeout = timeout
63
+
64
+ async def run(
65
+ self,
66
+ args: dict[str, Any],
67
+ context=None,
68
+ *,
69
+ convert_result: bool = True,
70
+ ):
71
+ _ = (context, convert_result)
72
+
73
+ if self.parameters:
74
+ try:
75
+ import jsonschema
76
+ jsonschema.validate(instance=args, schema=self.parameters)
77
+ except Exception as exc:
78
+ msg = f"Tool input validation failed for {self.qualified}: {exc}"
79
+ raise ValueError(
80
+ msg
81
+ ) from exc
82
+
83
+ return await self.client.call(self.name, args, timeout=self._default_timeout)
84
+
85
+ class DefaultQAAgent(Agent):
86
+
87
    def __init__(
        self,
        *,
        model: MLModel,
        tool_clients: list[ToolClient] | None = None,
        max_steps: int = 6,
        on_parse_error: ParseErrorMode = ParseErrorMode.RAISE,
        enable_stop_guard: bool = True,
        per_call_timeout: float | None = 20.0,
    ):
        """Configure the agent.

        Args:
            model: Backing ML model; ``setup()`` is invoked immediately here.
            tool_clients: MCP tool clients (stdio/sse) indexed lazily on
                first arun()/astream() call.
            max_steps: Maximum ReAct iterations before returning the
                stop message.
            on_parse_error: Policy for malformed model output.
            enable_stop_guard: Stored but not read anywhere in this module —
                TODO confirm intended use.
            per_call_timeout: Per tool-call timeout in seconds (None = off).
        """

        # Only clients MCP (stdio/sse)
        self._tool_clients: list[ToolClient] = tool_clients or []
        self._indexed_tools: dict[str, Any] = {}  # qualified -> proxy

        self.model = model
        # NOTE(review): assumes MLModel.setup() is synchronous and safe to
        # call at construction time — confirm against MLModel.
        self.model.setup()
        self.max_steps = max_steps
        self.per_call_timeout = per_call_timeout
        self.on_parse_error = on_parse_error
        self.enable_stop_guard = enable_stop_guard

        # Tool index is built lazily by arun()/astream().
        self._tools_index_built = False
110
+
111
+ # ---------- tools helpers ----------
112
+ def _get_tool(self, name: str) -> Any:
113
+ """
114
+ Look up tools ONLY among client-indexed tools.
115
+ Expected names are qualified: '<alias>.<tool_name>'.
116
+ """
117
+ if not self._tools_index_built:
118
+ msg = "Tool index not built. Ensure arun()/astream() was used."
119
+ raise RuntimeError(msg)
120
+ if name in self._indexed_tools:
121
+ return self._indexed_tools[name]
122
+ available = sorted(self._indexed_tool_names())
123
+ msg = f"Unknown tool: {name}. Available: {', '.join(available)}"
124
+ raise KeyError(msg)
125
+
126
+ def _indexed_tool_names(self) -> list[str]:
127
+ return list(self._indexed_tools.keys()) if self._tools_index_built else []
128
+
129
+ def _tool_names(self) -> str:
130
+ return ", ".join(sorted(self._indexed_tool_names()))
131
+
132
+ def _tool_descriptions(self) -> str:
133
+ parts: list[str] = []
134
+ if self._tools_index_built:
135
+ for qn, t in self._indexed_tools.items():
136
+ desc = t.description or "No description."
137
+ try:
138
+ schema_json = json.dumps(t.parameters or {}, ensure_ascii=False)
139
+ except Exception:
140
+ schema_json = str(t.parameters)
141
+ parts.append(f"- {qn}: {desc}\n Args JSON schema: {schema_json}")
142
+ return "\n".join(parts)
143
+
144
    async def _build_clients_index(self):
        """Populate the qualified-name -> proxy index from all tool clients.

        Rebuilds the index from scratch on every call; each remote tool is
        wrapped in a _ClientToolProxy carrying the configured per-call
        timeout. Marks the index as built when done.
        """
        self._indexed_tools.clear()
        for client in self._tool_clients:
            infos: list[ToolInfo] = await client.list_tools()
            for ti in infos:
                # Qualify with the client alias so names are unique per server.
                qname = f"{ti.alias}.{ti.name}"
                proxy = _ClientToolProxy(
                    client=client,
                    alias=ti.alias,
                    name=ti.name,
                    schema=ti.input_schema or {},
                    description=ti.description or "",
                )
                proxy.set_timeout(self.per_call_timeout)
                self._indexed_tools[qname] = proxy
        self._tools_index_built = True
160
+
161
+ # ---------- prompt composition ----------
162
+ def _react_text(self, user_query: str, scratchpad: str) -> str:
163
+ tmpl = get_prompt("react_chat")
164
+ return tmpl.render_text(
165
+ user_query=user_query,
166
+ tools=self._tool_descriptions(),
167
+ tool_names=self._tool_names(),
168
+ agent_scratchpad=scratchpad,
169
+ chat_history="",
170
+ )
171
+
172
+ @staticmethod
173
+ def _stopped_message() -> str:
174
+ return "Agent stopped due to iteration limit or time limit."
175
+
176
+ def _stopped_response(self, used_tools: list[str]) -> AgentOutput:
177
+ return AgentOutput(answer=self._stopped_message(), used_tools=used_tools, citations=[])
178
+
179
+ @staticmethod
180
+ def _serialize_observation(content: Any) -> str:
181
+ if isinstance(content, str | bytes):
182
+ return content if isinstance(content, str) else content.decode("utf-8", errors="ignore")
183
+ try:
184
+ return json.dumps(content, ensure_ascii=False)
185
+ except Exception:
186
+ return str(content)
187
+
188
+ # ---------- core run ----------
189
+ def run(self, _user_query: str) -> AgentOutput:
190
+ msg = "DefaultQAAgent is async-only for now. Use arun()."
191
+ raise NotImplementedError(msg)
192
+
193
+ async def _run_async(self, user_query: str) -> AgentOutput:
194
+ if self._tool_clients and not self._tools_index_built:
195
+ await self._build_clients_index()
196
+
197
+ scratch = ""
198
+ used_tools: list[str] = []
199
+ parse_retries = 0
200
+
201
+ for _ in range(self.max_steps):
202
+ prompt = self._react_text(user_query, scratch)
203
+ out = await self.model.ainvoke(prompt)
204
+
205
+ m_final = _FINAL_RE.search(out or "")
206
+ if m_final:
207
+ return AgentOutput(
208
+ answer=(m_final.group("answer") or "").strip(),
209
+ used_tools=used_tools,
210
+ citations=[],
211
+ )
212
+
213
+ m_tool = _TOOL_CALL_RE.search(out or "")
214
+ if not m_tool:
215
+ parse_retries += 1
216
+ if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= 1:
217
+ msg = (
218
+ "Invalid ReAct output. Expected EXACT format (Final or Tool-call). "
219
+ f"Got:\n{out}"
220
+ )
221
+ raise ValueError(
222
+ msg
223
+ )
224
+ scratch += (
225
+ "\nThought: Previous output violated the strict format. "
226
+ "Reply again using EXACTLY one of the two specified formats.\n"
227
+ )
228
+ continue
229
+
230
+ action = m_tool.group("action").strip()
231
+ raw_input = m_tool.group("input").strip()
232
+
233
+ try:
234
+ args = json.loads(raw_input)
235
+ if not isinstance(args, dict):
236
+ msg = "Action Input must be a JSON object."
237
+ raise ValueError(msg)
238
+ except Exception as e:
239
+ parse_retries += 1
240
+ if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= 1:
241
+ msg = f"Invalid Action Input JSON: {raw_input!r} ({e})"
242
+ raise ValueError(msg) from e
243
+ scratch += (
244
+ "\nThought: Action Input must be a ONE-LINE JSON object. "
245
+ "Retry with correct JSON.\n"
246
+ )
247
+ continue
248
+
249
+ tool = self._get_tool(action)
250
+
251
+ try:
252
+ result = await tool.run(args, context=None, convert_result=True)
253
+ except Exception as e:
254
+ # unified error payload
255
+ err = {
256
+ "error": {
257
+ "type": e.__class__.__name__,
258
+ "server": getattr(tool, "alias", "local"),
259
+ "tool": getattr(tool, "name", getattr(tool, "qualified", "unknown")),
260
+ "message": str(e),
261
+ "retryable": False,
262
+ }
263
+ }
264
+ result = err
265
+
266
+ used_tools.append(action)
267
+ observation = self._serialize_observation(result)
268
+
269
+ scratch += (
270
+ "\nThought: Do I need to use a tool? Yes"
271
+ f"\nAction: {action}"
272
+ f"\nAction Input: {raw_input}"
273
+ f"\nObservation: {observation}\n"
274
+ )
275
+
276
+ return self._stopped_response(used_tools)
277
+
278
+ # ---------- public APIs ----------
279
+ async def arun(self, user_query: str) -> AgentOutput:
280
+ return await self._run_async(user_query)
281
+
282
+ # ---------- streaming ----------
283
+ @no_type_check
284
+ async def astream(self, user_query: str) -> AsyncIterator[str]:
285
+ if self._tool_clients and not self._tools_index_built:
286
+ await self._build_clients_index()
287
+
288
+ scratch = ""
289
+ used_tools: list[str] = []
290
+ parse_retries = 0
291
+
292
+ for _ in range(self.max_steps):
293
+ prompt = self._react_text(user_query, scratch)
294
+
295
+ buffer = ""
296
+
297
+ # Normalize model.astream: it might be an async iterator already,
298
+ # or a coroutine (or nested coroutines) that resolves to one.
299
+ _val = self.model.astream(prompt)
300
+ while inspect.iscoroutine(_val):
301
+ _val = await _val
302
+
303
+ # Optional guard (helpful during tests)
304
+ if not hasattr(_val, "__aiter__"):
305
+ msg = f"model.astream() did not yield an AsyncIterator; got {type(_val)!r}"
306
+ raise TypeError(msg)
307
+
308
+ model_stream = _val # now an AsyncIterator[str]
309
+
310
+ async for chunk in model_stream:
311
+ buffer += chunk
312
+
313
+ m_final = _FINAL_RE.search(buffer or "")
314
+ if m_final:
315
+ answer = (m_final.group("answer") or "").strip()
316
+ if answer:
317
+ yield answer
318
+ return
319
+
320
+ m_tool = _TOOL_CALL_RE.search(buffer or "")
321
+ if not m_tool:
322
+ parse_retries += 1
323
+ if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= 1:
324
+ msg = f"Invalid ReAct output (stream). Expected EXACT format. Got:\n{buffer}"
325
+ raise ValueError(msg)
326
+ scratch += (
327
+ "\nThought: Previous output violated the strict format. "
328
+ "Reply again using EXACTLY one of the two specified formats.\n"
329
+ )
330
+ continue
331
+
332
+ action = m_tool.group("action").strip()
333
+ raw_input = m_tool.group("input").strip()
334
+
335
+ try:
336
+ args = json.loads(raw_input)
337
+ if not isinstance(args, dict):
338
+ msg = "Action Input must be a JSON object."
339
+ raise ValueError(msg)
340
+ except Exception as e:
341
+ parse_retries += 1
342
+ if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= 1:
343
+ msg = f"Invalid Action Input JSON: {raw_input!r} ({e})"
344
+ raise ValueError(msg) from e
345
+ scratch += (
346
+ "\nThought: Action Input must be a ONE-LINE JSON object. "
347
+ "Retry with correct JSON.\n"
348
+ )
349
+ continue
350
+
351
+ tool = self._get_tool(action)
352
+
353
+ try:
354
+ result = await tool.run(args, context=None, convert_result=True)
355
+ except Exception as e:
356
+ result = {
357
+ "error": {
358
+ "type": e.__class__.__name__,
359
+ "server": getattr(tool, "alias", "local"),
360
+ "tool": getattr(tool, "name", getattr(tool, "qualified", "unknown")),
361
+ "message": str(e),
362
+ "retryable": False,
363
+ }
364
+ }
365
+
366
+ used_tools.append(action)
367
+ observation = self._serialize_observation(result)
368
+
369
+ scratch += (
370
+ "\nThought: Do I need to use a tool? Yes"
371
+ f"\nAction: {action}"
372
+ f"\nAction Input: {raw_input}"
373
+ f"\nObservation: {observation}\n"
374
+ )
375
+
376
+ yield self._stopped_message()
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+ from typing import Final
7
+
8
+ # avoid "magic numbers"
9
+ _PARTS_EXPECTED: Final[int] = 2
10
+
11
+
12
+ class _SafeDict(dict[str, Any]):
13
+ def __missing__(self, key: str) -> str:
14
+ return "{" + key + "}"
15
+
16
+
17
@dataclass
class Prompt:
    """A named two-part (system/user) prompt template.

    Placeholders use str.format syntax; unknown placeholders survive
    rendering unchanged (see _SafeDict).
    """

    name: str
    system: str
    user: str

    def render_text(self, **kwargs: Any) -> str:
        """Render both sections into one blank-line-separated text block."""
        values = _SafeDict(**kwargs)
        parts = (self.system.format_map(values), self.user.format_map(values))
        return "\n\n".join(parts).strip()

    def render_messages(self, **kwargs: Any) -> list[dict[str, str]]:
        """Render as chat messages: one system entry then one user entry."""
        values = _SafeDict(**kwargs)
        return [
            {"role": "system", "content": self.system.format_map(values)},
            {"role": "user", "content": self.user.format_map(values)},
        ]
35
+
36
+
37
# Process-wide cache of loaded prompts, keyed by prompt name.
_prompt_cache: dict[str, Prompt] = {}
38
+
39
+
40
+ def _load_file(name: str) -> Prompt:
41
+ base = Path(__file__).resolve().parent
42
+ path = base / f"{name}.prompt"
43
+ if not path.exists():
44
+ msg = f"Prompt '{name}' not found at {path}"
45
+ raise FileNotFoundError(msg)
46
+ raw = path.read_text(encoding="utf-8")
47
+ parts = raw.split("\n---\n", 1)
48
+ if len(parts) == _PARTS_EXPECTED:
49
+ system, user = parts
50
+ else:
51
+ system, user = raw, "{input}"
52
+ return Prompt(name=name, system=system.strip(), user=user.strip())
53
+
54
+
55
def get_prompt(name: str) -> Prompt:
    """Return the cached Prompt for *name*, loading it on first use."""
    cached = _prompt_cache.get(name)
    if cached is None:
        cached = _load_file(name)
        _prompt_cache[name] = cached
    return cached
@@ -0,0 +1,37 @@
1
+ You are a concise, helpful assistant. Use tools only when they improve accuracy.
2
+
3
+ TOOLS
4
+ ------
5
+ {tools}
6
+
7
+ FORMAT (STRICT — EXACTLY ONE; NO extra lines, NO Markdown)
8
+ ----------------------------------------------------------
9
+ 1. If you do NOT need a tool:
10
+ Thought: Do I need to use a tool? No
11
+ Final Answer: <your final answer in plain text>
12
+
13
+ 2. If you DO need a tool:
14
+ Thought: Do I need to use a tool? Yes
15
+ Action: <one tool name from: {tool_names}>
16
+ Action Input: <ONE-LINE JSON object with the tool arguments>
17
+
18
+ RULES
19
+ -----
20
+ - Output MUST match exactly one of the two blocks above.
21
+ - `Action Input` MUST be a valid ONE-LINE JSON object (e.g. {{"a": 1, "b": 2}}).
22
+ - Do NOT add anything before/after the block.
23
+ - Do NOT print "Observation". The system will add it after tool execution.
24
+
25
+ PREVIOUS CONVERSATION
26
+ --------------------
27
+ {chat_history}
28
+
29
+ NEW INPUT
30
+ ---------
31
+ {user_query}
32
+
33
+ SCRATCHPAD
34
+ ----------
35
+ {agent_scratchpad}
36
+
37
+ Assistant:
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+ from typing import Optional
5
+
6
+ from mcp.server.fastmcp.tools.base import Tool
7
+ from pydantic import BaseModel
8
+ from pydantic import Field
9
+
10
+ from amsdal_ml.ml_retrievers.openai_retriever import OpenAIRetriever
11
+
12
+
13
class RetrieverArgs(BaseModel):
    """Arguments accepted by the retriever.search tool."""

    # Free-text search query.
    query: str = Field(..., description='User search query')
    # Number of results to return.
    k: int = 5
    # Tag filters forwarded to asimilarity_search; exact semantics are
    # defined by OpenAIRetriever — TODO confirm include/exclude behavior.
    include_tags: Optional[list[str]] = None
    exclude_tags: Optional[list[str]] = None
18
+
19
+
20
# Module-level retriever instance shared by all tool invocations.
_retriever = OpenAIRetriever()
21
+
22
+
23
async def retriever_search(args: RetrieverArgs) -> list[dict[str, Any]]:
    """Run a similarity search and normalize each chunk to a plain dict."""
    chunks = await _retriever.asimilarity_search(
        query=args.query,
        k=args.k,
        include_tags=args.include_tags,
        exclude_tags=args.exclude_tags,
    )

    def _as_dict(chunk: Any) -> dict[str, Any]:
        # Preference order: pydantic v2, pydantic v1, plain dict, str fallback.
        if hasattr(chunk, 'model_dump'):
            return chunk.model_dump()
        if hasattr(chunk, 'dict'):
            return chunk.dict()
        if isinstance(chunk, dict):
            return chunk
        return {'raw_text': str(chunk)}

    return [_as_dict(c) for c in chunks]
41
+
42
+
43
# Expose retriever_search as an MCP tool with structured (JSON) output.
retriever_tool = Tool.from_function(
    retriever_search,
    name='retriever.search',
    description='Semantic search in knowledge base (OpenAI embeddings)',
    structured_output=True,
)
amsdal_ml/app.py ADDED
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ # During type checking, point to the known-available fallback.
6
+ if TYPE_CHECKING:
7
+ from amsdal.contrib.app_config import AppConfig as BaseAppConfig
8
+ else:
9
+ try:
10
+ # At runtime prefer the real AMSDAL core if present.
11
+ from amsdal.configs.app import AppConfig as BaseAppConfig # type: ignore[import-not-found]
12
+ except Exception: # pragma: no cover
13
+ from amsdal.contrib.app_config import AppConfig as BaseAppConfig
14
+
15
+
16
class MLPluginAppConfig(BaseAppConfig):
    """AMSDAL application config registering the ML plugin."""

    name = "amsdal_ml"
    verbose_name = "AMSDAL ML Plugin"

    def on_ready(self) -> None:
        """Hook invoked when the app is ready; intentionally a no-op."""
        pass

    def on_server_startup(self) -> None:
        """Hook invoked on server startup; intentionally a no-op."""
        pass
File without changes
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+ from typing import Protocol
6
+
7
+
8
@dataclass
class ToolInfo:
    """Metadata describing one tool exposed by a ToolClient."""

    # Owning client's alias; used to qualify the tool name ('<alias>.<name>').
    alias: str
    # Tool name as reported by the server.
    name: str
    # Human-readable description (may be empty).
    description: str
    # JSON schema for the tool's arguments.
    input_schema: dict[str, Any]
14
+
15
+
16
class ToolClient(Protocol):
    """Structural interface implemented by MCP tool clients (stdio/sse)."""

    # Short identifier used to qualify tool names ('<alias>.<tool>').
    alias: str

    async def list_tools(self) -> list[ToolInfo]: ...
    async def call(self, tool_name: str, args: dict[str, Any], *, timeout: float | None = None) -> Any: ...
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import AsyncExitStack
4
+ from typing import Any
5
+ from typing import Optional
6
+
7
+ from mcp import ClientSession
8
+ from mcp.client.sse import sse_client
9
+
10
+ from amsdal_ml.mcp_client.base import ToolClient
11
+ from amsdal_ml.mcp_client.base import ToolInfo
12
+
13
+
14
class HttpClient(ToolClient):
    """ToolClient speaking MCP over an SSE (HTTP) transport.

    A fresh transport + session is established per operation and always torn
    down, even when session setup itself fails.
    """

    def __init__(self, *, alias: str, url: str, headers: Optional[dict[str, str]] = None):
        self.alias = alias
        self.url = url
        self.headers = headers or {}

    async def _session(self):
        """Open an SSE transport and an initialized ClientSession.

        Returns (stack, session); the caller must ``await stack.aclose()``.
        BUGFIX: previously, if ClientSession entry or initialize() raised,
        the partially-entered AsyncExitStack leaked the open transport —
        it is now closed before re-raising.
        """
        stack = AsyncExitStack()
        try:
            rx, tx = await stack.enter_async_context(sse_client(self.url, headers=self.headers))
            s = await stack.enter_async_context(ClientSession(rx, tx))
            await s.initialize()
        except BaseException:
            await stack.aclose()
            raise
        return stack, s

    async def list_tools(self) -> list[ToolInfo]:
        """Fetch the server's tool list as ToolInfo records."""
        stack, s = await self._session()
        try:
            resp = await s.list_tools()
            return [
                ToolInfo(
                    alias=self.alias,
                    name=t.name,
                    description=t.description or '',
                    input_schema=(getattr(t, 'inputSchema', None) or {}),
                )
                for t in resp.tools
            ]
        finally:
            await stack.aclose()

    async def call(self, tool_name: str, args: dict[str, Any], *, timeout: float | None = None,) -> Any:
        """Invoke *tool_name* with *args* and return the result content.

        NOTE(review): *timeout* is accepted for interface parity but is
        currently ignored by this transport.
        """
        _ = timeout  # ARG002
        stack, s = await self._session()
        try:
            res = await s.call_tool(tool_name, args)
            return getattr(res, 'content', res)
        finally:
            await stack.aclose()
+ await stack.aclose()