xgae 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xgae/engine/responser/xga_non_stream_responser.py +213 -0
- xgae/engine/responser/xga_responser_base.py +751 -0
- xgae/engine/responser/xga_stream_responser.py +787 -0
- xgae/engine/xga_base.py +23 -9
- xgae/engine/xga_engine.py +228 -55
- xgae/engine/xga_mcp_tool_box.py +8 -3
- xgae/engine/xga_prompt_builder.py +27 -45
- xgae/utils/json_helpers.py +174 -0
- xgae/utils/llm_client.py +17 -6
- xgae/utils/setup_env.py +1 -31
- xgae/utils/utils.py +42 -0
- xgae/utils/xml_tool_parser.py +236 -0
- {xgae-0.1.2.dist-info → xgae-0.1.4.dist-info}/METADATA +1 -1
- xgae-0.1.4.dist-info/RECORD +16 -0
- xgae/engine/responser/xga_responser_utils.py +0 -0
- xgae/engine/responser/xga_stream_reponser.py +0 -0
- xgae-0.1.2.dist-info/RECORD +0 -13
- {xgae-0.1.2.dist-info → xgae-0.1.4.dist-info}/WHEEL +0 -0
xgae/engine/xga_base.py
CHANGED
@@ -1,19 +1,29 @@
-from typing import Union, Optional, Dict, List, Any, Literal
+from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 
+class XGAError(Exception):
+    """Custom exception for errors in the XGA system."""
+    pass
 
 
-
-
-
-    type: Literal["status", "tool", "assistant", "assistant_response_end"]
+class XGAContextMsg(TypedDict, total=False):
+    type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
+    content: Union[Dict[str, Any], List[Any], str]
     is_llm_message: bool
+    metadata: Dict[str, Any]
+    message_id: str
+    session_id: str
+    agent_id: str
+    task_id: str
+    task_run_id: str
+    trace_id: str
+
+class XGAResponseMsg(TypedDict, total=False):
+    type: Literal["content", "status"]
     content: Union[Dict[str, Any], List[Any], str]
-
-
-    agent_id: Optional[str]
-    task_id: Optional[str]
+    status: Literal["error", "status"]
+    message: str
 
 @dataclass
 class XGAToolSchema:
@@ -42,6 +52,10 @@ class XGAToolBox(ABC):
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         pass
 
+    @abstractmethod
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        pass
+
     @abstractmethod
     async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
         pass
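Since the new `XGAContextMsg` and `XGAResponseMsg` are declared with `total=False`, every key is optional. A minimal sketch of building a partial context message (the values here are hypothetical):

import json
from xgae.engine.xga_base import XGAContextMsg

# total=False: a partial dict still satisfies the TypedDict
msg: XGAContextMsg = {
    "type": "assistant",
    "content": "hello",
    "is_llm_message": True,
}
print(json.dumps(msg))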
xgae/engine/xga_engine.py
CHANGED
@@ -1,105 +1,278 @@
 
-
+import logging
+import json
+
+from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
 from uuid import uuid4
 
-from xgae.engine.
-from xgae.
+from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
+from xgae.engine.xga_base import XGAContextMsg, XGAToolBox, XGAResponseMsg
+from xgae.utils.llm_client import LLMClient, LLMConfig
 from xgae.utils.setup_env import langfuse
+from xgae.utils.utils import handle_error
+
 from xga_prompt_builder import XGAPromptBuilder
 from xga_mcp_tool_box import XGAMcpToolBox
 
-
-class XGAEngine():
+class XGATaskEngine:
     def __init__(self,
                  session_id: Optional[str] = None,
-
+                 task_id: Optional[str] = None,
                  agent_id: Optional[str] = None,
-
+                 system_prompt: Optional[str] = None,
+                 llm_config: Optional[LLMConfig] = None,
                  prompt_builder: Optional[XGAPromptBuilder] = None,
                  tool_box: Optional[XGAToolBox] = None):
         self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
+        self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
         self.agent_id = agent_id
 
-        self.messages: List[XGAMessage] = []
         self.llm_client = LLMClient(llm_config)
         self.model_name = self.llm_client.model_name
         self.is_stream = self.llm_client.is_stream
 
-        self.prompt_builder = prompt_builder or XGAPromptBuilder()
+        self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
         self.tool_box = tool_box or XGAMcpToolBox()
 
-        self.
-        self.
+        self.task_context_msgs: List[XGAContextMsg] = []
+        self.task_no = -1
+        self.task_run_id = f"{self.task_id}[{self.task_no}]"
+        self.trace_id = None
 
-    async def __async_init__(self) -> None:
+    async def __async_init__(self, general_tools: List[str], custom_tools: List[str]) -> None:
         await self.tool_box.load_mcp_tools_schema()
+        await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
+        general_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
+        custom_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "custom_tool")
+
+        self.task_prompt = self.prompt_builder.build_task_prompt(self.model_name, general_tool_schemas, custom_tool_schemas)
 
     @classmethod
     async def create(cls,
                      session_id: Optional[str] = None,
-
+                     task_id: Optional[str] = None,
                      agent_id: Optional[str] = None,
-
+                     system_prompt: Optional[str] = None,
+                     general_tools: Optional[List[str]] = None,
+                     custom_tools: Optional[List[str]] = None,
+                     llm_config: Optional[LLMConfig] = None,
                      prompt_builder: Optional[XGAPromptBuilder] = None,
-                     tool_box: Optional[XGAToolBox] = None) -> '
-        engine:
-
-
-
-
-
-
+                     tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
+        engine: XGATaskEngine = cls(session_id=session_id,
+                                    task_id=task_id,
+                                    agent_id=agent_id,
+                                    system_prompt=system_prompt,
+                                    llm_config=llm_config,
+                                    prompt_builder=prompt_builder,
+                                    tool_box=tool_box)
+
+        general_tools = general_tools or ["complete"]
+        custom_tools = custom_tools or []
+        await engine.__async_init__(general_tools, custom_tools)
+
+        logging.info("*"*30 + f" XGATaskEngine Task'{engine.task_id}' Initialized " + "*"*30)
+        logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
+        logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
+
         return engine
 
 
     async def run_task(self,
-
-
-
-                       general_tools: Optional[List[str]] = ["*"],
-                       custom_tools: Optional[List[str]] = []) -> AsyncGenerator:
+                       task_message: Dict[str, Any],
+                       max_auto_run: int = 25,
+                       trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
         try:
-            self.
-
-
-
+            self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
+
+            self.task_no += 1
+            self.task_run_id = f"{self.task_id}[{self.task_no}]"
+
+            self.add_context_msg(type="user", content=task_message, is_llm_message=True)
 
+            if max_auto_run <= 1:
+                continuous_state: TaskRunContinuousState = {
+                    "accumulated_content": "",
+                    "auto_continue_count": 0,
+                    "auto_continue": False
+                }
+                async for chunk in self._run_task_once(continuous_state):
+                    yield chunk
+            else:
+                async for chunk in self._run_task_auto(max_auto_run):
+                    yield chunk
         finally:
             await self.tool_box.destroy_task_tool_box(self.task_id)
-
 
-    def _run_task_once(self):
-
+    async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
+        llm_messages = [{"role": "system", "content": self.task_prompt}]
+        cxt_llm_contents = self._get_context_llm_contents()
+        llm_messages.extend(cxt_llm_contents)
+
+        partial_content = continuous_state.get('accumulated_content', '')
+        if partial_content:
+            temp_assistant_message = {
+                "role": "assistant",
+                "content": partial_content
+            }
+            llm_messages.append(temp_assistant_message)
+
+        llm_response = await self.llm_client.create_completion(llm_messages)
+        response_processor = self._create_response_processer()
+
+        async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
+            yield chunk
 
-    def build_task_prompt(self, system_prompt: Optional[str] = None) -> str:
-        task_prompt = self.prompt_builder.build_system_prompt(self.model_name, system_prompt)
 
-
-
-
+    async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator:
+        continuous_state: TaskRunContinuousState = {
+            "accumulated_content": "",
+            "auto_continue_count": 0,
+            "auto_continue": True
+        }
 
-
-
-
+        def update_continuous_state(_auto_continue_count, _auto_continue):
+            continuous_state["auto_continue_count"] = _auto_continue_count
+            continuous_state["auto_continue"] = _auto_continue
 
-
+        auto_continue_count = 0
+        auto_continue = True
+        while auto_continue and auto_continue_count < max_auto_run:
+            auto_continue = False
 
-
-
-
-
-
+            try:
+                async for chunk in self._run_task_once(continuous_state):
+                    try:
+                        if chunk.get("type") == "status":
+                            content = json.loads(chunk.get('content', '{}'))
+                            status_type = content.get('status_type', None)
+                            if status_type == "error":
+                                logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
+                                yield chunk
+                                return
+                            elif status_type == 'finish':
+                                finish_reason = content.get('finish_reason', None)
+                                if finish_reason == 'stop':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='stop', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                                elif finish_reason == 'xml_tool_limit_reached':
+                                    logging.info(f"run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
+                                    auto_continue = False
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                elif finish_reason == 'length':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='length', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                    except Exception as parse_error:
+                        logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
+                        yield {
+                            "type": "status",
+                            "status": "error",
+                            "message": f"Error in parse chunk: {str(parse_error)}"
+                        }
+                        return
+
+                    # Otherwise just yield the chunk normally
+                    yield chunk
+
+                # If not auto-continuing, we're done
+                if not auto_continue:
+                    break
+            except Exception as run_error:
+                logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
+                yield {
+                    "type": "status",
+                    "status": "error",
+                    "message": f"Call task_run_once error: {str(run_error)}"
+                }
+                return
+
+    def add_context_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
+                        content: Union[Dict[str, Any], List[Any], str],
+                        is_llm_message: bool,
+                        metadata: Optional[Dict[str, Any]] = None) -> XGAContextMsg:
+        message = XGAContextMsg(
+            message_id=f"xga_msg_{uuid4()}",
+            type=type,
+            content=content,
+            is_llm_message=is_llm_message,
+            metadata=metadata,
+            session_id=self.session_id,
+            agent_id=self.agent_id,
+            task_id=self.task_id,
+            task_run_id=self.task_run_id,
+            trace_id=self.trace_id
+        )
+        self.task_context_msgs.append(message)
+
+        return message
+
+    def _get_context_llm_contents(self) -> List[Dict[str, Any]]:
+        llm_messages = []
+        for message in self.task_context_msgs:
+            if message["is_llm_message"]:
+                llm_messages.append(message)
+
+        cxt_llm_contents = []
+        for llm_message in llm_messages:
+            content = llm_message["content"]
+            # @todo content List type
+            if isinstance(content, str):
+                try:
+                    _content = json.loads(content)
+                    cxt_llm_contents.append(_content)
+                except json.JSONDecodeError as e:
+                    logging.error(f"get_context_llm_contents: Failed to decode json, content=:{content}")
+                    handle_error(e)
+            else:
+                cxt_llm_contents.append(content)
+
+        return cxt_llm_contents
+
+    def _create_response_processer(self) -> TaskResponseProcessor:
+        response_context = self._create_response_context()
+        is_stream = response_context.get("is_stream", False)
+        if is_stream:
+            from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
+            return StreamTaskResponser(response_context)
+        else:
+            from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
+            return NonStreamTaskResponser(response_context)
+
+    def _create_response_context(self) -> TaskResponseContext:
+        response_context: TaskResponseContext = {
+            "is_stream": self.is_stream,
+            "task_id": self.task_id,
+            "task_run_id": self.task_run_id,
+            "trace_id": self.trace_id,
+            "model_name": self.model_name,
+            "max_xml_tool_calls": 0,
+            "add_context_msg": self.add_context_msg,
+            "tool_box": self.tool_box,
+            "tool_execution_strategy": "parallel",
+            "xml_adding_strategy": "user_message",
+        }
+        return response_context
 
 if __name__ == "__main__":
     import asyncio
-
+    from xgae.utils.utils import read_file
     async def main():
-
-
-        engine = await
-
-
+        tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
+        system_prompt = read_file("templates/scp_test_prompt.txt")
+        engine = await XGATaskEngine.create(tool_box=tool_box,
+                                            general_tools=[],
+                                            custom_tools=["bomc_fault.*"],
+                                            llm_config=LLMConfig(stream=False),
+                                            system_prompt=system_prompt)
+        #engine = await XGATaskEngine.create()
+
+        async for chunk in engine.run_task(task_message={"role": "user", "content": "Locate the fault on 10.0.0.1"},
+                                           max_auto_run=8):
             print(chunk)
 
     asyncio.run(main())
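The auto-continue loop in `_run_task_auto` keys off `status` chunks whose `content` is a JSON string. A sketch of a chunk that would trigger one more run; the field names come straight from the parsing code above, the values are hypothetical:

import json

# finish_reason 'stop' or 'length' re-enters the loop;
# 'xml_tool_limit_reached' halts auto-continue instead.
finish_chunk = {
    "type": "status",
    "content": json.dumps({"status_type": "finish", "finish_reason": "stop"}),
}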
xgae/engine/xga_mcp_tool_box.py
CHANGED
@@ -7,9 +7,8 @@ from typing import List, Any, Dict, Optional, Literal, override
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_mcp_adapters.tools import load_mcp_tools
 
-from xgae.engine.xga_base import XGAToolSchema, XGAToolBox, XGAToolResult
-from xgae.utils.setup_env import
-
+from xgae.engine.xga_base import XGAError, XGAToolSchema, XGAToolBox, XGAToolResult
+from xgae.utils.setup_env import langfuse
 
 class XGAMcpToolBox(XGAToolBox):
     GENERAL_MCP_SERVER_NAME = "xga_general"
@@ -76,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
         await self.call_tool(task_id, "end_task", {"task_id": task_id})
         self.task_tool_schemas.pop(task_id, None)
 
+    @override
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        task_tool_schema = self.task_tool_schemas.get(task_id, {})
+        task_tool_names = list(task_tool_schema.keys())
+        return task_tool_names
+
     @override
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         task_tool_schemas = []
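A minimal usage sketch of the new `get_task_tool_names` accessor, assuming the configured MCP servers are reachable; the task id is hypothetical (and `creat_task_tool_box` is the method's spelling in this release):

import asyncio
from xgae.engine.xga_mcp_tool_box import XGAMcpToolBox

async def main():
    tool_box = XGAMcpToolBox()
    await tool_box.load_mcp_tools_schema()
    # Register a per-task tool box, then list the tool names it holds
    await tool_box.creat_task_tool_box("demo_task", ["complete"], [])
    print(tool_box.get_task_tool_names("demo_task"))
    await tool_box.destroy_task_tool_box("demo_task")

asyncio.run(main())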
xgae/engine/xga_prompt_builder.py
CHANGED
@@ -1,25 +1,28 @@
-import datetime
-import sys
 import json
-
+import datetime
 from typing import Optional, List
-from io import StringIO
 
-from xga_base import XGAToolSchema
-from xgae.utils.
+from xga_base import XGAToolSchema, XGAError
+from xgae.utils.utils import read_file, format_file_with_args
 
 
 class XGAPromptBuilder():
-    def __init__(self,
-        self.system_prompt =
-
-
+    def __init__(self, system_prompt: Optional[str] = None):
+        self.system_prompt = system_prompt
+
+    def build_task_prompt(self, model_name: str, general_tool_schemas: List[XGAToolSchema], custom_tool_schemas: List[XGAToolSchema]) -> str:
+        if self.system_prompt is None:
+            self.system_prompt = self._load_default_system_prompt(model_name)
+
+        task_prompt = self.system_prompt
+
+        tool_prompt = self.build_general_tool_prompt(general_tool_schemas)
+        task_prompt = task_prompt + "\n" + tool_prompt
+
+        tool_prompt = self.build_custom_tool_prompt(custom_tool_schemas)
+        task_prompt = task_prompt + "\n" + tool_prompt
 
-
-        task_system_prompt = system_prompt if system_prompt else self.system_prompt
-        if task_system_prompt is None:
-            task_system_prompt = self._load_default_system_prompt(model_name)
-        return task_system_prompt
+        return task_prompt
 
     def build_general_tool_prompt(self, tool_schemas: List[XGAToolSchema]) -> str:
         tool_prompt = ""
@@ -41,7 +44,7 @@ class XGAPromptBuilder():
         input_schema = tool_schema.input_schema
         openai_function["parameters"] = openai_parameters
         openai_parameters["type"] = input_schema["type"]
-        openai_parameters["properties"] = input_schema.get("properties",
+        openai_parameters["properties"] = input_schema.get("properties", {})
         openai_parameters["required"] = input_schema["required"]
 
         openai_schemas.append(openai_schema)
@@ -65,32 +68,21 @@ class XGAPromptBuilder():
         for tool_schema in tool_schemas:
             description = tool_schema.description if tool_schema.description else 'No description available'
             tool_info += f"- **{tool_schema.tool_name}**: {description}\n"
-
+            parameters = tool_schema.input_schema.get("properties", {})
+            tool_info += f"  Parameters: {parameters}\n"
         tool_prompt = tool_prompt.replace("{tool_schemas}", tool_info)
 
         return tool_prompt
 
     def _load_default_system_prompt(self, model_name) -> Optional[str]:
         if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
-
+            system_prompt_template = read_file("templates/gemini_system_prompt_template.txt")
         else:
-
-
-
-
-
-            try:
-                namespace = {
-                    "datetime": datetime,
-                    "__builtins__": __builtins__
-                }
-                code = f"print(f\"\"\"{org_prompt_template}\"\"\")"
-                exec(code, namespace)
-                system_prompt_template = buffer.getvalue()
-            finally:
-                sys.stdout = original_stdout
-
-        system_prompt = system_prompt_template.format(
+            system_prompt_template = read_file("templates/system_prompt_template.txt")
+
+        system_prompt = format_file_with_args(system_prompt_template, {"datetime": datetime})
+
+        system_prompt = system_prompt.format(
             current_date=datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d'),
             current_time=datetime.datetime.now(datetime.timezone.utc).strftime('%H:%M:%S'),
             current_year=datetime.datetime.now(datetime.timezone.utc).strftime('%Y')
@@ -102,14 +94,4 @@ class XGAPromptBuilder():
 
         return system_prompt
 
-if __name__ == "__main__":
-
-    prompt_builder = XGAPromptBuilder()
-    prompt = prompt_builder.build_system_prompt("openai/qwen3-235b-a22b")
-
-    # system_prompt = read_file("templates/scp_test_prompt.txt")
-    # prompt = prompt_builder.build_system_prompt("openai/qwen3-235b-a22b", system_prompt=system_prompt)
-
-    print(prompt)
-
 
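A minimal usage sketch of the reworked builder. It assumes `XGAToolSchema` can be constructed with keyword fields `tool_name`, `description`, and a JSON-Schema-style `input_schema` (the attributes the builder reads above); the echo schema itself is hypothetical:

from xgae.engine.xga_base import XGAToolSchema
from xgae.engine.xga_prompt_builder import XGAPromptBuilder

# Hypothetical tool schema, for illustration only
echo_schema = XGAToolSchema(
    tool_name="echo",
    description="Echo the input back",
    input_schema={"type": "object",
                  "properties": {"text": {"type": "string"}},
                  "required": ["text"]},
)

# With an explicit system_prompt the default template files are never read
builder = XGAPromptBuilder(system_prompt="You are a helpful agent.")
prompt = builder.build_task_prompt("openai/qwen3-235b-a22b", [echo_schema], [])
print(prompt)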