xgae 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.

Potentially problematic release.


This version of xgae may be problematic. (The original page linked to further details here; that link was not preserved in this copy.)

xgae/engine/xga_base.py CHANGED
@@ -1,19 +1,29 @@
1
- from typing import Union, Optional, Dict, List, Any, Literal
1
+ from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
2
2
  from dataclasses import dataclass
3
3
  from abc import ABC, abstractmethod
4
4
 
5
+ class XGAError(Exception):
6
+ """Custom exception for errors in the XGA system."""
7
+ pass
5
8
 
6
9
 
7
- @dataclass
8
- class XGAMessage:
9
- message_id: str
10
- type: Literal["status", "tool", "assistant", "assistant_response_end"]
10
+ class XGAContextMsg(TypedDict, total=False):
11
+ type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
12
+ content: Union[Dict[str, Any], List[Any], str]
11
13
  is_llm_message: bool
14
+ metadata: Dict[str, Any]
15
+ message_id: str
16
+ session_id: str
17
+ agent_id: str
18
+ task_id: str
19
+ task_run_id: str
20
+ trace_id: str
21
+
22
+ class XGAResponseMsg(TypedDict, total=False):
23
+ type: Literal["content", "status"]
12
24
  content: Union[Dict[str, Any], List[Any], str]
13
- metadata: Optional[Dict[str, Any]]
14
- session_id: Optional[str]
15
- agent_id: Optional[str]
16
- task_id: Optional[str]
25
+ status: Literal["error", "status"]
26
+ message: str
17
27
 
18
28
  @dataclass
19
29
  class XGAToolSchema:
@@ -42,6 +52,10 @@ class XGAToolBox(ABC):
42
52
  def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
43
53
  pass
44
54
 
55
+ @abstractmethod
56
+ def get_task_tool_names(self, task_id: str) -> List[str]:
57
+ pass
58
+
45
59
  @abstractmethod
46
60
  async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
47
61
  pass
xgae/engine/xga_engine.py CHANGED
@@ -1,105 +1,278 @@
1
1
 
2
- from typing import List, Any, Dict, Optional, AsyncGenerator
2
+ import logging
3
+ import json
4
+
5
+ from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
3
6
  from uuid import uuid4
4
7
 
5
- from xgae.engine.xga_base import XGAMessage, XGAToolSchema, XGAToolBox
6
- from xgae.utils.llm_client import LLMClient
8
+ from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
9
+ from xgae.engine.xga_base import XGAContextMsg, XGAToolBox, XGAResponseMsg
10
+ from xgae.utils.llm_client import LLMClient, LLMConfig
7
11
  from xgae.utils.setup_env import langfuse
12
+ from xgae.utils.utils import handle_error
13
+
8
14
  from xga_prompt_builder import XGAPromptBuilder
9
15
  from xga_mcp_tool_box import XGAMcpToolBox
10
16
 
11
-
12
- class XGAEngine():
17
+ class XGATaskEngine:
13
18
  def __init__(self,
14
19
  session_id: Optional[str] = None,
15
- trace_id: Optional[str] = None,
20
+ task_id: Optional[str] = None,
16
21
  agent_id: Optional[str] = None,
17
- llm_config: Optional[Dict[str, Any]] = None,
22
+ system_prompt: Optional[str] = None,
23
+ llm_config: Optional[LLMConfig] = None,
18
24
  prompt_builder: Optional[XGAPromptBuilder] = None,
19
25
  tool_box: Optional[XGAToolBox] = None):
20
26
  self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
27
+ self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
21
28
  self.agent_id = agent_id
22
29
 
23
- self.messages: List[XGAMessage] = []
24
30
  self.llm_client = LLMClient(llm_config)
25
31
  self.model_name = self.llm_client.model_name
26
32
  self.is_stream = self.llm_client.is_stream
27
33
 
28
- self.prompt_builder = prompt_builder or XGAPromptBuilder()
34
+ self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
29
35
  self.tool_box = tool_box or XGAMcpToolBox()
30
36
 
31
- self.task_id = None
32
- self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
37
+ self.task_context_msgs: List[XGAContextMsg] = []
38
+ self.task_no = -1
39
+ self.task_run_id = f"{self.task_id}[{self.task_no}]"
40
+ self.trace_id = None
33
41
 
34
- async def __async_init__(self) -> None:
42
+ async def __async_init__(self, general_tools:List[str], custom_tools: List[str]) -> None:
35
43
  await self.tool_box.load_mcp_tools_schema()
44
+ await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
45
+ general_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
46
+ custom_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "custom_tool")
47
+
48
+ self.task_prompt = self.prompt_builder.build_task_prompt(self.model_name, general_tool_schemas, custom_tool_schemas)
36
49
 
37
50
  @classmethod
38
51
  async def create(cls,
39
52
  session_id: Optional[str] = None,
40
- trace_id: Optional[str] = None,
53
+ task_id: Optional[str] = None,
41
54
  agent_id: Optional[str] = None,
42
- llm_config: Optional[Dict[str, Any]] = None,
55
+ system_prompt: Optional[str] = None,
56
+ general_tools: Optional[List[str]] = None,
57
+ custom_tools: Optional[List[str]] = None,
58
+ llm_config: Optional[LLMConfig] = None,
43
59
  prompt_builder: Optional[XGAPromptBuilder] = None,
44
- tool_box: Optional[XGAToolBox] = None) -> 'XGAEngine' :
45
- engine: XGAEngine = cls(session_id=session_id,
46
- trace_id=trace_id,
47
- agent_id=agent_id,
48
- llm_config=llm_config,
49
- prompt_builder=prompt_builder,
50
- tool_box=tool_box)
51
-
52
- await engine.__async_init__()
60
+ tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
61
+ engine: XGATaskEngine = cls(session_id=session_id,
62
+ task_id=task_id,
63
+ agent_id=agent_id,
64
+ system_prompt=system_prompt,
65
+ llm_config=llm_config,
66
+ prompt_builder=prompt_builder,
67
+ tool_box=tool_box)
68
+
69
+ general_tools = general_tools or ["complete"]
70
+ custom_tools = custom_tools or []
71
+ await engine.__async_init__(general_tools, custom_tools)
72
+
73
+ logging.info("*"*30 + f" XGATaskEngine Task'{engine.task_id}' Initialized " + "*"*30)
74
+ logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
75
+ logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
76
+
53
77
  return engine
54
78
 
55
79
 
56
80
  async def run_task(self,
57
- task_messages: List[Dict[str, Any]],
58
- task_id: Optional[str] = None,
59
- system_prompt: Optional[str] = None,
60
- general_tools: Optional[List[str]] = ["*"],
61
- custom_tools: Optional[List[str]] = []) -> AsyncGenerator:
81
+ task_message: Dict[str, Any],
82
+ max_auto_run: int = 25,
83
+ trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
62
84
  try:
63
- self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
64
- await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
65
- task_prompt = self.build_task_prompt(system_prompt)
66
- yield task_prompt
85
+ self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
86
+
87
+ self.task_no += 1
88
+ self.task_run_id = f"{self.task_id}[{self.task_no}]"
89
+
90
+ self.add_context_msg(type="user", content=task_message, is_llm_message=True)
67
91
 
92
+ if max_auto_run <= 1:
93
+ continuous_state:TaskRunContinuousState = {
94
+ "accumulated_content": "",
95
+ "auto_continue_count": 0,
96
+ "auto_continue": False
97
+ }
98
+ async for chunk in self._run_task_once(continuous_state):
99
+ yield chunk
100
+ else:
101
+ async for chunk in self._run_task_auto(max_auto_run):
102
+ yield chunk
68
103
  finally:
69
104
  await self.tool_box.destroy_task_tool_box(self.task_id)
70
-
71
105
 
72
- def _run_task_once(self):
73
- pass
106
+ async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
107
+ llm_messages = [{"role": "system", "content": self.task_prompt}]
108
+ cxt_llm_contents = self._get_context_llm_contents()
109
+ llm_messages.extend(cxt_llm_contents)
110
+
111
+ partial_content = continuous_state.get('accumulated_content', '')
112
+ if partial_content:
113
+ temp_assistant_message = {
114
+ "role": "assistant",
115
+ "content": partial_content
116
+ }
117
+ llm_messages.append(temp_assistant_message)
118
+
119
+ llm_response = await self.llm_client.create_completion(llm_messages)
120
+ response_processor = self._create_response_processer()
121
+
122
+ async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
123
+ yield chunk
74
124
 
75
- def build_task_prompt(self, system_prompt: Optional[str] = None) -> str:
76
- task_prompt = self.prompt_builder.build_system_prompt(self.model_name, system_prompt)
77
125
 
78
- tool_schemas =self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
79
- tool_prompt = self.prompt_builder.build_general_tool_prompt(tool_schemas)
80
- task_prompt = task_prompt + "\n" + tool_prompt
126
+ async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator:
127
+ continuous_state: TaskRunContinuousState = {
128
+ "accumulated_content": "",
129
+ "auto_continue_count": 0,
130
+ "auto_continue": True
131
+ }
81
132
 
82
- tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "custom_tool")
83
- tool_prompt = self.prompt_builder.build_custom_tool_prompt(tool_schemas)
84
- task_prompt = task_prompt + "\n" + tool_prompt
133
+ def update_continuous_state(_auto_continue_count, _auto_continue):
134
+ continuous_state["auto_continue_count"] = _auto_continue_count
135
+ continuous_state["auto_continue"] = _auto_continue
85
136
 
86
- return task_prompt
137
+ auto_continue_count = 0
138
+ auto_continue = True
139
+ while auto_continue and auto_continue_count < max_auto_run:
140
+ auto_continue = False
87
141
 
88
- def add_message(self, message: XGAMessage):
89
- message.message_id = f"xga_msg_{uuid4()}"
90
- message.session_id = self.session_id
91
- message.agent_id = self.agent_id
92
- self.messages.append(message)
142
+ try:
143
+ async for chunk in self._run_task_once(continuous_state):
144
+ try:
145
+ if chunk.get("type") == "status":
146
+ content = json.loads(chunk.get('content', '{}'))
147
+ status_type = content.get('status_type', None)
148
+ if status_type == "error":
149
+ logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
150
+ yield chunk
151
+ return
152
+ elif status_type == 'finish':
153
+ finish_reason = content.get('finish_reason', None)
154
+ if finish_reason == 'stop' :
155
+ auto_continue = True
156
+ auto_continue_count += 1
157
+ update_continuous_state(auto_continue_count, auto_continue)
158
+ logging.info(f"run_task_auto: Detected finish_reason='stop', auto-continuing ({auto_continue_count}/{max_auto_run})")
159
+ continue
160
+ elif finish_reason == 'xml_tool_limit_reached':
161
+ logging.info(f"run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
162
+ auto_continue = False
163
+ update_continuous_state(auto_continue_count, auto_continue)
164
+ elif finish_reason == 'length':
165
+ auto_continue = True
166
+ auto_continue_count += 1
167
+ update_continuous_state(auto_continue_count, auto_continue)
168
+ logging.info(f"run_task_auto: Detected finish_reason='length', auto-continuing ({auto_continue_count}/{max_auto_run})")
169
+ continue
170
+ except Exception as parse_error:
171
+ logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
172
+ yield {
173
+ "type": "status",
174
+ "status": "error",
175
+ "message": f"Error in parse chunk: {str(parse_error)}"
176
+ }
177
+ return
178
+
179
+ # Otherwise just yield the chunk normally
180
+ yield chunk
181
+
182
+ # If not auto-continuing, we're done
183
+ if not auto_continue:
184
+ break
185
+ except Exception as run_error:
186
+ logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
187
+ yield {
188
+ "type": "status",
189
+ "status": "error",
190
+ "message": f"Call task_run_once error: {str(run_error)}"
191
+ }
192
+ return
193
+
194
+ def add_context_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
195
+ content: Union[Dict[str, Any], List[Any], str],
196
+ is_llm_message: bool,
197
+ metadata: Optional[Dict[str, Any]]=None)-> XGAContextMsg:
198
+ message = XGAContextMsg(
199
+ message_id = f"xga_msg_{uuid4()}",
200
+ type = type,
201
+ content = content,
202
+ is_llm_message=is_llm_message,
203
+ metadata = metadata,
204
+ session_id = self.session_id,
205
+ agent_id = self.agent_id,
206
+ task_id = self.task_id,
207
+ task_run_id = self.task_run_id,
208
+ trace_id = self.trace_id
209
+ )
210
+ self.task_context_msgs.append(message)
211
+
212
+ return message
213
+
214
+ def _get_context_llm_contents (self) -> List[Dict[str, Any]]:
215
+ llm_messages = []
216
+ for message in self.task_context_msgs:
217
+ if message["is_llm_message"]:
218
+ llm_messages.append(message)
219
+
220
+ cxt_llm_contents = []
221
+ for llm_message in llm_messages:
222
+ content = llm_message["content"]
223
+ # @todo content List type
224
+ if isinstance(content, str):
225
+ try:
226
+ _content = json.loads(content)
227
+ cxt_llm_contents.append(_content)
228
+ except json.JSONDecodeError as e:
229
+ logging.error(f"get_context_llm_contents: Failed to decode json, content=:{content}")
230
+ handle_error(e)
231
+ else:
232
+ cxt_llm_contents.append(content)
233
+
234
+ return cxt_llm_contents
235
+
236
+ def _create_response_processer(self) -> TaskResponseProcessor:
237
+ response_context = self._create_response_context()
238
+ is_stream = response_context.get("is_stream", False)
239
+ if is_stream:
240
+ from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
241
+ return StreamTaskResponser(response_context)
242
+ else:
243
+ from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
244
+ return NonStreamTaskResponser(response_context)
245
+
246
+ def _create_response_context(self) -> TaskResponseContext:
247
+ response_context: TaskResponseContext = {
248
+ "is_stream": self.is_stream,
249
+ "task_id": self.task_id,
250
+ "task_run_id": self.task_run_id,
251
+ "trace_id": self.trace_id,
252
+ "model_name": self.model_name,
253
+ "max_xml_tool_calls": 0,
254
+ "add_context_msg": self.add_context_msg,
255
+ "tool_box": self.tool_box,
256
+ "tool_execution_strategy": "parallel",
257
+ "xml_adding_strategy": "user_message",
258
+ }
259
+ return response_context
93
260
 
94
261
  if __name__ == "__main__":
95
262
  import asyncio
96
-
263
+ from xgae.utils.utils import read_file
97
264
  async def main():
98
- #tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
99
- tool_box = None
100
- engine = await XGAEngine.create(tool_box=tool_box)
101
- # async for chunk in engine.run_task(task_messages=[{}], custom_tools=["bomc_fault.*"]):
102
- async for chunk in engine.run_task(task_messages=[{}], custom_tools=[]):
265
+ tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
266
+ system_prompt = read_file("templates/scp_test_prompt.txt")
267
+ engine = await XGATaskEngine.create(tool_box=tool_box,
268
+ general_tools=[],
269
+ custom_tools=["bomc_fault.*"],
270
+ llm_config=LLMConfig(stream=False),
271
+ system_prompt=system_prompt)
272
+ #engine = await XGATaskEngine.create()
273
+
274
+ async for chunk in engine.run_task(task_message={"role": "user", "content": "定位10.0.0.1的故障"},
275
+ max_auto_run=8):
103
276
  print(chunk)
104
277
 
105
278
  asyncio.run(main())
@@ -7,9 +7,8 @@ from typing import List, Any, Dict, Optional, Literal, override
7
7
  from langchain_mcp_adapters.client import MultiServerMCPClient
8
8
  from langchain_mcp_adapters.tools import load_mcp_tools
9
9
 
10
- from xgae.engine.xga_base import XGAToolSchema, XGAToolBox, XGAToolResult
11
- from xgae.utils.setup_env import XGAError
12
-
10
+ from xgae.engine.xga_base import XGAError, XGAToolSchema, XGAToolBox, XGAToolResult
11
+ from xgae.utils.setup_env import langfuse
13
12
 
14
13
  class XGAMcpToolBox(XGAToolBox):
15
14
  GENERAL_MCP_SERVER_NAME = "xga_general"
@@ -76,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
76
75
  await self.call_tool(task_id, "end_task", {"task_id": task_id})
77
76
  self.task_tool_schemas.pop(task_id, None)
78
77
 
78
+ @override
79
+ def get_task_tool_names(self, task_id: str) -> List[str]:
80
+ task_tool_schema = self.task_tool_schemas.get(task_id, {})
81
+ task_tool_names = list(task_tool_schema.keys())
82
+ return task_tool_names
83
+
79
84
  @override
80
85
  def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
81
86
  task_tool_schemas = []
@@ -1,25 +1,28 @@
1
- import datetime
2
- import sys
3
1
  import json
4
-
2
+ import datetime
5
3
  from typing import Optional, List
6
- from io import StringIO
7
4
 
8
- from xga_base import XGAToolSchema
9
- from xgae.utils.setup_env import read_file, XGAError
5
+ from xga_base import XGAToolSchema, XGAError
6
+ from xgae.utils.utils import read_file, format_file_with_args
10
7
 
11
8
 
12
9
  class XGAPromptBuilder():
13
- def __init__(self, system_prompt_template: Optional[str] = None):
14
- self.system_prompt = None
15
- if system_prompt_template:
16
- self.system_prompt = system_prompt_template
10
+ def __init__(self, system_prompt: Optional[str] = None):
11
+ self.system_prompt = system_prompt
12
+
13
+ def build_task_prompt(self, model_name: str, general_tool_schemas: List[XGAToolSchema], custom_tool_schemas: List[XGAToolSchema])-> str:
14
+ if self.system_prompt is None:
15
+ self.system_prompt = self._load_default_system_prompt(model_name)
16
+
17
+ task_prompt = self.system_prompt
18
+
19
+ tool_prompt = self.build_general_tool_prompt(general_tool_schemas)
20
+ task_prompt = task_prompt + "\n" + tool_prompt
21
+
22
+ tool_prompt = self.build_custom_tool_prompt(custom_tool_schemas)
23
+ task_prompt = task_prompt + "\n" + tool_prompt
17
24
 
18
- def build_system_prompt(self, model_name: str, system_prompt: Optional[str]=None)-> str:
19
- task_system_prompt = system_prompt if system_prompt else self.system_prompt
20
- if task_system_prompt is None:
21
- task_system_prompt = self._load_default_system_prompt(model_name)
22
- return task_system_prompt
25
+ return task_prompt
23
26
 
24
27
  def build_general_tool_prompt(self, tool_schemas:List[XGAToolSchema])-> str:
25
28
  tool_prompt = ""
@@ -41,7 +44,7 @@ class XGAPromptBuilder():
41
44
  input_schema = tool_schema.input_schema
42
45
  openai_function["parameters"] = openai_parameters
43
46
  openai_parameters["type"] = input_schema["type"]
44
- openai_parameters["properties"] = input_schema.get("properties", [])
47
+ openai_parameters["properties"] = input_schema.get("properties", {})
45
48
  openai_parameters["required"] = input_schema["required"]
46
49
 
47
50
  openai_schemas.append(openai_schema)
@@ -65,32 +68,21 @@ class XGAPromptBuilder():
65
68
  for tool_schema in tool_schemas:
66
69
  description = tool_schema.description if tool_schema.description else 'No description available'
67
70
  tool_info += f"- **{tool_schema.tool_name}**: {description}\n"
68
- tool_info += f" Parameters: {tool_schema.input_schema}\n"
71
+ parameters = tool_schema.input_schema.get("properties", {})
72
+ tool_info += f" Parameters: {parameters}\n"
69
73
  tool_prompt = tool_prompt.replace("{tool_schemas}", tool_info)
70
74
 
71
75
  return tool_prompt
72
76
 
73
77
  def _load_default_system_prompt(self, model_name) -> Optional[str]:
74
78
  if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
75
- org_prompt_template = read_file("templates/gemini_system_prompt_template.txt")
79
+ system_prompt_template = read_file("templates/gemini_system_prompt_template.txt")
76
80
  else:
77
- org_prompt_template = read_file("templates/system_prompt_template.txt")
78
-
79
- original_stdout = sys.stdout
80
- buffer = StringIO()
81
- sys.stdout = buffer
82
- try:
83
- namespace = {
84
- "datetime": datetime,
85
- "__builtins__": __builtins__
86
- }
87
- code = f"print(f\"\"\"{org_prompt_template}\"\"\")"
88
- exec(code, namespace)
89
- system_prompt_template = buffer.getvalue()
90
- finally:
91
- sys.stdout = original_stdout
92
-
93
- system_prompt = system_prompt_template.format(
81
+ system_prompt_template = read_file("templates/system_prompt_template.txt")
82
+
83
+ system_prompt = format_file_with_args(system_prompt_template, {"datetime": datetime})
84
+
85
+ system_prompt = system_prompt.format(
94
86
  current_date=datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d'),
95
87
  current_time=datetime.datetime.now(datetime.timezone.utc).strftime('%H:%M:%S'),
96
88
  current_year=datetime.datetime.now(datetime.timezone.utc).strftime('%Y')
@@ -102,14 +94,4 @@ class XGAPromptBuilder():
102
94
 
103
95
  return system_prompt
104
96
 
105
- if __name__ == "__main__":
106
-
107
- prompt_builder = XGAPromptBuilder()
108
- prompt = prompt_builder.build_system_prompt("openai/qwen3-235b-a22b")
109
-
110
- # system_prompt = read_file("templates/scp_test_prompt.txt")
111
- # prompt = prompt_builder.build_system_prompt("openai/qwen3-235b-a22b", system_prompt=system_prompt)
112
-
113
- print(prompt)
114
-
115
97