xgae 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xgae might be problematic.

xgae/engine/xga_base.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Union, Optional, Dict, List, Any, Literal
+from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 
@@ -6,16 +6,23 @@ class XGAError(Exception):
     """Custom exception for errors in the XGA system."""
     pass
 
-@dataclass
-class XGAMessage:
-    message_id: str
-    type: Literal["status", "tool", "assistant", "assistant_response_end"]
-    is_llm_message: bool
+
+class XGAResponseMsg(TypedDict, total=False):
+    type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
     content: Union[Dict[str, Any], List[Any], str]
-    metadata: Optional[Dict[str, Any]]
+    is_llm_message: bool
+    metadata: Dict[str, Any]
+    message_id: str
+    task_id: str
+    task_run_id: str
+    trace_id: str
     session_id: Optional[str]
     agent_id: Optional[str]
-    task_id: Optional[str]
+
+class XGATaskResult(TypedDict, total=False):
+    type: Literal["ask", "answer", "error"]
+    content: str
+    attachments: Optional[List[str]]
 
 @dataclass
 class XGAToolSchema:
@@ -31,6 +38,7 @@ class XGAToolResult:
     success: bool
     output: str
 
+
 class XGAToolBox(ABC):
     @abstractmethod
     async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
@@ -44,6 +52,10 @@ class XGAToolBox(ABC):
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
        pass
 
+    @abstractmethod
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        pass
+
     @abstractmethod
     async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
        pass
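
The hunk above swaps the XGAMessage @dataclass for TypedDicts, which changes messages from class instances to plain dicts, and total=False makes every declared key optional. A minimal standalone sketch of what that allows (illustrative only, not code from the package):

    from typing import List, Literal, Optional, TypedDict

    class XGATaskResult(TypedDict, total=False):
        type: Literal["ask", "answer", "error"]
        content: str
        attachments: Optional[List[str]]

    # Instances are ordinary dicts; any subset of keys is valid under total=False.
    full: XGATaskResult = {"type": "answer", "content": "42", "attachments": None}
    partial: XGATaskResult = {"type": "error"}  # omitted keys are simply absent
    print(full["content"], partial.get("content", "<no content>"))
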
xgae/engine/xga_engine.py CHANGED
@@ -1,30 +1,32 @@
 
-from typing import List, Any, Dict, Optional, AsyncGenerator
+import logging
+import json
+
+from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
 from uuid import uuid4
 
-from xgae.engine.xga_base import XGAMessage, XGAToolBox
-from xgae.utils.llm_client import LLMClient
+from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
+from xgae.engine.xga_base import XGAResponseMsg, XGAToolBox, XGATaskResult
+from xgae.utils.llm_client import LLMClient, LLMConfig
 from xgae.utils.setup_env import langfuse
+from xgae.utils.utils import handle_error
+
 from xga_prompt_builder import XGAPromptBuilder
 from xga_mcp_tool_box import XGAMcpToolBox
 
-
-class XGATaskEngine():
+class XGATaskEngine:
     def __init__(self,
                  session_id: Optional[str] = None,
                  task_id: Optional[str] = None,
                  agent_id: Optional[str] = None,
-                 trace_id: Optional[str] = None,
                  system_prompt: Optional[str] = None,
-                 llm_config: Optional[Dict[str, Any]] = None,
+                 llm_config: Optional[LLMConfig] = None,
                  prompt_builder: Optional[XGAPromptBuilder] = None,
                  tool_box: Optional[XGAToolBox] = None):
-        self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
         self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
         self.agent_id = agent_id
-        self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
+        self.session_id = session_id
 
-        self.messages: List[XGAMessage] = []
         self.llm_client = LLMClient(llm_config)
         self.model_name = self.llm_client.model_name
         self.is_stream = self.llm_client.is_stream
@@ -32,8 +34,12 @@ class XGATaskEngine():
         self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
         self.tool_box = tool_box or XGAMcpToolBox()
 
+        self.task_response_msgs: List[XGAResponseMsg] = []
+        self.task_no = -1
+        self.task_run_id = f"{self.task_id}[{self.task_no}]"
+        self.trace_id = None
 
-    async def __async_init__(self, general_tools:List[str], custom_tools: List[str]) -> None:
+    async def _post_init_(self, general_tools:List[str], custom_tools: List[str]) -> None:
         await self.tool_box.load_mcp_tools_schema()
         await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
         general_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
@@ -46,54 +52,296 @@ class XGATaskEngine():
                      session_id: Optional[str] = None,
                      task_id: Optional[str] = None,
                      agent_id: Optional[str] = None,
-                     trace_id: Optional[str] = None,
                      system_prompt: Optional[str] = None,
                      general_tools: Optional[List[str]] = None,
                      custom_tools: Optional[List[str]] = None,
-                     llm_config: Optional[Dict[str, Any]] = None,
+                     llm_config: Optional[LLMConfig] = None,
                      prompt_builder: Optional[XGAPromptBuilder] = None,
                      tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
         engine: XGATaskEngine = cls(session_id=session_id,
                                     task_id=task_id,
                                     agent_id=agent_id,
-                                    trace_id=trace_id,
                                     system_prompt=system_prompt,
                                     llm_config=llm_config,
                                     prompt_builder=prompt_builder,
                                     tool_box=tool_box)
-        general_tools = general_tools or ["*"]
+
+        general_tools = general_tools or ["complete", "ask"]
         custom_tools = custom_tools or []
-        await engine.__async_init__(general_tools, custom_tools)
+        await engine._post_init_(general_tools, custom_tools)
+
+        logging.info("*"*30 + f" XGATaskEngine Task'{engine.task_id}' Initialized " + "*"*30)
+        logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
+        logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
+
        return engine
 
+    async def run_task_with_final_answer(self,
+                                         task_message: Dict[str, Any],
+                                         max_auto_run: int = 25,
+                                         trace_id: Optional[str] = None) -> XGATaskResult:
+        chunks = []
+        async for chunk in self.run_task(task_message=task_message, max_auto_run=max_auto_run, trace_id=trace_id):
+            chunks.append(chunk)
 
-    async def run_task(self, task_messages: List[Dict[str, Any]]) -> AsyncGenerator:
+        final_result = self._parse_final_result(chunks)
+        return final_result
+
+    async def run_task(self,
+                       task_message: Dict[str, Any],
+                       max_auto_run: int = 25,
+                       trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
         try:
-            yield self.task_prompt
+            self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
+
+            self.task_no += 1
+            self.task_run_id = f"{self.task_id}[{self.task_no}]"
 
+            self.add_response_msg(type="user", content=task_message, is_llm_message=True)
+
+            if max_auto_run <= 1:
+                continuous_state:TaskRunContinuousState = {
+                    "accumulated_content": "",
+                    "auto_continue_count": 0,
+                    "auto_continue": False
+                }
+                async for chunk in self._run_task_once(continuous_state):
+                    yield chunk
+            else:
+                async for chunk in self._run_task_auto(max_auto_run):
+                    yield chunk
         finally:
             await self.tool_box.destroy_task_tool_box(self.task_id)
-
 
-    def _run_task_once(self):
-        pass
+    async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
+        llm_messages = [{"role": "system", "content": self.task_prompt}]
+        cxt_llm_contents = self._get_response_llm_contents()
+        llm_messages.extend(cxt_llm_contents)
+
+        partial_content = continuous_state.get('accumulated_content', '')
+        if partial_content:
+            temp_assistant_message = {
+                "role": "assistant",
+                "content": partial_content
+            }
+            llm_messages.append(temp_assistant_message)
+
+        llm_response = await self.llm_client.create_completion(llm_messages)
+        response_processor = self._create_response_processer()
+
+        async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
+            self._reponse_chunk_log(chunk)
+            yield chunk
+
+    async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator[Dict[str, Any], None]:
+        continuous_state: TaskRunContinuousState = {
+            "accumulated_content": "",
+            "auto_continue_count": 0,
+            "auto_continue": True
+        }
+
+        def update_continuous_state(_auto_continue_count, _auto_continue):
+            continuous_state["auto_continue_count"] = _auto_continue_count
+            continuous_state["auto_continue"] = _auto_continue
+
+        auto_continue_count = 0
+        auto_continue = True
+        while auto_continue and auto_continue_count < max_auto_run:
+            auto_continue = False
 
+            try:
+                async for chunk in self._run_task_once(continuous_state):
+                    yield chunk
+                    try:
+                        if chunk.get("type") == "status":
+                            content = json.loads(chunk.get('content', '{}'))
+                            status_type = content.get('status_type', None)
+                            if status_type == "error":
+                                logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
+                                auto_continue = False
+                                break
+                            elif status_type == 'finish':
+                                finish_reason = content.get('finish_reason', None)
+                                if finish_reason == 'completed':
+                                    logging.warning(f"run_task_auto: Detected finish_reason='completed', Task Completed Success !")
+                                    auto_continue = False
+                                    break
+                                elif finish_reason == 'xml_tool_limit_reached':
+                                    logging.warning(f"run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
+                                    auto_continue = False
+                                    break
+                                elif finish_reason == 'stop' or finish_reason == 'length': # 'length' never occur
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='{finish_reason}', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                    except StopAsyncIteration:
+                        pass
+                    except Exception as parse_error:
+                        logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
+                        content = {"role": "system", "status_type": "error", "message": "Parse response chunk Error"}
+                        error_msg = self.add_response_msg(type="status", content=content, is_llm_message=False)
+                        yield error_msg
+            except Exception as run_error:
+                logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
+                content = {"role": "system", "status_type": "error", "message": "Call task_run_once error"}
+                error_msg = self.add_response_msg(type="status", content=content, is_llm_message=False)
+                yield error_msg
+
+    def _parse_final_result(self, chunks: List[Dict[str, Any]]) -> XGATaskResult:
+        final_result: XGATaskResult = None
+        try:
+            finish_reason = ''
+            for chunk in reversed(chunks):
+                chunk_type = chunk.get("type")
+                if chunk_type == "status":
+                    status_content = json.loads(chunk.get('content', '{}'))
+                    status_type = status_content.get('status_type', None)
+                    if status_type == "error":
+                        error = status_content.get('message', 'Unknown error')
+                        final_result = XGATaskResult(type="error", content=error)
+                        break
+                    elif status_type == "finish":
+                        finish_reason = status_content.get('finish_reason', None)
+                        if finish_reason == 'xml_tool_limit_reached':
+                            error = "Completed due to over task max_auto_run limit !"
+                            final_result = XGATaskResult(type="error", content=error)
+                            break
+                    continue
+                elif chunk_type == "tool" and finish_reason in ['completed', 'stop']:
+                    tool_content = json.loads(chunk.get('content', '{}'))
+                    tool_execution = tool_content.get('tool_execution')
+                    tool_name = tool_execution.get('function_name')
+                    if tool_name == "complete":
+                        result_content = tool_execution["arguments"].get("text", "Task completed with no answer")
+                        attachments = tool_execution["arguments"].get("attachments", None)
+                        final_result = XGATaskResult(type="answer", content=result_content, attachments=attachments)
+                    elif tool_name == "ask":
+                        result_content = tool_execution["arguments"].get("text", "Task ask for more info")
+                        attachments = tool_execution["arguments"].get("attachments", None)
+                        final_result = XGATaskResult(type="ask", content=result_content, attachments=attachments)
+                    else:
+                        tool_result = tool_execution.get("result", None)
+                        if tool_result is not None:
+                            success = tool_result.get("success")
+                            output = tool_result.get("output")
+                            result_type = "answer" if success else "error"
+                            result_content = f"Task execute '{tool_name}' {result_type}: {output}"
+                            final_result = XGATaskResult(type=result_type, content=result_content)
+                elif chunk_type == "assistant" and finish_reason == 'stop':
+                    assis_content = chunk.get('content', '{}')
+                    result_content = assis_content.get("content", "LLM output is empty")
+                    final_result = XGATaskResult(type="answer", content=result_content)
+                if final_result is not None:
+                    break
+        except Exception as e:
+            logging.error(f"parse_final_result: Final result pass error: {str(e)}")
+            final_result = XGATaskResult(type="error", content="Parse final result failed!")
+            handle_error(e)
+
+        return final_result
+
+    def add_response_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
+                         content: Union[Dict[str, Any], List[Any], str],
+                         is_llm_message: bool,
+                         metadata: Optional[Dict[str, Any]]=None)-> XGAResponseMsg:
+        message = XGAResponseMsg(
+            message_id = f"xga_msg_{uuid4()}",
+            type = type,
+            content = content,
+            is_llm_message=is_llm_message,
+            metadata = metadata,
+            session_id = self.session_id,
+            agent_id = self.agent_id,
+            task_id = self.task_id,
+            task_run_id = self.task_run_id,
+            trace_id = self.trace_id
+        )
+        self.task_response_msgs.append(message)
+
+        return message
+
+    def _get_response_llm_contents (self) -> List[Dict[str, Any]]:
+        llm_messages = []
+        for message in self.task_response_msgs:
+            if message["is_llm_message"]:
+                llm_messages.append(message)
+
+        cxt_llm_contents = []
+        for llm_message in llm_messages:
+            content = llm_message["content"]
+            # @todo content List type
+            if isinstance(content, str):
+                try:
+                    _content = json.loads(content)
+                    cxt_llm_contents.append(_content)
+                except json.JSONDecodeError as e:
+                    logging.error(f"get_context_llm_contents: Failed to decode json, content=:{content}")
+                    handle_error(e)
+            else:
+                cxt_llm_contents.append(content)
+
+        return cxt_llm_contents
+
+    def _create_response_processer(self) -> TaskResponseProcessor:
+        response_context = self._create_response_context()
+        is_stream = response_context.get("is_stream", False)
+        if is_stream:
+            from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
+            return StreamTaskResponser(response_context)
+        else:
+            from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
+            return NonStreamTaskResponser(response_context)
+
+    def _create_response_context(self) -> TaskResponseContext:
+        response_context: TaskResponseContext = {
+            "is_stream": self.is_stream,
+            "task_id": self.task_id,
+            "task_run_id": self.task_run_id,
+            "trace_id": self.trace_id,
+            "model_name": self.model_name,
+            "max_xml_tool_calls": 0,
+            "add_context_msg": self.add_response_msg,
+            "tool_box": self.tool_box,
+            "tool_execution_strategy": "parallel",
+            "xml_adding_strategy": "user_message",
+        }
+        return response_context
+
+    def _reponse_chunk_log(self, chunk):
+        chunk_type = chunk.get('type')
+        prefix = ""
+
+        if chunk_type == 'status':
+            content = json.loads(chunk.get('content', '{}'))
+            status_type = content.get('status_type', "empty")
+            prefix = "-" + status_type
+        elif chunk_type == 'tool':
+            tool_content = json.loads(chunk.get('content', '{}'))
+            tool_execution = tool_content.get('tool_execution')
+            tool_name = tool_execution.get('function_name')
+            prefix = "-" + tool_name
+
+        logging.info(f"TASK_RESP_CHUNK[{chunk_type}{prefix}]: {chunk}")
 
-    def add_message(self, message: XGAMessage):
-        message.message_id = f"xga_msg_{uuid4()}"
-        message.session_id = self.session_id
-        message.agent_id = self.agent_id
-        self.messages.append(message)
 
 if __name__ == "__main__":
     import asyncio
-
+    from xgae.utils.utils import read_file
     async def main():
         tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
-        engine = await XGATaskEngine.create(tool_box=tool_box, custom_tools=["bomc_fault.*"])
-        # engine = await XGATaskEngine.create()
-
-        async for chunk in engine.run_task(task_messages=[{}]):
-            print(chunk)
+        system_prompt = read_file("templates/scp_test_prompt.txt")
+        engine = await XGATaskEngine.create(tool_box=tool_box,
                                            general_tools=[],
                                            custom_tools=["bomc_fault.*"],
                                            llm_config=LLMConfig(stream=False),
                                            system_prompt=system_prompt)
+        # engine = await XGATaskEngine.create(llm_config=LLMConfig(stream=False))
+        #chunks = []
+        # async for chunk in engine.run_task(task_message={"role": "user", "content": "Locate the fault at 10.0.0.1"},max_auto_run=8):
+        #     print(chunk)
+        #final_result = await engine.run_task_with_final_answer(task_message={"role": "user", "content": "1+1"}, max_auto_run=2)
 
+        final_result = await engine.run_task_with_final_answer(task_message={"role": "user", "content": "Locate the fault at 10.0.1.1"},max_auto_run=8)
+        print("FINAL RESULT:", final_result)
     asyncio.run(main())
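
The new entry points above support two calling styles: stream chunks from run_task yourself, or let run_task_with_final_answer drain the generator and scan the chunks in reverse for the answer. A hedged sketch mirroring the __main__ block (assumes .env supplies the model settings and any needed MCP tool server is reachable):

    import asyncio
    from xgae.utils.llm_client import LLMConfig

    async def demo():
        engine = await XGATaskEngine.create(llm_config=LLMConfig(stream=False))
        result = await engine.run_task_with_final_answer(
            task_message={"role": "user", "content": "1+1"},
            max_auto_run=2)
        # result is an XGATaskResult dict: type is "answer", "ask", or "error"
        print(result.get("type"), result.get("content"))

    asyncio.run(demo())

Note that run_task destroys the task's tool box in its finally block, so creating a fresh engine per task appears to be the intended pattern.
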
xgae/engine/xga_mcp_tool_box.py CHANGED
@@ -75,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
         await self.call_tool(task_id, "end_task", {"task_id": task_id})
         self.task_tool_schemas.pop(task_id, None)
 
+    @override
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        task_tool_schema = self.task_tool_schemas.get(task_id, {})
+        task_tool_names = list(task_tool_schema.keys())
+        return task_tool_names
+
     @override
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         task_tool_schemas = []
@@ -106,10 +112,7 @@ class XGAMcpToolBox(XGAToolBox):
         if mcp_tool:
             tool_args = args or {}
             if server_name == self.GENERAL_MCP_SERVER_NAME:
-                pass
-                #tool_args["task_id"] = task_id #xga general tool, first param must be task_id
-            else:
-                tool_args = args
+                tool_args = dict({"task_id": task_id}, **tool_args)
 
             try:
                 tool_result = await mcp_tool.arun(tool_args)
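
The rewritten branch in call_tool seeds general-tool calls with the task_id via dict({"task_id": task_id}, **tool_args). Because keyword expansion is applied after the seed mapping, a caller-supplied task_id would override the injected one. A plain-Python illustration:

    tool_args = {"host": "10.0.0.1"}
    print(dict({"task_id": "xga_task_1"}, **tool_args))
    # {'task_id': 'xga_task_1', 'host': '10.0.0.1'}

    print(dict({"task_id": "xga_task_1"}, **{"task_id": "caller_wins"}))
    # {'task_id': 'caller_wins'}
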
@@ -0,0 +1,174 @@
+"""
+JSON helper utilities for handling both legacy (string) and new (dict/list) formats.
+
+These utilities help with the transition from storing JSON as strings to storing
+them as proper JSONB objects in the database.
+"""
+
+import json
+from typing import Any, Union, Dict, List
+
+
+def ensure_dict(value: Union[str, Dict[str, Any], None], default: Dict[str, Any] = None) -> Dict[str, Any]:
+    """
+    Ensure a value is a dictionary.
+
+    Handles:
+    - None -> returns default or {}
+    - Dict -> returns as-is
+    - JSON string -> parses and returns dict
+    - Other -> returns default or {}
+
+    Args:
+        value: The value to ensure is a dict
+        default: Default value if conversion fails
+
+    Returns:
+        A dictionary
+    """
+    if default is None:
+        default = {}
+
+    if value is None:
+        return default
+
+    if isinstance(value, dict):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, dict):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -> List[Any]:
+    """
+    Ensure a value is a list.
+
+    Handles:
+    - None -> returns default or []
+    - List -> returns as-is
+    - JSON string -> parses and returns list
+    - Other -> returns default or []
+
+    Args:
+        value: The value to ensure is a list
+        default: Default value if conversion fails
+
+    Returns:
+        A list
+    """
+    if default is None:
+        default = []
+
+    if value is None:
+        return default
+
+    if isinstance(value, list):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, list):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) -> Any:
+    """
+    Safely parse a value that might be JSON string or already parsed.
+
+    This handles the transition period where some data might be stored as
+    JSON strings (old format) and some as proper objects (new format).
+
+    Args:
+        value: The value to parse
+        default: Default value if parsing fails
+
+    Returns:
+        Parsed value or default
+    """
+    if value is None:
+        return default
+
+    # If it's already a dict or list, return as-is
+    if isinstance(value, (dict, list)):
+        return value
+
+    # If it's a string, try to parse it
+    if isinstance(value, str):
+        try:
+            return json.loads(value)
+        except (json.JSONDecodeError, TypeError):
+            # If it's not valid JSON, return the string itself
+            return value
+
+    # For any other type, return as-is
+    return value
+
+
+def to_json_string(value: Any) -> str:
+    """
+    Convert a value to a JSON string if needed.
+
+    This is used for backwards compatibility when yielding data that
+    expects JSON strings.
+
+    Args:
+        value: The value to convert
+
+    Returns:
+        JSON string representation
+    """
+    if isinstance(value, str):
+        # If it's already a string, check if it's valid JSON
+        try:
+            json.loads(value)
+            return value  # It's already a JSON string
+        except (json.JSONDecodeError, TypeError):
+            # It's a plain string, encode it as JSON
+            return json.dumps(value)
+
+    # For all other types, convert to JSON
+    return json.dumps(value)
+
+
+def format_for_yield(message_object: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Format a message object for yielding, ensuring content and metadata are JSON strings.
+
+    This maintains backward compatibility with clients expecting JSON strings
+    while the database now stores proper objects.
+
+    Args:
+        message_object: The message object from the database
+
+    Returns:
+        Message object with content and metadata as JSON strings
+    """
+    if not message_object:
+        return message_object
+
+    # Create a copy to avoid modifying the original
+    formatted = message_object.copy()
+
+    # Ensure content is a JSON string
+    if 'content' in formatted and not isinstance(formatted['content'], str):
+        formatted['content'] = json.dumps(formatted['content'])
+
+    # Ensure metadata is a JSON string
+    if 'metadata' in formatted and not isinstance(formatted['metadata'], str):
+        formatted['metadata'] = json.dumps(formatted['metadata'])
+
+    return formatted
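
A quick demonstration of the new helpers' behavior (the diff does not show the new file's path, so the import below is an assumption; adjust it to wherever the module lands in the wheel):

    from xgae.utils.json_helpers import (  # path assumed, not shown in the diff
        ensure_dict, ensure_list, safe_json_parse, format_for_yield)

    print(ensure_dict('{"a": 1}'))        # {'a': 1} - legacy JSON string parsed
    print(ensure_dict("not json"))        # {}       - falls back to the default
    print(ensure_list('[1, 2]'))          # [1, 2]
    print(safe_json_parse("plain text"))  # 'plain text' - returned unchanged

    msg = {"content": {"a": 1}, "metadata": {"b": 2}}
    print(format_for_yield(msg))
    # {'content': '{"a": 1}', 'metadata': '{"b": 2}'} - re-encoded as JSON strings
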
xgae/utils/llm_client.py CHANGED
@@ -4,11 +4,24 @@ import logging
 import os
 import litellm
 
-from typing import Union, Dict, Any, Optional, List
+from typing import Union, Dict, Any, Optional, List, TypedDict
 
 from litellm.utils import ModelResponse, CustomStreamWrapper
 from openai import OpenAIError
 
+class LLMConfig(TypedDict, total=False):
+    model: str
+    model_name: str
+    model_id: str
+    api_key: str
+    api_base: str
+    temperature: float
+    max_tokens: int
+    stream: bool
+    enable_thinking: bool
+    reasoning_effort: str
+    response_format: str
+    top_p: int
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
@@ -18,7 +31,7 @@ class LLMClient:
     RATE_LIMIT_DELAY = 30
     RETRY_DELAY = 0.1
 
-    def __init__(self, llm_config: Optional[Dict[str, Any]]=None) -> None:
+    def __init__(self, llm_config: LLMConfig=None) -> None:
         """
         Arg: llm_config (Optional[Dict[str, Any]], optional)
             model: Override default model to use, default set by .env LLM_MODEL
@@ -34,7 +47,7 @@ class LLMClient:
             reasoning_effort: Optional level of reasoning effort, default is 'low'
             top_p: Optional Top-p sampling parameter, default is None
         """
-        llm_config = llm_config or {}
+        llm_config = llm_config or LLMConfig()
         litellm.modify_params = True
         litellm.drop_params = True
 
@@ -214,9 +227,7 @@ class LLMClient:
 
 if __name__ == "__main__":
     async def llm_completion():
-        llm_client = LLMClient({
-            "stream": False #default is True
-        })
+        llm_client = LLMClient(LLMConfig(stream=False))
         messages = [{"role": "user", "content": "Today is August 15, 2025; list Beijing's daily temperatures for this week"}]
         response = await llm_client.create_completion(messages)
         if llm_client.is_stream:
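
One closing sketch: since LLMConfig is a total=False TypedDict, any subset of keys is valid, and unset keys fall back to the .env defaults described in the docstring (illustrative only, extending the __main__ example above):

    from xgae.utils.llm_client import LLMClient, LLMConfig

    config = LLMConfig(stream=False, temperature=0.2, max_tokens=1024)
    client = LLMClient(config)
    # TypedDict instances are plain dicts at runtime:
    assert config == {"stream": False, "temperature": 0.2, "max_tokens": 1024}
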