xgae 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xgae might be problematic; see the package registry's advisory page for more details.
- xgae/engine/responser/xga_non_stream_responser.py +216 -0
- xgae/engine/responser/xga_responser_base.py +710 -0
- xgae/engine/responser/xga_stream_responser.py +830 -0
- xgae/engine/xga_base.py +20 -8
- xgae/engine/xga_engine.py +280 -32
- xgae/engine/xga_mcp_tool_box.py +7 -4
- xgae/utils/json_helpers.py +174 -0
- xgae/utils/llm_client.py +17 -6
- xgae/utils/setup_env.py +1 -4
- xgae/utils/xml_tool_parser.py +236 -0
- {xgae-0.1.3.dist-info → xgae-0.1.5.dist-info}/METADATA +1 -1
- xgae-0.1.5.dist-info/RECORD +16 -0
- xgae/engine/responser/xga_responser_utils.py +0 -0
- xgae/engine/responser/xga_stream_reponser.py +0 -0
- xgae-0.1.3.dist-info/RECORD +0 -14
- {xgae-0.1.3.dist-info → xgae-0.1.5.dist-info}/WHEEL +0 -0
xgae/engine/xga_base.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Union, Optional, Dict, List, Any, Literal
|
|
1
|
+
from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
|
|
@@ -6,16 +6,23 @@ class XGAError(Exception):
|
|
|
6
6
|
"""Custom exception for errors in the XGA system."""
|
|
7
7
|
pass
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
class
|
|
11
|
-
|
|
12
|
-
type: Literal["status", "tool", "assistant", "assistant_response_end"]
|
|
13
|
-
is_llm_message: bool
|
|
9
|
+
|
|
10
|
+
class XGAResponseMsg(TypedDict, total=False):
    """A message recorded by the task engine during a task run.

    Declared with ``total=False``, so every key is optional.
    """
    # Kind of message; "user" marks the task input message.
    type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
    # Message payload: a dict, a list, or a (possibly JSON-encoded) string.
    content: Union[Dict[str, Any], List[Any], str]
    # True when the message belongs in the conversation sent back to the LLM.
    is_llm_message: bool
    # Free-form extra information attached to the message.
    metadata: Dict[str, Any]
    # Identifiers stamped onto the message by the engine.
    message_id: str
    task_id: str
    task_run_id: str
    trace_id: str
    session_id: Optional[str]
    agent_id: Optional[str]
|
|
18
|
-
|
|
21
|
+
|
|
22
|
+
class XGATaskResult(TypedDict, total=False):
    """Final outcome of a task run; every key is optional (``total=False``)."""
    # "answer" = task finished, "ask" = more input is requested, "error" = failure.
    type: Literal["ask", "answer", "error"]
    # Human-readable answer, question or error text.
    content: str
    # Optional list of attachment references that accompany the result.
    attachments: Optional[List[str]]
|
|
19
26
|
|
|
20
27
|
@dataclass
|
|
21
28
|
class XGAToolSchema:
|
|
@@ -31,6 +38,7 @@ class XGAToolResult:
|
|
|
31
38
|
success: bool
|
|
32
39
|
output: str
|
|
33
40
|
|
|
41
|
+
|
|
34
42
|
class XGAToolBox(ABC):
|
|
35
43
|
@abstractmethod
|
|
36
44
|
async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
|
|
@@ -44,6 +52,10 @@ class XGAToolBox(ABC):
|
|
|
44
52
|
def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
|
|
45
53
|
pass
|
|
46
54
|
|
|
55
|
+
@abstractmethod
|
|
56
|
+
def get_task_tool_names(self, task_id: str) -> List[str]:
|
|
57
|
+
pass
|
|
58
|
+
|
|
47
59
|
@abstractmethod
|
|
48
60
|
async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
|
|
49
61
|
pass
|
xgae/engine/xga_engine.py
CHANGED
|
@@ -1,30 +1,32 @@
|
|
|
1
1
|
|
|
2
|
-
|
|
2
|
+
import logging
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
|
|
3
6
|
from uuid import uuid4
|
|
4
7
|
|
|
5
|
-
from xgae.engine.
|
|
6
|
-
from xgae.
|
|
8
|
+
from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
|
|
9
|
+
from xgae.engine.xga_base import XGAResponseMsg, XGAToolBox, XGATaskResult
|
|
10
|
+
from xgae.utils.llm_client import LLMClient, LLMConfig
|
|
7
11
|
from xgae.utils.setup_env import langfuse
|
|
12
|
+
from xgae.utils.utils import handle_error
|
|
13
|
+
|
|
8
14
|
from xga_prompt_builder import XGAPromptBuilder
|
|
9
15
|
from xga_mcp_tool_box import XGAMcpToolBox
|
|
10
16
|
|
|
11
|
-
|
|
12
|
-
class XGATaskEngine():
|
|
17
|
+
class XGATaskEngine:
|
|
13
18
|
def __init__(self,
|
|
14
19
|
session_id: Optional[str] = None,
|
|
15
20
|
task_id: Optional[str] = None,
|
|
16
21
|
agent_id: Optional[str] = None,
|
|
17
|
-
trace_id: Optional[str] = None,
|
|
18
22
|
system_prompt: Optional[str] = None,
|
|
19
|
-
llm_config: Optional[
|
|
23
|
+
llm_config: Optional[LLMConfig] = None,
|
|
20
24
|
prompt_builder: Optional[XGAPromptBuilder] = None,
|
|
21
25
|
tool_box: Optional[XGAToolBox] = None):
|
|
22
|
-
self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
|
|
23
26
|
self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
|
|
24
27
|
self.agent_id = agent_id
|
|
25
|
-
self.
|
|
28
|
+
self.session_id = session_id
|
|
26
29
|
|
|
27
|
-
self.messages: List[XGAMessage] = []
|
|
28
30
|
self.llm_client = LLMClient(llm_config)
|
|
29
31
|
self.model_name = self.llm_client.model_name
|
|
30
32
|
self.is_stream = self.llm_client.is_stream
|
|
@@ -32,8 +34,12 @@ class XGATaskEngine():
|
|
|
32
34
|
self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
|
|
33
35
|
self.tool_box = tool_box or XGAMcpToolBox()
|
|
34
36
|
|
|
37
|
+
self.task_response_msgs: List[XGAResponseMsg] = []
|
|
38
|
+
self.task_no = -1
|
|
39
|
+
self.task_run_id = f"{self.task_id}[{self.task_no}]"
|
|
40
|
+
self.trace_id = None
|
|
35
41
|
|
|
36
|
-
async def
|
|
42
|
+
async def _post_init_(self, general_tools:List[str], custom_tools: List[str]) -> None:
|
|
37
43
|
await self.tool_box.load_mcp_tools_schema()
|
|
38
44
|
await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
|
|
39
45
|
general_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
|
|
@@ -46,54 +52,296 @@ class XGATaskEngine():
|
|
|
46
52
|
session_id: Optional[str] = None,
|
|
47
53
|
task_id: Optional[str] = None,
|
|
48
54
|
agent_id: Optional[str] = None,
|
|
49
|
-
trace_id: Optional[str] = None,
|
|
50
55
|
system_prompt: Optional[str] = None,
|
|
51
56
|
general_tools: Optional[List[str]] = None,
|
|
52
57
|
custom_tools: Optional[List[str]] = None,
|
|
53
|
-
llm_config: Optional[
|
|
58
|
+
llm_config: Optional[LLMConfig] = None,
|
|
54
59
|
prompt_builder: Optional[XGAPromptBuilder] = None,
|
|
55
60
|
tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
|
|
56
61
|
engine: XGATaskEngine = cls(session_id=session_id,
|
|
57
62
|
task_id=task_id,
|
|
58
63
|
agent_id=agent_id,
|
|
59
|
-
trace_id=trace_id,
|
|
60
64
|
system_prompt=system_prompt,
|
|
61
65
|
llm_config=llm_config,
|
|
62
66
|
prompt_builder=prompt_builder,
|
|
63
67
|
tool_box=tool_box)
|
|
64
|
-
|
|
68
|
+
|
|
69
|
+
general_tools = general_tools or ["complete", "ask"]
|
|
65
70
|
custom_tools = custom_tools or []
|
|
66
|
-
await engine.
|
|
71
|
+
await engine._post_init_(general_tools, custom_tools)
|
|
72
|
+
|
|
73
|
+
logging.info("*"*30 + f" XGATaskEngine Task'{engine.task_id}' Initialized " + "*"*30)
|
|
74
|
+
logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
|
|
75
|
+
logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
|
|
76
|
+
|
|
67
77
|
return engine
|
|
68
78
|
|
|
79
|
+
async def run_task_with_final_answer(self,
|
|
80
|
+
task_message: Dict[str, Any],
|
|
81
|
+
max_auto_run: int = 25,
|
|
82
|
+
trace_id: Optional[str] = None) -> XGATaskResult:
|
|
83
|
+
chunks = []
|
|
84
|
+
async for chunk in self.run_task(task_message=task_message, max_auto_run=max_auto_run, trace_id=trace_id):
|
|
85
|
+
chunks.append(chunk)
|
|
69
86
|
|
|
70
|
-
|
|
87
|
+
final_result = self._parse_final_result(chunks)
|
|
88
|
+
return final_result
|
|
89
|
+
|
|
90
|
+
async def run_task(self,
|
|
91
|
+
task_message: Dict[str, Any],
|
|
92
|
+
max_auto_run: int = 25,
|
|
93
|
+
trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
|
|
71
94
|
try:
|
|
72
|
-
|
|
95
|
+
self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
|
|
96
|
+
|
|
97
|
+
self.task_no += 1
|
|
98
|
+
self.task_run_id = f"{self.task_id}[{self.task_no}]"
|
|
73
99
|
|
|
100
|
+
self.add_response_msg(type="user", content=task_message, is_llm_message=True)
|
|
101
|
+
|
|
102
|
+
if max_auto_run <= 1:
|
|
103
|
+
continuous_state:TaskRunContinuousState = {
|
|
104
|
+
"accumulated_content": "",
|
|
105
|
+
"auto_continue_count": 0,
|
|
106
|
+
"auto_continue": False
|
|
107
|
+
}
|
|
108
|
+
async for chunk in self._run_task_once(continuous_state):
|
|
109
|
+
yield chunk
|
|
110
|
+
else:
|
|
111
|
+
async for chunk in self._run_task_auto(max_auto_run):
|
|
112
|
+
yield chunk
|
|
74
113
|
finally:
|
|
75
114
|
await self.tool_box.destroy_task_tool_box(self.task_id)
|
|
76
|
-
|
|
77
115
|
|
|
78
|
-
def _run_task_once(self):
|
|
79
|
-
|
|
116
|
+
async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
|
|
117
|
+
llm_messages = [{"role": "system", "content": self.task_prompt}]
|
|
118
|
+
cxt_llm_contents = self._get_response_llm_contents()
|
|
119
|
+
llm_messages.extend(cxt_llm_contents)
|
|
120
|
+
|
|
121
|
+
partial_content = continuous_state.get('accumulated_content', '')
|
|
122
|
+
if partial_content:
|
|
123
|
+
temp_assistant_message = {
|
|
124
|
+
"role": "assistant",
|
|
125
|
+
"content": partial_content
|
|
126
|
+
}
|
|
127
|
+
llm_messages.append(temp_assistant_message)
|
|
128
|
+
|
|
129
|
+
llm_response = await self.llm_client.create_completion(llm_messages)
|
|
130
|
+
response_processor = self._create_response_processer()
|
|
131
|
+
|
|
132
|
+
async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
|
|
133
|
+
self._reponse_chunk_log(chunk)
|
|
134
|
+
yield chunk
|
|
135
|
+
|
|
136
|
+
async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator[Dict[str, Any], None]:
|
|
137
|
+
continuous_state: TaskRunContinuousState = {
|
|
138
|
+
"accumulated_content": "",
|
|
139
|
+
"auto_continue_count": 0,
|
|
140
|
+
"auto_continue": True
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
def update_continuous_state(_auto_continue_count, _auto_continue):
|
|
144
|
+
continuous_state["auto_continue_count"] = _auto_continue_count
|
|
145
|
+
continuous_state["auto_continue"] = _auto_continue
|
|
146
|
+
|
|
147
|
+
auto_continue_count = 0
|
|
148
|
+
auto_continue = True
|
|
149
|
+
while auto_continue and auto_continue_count < max_auto_run:
|
|
150
|
+
auto_continue = False
|
|
80
151
|
|
|
152
|
+
try:
|
|
153
|
+
async for chunk in self._run_task_once(continuous_state):
|
|
154
|
+
yield chunk
|
|
155
|
+
try:
|
|
156
|
+
if chunk.get("type") == "status":
|
|
157
|
+
content = json.loads(chunk.get('content', '{}'))
|
|
158
|
+
status_type = content.get('status_type', None)
|
|
159
|
+
if status_type == "error":
|
|
160
|
+
logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
|
|
161
|
+
auto_continue = False
|
|
162
|
+
break
|
|
163
|
+
elif status_type == 'finish':
|
|
164
|
+
finish_reason = content.get('finish_reason', None)
|
|
165
|
+
if finish_reason == 'completed':
|
|
166
|
+
logging.warning(f"run_task_auto: Detected finish_reason='completed', Task Completed Success !")
|
|
167
|
+
auto_continue = False
|
|
168
|
+
break
|
|
169
|
+
elif finish_reason == 'xml_tool_limit_reached':
|
|
170
|
+
logging.warning(f"run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
|
|
171
|
+
auto_continue = False
|
|
172
|
+
break
|
|
173
|
+
elif finish_reason == 'stop' or finish_reason == 'length': # 'length' never occur
|
|
174
|
+
auto_continue = True
|
|
175
|
+
auto_continue_count += 1
|
|
176
|
+
update_continuous_state(auto_continue_count, auto_continue)
|
|
177
|
+
logging.info(f"run_task_auto: Detected finish_reason='{finish_reason}', auto-continuing ({auto_continue_count}/{max_auto_run})")
|
|
178
|
+
except StopAsyncIteration:
|
|
179
|
+
pass
|
|
180
|
+
except Exception as parse_error:
|
|
181
|
+
logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
|
|
182
|
+
content = {"role": "system", "status_type": "error", "message": "Parse response chunk Error"}
|
|
183
|
+
error_msg = self.add_response_msg(type="status", content=content, is_llm_message=False)
|
|
184
|
+
yield error_msg
|
|
185
|
+
except Exception as run_error:
|
|
186
|
+
logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
|
|
187
|
+
content = {"role": "system", "status_type": "error", "message": "Call task_run_once error"}
|
|
188
|
+
error_msg = self.add_response_msg(type="status", content=content, is_llm_message=False)
|
|
189
|
+
yield error_msg
|
|
190
|
+
|
|
191
|
+
def _parse_final_result(self, chunks: List[Dict[str, Any]]) -> XGATaskResult:
|
|
192
|
+
final_result: XGATaskResult = None
|
|
193
|
+
try:
|
|
194
|
+
finish_reason = ''
|
|
195
|
+
for chunk in reversed(chunks):
|
|
196
|
+
chunk_type = chunk.get("type")
|
|
197
|
+
if chunk_type == "status":
|
|
198
|
+
status_content = json.loads(chunk.get('content', '{}'))
|
|
199
|
+
status_type = status_content.get('status_type', None)
|
|
200
|
+
if status_type == "error":
|
|
201
|
+
error = status_content.get('message', 'Unknown error')
|
|
202
|
+
final_result = XGATaskResult(type="error", content=error)
|
|
203
|
+
break
|
|
204
|
+
elif status_type == "finish":
|
|
205
|
+
finish_reason = status_content.get('finish_reason', None)
|
|
206
|
+
if finish_reason == 'xml_tool_limit_reached':
|
|
207
|
+
error = "Completed due to over task max_auto_run limit !"
|
|
208
|
+
final_result = XGATaskResult(type="error", content=error)
|
|
209
|
+
break
|
|
210
|
+
continue
|
|
211
|
+
elif chunk_type == "tool" and finish_reason in ['completed', 'stop']:
|
|
212
|
+
tool_content = json.loads(chunk.get('content', '{}'))
|
|
213
|
+
tool_execution = tool_content.get('tool_execution')
|
|
214
|
+
tool_name = tool_execution.get('function_name')
|
|
215
|
+
if tool_name == "complete":
|
|
216
|
+
result_content = tool_execution["arguments"].get("text", "Task completed with no answer")
|
|
217
|
+
attachments = tool_execution["arguments"].get("attachments", None)
|
|
218
|
+
final_result = XGATaskResult(type="answer", content=result_content, attachments=attachments)
|
|
219
|
+
elif tool_name == "ask":
|
|
220
|
+
result_content = tool_execution["arguments"].get("text", "Task ask for more info")
|
|
221
|
+
attachments = tool_execution["arguments"].get("attachments", None)
|
|
222
|
+
final_result = XGATaskResult(type="ask", content=result_content, attachments=attachments)
|
|
223
|
+
else:
|
|
224
|
+
tool_result = tool_execution.get("result", None)
|
|
225
|
+
if tool_result is not None:
|
|
226
|
+
success = tool_result.get("success")
|
|
227
|
+
output = tool_result.get("output")
|
|
228
|
+
result_type = "answer" if success else "error"
|
|
229
|
+
result_content = f"Task execute '{tool_name}' {result_type}: {output}"
|
|
230
|
+
final_result = XGATaskResult(type=result_type, content=result_content)
|
|
231
|
+
elif chunk_type == "assistant" and finish_reason == 'stop':
|
|
232
|
+
assis_content = chunk.get('content', '{}')
|
|
233
|
+
result_content = assis_content.get("content", "LLM output is empty")
|
|
234
|
+
final_result = XGATaskResult(type="answer", content=result_content)
|
|
235
|
+
if final_result is not None:
|
|
236
|
+
break
|
|
237
|
+
except Exception as e:
|
|
238
|
+
logging.error(f"parse_final_result: Final result pass error: {str(e)}")
|
|
239
|
+
final_result = XGATaskResult(type="error", content="Parse final result failed!")
|
|
240
|
+
handle_error(e)
|
|
241
|
+
|
|
242
|
+
return final_result
|
|
243
|
+
|
|
244
|
+
def add_response_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
|
|
245
|
+
content: Union[Dict[str, Any], List[Any], str],
|
|
246
|
+
is_llm_message: bool,
|
|
247
|
+
metadata: Optional[Dict[str, Any]]=None)-> XGAResponseMsg:
|
|
248
|
+
message = XGAResponseMsg(
|
|
249
|
+
message_id = f"xga_msg_{uuid4()}",
|
|
250
|
+
type = type,
|
|
251
|
+
content = content,
|
|
252
|
+
is_llm_message=is_llm_message,
|
|
253
|
+
metadata = metadata,
|
|
254
|
+
session_id = self.session_id,
|
|
255
|
+
agent_id = self.agent_id,
|
|
256
|
+
task_id = self.task_id,
|
|
257
|
+
task_run_id = self.task_run_id,
|
|
258
|
+
trace_id = self.trace_id
|
|
259
|
+
)
|
|
260
|
+
self.task_response_msgs.append(message)
|
|
261
|
+
|
|
262
|
+
return message
|
|
263
|
+
|
|
264
|
+
def _get_response_llm_contents (self) -> List[Dict[str, Any]]:
|
|
265
|
+
llm_messages = []
|
|
266
|
+
for message in self.task_response_msgs:
|
|
267
|
+
if message["is_llm_message"]:
|
|
268
|
+
llm_messages.append(message)
|
|
269
|
+
|
|
270
|
+
cxt_llm_contents = []
|
|
271
|
+
for llm_message in llm_messages:
|
|
272
|
+
content = llm_message["content"]
|
|
273
|
+
# @todo content List type
|
|
274
|
+
if isinstance(content, str):
|
|
275
|
+
try:
|
|
276
|
+
_content = json.loads(content)
|
|
277
|
+
cxt_llm_contents.append(_content)
|
|
278
|
+
except json.JSONDecodeError as e:
|
|
279
|
+
logging.error(f"get_context_llm_contents: Failed to decode json, content=:{content}")
|
|
280
|
+
handle_error(e)
|
|
281
|
+
else:
|
|
282
|
+
cxt_llm_contents.append(content)
|
|
283
|
+
|
|
284
|
+
return cxt_llm_contents
|
|
285
|
+
|
|
286
|
+
def _create_response_processer(self) -> TaskResponseProcessor:
|
|
287
|
+
response_context = self._create_response_context()
|
|
288
|
+
is_stream = response_context.get("is_stream", False)
|
|
289
|
+
if is_stream:
|
|
290
|
+
from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
|
|
291
|
+
return StreamTaskResponser(response_context)
|
|
292
|
+
else:
|
|
293
|
+
from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
|
|
294
|
+
return NonStreamTaskResponser(response_context)
|
|
295
|
+
|
|
296
|
+
def _create_response_context(self) -> TaskResponseContext:
|
|
297
|
+
response_context: TaskResponseContext = {
|
|
298
|
+
"is_stream": self.is_stream,
|
|
299
|
+
"task_id": self.task_id,
|
|
300
|
+
"task_run_id": self.task_run_id,
|
|
301
|
+
"trace_id": self.trace_id,
|
|
302
|
+
"model_name": self.model_name,
|
|
303
|
+
"max_xml_tool_calls": 0,
|
|
304
|
+
"add_context_msg": self.add_response_msg,
|
|
305
|
+
"tool_box": self.tool_box,
|
|
306
|
+
"tool_execution_strategy": "parallel",
|
|
307
|
+
"xml_adding_strategy": "user_message",
|
|
308
|
+
}
|
|
309
|
+
return response_context
|
|
310
|
+
|
|
311
|
+
def _reponse_chunk_log(self, chunk):
|
|
312
|
+
chunk_type = chunk.get('type')
|
|
313
|
+
prefix = ""
|
|
314
|
+
|
|
315
|
+
if chunk_type == 'status':
|
|
316
|
+
content = json.loads(chunk.get('content', '{}'))
|
|
317
|
+
status_type = content.get('status_type', "empty")
|
|
318
|
+
prefix = "-" + status_type
|
|
319
|
+
elif chunk_type == 'tool':
|
|
320
|
+
tool_content = json.loads(chunk.get('content', '{}'))
|
|
321
|
+
tool_execution = tool_content.get('tool_execution')
|
|
322
|
+
tool_name = tool_execution.get('function_name')
|
|
323
|
+
prefix = "-" + tool_name
|
|
324
|
+
|
|
325
|
+
logging.info(f"TASK_RESP_CHUNK[{chunk_type}{prefix}]: {chunk}")
|
|
81
326
|
|
|
82
|
-
def add_message(self, message: XGAMessage):
|
|
83
|
-
message.message_id = f"xga_msg_{uuid4()}"
|
|
84
|
-
message.session_id = self.session_id
|
|
85
|
-
message.agent_id = self.agent_id
|
|
86
|
-
self.messages.append(message)
|
|
87
327
|
|
|
88
328
|
if __name__ == "__main__":
    import asyncio
    from xgae.utils.utils import read_file

    async def main():
        """Ad-hoc manual run: a fault-location task against the custom MCP tool servers."""
        tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
        system_prompt = read_file("templates/scp_test_prompt.txt")
        engine = await XGATaskEngine.create(tool_box=tool_box,
                                            general_tools=[],
                                            custom_tools=["bomc_fault.*"],
                                            llm_config=LLMConfig(stream=False),
                                            system_prompt=system_prompt)

        final_result = await engine.run_task_with_final_answer(
            task_message={"role": "user", "content": "定位10.0.1.1故障"},
            max_auto_run=8,
        )
        print("FINAL RESULT:", final_result)

    asyncio.run(main())
|
xgae/engine/xga_mcp_tool_box.py
CHANGED
|
@@ -75,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
75
75
|
await self.call_tool(task_id, "end_task", {"task_id": task_id})
|
|
76
76
|
self.task_tool_schemas.pop(task_id, None)
|
|
77
77
|
|
|
78
|
+
@override
|
|
79
|
+
def get_task_tool_names(self, task_id: str) -> List[str]:
|
|
80
|
+
task_tool_schema = self.task_tool_schemas.get(task_id, {})
|
|
81
|
+
task_tool_names = list(task_tool_schema.keys())
|
|
82
|
+
return task_tool_names
|
|
83
|
+
|
|
78
84
|
@override
|
|
79
85
|
def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
|
|
80
86
|
task_tool_schemas = []
|
|
@@ -106,10 +112,7 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
106
112
|
if mcp_tool:
|
|
107
113
|
tool_args = args or {}
|
|
108
114
|
if server_name == self.GENERAL_MCP_SERVER_NAME:
|
|
109
|
-
|
|
110
|
-
#tool_args["task_id"] = task_id #xga general tool, first param must be task_id
|
|
111
|
-
else:
|
|
112
|
-
tool_args = args
|
|
115
|
+
tool_args = dict({"task_id": task_id}, **tool_args)
|
|
113
116
|
|
|
114
117
|
try:
|
|
115
118
|
tool_result = await mcp_tool.arun(tool_args)
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JSON helper utilities for handling both legacy (string) and new (dict/list) formats.
|
|
3
|
+
|
|
4
|
+
These utilities help with the transition from storing JSON as strings to storing
|
|
5
|
+
them as proper JSONB objects in the database.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from typing import Any, Union, Dict, List
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def ensure_dict(value: Union[str, Dict[str, Any], None], default: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Coerce a value into a dictionary.

    Rules:
      - a dict is returned unchanged (the same object)
      - a string holding a JSON object is decoded and returned
      - anything else (None, invalid JSON, JSON that is not an object,
        other types) yields the fallback

    Args:
        value: The value to coerce
        default: Fallback to return; a fresh empty dict when omitted

    Returns:
        A dictionary
    """
    fallback = {} if default is None else default

    if isinstance(value, dict):
        return value

    if isinstance(value, str):
        try:
            decoded = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return fallback
        return decoded if isinstance(decoded, dict) else fallback

    return fallback
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -> List[Any]:
    """
    Coerce a value into a list.

    Rules:
      - a list is returned unchanged (the same object)
      - a string holding a JSON array is decoded and returned
      - anything else (None, invalid JSON, JSON that is not an array,
        other types) yields the fallback

    Args:
        value: The value to coerce
        default: Fallback to return; a fresh empty list when omitted

    Returns:
        A list
    """
    fallback = [] if default is None else default

    if isinstance(value, list):
        return value

    if isinstance(value, str):
        try:
            decoded = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return fallback
        return decoded if isinstance(decoded, list) else fallback

    return fallback
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) -> Any:
    """
    Decode a value that may be a JSON string or an already-decoded object.

    Supports data stored either as JSON strings (old format) or as real
    objects (new format).

    Args:
        value: The value to decode
        default: Returned only when value is None

    Returns:
        The decoded JSON for a valid JSON string, the original string when it
        is not valid JSON, `default` for None, and any other value unchanged.
    """
    if value is None:
        return default

    if not isinstance(value, str):
        # dicts, lists and every other non-string value pass through untouched
        return value

    try:
        return json.loads(value)
    except (json.JSONDecodeError, TypeError):
        # Not JSON: hand back the raw string
        return value
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def to_json_string(value: Any) -> str:
    """
    Return a JSON string for a value, encoding it only when necessary.

    Used for backwards compatibility when yielding data to consumers that
    expect JSON strings.

    Args:
        value: The value to convert

    Returns:
        The value itself if it is a string that already parses as JSON,
        otherwise json.dumps(value)
    """
    if not isinstance(value, str):
        return json.dumps(value)

    try:
        json.loads(value)
    except (json.JSONDecodeError, TypeError):
        # Plain text: quote it as a JSON string literal
        return json.dumps(value)

    # Already valid JSON text
    return value
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def format_for_yield(message_object: Dict[str, Any]) -> Dict[str, Any]:
    """
    Prepare a message object for yielding by JSON-encoding content and metadata.

    Keeps clients that expect JSON strings working while the database stores
    real objects. The input object is never modified; a shallow copy is returned.

    Args:
        message_object: The message object from the database

    Returns:
        A copy whose 'content' and 'metadata' (when present and not already
        strings) are JSON strings. Falsy input is returned as-is.
    """
    if not message_object:
        return message_object

    formatted = message_object.copy()
    for field in ('content', 'metadata'):
        if field in formatted and not isinstance(formatted[field], str):
            formatted[field] = json.dumps(formatted[field])

    return formatted
|
xgae/utils/llm_client.py
CHANGED
|
@@ -4,11 +4,24 @@ import logging
|
|
|
4
4
|
import os
|
|
5
5
|
import litellm
|
|
6
6
|
|
|
7
|
-
from typing import Union, Dict, Any, Optional, List
|
|
7
|
+
from typing import Union, Dict, Any, Optional, List, TypedDict
|
|
8
8
|
|
|
9
9
|
from litellm.utils import ModelResponse, CustomStreamWrapper
|
|
10
10
|
from openai import OpenAIError
|
|
11
11
|
|
|
12
|
+
class LLMConfig(TypedDict, total=False):
    """Optional settings for ``LLMClient``; every key may be omitted (``total=False``).

    Keys that are not given fall back to the client's defaults, e.g. the model
    from the LLM_MODEL environment setting.
    """
    model: str
    model_name: str
    model_id: str
    api_key: str
    api_base: str
    temperature: float
    max_tokens: int
    stream: bool
    enable_thinking: bool
    reasoning_effort: str
    response_format: str
    # Top-p (nucleus) sampling is a probability mass, so it takes fractional values such as 0.9.
    top_p: float
|
|
12
25
|
|
|
13
26
|
class LLMError(Exception):
|
|
14
27
|
"""Base exception for LLM-related errors."""
|
|
@@ -18,7 +31,7 @@ class LLMClient:
|
|
|
18
31
|
RATE_LIMIT_DELAY = 30
|
|
19
32
|
RETRY_DELAY = 0.1
|
|
20
33
|
|
|
21
|
-
def __init__(self, llm_config:
|
|
34
|
+
def __init__(self, llm_config: LLMConfig=None) -> None:
|
|
22
35
|
"""
|
|
23
36
|
Arg: llm_config (Optional[Dict[str, Any]], optional)
|
|
24
37
|
model: Override default model to use, default set by .env LLM_MODEL
|
|
@@ -34,7 +47,7 @@ class LLMClient:
|
|
|
34
47
|
reasoning_effort: Optional level of reasoning effort, default is ‘low’
|
|
35
48
|
top_p: Optional Top-p sampling parameter, default is None
|
|
36
49
|
"""
|
|
37
|
-
llm_config = llm_config or
|
|
50
|
+
llm_config = llm_config or LLMConfig()
|
|
38
51
|
litellm.modify_params = True
|
|
39
52
|
litellm.drop_params = True
|
|
40
53
|
|
|
@@ -214,9 +227,7 @@ class LLMClient:
|
|
|
214
227
|
|
|
215
228
|
if __name__ == "__main__":
|
|
216
229
|
async def llm_completion():
|
|
217
|
-
llm_client = LLMClient(
|
|
218
|
-
"stream": False #default is True
|
|
219
|
-
})
|
|
230
|
+
llm_client = LLMClient(LLMConfig(stream=False))
|
|
220
231
|
messages = [{"role": "user", "content": "今天是2025年8月15日,北京本周每天温度"}]
|
|
221
232
|
response = await llm_client.create_completion(messages)
|
|
222
233
|
if llm_client.is_stream:
|