xgae 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xgae might be problematic.

xgae/engine/xga_base.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Union, Optional, Dict, List, Any, Literal
+from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 
@@ -6,16 +6,24 @@ class XGAError(Exception):
     """Custom exception for errors in the XGA system."""
     pass
 
-@dataclass
-class XGAMessage:
-    message_id: str
-    type: Literal["status", "tool", "assistant", "assistant_response_end"]
+
+class XGAContextMsg(TypedDict, total=False):
+    type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
+    content: Union[Dict[str, Any], List[Any], str]
     is_llm_message: bool
+    metadata: Dict[str, Any]
+    message_id: str
+    session_id: str
+    agent_id: str
+    task_id: str
+    task_run_id: str
+    trace_id: str
+
+class XGAResponseMsg(TypedDict, total=False):
+    type: Literal["content", "status"]
     content: Union[Dict[str, Any], List[Any], str]
-    metadata: Optional[Dict[str, Any]]
-    session_id: Optional[str]
-    agent_id: Optional[str]
-    task_id: Optional[str]
+    status: Literal["error", "status"]
+    message: str
 
 @dataclass
 class XGAToolSchema:
@@ -44,6 +52,10 @@ class XGAToolBox(ABC):
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         pass
 
+    @abstractmethod
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        pass
+
     @abstractmethod
    async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
        pass
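
Note on the new message types: `XGAContextMsg` and `XGAResponseMsg` are declared with `total=False`, so they are plain dicts at runtime and every key is optional. A minimal sketch of what that means for callers (the field values are illustrative, not from the package):

from xgae.engine.xga_base import XGAContextMsg, XGAResponseMsg

# total=False: any subset of the declared keys type-checks.
msg: XGAContextMsg = {"type": "user", "content": "hello", "is_llm_message": True}
err: XGAResponseMsg = {"type": "status", "status": "error", "message": "tool call failed"}

print(msg.get("metadata"))  # None -- optional keys are simply absent until set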
xgae/engine/xga_engine.py CHANGED
@@ -1,30 +1,32 @@
 
-from typing import List, Any, Dict, Optional, AsyncGenerator
+import logging
+import json
+
+from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
 from uuid import uuid4
 
-from xgae.engine.xga_base import XGAMessage, XGAToolBox
-from xgae.utils.llm_client import LLMClient
+from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
+from xgae.engine.xga_base import XGAContextMsg, XGAToolBox, XGAResponseMsg
+from xgae.utils.llm_client import LLMClient, LLMConfig
 from xgae.utils.setup_env import langfuse
+from xgae.utils.utils import handle_error
+
 from xga_prompt_builder import XGAPromptBuilder
 from xga_mcp_tool_box import XGAMcpToolBox
 
-
-class XGATaskEngine():
+class XGATaskEngine:
     def __init__(self,
                  session_id: Optional[str] = None,
                  task_id: Optional[str] = None,
                  agent_id: Optional[str] = None,
-                 trace_id: Optional[str] = None,
                  system_prompt: Optional[str] = None,
-                 llm_config: Optional[Dict[str, Any]] = None,
+                 llm_config: Optional[LLMConfig] = None,
                  prompt_builder: Optional[XGAPromptBuilder] = None,
                  tool_box: Optional[XGAToolBox] = None):
         self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
         self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
         self.agent_id = agent_id
-        self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
 
-        self.messages: List[XGAMessage] = []
         self.llm_client = LLMClient(llm_config)
         self.model_name = self.llm_client.model_name
         self.is_stream = self.llm_client.is_stream
@@ -32,6 +34,10 @@ class XGATaskEngine():
         self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
         self.tool_box = tool_box or XGAMcpToolBox()
 
+        self.task_context_msgs: List[XGAContextMsg] = []
+        self.task_no = -1
+        self.task_run_id = f"{self.task_id}[{self.task_no}]"
+        self.trace_id = None
 
     async def __async_init__(self, general_tools: List[str], custom_tools: List[str]) -> None:
         await self.tool_box.load_mcp_tools_schema()
@@ -46,54 +52,227 @@ class XGATaskEngine():
                      session_id: Optional[str] = None,
                      task_id: Optional[str] = None,
                      agent_id: Optional[str] = None,
-                     trace_id: Optional[str] = None,
                      system_prompt: Optional[str] = None,
                      general_tools: Optional[List[str]] = None,
                      custom_tools: Optional[List[str]] = None,
-                     llm_config: Optional[Dict[str, Any]] = None,
+                     llm_config: Optional[LLMConfig] = None,
                      prompt_builder: Optional[XGAPromptBuilder] = None,
                      tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
         engine: XGATaskEngine = cls(session_id=session_id,
                                     task_id=task_id,
                                     agent_id=agent_id,
-                                    trace_id=trace_id,
                                     system_prompt=system_prompt,
                                     llm_config=llm_config,
                                     prompt_builder=prompt_builder,
                                     tool_box=tool_box)
-        general_tools = general_tools or ["*"]
+
+        general_tools = general_tools or ["complete"]
         custom_tools = custom_tools or []
         await engine.__async_init__(general_tools, custom_tools)
+
+        logging.info("*" * 30 + f" XGATaskEngine Task '{engine.task_id}' Initialized " + "*" * 30)
+        logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
+        logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
+
         return engine
 
 
-    async def run_task(self, task_messages: List[Dict[str, Any]]) -> AsyncGenerator:
+    async def run_task(self,
+                       task_message: Dict[str, Any],
+                       max_auto_run: int = 25,
+                       trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
         try:
-            yield self.task_prompt
+            self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
 
+            self.task_no += 1
+            self.task_run_id = f"{self.task_id}[{self.task_no}]"
+
+            self.add_context_msg(type="user", content=task_message, is_llm_message=True)
+
+            if max_auto_run <= 1:
+                continuous_state: TaskRunContinuousState = {
+                    "accumulated_content": "",
+                    "auto_continue_count": 0,
+                    "auto_continue": False
+                }
+                async for chunk in self._run_task_once(continuous_state):
+                    yield chunk
+            else:
+                async for chunk in self._run_task_auto(max_auto_run):
+                    yield chunk
         finally:
             await self.tool_box.destroy_task_tool_box(self.task_id)
-
 
-    def _run_task_once(self):
-        pass
+    async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
+        llm_messages = [{"role": "system", "content": self.task_prompt}]
+        cxt_llm_contents = self._get_context_llm_contents()
+        llm_messages.extend(cxt_llm_contents)
+
+        partial_content = continuous_state.get('accumulated_content', '')
+        if partial_content:
+            temp_assistant_message = {
+                "role": "assistant",
+                "content": partial_content
+            }
+            llm_messages.append(temp_assistant_message)
+
+        llm_response = await self.llm_client.create_completion(llm_messages)
+        response_processor = self._create_response_processer()
+
+        async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
+            yield chunk
+
 
+    async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator:
+        continuous_state: TaskRunContinuousState = {
+            "accumulated_content": "",
+            "auto_continue_count": 0,
+            "auto_continue": True
+        }
 
-    def add_message(self, message: XGAMessage):
-        message.message_id = f"xga_msg_{uuid4()}"
-        message.session_id = self.session_id
-        message.agent_id = self.agent_id
-        self.messages.append(message)
+        def update_continuous_state(_auto_continue_count, _auto_continue):
+            continuous_state["auto_continue_count"] = _auto_continue_count
+            continuous_state["auto_continue"] = _auto_continue
+
+        auto_continue_count = 0
+        auto_continue = True
+        while auto_continue and auto_continue_count < max_auto_run:
+            auto_continue = False
+
+            try:
+                async for chunk in self._run_task_once(continuous_state):
+                    try:
+                        if chunk.get("type") == "status":
+                            content = json.loads(chunk.get('content', '{}'))
+                            status_type = content.get('status_type', None)
+                            if status_type == "error":
+                                logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
+                                yield chunk
+                                return
+                            elif status_type == 'finish':
+                                finish_reason = content.get('finish_reason', None)
+                                if finish_reason == 'stop':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='stop', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                                elif finish_reason == 'xml_tool_limit_reached':
+                                    logging.info("run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
+                                    auto_continue = False
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                elif finish_reason == 'length':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='length', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                    except Exception as parse_error:
+                        logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
+                        yield {
+                            "type": "status",
+                            "status": "error",
+                            "message": f"Error in parse chunk: {str(parse_error)}"
+                        }
+                        return
+
+                    # Otherwise just yield the chunk normally
+                    yield chunk
+
+                # If not auto-continuing, we're done
+                if not auto_continue:
+                    break
+            except Exception as run_error:
+                logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
+                yield {
+                    "type": "status",
+                    "status": "error",
+                    "message": f"Call task_run_once error: {str(run_error)}"
+                }
+                return
+
+    def add_context_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
+                        content: Union[Dict[str, Any], List[Any], str],
+                        is_llm_message: bool,
+                        metadata: Optional[Dict[str, Any]] = None) -> XGAContextMsg:
+        message = XGAContextMsg(
+            message_id=f"xga_msg_{uuid4()}",
+            type=type,
+            content=content,
+            is_llm_message=is_llm_message,
+            metadata=metadata,
+            session_id=self.session_id,
+            agent_id=self.agent_id,
+            task_id=self.task_id,
+            task_run_id=self.task_run_id,
+            trace_id=self.trace_id
+        )
+        self.task_context_msgs.append(message)
+
+        return message
+
+    def _get_context_llm_contents(self) -> List[Dict[str, Any]]:
+        llm_messages = []
+        for message in self.task_context_msgs:
+            if message["is_llm_message"]:
+                llm_messages.append(message)
+
+        cxt_llm_contents = []
+        for llm_message in llm_messages:
+            content = llm_message["content"]
+            # @todo content List type
+            if isinstance(content, str):
+                try:
+                    _content = json.loads(content)
+                    cxt_llm_contents.append(_content)
+                except json.JSONDecodeError as e:
+                    logging.error(f"get_context_llm_contents: Failed to decode json, content={content}")
+                    handle_error(e)
+            else:
+                cxt_llm_contents.append(content)
+
+        return cxt_llm_contents
+
+    def _create_response_processer(self) -> TaskResponseProcessor:
+        response_context = self._create_response_context()
+        is_stream = response_context.get("is_stream", False)
+        if is_stream:
+            from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
+            return StreamTaskResponser(response_context)
+        else:
+            from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
+            return NonStreamTaskResponser(response_context)
+
+    def _create_response_context(self) -> TaskResponseContext:
+        response_context: TaskResponseContext = {
+            "is_stream": self.is_stream,
+            "task_id": self.task_id,
+            "task_run_id": self.task_run_id,
+            "trace_id": self.trace_id,
+            "model_name": self.model_name,
+            "max_xml_tool_calls": 0,
+            "add_context_msg": self.add_context_msg,
+            "tool_box": self.tool_box,
+            "tool_execution_strategy": "parallel",
+            "xml_adding_strategy": "user_message",
+        }
+        return response_context
 
 if __name__ == "__main__":
     import asyncio
-
+    from xgae.utils.utils import read_file
     async def main():
         tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
-        engine = await XGATaskEngine.create(tool_box=tool_box, custom_tools=["bomc_fault.*"])
-        # engine = await XGATaskEngine.create()
+        system_prompt = read_file("templates/scp_test_prompt.txt")
+        engine = await XGATaskEngine.create(tool_box=tool_box,
+                                            general_tools=[],
+                                            custom_tools=["bomc_fault.*"],
+                                            llm_config=LLMConfig(stream=False),
+                                            system_prompt=system_prompt)
+        # engine = await XGATaskEngine.create()
 
-        async for chunk in engine.run_task(task_messages=[{}]):
+        async for chunk in engine.run_task(task_message={"role": "user", "content": "定位10.0.0.1的故障"},
+                                           max_auto_run=8):
             print(chunk)
 
     asyncio.run(main())
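
For readers tracing the new auto-continue loop: `_run_task_auto` expects each `status` chunk's `content` to be a JSON string carrying a `status_type` and, for finish statuses, a `finish_reason` ('stop' and 'length' trigger auto-continue, 'xml_tool_limit_reached' stops it). A hedged sketch of a consumer that mirrors that contract (the chunk values are assumptions inferred from the parsing code above, not captured output):

import json

# Hypothetical chunks shaped like the ones _run_task_auto parses above.
chunks = [
    {"type": "content", "content": "partial answer..."},
    {"type": "status", "content": json.dumps({"status_type": "finish", "finish_reason": "stop"})},
]

for chunk in chunks:
    if chunk.get("type") == "status":
        status = json.loads(chunk.get("content", "{}"))
        if status.get("status_type") == "finish":
            print(f"finish_reason={status.get('finish_reason')}")
    else:
        print(chunk["content"])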
@@ -75,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
         await self.call_tool(task_id, "end_task", {"task_id": task_id})
         self.task_tool_schemas.pop(task_id, None)
 
+    @override
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        task_tool_schema = self.task_tool_schemas.get(task_id, {})
+        task_tool_names = list(task_tool_schema.keys())
+        return task_tool_names
+
     @override
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         task_tool_schemas = []
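
The new `get_task_tool_names` override simply lists the keys of the per-task schema map, giving callers a cheap way to check which tools a task loaded. A usage sketch, assuming a tool box initialized through the engine's `create(...)` flow above (the task id and tool names shown are placeholders):

# Placeholder task id; in practice use engine.task_id after XGATaskEngine.create(...).
names = tool_box.get_task_tool_names(task_id="xga_task_example")
print(names)  # e.g. ["complete", "bomc_fault.locate_fault"] -- illustrative only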
@@ -0,0 +1,174 @@
+"""
+JSON helper utilities for handling both legacy (string) and new (dict/list) formats.
+
+These utilities help with the transition from storing JSON as strings to storing
+them as proper JSONB objects in the database.
+"""
+
+import json
+from typing import Any, Union, Dict, List
+
+
+def ensure_dict(value: Union[str, Dict[str, Any], None], default: Dict[str, Any] = None) -> Dict[str, Any]:
+    """
+    Ensure a value is a dictionary.
+
+    Handles:
+    - None -> returns default or {}
+    - Dict -> returns as-is
+    - JSON string -> parses and returns dict
+    - Other -> returns default or {}
+
+    Args:
+        value: The value to ensure is a dict
+        default: Default value if conversion fails
+
+    Returns:
+        A dictionary
+    """
+    if default is None:
+        default = {}
+
+    if value is None:
+        return default
+
+    if isinstance(value, dict):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, dict):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -> List[Any]:
+    """
+    Ensure a value is a list.
+
+    Handles:
+    - None -> returns default or []
+    - List -> returns as-is
+    - JSON string -> parses and returns list
+    - Other -> returns default or []
+
+    Args:
+        value: The value to ensure is a list
+        default: Default value if conversion fails
+
+    Returns:
+        A list
+    """
+    if default is None:
+        default = []
+
+    if value is None:
+        return default
+
+    if isinstance(value, list):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, list):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) -> Any:
+    """
+    Safely parse a value that might be a JSON string or already parsed.
+
+    This handles the transition period where some data might be stored as
+    JSON strings (old format) and some as proper objects (new format).
+
+    Args:
+        value: The value to parse
+        default: Default value if parsing fails
+
+    Returns:
+        Parsed value or default
+    """
+    if value is None:
+        return default
+
+    # If it's already a dict or list, return as-is
+    if isinstance(value, (dict, list)):
+        return value
+
+    # If it's a string, try to parse it
+    if isinstance(value, str):
+        try:
+            return json.loads(value)
+        except (json.JSONDecodeError, TypeError):
+            # If it's not valid JSON, return the string itself
+            return value
+
+    # For any other type, return as-is
+    return value
+
+
+def to_json_string(value: Any) -> str:
+    """
+    Convert a value to a JSON string if needed.
+
+    This is used for backwards compatibility when yielding data that
+    expects JSON strings.
+
+    Args:
+        value: The value to convert
+
+    Returns:
+        JSON string representation
+    """
+    if isinstance(value, str):
+        # If it's already a string, check if it's valid JSON
+        try:
+            json.loads(value)
+            return value  # It's already a JSON string
+        except (json.JSONDecodeError, TypeError):
+            # It's a plain string, encode it as JSON
+            return json.dumps(value)
+
+    # For all other types, convert to JSON
+    return json.dumps(value)
+
+
+def format_for_yield(message_object: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Format a message object for yielding, ensuring content and metadata are JSON strings.
+
+    This maintains backward compatibility with clients expecting JSON strings
+    while the database now stores proper objects.
+
+    Args:
+        message_object: The message object from the database
+
+    Returns:
+        Message object with content and metadata as JSON strings
+    """
+    if not message_object:
+        return message_object
+
+    # Create a copy to avoid modifying the original
+    formatted = message_object.copy()
+
+    # Ensure content is a JSON string
+    if 'content' in formatted and not isinstance(formatted['content'], str):
+        formatted['content'] = json.dumps(formatted['content'])
+
+    # Ensure metadata is a JSON string
+    if 'metadata' in formatted and not isinstance(formatted['metadata'], str):
+        formatted['metadata'] = json.dumps(formatted['metadata'])
+
+    return formatted
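
A short usage sketch of the new helpers (the diff does not show the added file's path, so the import below is an assumption):

# Assumed module path -- the diff omits the added file's name.
from xgae.utils.json_helpers import ensure_dict, safe_json_parse, format_for_yield

print(ensure_dict('{"a": 1}'))             # {'a': 1}
print(ensure_dict("not json", {"x": 0}))   # {'x': 0} -- fallback default
print(safe_json_parse('[1, 2, 3]'))        # [1, 2, 3]

msg = {"content": {"a": 1}, "metadata": {"trace": "t1"}}
print(format_for_yield(msg))  # content/metadata re-encoded as JSON strings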
xgae/utils/llm_client.py CHANGED
@@ -4,11 +4,24 @@ import logging
 import os
 import litellm
 
-from typing import Union, Dict, Any, Optional, List
+from typing import Union, Dict, Any, Optional, List, TypedDict
 
 from litellm.utils import ModelResponse, CustomStreamWrapper
 from openai import OpenAIError
 
+class LLMConfig(TypedDict, total=False):
+    model: str
+    model_name: str
+    model_id: str
+    api_key: str
+    api_base: str
+    temperature: float
+    max_tokens: int
+    stream: bool
+    enable_thinking: bool
+    reasoning_effort: str
+    response_format: str
+    top_p: int
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
@@ -18,7 +31,7 @@ class LLMClient:
     RATE_LIMIT_DELAY = 30
     RETRY_DELAY = 0.1
 
-    def __init__(self, llm_config: Optional[Dict[str, Any]]=None) -> None:
+    def __init__(self, llm_config: LLMConfig=None) -> None:
         """
         Arg: llm_config (Optional[Dict[str, Any]], optional)
             model: Override default model to use, default set by .env LLM_MODEL
@@ -34,7 +47,7 @@ class LLMClient:
             reasoning_effort: Optional level of reasoning effort, default is 'low'
             top_p: Optional Top-p sampling parameter, default is None
         """
-        llm_config = llm_config or {}
+        llm_config = llm_config or LLMConfig()
         litellm.modify_params = True
         litellm.drop_params = True
 
@@ -214,9 +227,7 @@ class LLMClient:
 
 if __name__ == "__main__":
     async def llm_completion():
-        llm_client = LLMClient({
-            "stream": False  # default is True
-        })
+        llm_client = LLMClient(LLMConfig(stream=False))
         messages = [{"role": "user", "content": "今天是2025年8月15日,北京本周每天温度"}]
         response = await llm_client.create_completion(messages)
         if llm_client.is_stream:
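
Because `LLMConfig` is declared with `total=False`, any subset of its keys is a valid config; unset keys fall back to the `.env` defaults described in the docstring. A brief sketch (the model name and settings are placeholders, not package defaults):

from xgae.utils.llm_client import LLMClient, LLMConfig

# Placeholder values -- substitute your own provider settings.
config = LLMConfig(model="openai/gpt-4o-mini", temperature=0.2, stream=True)
client = LLMClient(config)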
xgae/utils/setup_env.py CHANGED
@@ -89,8 +89,5 @@ langfuse: Langfuse = Langfuse if _langfuse_initialized else setup_langfuse()
 
 
 if __name__ == "__main__":
-    try:
     trace_id = langfuse.create_trace_id()
-        print(f"trace_id={trace_id}")
-    except Exception as e:
-        handle_error(e)
+    logging.warning(f"trace_id={trace_id}")