amsdal_ml 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amsdal_ml/Third-Party Materials - AMSDAL Dependencies - License Notices.md +617 -0
- amsdal_ml/__about__.py +1 -1
- amsdal_ml/agents/__init__.py +13 -0
- amsdal_ml/agents/agent.py +5 -7
- amsdal_ml/agents/default_qa_agent.py +108 -143
- amsdal_ml/agents/functional_calling_agent.py +233 -0
- amsdal_ml/agents/mcp_client_tool.py +46 -0
- amsdal_ml/agents/python_tool.py +86 -0
- amsdal_ml/agents/retriever_tool.py +5 -6
- amsdal_ml/agents/tool_adapters.py +98 -0
- amsdal_ml/fileio/base_loader.py +7 -5
- amsdal_ml/fileio/openai_loader.py +16 -17
- amsdal_ml/mcp_client/base.py +2 -0
- amsdal_ml/mcp_client/http_client.py +7 -1
- amsdal_ml/mcp_client/stdio_client.py +19 -16
- amsdal_ml/mcp_server/server_retriever_stdio.py +8 -11
- amsdal_ml/ml_ingesting/__init__.py +29 -0
- amsdal_ml/ml_ingesting/default_ingesting.py +49 -51
- amsdal_ml/ml_ingesting/embedders/__init__.py +4 -0
- amsdal_ml/ml_ingesting/embedders/embedder.py +12 -0
- amsdal_ml/ml_ingesting/embedders/openai_embedder.py +30 -0
- amsdal_ml/ml_ingesting/embedding_data.py +3 -0
- amsdal_ml/ml_ingesting/loaders/__init__.py +6 -0
- amsdal_ml/ml_ingesting/loaders/folder_loader.py +52 -0
- amsdal_ml/ml_ingesting/loaders/loader.py +28 -0
- amsdal_ml/ml_ingesting/loaders/pdf_loader.py +136 -0
- amsdal_ml/ml_ingesting/loaders/text_loader.py +44 -0
- amsdal_ml/ml_ingesting/model_ingester.py +278 -0
- amsdal_ml/ml_ingesting/pipeline.py +131 -0
- amsdal_ml/ml_ingesting/pipeline_interface.py +31 -0
- amsdal_ml/ml_ingesting/processors/__init__.py +4 -0
- amsdal_ml/ml_ingesting/processors/cleaner.py +14 -0
- amsdal_ml/ml_ingesting/processors/text_cleaner.py +42 -0
- amsdal_ml/ml_ingesting/splitters/__init__.py +4 -0
- amsdal_ml/ml_ingesting/splitters/splitter.py +15 -0
- amsdal_ml/ml_ingesting/splitters/token_splitter.py +85 -0
- amsdal_ml/ml_ingesting/stores/__init__.py +4 -0
- amsdal_ml/ml_ingesting/stores/embedding_data.py +63 -0
- amsdal_ml/ml_ingesting/stores/store.py +22 -0
- amsdal_ml/ml_ingesting/types.py +40 -0
- amsdal_ml/ml_models/models.py +96 -4
- amsdal_ml/ml_models/openai_model.py +430 -122
- amsdal_ml/ml_models/utils.py +7 -0
- amsdal_ml/ml_retrievers/__init__.py +17 -0
- amsdal_ml/ml_retrievers/adapters.py +93 -0
- amsdal_ml/ml_retrievers/default_retriever.py +11 -1
- amsdal_ml/ml_retrievers/openai_retriever.py +27 -7
- amsdal_ml/ml_retrievers/query_retriever.py +487 -0
- amsdal_ml/ml_retrievers/retriever.py +12 -0
- amsdal_ml/models/embedding_model.py +7 -7
- amsdal_ml/prompts/__init__.py +77 -0
- amsdal_ml/prompts/database_query_agent.prompt +14 -0
- amsdal_ml/prompts/functional_calling_agent_base.prompt +9 -0
- amsdal_ml/prompts/nl_query_filter.prompt +318 -0
- amsdal_ml/{agents/promts → prompts}/react_chat.prompt +17 -8
- amsdal_ml/utils/__init__.py +5 -0
- amsdal_ml/utils/query_utils.py +189 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/METADATA +61 -3
- amsdal_ml-0.2.1.dist-info/RECORD +72 -0
- {amsdal_ml-0.1.4.dist-info → amsdal_ml-0.2.1.dist-info}/WHEEL +1 -1
- amsdal_ml/agents/promts/__init__.py +0 -58
- amsdal_ml-0.1.4.dist-info/RECORD +0 -39
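Based on the `default_qa_agent.py` diff below, local `PythonTool` instances and MCP `ToolClient`s are now passed through a single `tools=` argument (replacing the 0.1.4 `tool_clients=` argument), and the agent stays async-only. A minimal usage sketch under those assumptions; the placeholder values (`my_model`, `my_tool`, the question string) and the import path are illustrative, not taken from the package docs:

# Hypothetical usage sketch inferred from this diff; placeholders are not part of the package.
import asyncio

from amsdal_ml.agents.default_qa_agent import DefaultQAAgent

async def main() -> None:
    my_model = ...  # any MLModel implementation, e.g. the OpenAI-backed model in ml_models/
    my_tool = ...   # a PythonTool or a ToolClient; both are accepted by `tools=` in 0.2.1

    agent = DefaultQAAgent(
        model=my_model,
        tools=[my_tool],        # replaces the 0.1.4 `tool_clients=` argument
        max_steps=6,
        per_call_timeout=20.0,
    )
    # The agent is async-only: run()/stream() raise NotImplementedError.
    output = await agent.arun('What changed in amsdal_ml 0.2.1?')
    print(output.answer, output.used_tools)

asyncio.run(main())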
amsdal_ml/agents/agent.py
CHANGED
@@ -15,7 +15,7 @@ from amsdal_ml.fileio.base_loader import FileAttachment
 
 
 class AgentMessage(BaseModel):
-    role: Literal[
+    role: Literal['SYSTEM', 'USER', 'ASSISTANT']
     content: str
 
 
@@ -32,8 +32,7 @@ class Agent(ABC):
         user_query: str,
         *,
         attachments: Optional[list[FileAttachment]] = None,
-    ) -> AgentOutput:
-        ...
+    ) -> AgentOutput: ...
 
     @abstractmethod
     async def astream(
@@ -41,8 +40,7 @@
         user_query: str,
         *,
         attachments: Optional[list[FileAttachment]] = None,
-    ) -> AsyncIterator[str]:
-        ...
+    ) -> AsyncIterator[str]: ...
 
     def run(
         self,
@@ -50,7 +48,7 @@
         *,
         attachments: Optional[list[FileAttachment]] = None,
     ) -> AgentOutput:
-        msg =
+        msg = 'This agent is async-only. Use arun().'
         raise NotImplementedError(msg)
 
     def stream(
@@ -59,5 +57,5 @@
         *,
         attachments: Optional[list[FileAttachment]] = None,
     ) -> Iterator[str]:
-        msg =
+        msg = 'This agent is async-only. Use astream().'
         raise NotImplementedError(msg)
amsdal_ml/agents/default_qa_agent.py
CHANGED
@@ -12,23 +12,26 @@ from typing import no_type_check
 
 from amsdal_ml.agents.agent import Agent
 from amsdal_ml.agents.agent import AgentOutput
-from amsdal_ml.agents.
+from amsdal_ml.agents.mcp_client_tool import ClientToolProxy
+from amsdal_ml.agents.python_tool import PythonTool
+from amsdal_ml.agents.python_tool import _PythonToolProxy
 from amsdal_ml.fileio.base_loader import FileAttachment
 from amsdal_ml.mcp_client.base import ToolClient
 from amsdal_ml.mcp_client.base import ToolInfo
 from amsdal_ml.ml_models.models import MLModel
+from amsdal_ml.prompts import get_prompt
 
 # ---------- STRICT ReAct REGEX ----------
 _TOOL_CALL_RE = re.compile(
-    r
-    r
-    r
+    r'Thought:\s*Do I need to use a tool\?\s*Yes[\.\!]?\s*'
+    r'Action:\s*(?P<action>[^\n]+)\s*'
+    r'Action Input:\s*(?P<input>\{.*\})\s*',
     re.DOTALL | re.IGNORECASE,
 )
 
 _FINAL_RE = re.compile(
-    r
-    r
+    r'(?:Thought:\s*Do I need to use a tool\?\s*No[\.\!]?\s*)?'
+    r'Final Answer:\s*(?P<answer>.+)',
     re.DOTALL | re.IGNORECASE,
 )
 # ---------- constants ----------
@@ -36,9 +39,9 @@ _FINAL_RE = re.compile(
 _MAX_PARSE_RETRIES = 5
 
 
-
 # ---------- STRICT ReAct REGEX ----------
 
+
 @dataclass
 class Route:
     name: str
@@ -47,61 +50,23 @@
 
 
 class ParseErrorMode(Enum):
-    RAISE =
-    RETRY =
-
-
-# === proxy-tool ===
-class _ClientToolProxy:
-    def __init__(self, client: ToolClient, alias: str, name: str, schema: dict[str, Any], description: str):
-        self.client = client
-        self.alias = alias
-        self.name = name
-        self.qualified = f"{alias}.{name}"
-        self.parameters = schema
-        self.description = description
-        self._default_timeout: float | None = 20.0
-
-    def set_timeout(self, timeout: float | None) -> None:
-        self._default_timeout = timeout
-
-    async def run(
-        self,
-        args: dict[str, Any],
-        context=None,
-        *,
-        convert_result: bool = True,
-    ):
-        _ = (context, convert_result)
-
-        if self.parameters:
-            try:
-                import jsonschema
-                jsonschema.validate(instance=args, schema=self.parameters)
-            except Exception as exc:
-                msg = f"Tool input validation failed for {self.qualified}: {exc}"
-                raise ValueError(
-                    msg
-                ) from exc
+    RAISE = 'raise'
+    RETRY = 'retry'
 
-        return await self.client.call(self.name, args, timeout=self._default_timeout)
 
 class DefaultQAAgent(Agent):
-
     def __init__(
         self,
         *,
         model: MLModel,
-
+        tools: list[PythonTool | ToolClient] | None = None,
         max_steps: int = 6,
         on_parse_error: ParseErrorMode = ParseErrorMode.RAISE,
         enable_stop_guard: bool = True,
         per_call_timeout: float | None = 20.0,
     ):
-
-
-        self._tool_clients: list[ToolClient] = tool_clients or []
-        self._indexed_tools: dict[str, Any] = {}  # qualified -> proxy
+        self._tools: list[PythonTool | ToolClient] = tools or []
+        self._indexed_tools: dict[str, ClientToolProxy | _PythonToolProxy] = {}
 
         self.model = model
         self.model.setup()
@@ -110,7 +75,7 @@ class DefaultQAAgent(Agent):
         self.on_parse_error = on_parse_error
         self.enable_stop_guard = enable_stop_guard
 
-        self.
+        self._is_tools_index_built = False
 
     # ---------- tools helpers ----------
     def _get_tool(self, name: str) -> Any:
@@ -118,64 +83,76 @@
         Look up tools ONLY among client-indexed tools.
         Expected names are qualified: '<alias>.<tool_name>'.
         """
-        if not self.
-            msg =
+        if not self._is_tools_index_built:
+            msg = 'Tool index not built. Ensure arun()/astream() was used.'
             raise RuntimeError(msg)
         if name in self._indexed_tools:
             return self._indexed_tools[name]
         available = sorted(self._indexed_tool_names())
-        msg = f
+        msg = f'Unknown tool: {name}. Available: {", ".join(available)}'
         raise KeyError(msg)
 
     def _indexed_tool_names(self) -> list[str]:
-        return list(self._indexed_tools.keys()) if self.
+        return list(self._indexed_tools.keys()) if self._is_tools_index_built else []
 
     def _tool_names(self) -> str:
-        return
+        return ', '.join(sorted(self._indexed_tool_names()))
 
     def _tool_descriptions(self) -> str:
         parts: list[str] = []
-        if self.
+        if self._is_tools_index_built:
             for qn, t in self._indexed_tools.items():
-                desc = t.description or
+                desc = t.description or 'No description.'
                 try:
                     schema_json = json.dumps(t.parameters or {}, ensure_ascii=False)
                 except Exception:
                     schema_json = str(t.parameters)
-                parts.append(f
-        return
+                parts.append(f'- {qn}: {desc}\n Args JSON schema: {schema_json}')
+        return '\n'.join(parts)
 
     async def _build_clients_index(self):
         self._indexed_tools.clear()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        for tool in self._tools:
+            if isinstance(tool, ToolClient):
+                infos: list[ToolInfo] = await tool.list_tools()
+                for ti in infos:
+                    qname = f'{ti.alias}.{ti.name}'
+                    proxy = ClientToolProxy(
+                        client=tool,
+                        alias=ti.alias,
+                        name=ti.name,
+                        schema=ti.input_schema or {},
+                        description=ti.description or '',
+                    )
+                    proxy.set_timeout(self.per_call_timeout)
+                    self._indexed_tools[qname] = proxy
+            elif isinstance(tool, PythonTool):
+                if tool.name in self._indexed_tools:
+                    msg = f'Tool name conflict: {tool.name} is already defined.'
+                    raise ValueError(msg)
+                proxy = _PythonToolProxy(tool, timeout=self.per_call_timeout)  # type: ignore[assignment]
+                self._indexed_tools[tool.name] = proxy
+            else:
+                msg = f'Unsupported tool type: {type(tool)}'
+                raise TypeError(msg)
+
+        self._is_tools_index_built = True
 
     # ---------- prompt composition ----------
     def _react_text(self, user_query: str, scratchpad: str) -> str:
-        tmpl = get_prompt(
+        tmpl = get_prompt('react_chat')
         return tmpl.render_text(
             user_query=user_query,
             tools=self._tool_descriptions(),
             tool_names=self._tool_names(),
             agent_scratchpad=scratchpad,
-            chat_history=
+            chat_history='',
         )
 
     @staticmethod
     def _stopped_message() -> str:
-        return
+        return 'Agent stopped due to iteration limit or time limit.'
 
     def _stopped_response(self, used_tools: list[str]) -> AgentOutput:
         return AgentOutput(answer=self._stopped_message(), used_tools=used_tools, citations=[])
@@ -183,7 +160,7 @@ class DefaultQAAgent(Agent):
     @staticmethod
     def _serialize_observation(content: Any) -> str:
         if isinstance(content, str | bytes):
-            return content if isinstance(content, str) else content.decode(
+            return content if isinstance(content, str) else content.decode('utf-8', errors='ignore')
         try:
             return json.dumps(content, ensure_ascii=False)
         except Exception:
@@ -191,82 +168,73 @@ class DefaultQAAgent(Agent):
 
     # ---------- core run ----------
     def run(self, user_query: str, *, attachments: list[FileAttachment] | None = None) -> AgentOutput:
-        msg =
+        msg = 'DefaultQAAgent is async-only for now. Use arun().'
         raise NotImplementedError(msg)
 
     async def _run_async(self, user_query: str, *, attachments: list[FileAttachment] | None = None) -> AgentOutput:
-        if
+        if not self._is_tools_index_built:
             await self._build_clients_index()
 
-        scratch =
+        scratch = ''
         used_tools: list[str] = []
         parse_retries = 0
 
-
         for _ in range(self.max_steps):
             prompt = self._react_text(user_query, scratch)
             out = await self.model.ainvoke(prompt, attachments=attachments)
-            print(
+            print('Model output:', out)  # noqa: T201
             print('promt:', prompt)  # noqa: T201
-            m_final = _FINAL_RE.search(out or
+            m_final = _FINAL_RE.search(out or '')
             if m_final:
                 return AgentOutput(
-                    answer=(m_final.group(
+                    answer=(m_final.group('answer') or '').strip(),
                     used_tools=used_tools,
                     citations=[],
                 )
 
-            m_tool = _TOOL_CALL_RE.search(out or
+            m_tool = _TOOL_CALL_RE.search(out or '')
             if not m_tool:
                 parse_retries += 1
                 if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= _MAX_PARSE_RETRIES:
-                    msg = (
-
-                        f"Got:\n{out}"
-                    )
-                    raise ValueError(
-                        msg
-                    )
+                    msg = f'Invalid ReAct output. Expected EXACT format (Final or Tool-call). Got:\n{out}'
+                    raise ValueError(msg)
 
                 scratch += (
-
-
+                    '\nThought: Previous output violated the strict format. '
+                    'Reply again using EXACTLY one of the two specified formats.\n'
                 )
                 continue
 
-            action = m_tool.group(
-            raw_input = m_tool.group(
+            action = m_tool.group('action').strip()
+            raw_input = m_tool.group('input').strip()
 
             try:
                 args = json.loads(raw_input)
                 if not isinstance(args, dict):
-                    msg =
+                    msg = 'Action Input must be a JSON object.'
                     raise ValueError(msg)
             except Exception as e:
                 parse_retries += 1
                 if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= _MAX_PARSE_RETRIES:
-                    msg = f
+                    msg = f'Invalid Action Input JSON: {raw_input!r} ({e})'
                     raise ValueError(msg) from e
-                scratch += (
-                    "\nThought: Action Input must be a ONE-LINE JSON object. "
-                    "Retry with correct JSON.\n"
-                )
+                scratch += '\nThought: Action Input must be a ONE-LINE JSON object. Retry with correct JSON.\n'
                 continue
 
             tool = self._get_tool(action)
 
             try:
                 result = await tool.run(args, context=None, convert_result=True)
-                print(
+                print('Similarity search result:', result)  # noqa: T201
             except Exception as e:
                 # unified error payload
                 err = {
-
-
-
-
-
-
+                    'error': {
+                        'type': e.__class__.__name__,
+                        'server': getattr(tool, 'alias', 'local'),
+                        'tool': getattr(tool, 'name', getattr(tool, 'qualified', 'unknown')),
+                        'message': str(e),
+                        'retryable': False,
                     }
                 }
                 result = err
@@ -275,10 +243,10 @@ class DefaultQAAgent(Agent):
             observation = self._serialize_observation(result)
 
             scratch += (
-
-                f
-                f
-                f
+                '\nThought: Do I need to use a tool? Yes'
+                f'\nAction: {action}'
+                f'\nAction Input: {raw_input}'
+                f'\nObservation: {observation}\n'
             )
 
         return self._stopped_response(used_tools)
@@ -290,17 +258,17 @@ class DefaultQAAgent(Agent):
     # ---------- streaming ----------
     @no_type_check
     async def astream(self, user_query: str, *, attachments: list[FileAttachment] | None = None) -> AsyncIterator[str]:
-        if
+        if not self._is_tools_index_built:
             await self._build_clients_index()
 
-        scratch =
+        scratch = ''
         used_tools: list[str] = []
         parse_retries = 0
 
         for _ in range(self.max_steps):
             prompt = self._react_text(user_query, scratch)
 
-            buffer =
+            buffer = ''
 
             # Normalize model.astream: it might be an async iterator already,
             # or a coroutine (or nested coroutines) that resolves to one.
@@ -309,8 +277,8 @@
                 _val = await _val
 
             # Optional guard (helpful during tests)
-            if not hasattr(_val,
-                msg = f
+            if not hasattr(_val, '__aiter__'):
+                msg = f'model.astream() did not yield an AsyncIterator; got {type(_val)!r}'
                 raise TypeError(msg)
 
             model_stream = _val  # now an AsyncIterator[str]
@@ -318,42 +286,39 @@
             async for chunk in model_stream:
                 buffer += chunk
 
-            m_final = _FINAL_RE.search(buffer or
+            m_final = _FINAL_RE.search(buffer or '')
            if m_final:
-                answer = (m_final.group(
+                answer = (m_final.group('answer') or '').strip()
                 if answer:
                     yield answer
                 return
 
-            m_tool = _TOOL_CALL_RE.search(buffer or
+            m_tool = _TOOL_CALL_RE.search(buffer or '')
             if not m_tool:
                 parse_retries += 1
                 if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= _MAX_PARSE_RETRIES:
-                    msg = f
+                    msg = f'Invalid ReAct output (stream). Expected EXACT format. Got:\n{buffer}'
                     raise ValueError(msg)
                 scratch += (
-
-
+                    '\nThought: Previous output violated the strict format. '
+                    'Reply again using EXACTLY one of the two specified formats.\n'
                 )
                 continue
 
-            action = m_tool.group(
-            raw_input = m_tool.group(
+            action = m_tool.group('action').strip()
+            raw_input = m_tool.group('input').strip()
 
             try:
                 args = json.loads(raw_input)
                 if not isinstance(args, dict):
-                    msg =
+                    msg = 'Action Input must be a JSON object.'
                     raise ValueError(msg)
             except Exception as e:
                 parse_retries += 1
                 if self.on_parse_error == ParseErrorMode.RAISE or parse_retries >= _MAX_PARSE_RETRIES:
-                    msg = f
+                    msg = f'Invalid Action Input JSON: {raw_input!r} ({e})'
                     raise ValueError(msg) from e
-                scratch += (
-                    "\nThought: Action Input must be a ONE-LINE JSON object. "
-                    "Retry with correct JSON.\n"
-                )
+                scratch += '\nThought: Action Input must be a ONE-LINE JSON object. Retry with correct JSON.\n'
                 continue
 
             tool = self._get_tool(action)
@@ -362,12 +327,12 @@ class DefaultQAAgent(Agent):
                 result = await tool.run(args, context=None, convert_result=True)
             except Exception as e:
                 result = {
-
-
-
-
-
-
+                    'error': {
+                        'type': e.__class__.__name__,
+                        'server': getattr(tool, 'alias', 'local'),
+                        'tool': getattr(tool, 'name', getattr(tool, 'qualified', 'unknown')),
+                        'message': str(e),
+                        'retryable': False,
                     }
                 }
 
@@ -375,10 +340,10 @@ class DefaultQAAgent(Agent):
             observation = self._serialize_observation(result)
 
             scratch += (
-
-                f
-                f
-                f
+                '\nThought: Do I need to use a tool? Yes'
+                f'\nAction: {action}'
+                f'\nAction Input: {raw_input}'
+                f'\nObservation: {observation}\n'
             )
 
         yield self._stopped_message()
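For reference, the strict ReAct formats enforced by the `_TOOL_CALL_RE` and `_FINAL_RE` patterns added above can be exercised in isolation. A small sketch: the regexes are copied from this diff, while the qualified tool name `retriever.search` and the sample texts are made up for illustration.

import re

# Regexes as added in default_qa_agent.py in 0.2.1.
_TOOL_CALL_RE = re.compile(
    r'Thought:\s*Do I need to use a tool\?\s*Yes[\.\!]?\s*'
    r'Action:\s*(?P<action>[^\n]+)\s*'
    r'Action Input:\s*(?P<input>\{.*\})\s*',
    re.DOTALL | re.IGNORECASE,
)
_FINAL_RE = re.compile(
    r'(?:Thought:\s*Do I need to use a tool\?\s*No[\.\!]?\s*)?'
    r'Final Answer:\s*(?P<answer>.+)',
    re.DOTALL | re.IGNORECASE,
)

# Format 1: a tool call ('retriever.search' is a made-up qualified tool name).
tool_call = (
    'Thought: Do I need to use a tool? Yes\n'
    'Action: retriever.search\n'
    'Action Input: {"query": "amsdal"}\n'
)
m = _TOOL_CALL_RE.search(tool_call)
assert m and m.group('action').strip() == 'retriever.search'

# Format 2: a final answer.
final = 'Thought: Do I need to use a tool? No\nFinal Answer: 42'
assert _FINAL_RE.search(final).group('answer').strip() == '42'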