xgae 0.1.17__tar.gz → 0.1.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xgae might be problematic. Click here for more details.
- {xgae-0.1.17 → xgae-0.1.18}/PKG-INFO +1 -1
- {xgae-0.1.17 → xgae-0.1.18}/mcpservers/custom_servers.json +4 -0
- {xgae-0.1.17 → xgae-0.1.18}/pyproject.toml +2 -1
- {xgae-0.1.17 → xgae-0.1.18}/release.md +5 -0
- xgae-0.1.18/src/examples/agent/langgraph/react/final_result_agent.py +59 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/examples/agent/langgraph/react/react_agent.py +47 -43
- xgae-0.1.18/src/examples/engine/run_custom_and_agent_tools.py +42 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/examples/tools/custom_fault_tools_app.py +14 -15
- xgae-0.1.18/src/examples/tools/simu_a2a_tools_app.py +59 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/engine_base.py +4 -1
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/mcp_tool_box.py +32 -19
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/prompt_builder.py +17 -2
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/responser/non_stream_responser.py +1 -1
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/responser/responser_base.py +25 -32
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/responser/stream_responser.py +10 -16
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/task_engine.py +8 -4
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/tools/without_general_tools_app.py +3 -3
- xgae-0.1.18/templates/custom_tool_prompt_template.txt +25 -0
- {xgae-0.1.17 → xgae-0.1.18}/templates/example/fault_user_prompt.txt +1 -1
- xgae-0.1.18/templates/example/final_result_template.txt +19 -0
- xgae-0.1.18/uv.lock +1386 -0
- xgae-0.1.17/src/examples/engine/run_user_prompt.py +0 -30
- xgae-0.1.17/uv.lock +0 -1386
- {xgae-0.1.17 → xgae-0.1.18}/.env +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/.python-version +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/README.md +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/mcpservers/xga_server.json +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/mcpservers/xga_server_sse.json +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/examples/engine/run_general_tools.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/examples/engine/run_human_in_loop.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/examples/engine/run_simple.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/__init__.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/cli_app.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/engine/task_langfuse.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/__init__.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/json_helpers.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/llm_client.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/misc.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/setup_env.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/src/xgae/utils/xml_tool_parser.py +0 -0
- /xgae-0.1.17/templates/custom_tool_prompt_template.txt → /xgae-0.1.18/templates/agent_tool_prompt_template.txt +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/templates/gemini_system_prompt_template.txt +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/templates/general_tool_prompt_template.txt +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/templates/system_prompt_response_sample.txt +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/templates/system_prompt_template.txt +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/test/test_langfuse.py +0 -0
- {xgae-0.1.17 → xgae-0.1.18}/test/test_litellm_langfuse.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "xgae"
|
|
3
|
-
version = "0.1.
|
|
3
|
+
version = "0.1.18"
|
|
4
4
|
description = "Extreme General Agent Engine"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.13"
|
|
@@ -24,3 +24,4 @@ exclude = ["log/*", ".idea/*"]
|
|
|
24
24
|
xgae = "xgae.cli_app:main"
|
|
25
25
|
xgae-tools = "xgae.tools.without_general_tools_app:main"
|
|
26
26
|
example-fault-tools = "examples.tools.custom_fault_tools_app:main"
|
|
27
|
+
example-a2a-tools = "examples.tools.simu_a2a_tools_app:main"
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from xgae.utils.misc import read_file
|
|
7
|
+
from xgae.utils.llm_client import LLMClient, LangfuseMetadata
|
|
8
|
+
|
|
9
|
+
class FinalResultAgent:
|
|
10
|
+
def __init__(self):
|
|
11
|
+
self.model_client = LLMClient()
|
|
12
|
+
self.prompt_template: str = read_file("templates/example/final_result_template.txt")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def final_result(self, user_request: str, task_results: str, langfuse_metadata:LangfuseMetadata=None)-> Dict[str, Any]:
|
|
16
|
+
prompt = self.prompt_template.replace("{user_request}", user_request)
|
|
17
|
+
prompt = prompt.replace("{task_results}", task_results)
|
|
18
|
+
|
|
19
|
+
messages = [{"role": "user", "content": prompt}]
|
|
20
|
+
|
|
21
|
+
response_text: str = ""
|
|
22
|
+
response = await self.model_client.create_completion(
|
|
23
|
+
messages,
|
|
24
|
+
langfuse_metadata
|
|
25
|
+
)
|
|
26
|
+
if self.model_client.is_stream:
|
|
27
|
+
async for chunk in response:
|
|
28
|
+
choices = chunk.get("choices", [{}])
|
|
29
|
+
if not choices:
|
|
30
|
+
continue
|
|
31
|
+
delta = choices[0].get("delta", {})
|
|
32
|
+
content = delta.get("content", "")
|
|
33
|
+
if content:
|
|
34
|
+
response_text += content
|
|
35
|
+
else:
|
|
36
|
+
response_text = response.choices[0].message.content
|
|
37
|
+
|
|
38
|
+
cleaned_text = re.sub(r'^\s*```json|```\s*$', '', response_text, flags=re.MULTILINE).strip()
|
|
39
|
+
final_result = json.loads(cleaned_text)
|
|
40
|
+
return final_result
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
if __name__ == "__main__":
|
|
44
|
+
import asyncio
|
|
45
|
+
from xgae.utils.setup_env import setup_logging
|
|
46
|
+
setup_logging()
|
|
47
|
+
|
|
48
|
+
async def main():
|
|
49
|
+
final_result_agent = FinalResultAgent()
|
|
50
|
+
|
|
51
|
+
user_input = "locate 10.2.3.4 fault and solution"
|
|
52
|
+
answer = ("Task Summary: The fault for IP 10.2.3.4 was identified as a Business Recharge Fault (Code: F01), "
|
|
53
|
+
"caused by a Phone Recharge Application Crash. The solution applied was to restart the application. "
|
|
54
|
+
"Key Deliverables: Fault diagnosis and resolution steps. Impact Achieved: Service restored.")
|
|
55
|
+
return await final_result_agent.final_result(user_input, answer)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
final_result = asyncio.run(main())
|
|
59
|
+
print(f"FINAL_RESULT: {final_result} ")
|
|
@@ -40,32 +40,32 @@ class XGAReactAgent:
|
|
|
40
40
|
graph_builder = StateGraph(TaskState)
|
|
41
41
|
|
|
42
42
|
# Add nodes
|
|
43
|
-
graph_builder.add_node(
|
|
44
|
-
graph_builder.add_node(
|
|
45
|
-
graph_builder.add_node(
|
|
46
|
-
graph_builder.add_node(
|
|
43
|
+
graph_builder.add_node('supervisor', self._supervisor_node)
|
|
44
|
+
graph_builder.add_node('select_tool', self._select_tool_node)
|
|
45
|
+
graph_builder.add_node('exec_task', self._exec_task_node)
|
|
46
|
+
graph_builder.add_node('final_result', self._final_result_node)
|
|
47
47
|
|
|
48
48
|
# Add edges
|
|
49
|
-
graph_builder.add_edge(START,
|
|
49
|
+
graph_builder.add_edge(START, 'supervisor')
|
|
50
50
|
graph_builder.add_conditional_edges(
|
|
51
|
-
|
|
51
|
+
'supervisor',
|
|
52
52
|
self._next_condition,
|
|
53
53
|
{
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
'select_tool': 'select_tool',
|
|
55
|
+
'exec_task': 'exec_task',
|
|
56
|
+
'end': END
|
|
57
57
|
}
|
|
58
58
|
)
|
|
59
59
|
|
|
60
|
-
graph_builder.add_edge(
|
|
61
|
-
graph_builder.add_edge(
|
|
60
|
+
graph_builder.add_edge('select_tool', 'exec_task')
|
|
61
|
+
graph_builder.add_edge('exec_task', 'final_result')
|
|
62
62
|
|
|
63
63
|
graph_builder.add_conditional_edges(
|
|
64
|
-
|
|
64
|
+
'final_result',
|
|
65
65
|
self._next_condition,
|
|
66
66
|
{
|
|
67
|
-
|
|
68
|
-
|
|
67
|
+
'supervisor': 'supervisor',
|
|
68
|
+
'end': END
|
|
69
69
|
}
|
|
70
70
|
)
|
|
71
71
|
|
|
@@ -91,10 +91,10 @@ class XGAReactAgent:
|
|
|
91
91
|
|
|
92
92
|
next_node = "select_tool" if system_prompt else "exec_task"
|
|
93
93
|
return {
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
94
|
+
'system_prompt' : system_prompt,
|
|
95
|
+
'next_node' : next_node,
|
|
96
|
+
'general_tools' : general_tools,
|
|
97
|
+
'custom_tools' : custom_tools,
|
|
98
98
|
}
|
|
99
99
|
|
|
100
100
|
def _select_custom_tools(self, system_prompt: str) -> list[str]:
|
|
@@ -102,20 +102,21 @@ class XGAReactAgent:
|
|
|
102
102
|
return custom_tools
|
|
103
103
|
|
|
104
104
|
async def _select_tool_node(self, state: TaskState) -> Dict[str, Any]:
|
|
105
|
-
system_prompt = state.get(
|
|
105
|
+
system_prompt = state.get('system_prompt',None)
|
|
106
106
|
general_tools = []
|
|
107
107
|
custom_tools = self._select_custom_tools(system_prompt)
|
|
108
108
|
return {
|
|
109
|
-
|
|
110
|
-
|
|
109
|
+
'general_tools' : general_tools,
|
|
110
|
+
'custom_tools' : custom_tools,
|
|
111
111
|
}
|
|
112
112
|
|
|
113
113
|
async def _exec_task_node(self, state: TaskState) -> Dict[str, Any]:
|
|
114
|
-
user_input = state
|
|
115
|
-
system_prompt = state.get(
|
|
116
|
-
general_tools = state.get(
|
|
117
|
-
custom_tools = state.get(
|
|
114
|
+
user_input = state['user_input']
|
|
115
|
+
system_prompt = state.get('system_prompt',None)
|
|
116
|
+
general_tools = state.get('general_tools',[])
|
|
117
|
+
custom_tools = state.get('custom_tools',[])
|
|
118
118
|
is_system_prompt = True if system_prompt is not None else False
|
|
119
|
+
|
|
119
120
|
try:
|
|
120
121
|
logging.info(f"🔥 XGATaskEngine: run_task_with_final_answer: user_input={user_input}, general_tools={general_tools}, custom_tools={custom_tools}, is_system_prompt={is_system_prompt}")
|
|
121
122
|
engine = XGATaskEngine(tool_box=self.tool_box,
|
|
@@ -129,24 +130,25 @@ class XGAReactAgent:
|
|
|
129
130
|
logging.error("Failed to execute task: %s", str(e))
|
|
130
131
|
task_result = XGATaskResult(type="error", content="Failed to execute task")
|
|
131
132
|
|
|
132
|
-
iteration_count = state.get(
|
|
133
|
+
iteration_count = state.get('iteration_count', 0) + 1
|
|
133
134
|
return {
|
|
134
|
-
|
|
135
|
-
|
|
135
|
+
'task_result' : task_result,
|
|
136
|
+
'iteration_count': iteration_count,
|
|
136
137
|
}
|
|
137
138
|
|
|
138
139
|
async def _final_result_node(self, state: TaskState) -> Dict[str, Any]:
|
|
139
|
-
iteration_count = state
|
|
140
|
-
task_result = state
|
|
140
|
+
iteration_count = state['iteration_count']
|
|
141
|
+
task_result = state['task_result']
|
|
141
142
|
next_node = "end"
|
|
142
143
|
if iteration_count < self.MAX_TASK_RETRY and task_result["type"] == "error":
|
|
143
144
|
next_node = "supervisor"
|
|
145
|
+
|
|
144
146
|
return {
|
|
145
|
-
|
|
147
|
+
'next_node' : next_node
|
|
146
148
|
}
|
|
147
149
|
|
|
148
150
|
def _next_condition(self, state: TaskState) -> str:
|
|
149
|
-
next_node = state
|
|
151
|
+
next_node = state['next_node']
|
|
150
152
|
return next_node
|
|
151
153
|
|
|
152
154
|
|
|
@@ -161,24 +163,26 @@ class XGAReactAgent:
|
|
|
161
163
|
|
|
162
164
|
# Initialize state
|
|
163
165
|
initial_state = {
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
166
|
+
'messages' : [HumanMessage(content=f"information for: {user_input}")],
|
|
167
|
+
'user_input' : user_input,
|
|
168
|
+
'next_node' : None,
|
|
169
|
+
'agent_context' : {},
|
|
170
|
+
'iteration_count' : 0
|
|
169
171
|
}
|
|
170
172
|
|
|
171
173
|
# Run the retrieval graph with proper configuration
|
|
172
|
-
config = {
|
|
173
|
-
|
|
174
|
+
config = {'recursion_limit': 100,
|
|
175
|
+
'configurable': {
|
|
176
|
+
'thread_id': "manager_async_generate_thread"
|
|
177
|
+
}}
|
|
174
178
|
final_state = await self.graph.ainvoke(initial_state, config=config)
|
|
175
179
|
|
|
176
180
|
# Parse and return formatted results
|
|
177
181
|
result = final_state["task_result"]
|
|
178
182
|
|
|
179
183
|
logging.info("=" * 100)
|
|
180
|
-
logging.info("
|
|
181
|
-
logging.info("
|
|
184
|
+
logging.info(f"USER QUESTION: {user_input}")
|
|
185
|
+
logging.info(f"LLM ANSWER: {result}")
|
|
182
186
|
logging.info("=" * 100)
|
|
183
187
|
|
|
184
188
|
return result
|
|
@@ -194,8 +198,8 @@ if __name__ == "__main__":
|
|
|
194
198
|
async def main():
|
|
195
199
|
agent = XGAReactAgent()
|
|
196
200
|
user_inputs = [
|
|
197
|
-
|
|
198
|
-
"5+5",
|
|
201
|
+
"locate 10.2.3.4 fault and solution",
|
|
202
|
+
#"5+5",
|
|
199
203
|
]
|
|
200
204
|
for user_input in user_inputs:
|
|
201
205
|
result = await agent.generate(user_input)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from xgae.engine.mcp_tool_box import XGAMcpToolBox
|
|
4
|
+
from xgae.engine.task_engine import XGATaskEngine
|
|
5
|
+
from xgae.utils.misc import read_file
|
|
6
|
+
|
|
7
|
+
from xgae.utils.setup_env import setup_logging
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
is_stream = False
|
|
11
|
+
if is_stream:
|
|
12
|
+
setup_logging(log_level="ERROR") # only show chunk
|
|
13
|
+
else:
|
|
14
|
+
setup_logging()
|
|
15
|
+
|
|
16
|
+
async def main() -> None:
|
|
17
|
+
tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
|
|
18
|
+
system_prompt = read_file("templates/example/fault_user_prompt.txt")
|
|
19
|
+
|
|
20
|
+
engine = XGATaskEngine(tool_box=tool_box,
|
|
21
|
+
general_tools=[],
|
|
22
|
+
custom_tools=["*"],
|
|
23
|
+
system_prompt=system_prompt)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
user_input = "locate 10.2.3.4 fault and solution"
|
|
27
|
+
global is_stream
|
|
28
|
+
if is_stream:
|
|
29
|
+
chunks = []
|
|
30
|
+
async for chunk in engine.run_task(task_message={"role": "user", "content": user_input}):
|
|
31
|
+
chunks.append(chunk)
|
|
32
|
+
print(chunk)
|
|
33
|
+
|
|
34
|
+
final_result = engine.parse_final_result(chunks)
|
|
35
|
+
print(f"\n\nFINAL_RESULT: {final_result}")
|
|
36
|
+
else:
|
|
37
|
+
final_result = await engine.run_task_with_final_answer(task_message={"role": "user", "content": user_input})
|
|
38
|
+
print(f"\n\nFINAL_RESULT: {final_result}")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Before Run Exec: uv run example-fault-tools --alarmtype=2 , uv run example-a2a-tools
|
|
42
|
+
asyncio.run(main())
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import click
|
|
2
2
|
import logging
|
|
3
3
|
|
|
4
|
-
from typing import Annotated
|
|
5
|
-
from typing import Dict, Any
|
|
4
|
+
from typing import Annotated, Dict, Any
|
|
6
5
|
from pydantic import Field
|
|
7
6
|
|
|
8
7
|
from mcp.server.fastmcp import FastMCP
|
|
@@ -65,19 +64,19 @@ async def get_busi_fault_cause(fault_code: Annotated[str, Field(description="Fau
|
|
|
65
64
|
return fault_cause
|
|
66
65
|
|
|
67
66
|
|
|
68
|
-
@mcp.tool(
|
|
69
|
-
|
|
70
|
-
)
|
|
71
|
-
async def
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
67
|
+
# @mcp.tool(
|
|
68
|
+
# description="Get Equipment Type Fault Solution and Cause",
|
|
69
|
+
# )
|
|
70
|
+
# async def query_equip_fault_cause(fault_code: Annotated[str, Field(description="Fault Code")]) -> str:
|
|
71
|
+
# logging.info(f"get_equip_fault_cause: faultCode={fault_code}")
|
|
72
|
+
#
|
|
73
|
+
# fault_cause = ""
|
|
74
|
+
# if (fault_code == 'F02'):
|
|
75
|
+
# fault_cause = "Host Fault, Fault Cause is 'Host Disk is Damaged' ,Solution is 'Change Host Disk'"
|
|
76
|
+
# else:
|
|
77
|
+
# fault_cause = f"FaultCode '{fault_code}' is not Equipment Type"
|
|
78
|
+
#
|
|
79
|
+
# return fault_cause
|
|
81
80
|
|
|
82
81
|
|
|
83
82
|
@click.command()
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from typing import Annotated, Optional, Dict, List, Any, Literal, TypedDict
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
|
|
7
|
+
from mcp.server.fastmcp import FastMCP
|
|
8
|
+
|
|
9
|
+
mcp = FastMCP(name="Simulate A2A MCP Proxy")
|
|
10
|
+
|
|
11
|
+
# XGA agent tool must return {type, content, attachments} like XGAAgentResult class
|
|
12
|
+
class XGAAgentResult(TypedDict, total=False):
|
|
13
|
+
type: Literal["ask", "answer", "error"]
|
|
14
|
+
content: str
|
|
15
|
+
attachments: Optional[List[str]]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@mcp.tool(
|
|
19
|
+
description="Get Equipment Type Fault Solution and Cause",
|
|
20
|
+
)
|
|
21
|
+
# XGA agent tool, must have 2 parameter, first parameter must be task_id, second parameter must be str type
|
|
22
|
+
def query_equip_fault_cause(task_id:str, input: Annotated[str, Field(description="Fault Code")]):
|
|
23
|
+
logging.info(f"get_equip_fault_cause: task_id={task_id}, faultCode={input}")
|
|
24
|
+
|
|
25
|
+
fault_cause:XGAAgentResult = None
|
|
26
|
+
if 'F02' in input:
|
|
27
|
+
fault_cause:XGAAgentResult = {
|
|
28
|
+
'type': "answer",
|
|
29
|
+
'content': "Host Fault, Fault Cause is 'Host Disk is Damaged' ,Solution is 'Change Host Disk'"
|
|
30
|
+
}
|
|
31
|
+
else:
|
|
32
|
+
fault_cause:XGAAgentResult = {
|
|
33
|
+
'type': "ask",
|
|
34
|
+
'content': f"FaultCode '{input}' is not Equipment Type"
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
return fault_cause
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@click.command()
|
|
41
|
+
@click.option("--transport", type=click.Choice(["stdio", "sse"]), default="sse", help="Transport type")
|
|
42
|
+
@click.option("--host", default="0.0.0.0", help="Host to listen on for SSE")
|
|
43
|
+
@click.option("--port", default=21010, help="Port to listen on for SSE")
|
|
44
|
+
def main(transport: str, host: str, port: int):
|
|
45
|
+
if transport != "stdio":
|
|
46
|
+
from xgae.utils.setup_env import setup_logging
|
|
47
|
+
setup_logging()
|
|
48
|
+
logging.info("=" * 10 + f" Simulate A2A MCP Proxy Sever Started " + "=" * 10)
|
|
49
|
+
logging.info(f"=== transport={transport}, host={host}, port={port}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
mcp.settings.host = host
|
|
53
|
+
mcp.settings.port = port
|
|
54
|
+
|
|
55
|
+
mcp.run(transport=transport)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
if __name__ == '__main__':
|
|
59
|
+
main()
|
|
@@ -21,9 +21,12 @@ class XGATaskResult(TypedDict, total=False):
|
|
|
21
21
|
content: str
|
|
22
22
|
attachments: Optional[List[str]]
|
|
23
23
|
|
|
24
|
+
XGAToolType = Literal["general", "custom", "agent"]
|
|
25
|
+
|
|
24
26
|
@dataclass
|
|
25
27
|
class XGAToolSchema:
|
|
26
28
|
tool_name: str
|
|
29
|
+
tool_type: XGAToolType
|
|
27
30
|
server_name: str
|
|
28
31
|
description: str
|
|
29
32
|
input_schema: Dict[str, Any]
|
|
@@ -46,7 +49,7 @@ class XGAToolBox(ABC):
|
|
|
46
49
|
pass
|
|
47
50
|
|
|
48
51
|
@abstractmethod
|
|
49
|
-
def get_task_tool_schemas(self, task_id: str, type:
|
|
52
|
+
def get_task_tool_schemas(self, task_id: str, type: XGAToolType) -> List[XGAToolSchema]:
|
|
50
53
|
pass
|
|
51
54
|
|
|
52
55
|
@abstractmethod
|
|
@@ -7,10 +7,11 @@ from typing import List, Any, Dict, Optional, Literal, override
|
|
|
7
7
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
8
8
|
from langchain_mcp_adapters.tools import load_mcp_tools
|
|
9
9
|
|
|
10
|
-
from xgae.engine.engine_base import XGAError, XGAToolSchema, XGAToolBox, XGAToolResult
|
|
10
|
+
from xgae.engine.engine_base import XGAError, XGAToolSchema, XGAToolBox, XGAToolResult, XGAToolType
|
|
11
11
|
|
|
12
12
|
class XGAMcpToolBox(XGAToolBox):
|
|
13
13
|
GENERAL_MCP_SERVER_NAME = "xga_general"
|
|
14
|
+
AGENT_MCP_SERVER_PREFIX = "_@_"
|
|
14
15
|
|
|
15
16
|
def __init__(self,
|
|
16
17
|
custom_mcp_server_file: Optional[str] = None,
|
|
@@ -37,7 +38,7 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
37
38
|
@override
|
|
38
39
|
async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
|
|
39
40
|
task_tool_schemas = {}
|
|
40
|
-
general_tool_schemas = self.mcp_tool_schemas.get(
|
|
41
|
+
general_tool_schemas = self.mcp_tool_schemas.get(self.GENERAL_MCP_SERVER_NAME, {})
|
|
41
42
|
if "*" in general_tools:
|
|
42
43
|
task_tool_schemas = {tool_schema.tool_name: tool_schema for tool_schema in general_tool_schemas}
|
|
43
44
|
else:
|
|
@@ -49,7 +50,7 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
49
50
|
if len(custom_tools) == 1 and custom_tools[0] == "*":
|
|
50
51
|
custom_tools = []
|
|
51
52
|
for server_name in self.mcp_server_names:
|
|
52
|
-
if server_name !=
|
|
53
|
+
if server_name != self.GENERAL_MCP_SERVER_NAME:
|
|
53
54
|
custom_tools.append(f"{server_name}.*")
|
|
54
55
|
|
|
55
56
|
for server_tool_name in custom_tools:
|
|
@@ -76,7 +77,7 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
76
77
|
|
|
77
78
|
@override
|
|
78
79
|
async def destroy_task_tool_box(self, task_id: str):
|
|
79
|
-
tool_schemas = self.get_task_tool_schemas(task_id,
|
|
80
|
+
tool_schemas = self.get_task_tool_schemas(task_id, "general")
|
|
80
81
|
if len(tool_schemas) > 0:
|
|
81
82
|
await self.call_tool(task_id, "end_task", {'task_id': task_id})
|
|
82
83
|
self.task_tool_schemas.pop(task_id, None)
|
|
@@ -88,14 +89,12 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
88
89
|
return task_tool_names
|
|
89
90
|
|
|
90
91
|
@override
|
|
91
|
-
def get_task_tool_schemas(self, task_id: str,
|
|
92
|
+
def get_task_tool_schemas(self, task_id: str, tool_type: XGAToolType) -> List[XGAToolSchema]:
|
|
92
93
|
task_tool_schemas = []
|
|
93
94
|
|
|
94
95
|
all_task_tool_schemas = self.task_tool_schemas.get(task_id, {})
|
|
95
96
|
for tool_schema in all_task_tool_schemas.values():
|
|
96
|
-
if
|
|
97
|
-
task_tool_schemas.append(tool_schema)
|
|
98
|
-
elif type == "custom_tool" and tool_schema.server_name != self.GENERAL_MCP_SERVER_NAME:
|
|
97
|
+
if tool_schema.tool_type == tool_type:
|
|
99
98
|
task_tool_schemas.append(tool_schema)
|
|
100
99
|
|
|
101
100
|
return task_tool_schemas
|
|
@@ -114,16 +113,16 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
114
113
|
async with self._mcp_client.session(server_name) as session:
|
|
115
114
|
tools = await load_mcp_tools(session)
|
|
116
115
|
mcp_tool = next((t for t in tools if t.name == tool_name), None)
|
|
117
|
-
|
|
116
|
+
|
|
118
117
|
if mcp_tool:
|
|
119
118
|
tool_args = args or {}
|
|
120
|
-
|
|
119
|
+
tool_type = self._get_tool_type(server_name)
|
|
120
|
+
if tool_type == "general" or tool_type == "agent":
|
|
121
121
|
tool_args = dict({'task_id': task_id}, **tool_args)
|
|
122
|
-
is_general_tool = True
|
|
123
122
|
|
|
124
123
|
try:
|
|
125
124
|
tool_result = await mcp_tool.arun(tool_args)
|
|
126
|
-
if
|
|
125
|
+
if tool_type == "general":
|
|
127
126
|
tool_result = json.loads(tool_result)
|
|
128
127
|
result = XGAToolResult(success=tool_result['success'], output=str(tool_result['output']))
|
|
129
128
|
else:
|
|
@@ -144,11 +143,17 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
144
143
|
if not self.is_loaded_tool_schemas:
|
|
145
144
|
for server_name in self.mcp_server_names:
|
|
146
145
|
self.mcp_tool_schemas[server_name] = []
|
|
147
|
-
|
|
146
|
+
try:
|
|
147
|
+
mcp_tools = await self._mcp_client.get_tools(server_name=server_name)
|
|
148
|
+
except Exception as e:
|
|
149
|
+
logging.error(f"### McpToolBox load_mcp_tools_schema: Langchain mcp get_tools failed, "
|
|
150
|
+
f"need start mcp server '{server_name}' !")
|
|
151
|
+
continue
|
|
148
152
|
|
|
153
|
+
tool_type = self._get_tool_type(server_name)
|
|
149
154
|
for tool in mcp_tools:
|
|
150
155
|
input_schema = tool.args_schema
|
|
151
|
-
if
|
|
156
|
+
if tool_type == "general" or tool_type == "agent":
|
|
152
157
|
input_schema['properties'].pop("task_id", None)
|
|
153
158
|
if 'task_id' in input_schema['required']:
|
|
154
159
|
input_schema['required'].remove('task_id')
|
|
@@ -158,6 +163,7 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
158
163
|
|
|
159
164
|
metadata = tool.metadata or {}
|
|
160
165
|
tool_schema = XGAToolSchema(tool_name=tool.name,
|
|
166
|
+
tool_type=tool_type,
|
|
161
167
|
server_name=server_name,
|
|
162
168
|
description=tool.description,
|
|
163
169
|
input_schema=input_schema,
|
|
@@ -169,8 +175,8 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
169
175
|
self.is_loaded_tool_schemas = False
|
|
170
176
|
await self.load_mcp_tools_schema()
|
|
171
177
|
|
|
172
|
-
|
|
173
|
-
def _load_mcp_servers_config(mcp_config_path: str) -> Dict[str, Any]:
|
|
178
|
+
|
|
179
|
+
def _load_mcp_servers_config(self, mcp_config_path: str) -> Dict[str, Any]:
|
|
174
180
|
try:
|
|
175
181
|
if os.path.exists(mcp_config_path):
|
|
176
182
|
with open(mcp_config_path, 'r', encoding="utf-8") as f:
|
|
@@ -192,6 +198,13 @@ class XGAMcpToolBox(XGAToolBox):
|
|
|
192
198
|
logging.error(f"McpToolBox load_mcp_servers_config: Failed to load MCP servers config: {e}")
|
|
193
199
|
return {'mcpServers': {}}
|
|
194
200
|
|
|
201
|
+
def _get_tool_type(self, server_name: str) -> XGAToolType:
|
|
202
|
+
tool_type: XGAToolType = "custom"
|
|
203
|
+
if server_name == self.GENERAL_MCP_SERVER_NAME:
|
|
204
|
+
tool_type = "general"
|
|
205
|
+
elif server_name.startswith(self.AGENT_MCP_SERVER_PREFIX):
|
|
206
|
+
tool_type = "agent"
|
|
207
|
+
return tool_type
|
|
195
208
|
|
|
196
209
|
if __name__ == "__main__":
|
|
197
210
|
import asyncio
|
|
@@ -207,14 +220,14 @@ if __name__ == "__main__":
|
|
|
207
220
|
|
|
208
221
|
task_id = "task1"
|
|
209
222
|
await mcp_tool_box.load_mcp_tools_schema()
|
|
210
|
-
await mcp_tool_box.creat_task_tool_box(task_id=task_id, general_tools=["*"], custom_tools=["
|
|
211
|
-
tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "
|
|
223
|
+
await mcp_tool_box.creat_task_tool_box(task_id=task_id, general_tools=["*"], custom_tools=["*"])
|
|
224
|
+
tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "general")
|
|
212
225
|
print("general_tools_schemas" + "*"*50)
|
|
213
226
|
for tool_schema in tool_schemas:
|
|
214
227
|
print(asdict(tool_schema))
|
|
215
228
|
print()
|
|
216
229
|
|
|
217
|
-
tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "
|
|
230
|
+
tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "custom")
|
|
218
231
|
print("custom_tools_schemas" + "*" * 50)
|
|
219
232
|
for tool_schema in tool_schemas:
|
|
220
233
|
print(asdict(tool_schema))
|
|
@@ -11,7 +11,11 @@ class XGAPromptBuilder():
|
|
|
11
11
|
def __init__(self, system_prompt: Optional[str] = None):
|
|
12
12
|
self.system_prompt = system_prompt
|
|
13
13
|
|
|
14
|
-
def build_task_prompt(self,
|
|
14
|
+
def build_task_prompt(self,
|
|
15
|
+
model_name: str,
|
|
16
|
+
general_tool_schemas: List[XGAToolSchema],
|
|
17
|
+
custom_tool_schemas: List[XGAToolSchema],
|
|
18
|
+
agent_tool_schemas: List[XGAToolSchema])-> str:
|
|
15
19
|
if self.system_prompt is None:
|
|
16
20
|
self.system_prompt = self._load_default_system_prompt(model_name)
|
|
17
21
|
|
|
@@ -23,6 +27,9 @@ class XGAPromptBuilder():
|
|
|
23
27
|
tool_prompt = self.build_custom_tool_prompt(custom_tool_schemas)
|
|
24
28
|
task_prompt = task_prompt + "\n" + tool_prompt
|
|
25
29
|
|
|
30
|
+
tool_prompt = self.build_agent_tool_prompt(agent_tool_schemas)
|
|
31
|
+
task_prompt = task_prompt + "\n" + tool_prompt
|
|
32
|
+
|
|
26
33
|
return task_prompt
|
|
27
34
|
|
|
28
35
|
def build_general_tool_prompt(self, tool_schemas:List[XGAToolSchema])-> str:
|
|
@@ -61,10 +68,18 @@ class XGAPromptBuilder():
|
|
|
61
68
|
|
|
62
69
|
|
|
63
70
|
def build_custom_tool_prompt(self, tool_schemas:List[XGAToolSchema])-> str:
|
|
71
|
+
tool_prompt = self.build_mcp_tool_prompt("templates/custom_tool_prompt_template.txt", tool_schemas)
|
|
72
|
+
return tool_prompt
|
|
73
|
+
|
|
74
|
+
def build_agent_tool_prompt(self, tool_schemas:List[XGAToolSchema])-> str:
|
|
75
|
+
tool_prompt = self.build_mcp_tool_prompt("templates/agent_tool_prompt_template.txt", tool_schemas)
|
|
76
|
+
return tool_prompt
|
|
77
|
+
|
|
78
|
+
def build_mcp_tool_prompt(self, file_path: str, tool_schemas:List[XGAToolSchema])-> str:
|
|
64
79
|
tool_prompt = ""
|
|
65
80
|
tool_schemas = tool_schemas or []
|
|
66
81
|
if len(tool_schemas) > 0:
|
|
67
|
-
tool_prompt = read_file(
|
|
82
|
+
tool_prompt = read_file(file_path)
|
|
68
83
|
tool_info = ""
|
|
69
84
|
for tool_schema in tool_schemas:
|
|
70
85
|
description = tool_schema.description if tool_schema.description else 'No description available'
|
|
@@ -64,7 +64,7 @@ class NonStreamTaskResponser(TaskResponseProcessor):
|
|
|
64
64
|
tool_start_msg = self._add_tool_start_message(tool_context)
|
|
65
65
|
yield tool_start_msg
|
|
66
66
|
|
|
67
|
-
tool_message = self._add_tool_messsage(
|
|
67
|
+
tool_message = self._add_tool_messsage(tool_context, self.xml_adding_strategy)
|
|
68
68
|
|
|
69
69
|
tool_completed_msg = self._add_tool_completed_message(tool_context, tool_message['message_id'])
|
|
70
70
|
yield tool_completed_msg
|