xgae 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xgae might be problematic.
- xgae/engine/responser/xga_non_stream_responser.py +213 -0
- xgae/engine/responser/xga_responser_base.py +751 -0
- xgae/engine/responser/xga_stream_responser.py +787 -0
- xgae/engine/xga_base.py +21 -9
- xgae/engine/xga_engine.py +206 -27
- xgae/engine/xga_mcp_tool_box.py +6 -0
- xgae/utils/json_helpers.py +174 -0
- xgae/utils/llm_client.py +17 -6
- xgae/utils/setup_env.py +1 -4
- xgae/utils/xml_tool_parser.py +236 -0
- {xgae-0.1.3.dist-info → xgae-0.1.4.dist-info}/METADATA +1 -1
- xgae-0.1.4.dist-info/RECORD +16 -0
- xgae/engine/responser/xga_responser_utils.py +0 -0
- xgae/engine/responser/xga_stream_reponser.py +0 -0
- xgae-0.1.3.dist-info/RECORD +0 -14
- {xgae-0.1.3.dist-info → xgae-0.1.4.dist-info}/WHEEL +0 -0
xgae/engine/xga_base.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Union, Optional, Dict, List, Any, Literal
+from typing import Union, Optional, Dict, List, Any, Literal, TypedDict
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 
@@ -6,16 +6,24 @@ class XGAError(Exception):
     """Custom exception for errors in the XGA system."""
     pass
 
-
-class
-
-
+
+class XGAContextMsg(TypedDict, total=False):
+    type: Literal["user", "status", "tool", "assistant", "assistant_response_end"]
+    content: Union[Dict[str, Any], List[Any], str]
     is_llm_message: bool
+    metadata: Dict[str, Any]
+    message_id: str
+    session_id: str
+    agent_id: str
+    task_id: str
+    task_run_id: str
+    trace_id: str
+
+class XGAResponseMsg(TypedDict, total=False):
+    type: Literal["content", "status"]
     content: Union[Dict[str, Any], List[Any], str]
-
-
-    agent_id: Optional[str]
-    task_id: Optional[str]
+    status: Literal["error", "status"]
+    message: str
 
 @dataclass
 class XGAToolSchema:
@@ -44,6 +52,10 @@ class XGAToolBox(ABC):
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         pass
 
+    @abstractmethod
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        pass
+
     @abstractmethod
     async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
         pass
xgae/engine/xga_engine.py
CHANGED
@@ -1,30 +1,32 @@
 
-
+import logging
+import json
+
+from typing import List, Any, Dict, Optional, AsyncGenerator, cast, Union, Literal
 from uuid import uuid4
 
-from xgae.engine.
-from xgae.
+from xgae.engine.responser.xga_responser_base import TaskResponseContext, TaskResponseProcessor, TaskRunContinuousState
+from xgae.engine.xga_base import XGAContextMsg, XGAToolBox, XGAResponseMsg
+from xgae.utils.llm_client import LLMClient, LLMConfig
 from xgae.utils.setup_env import langfuse
+from xgae.utils.utils import handle_error
+
 from xga_prompt_builder import XGAPromptBuilder
 from xga_mcp_tool_box import XGAMcpToolBox
 
-
-class XGATaskEngine():
+class XGATaskEngine:
     def __init__(self,
                  session_id: Optional[str] = None,
                  task_id: Optional[str] = None,
                  agent_id: Optional[str] = None,
-                 trace_id: Optional[str] = None,
                  system_prompt: Optional[str] = None,
-                 llm_config: Optional[
+                 llm_config: Optional[LLMConfig] = None,
                  prompt_builder: Optional[XGAPromptBuilder] = None,
                  tool_box: Optional[XGAToolBox] = None):
         self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
         self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
         self.agent_id = agent_id
-        self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
 
-        self.messages: List[XGAMessage] = []
         self.llm_client = LLMClient(llm_config)
         self.model_name = self.llm_client.model_name
         self.is_stream = self.llm_client.is_stream
@@ -32,6 +34,10 @@ class XGATaskEngine():
         self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
         self.tool_box = tool_box or XGAMcpToolBox()
 
+        self.task_context_msgs: List[XGAContextMsg] = []
+        self.task_no = -1
+        self.task_run_id = f"{self.task_id}[{self.task_no}]"
+        self.trace_id = None
 
     async def __async_init__(self, general_tools:List[str], custom_tools: List[str]) -> None:
         await self.tool_box.load_mcp_tools_schema()
@@ -46,54 +52,227 @@ class XGATaskEngine():
                      session_id: Optional[str] = None,
                      task_id: Optional[str] = None,
                      agent_id: Optional[str] = None,
-                     trace_id: Optional[str] = None,
                      system_prompt: Optional[str] = None,
                      general_tools: Optional[List[str]] = None,
                      custom_tools: Optional[List[str]] = None,
-                     llm_config: Optional[
+                     llm_config: Optional[LLMConfig] = None,
                      prompt_builder: Optional[XGAPromptBuilder] = None,
                      tool_box: Optional[XGAToolBox] = None) -> 'XGATaskEngine':
         engine: XGATaskEngine = cls(session_id=session_id,
                                     task_id=task_id,
                                     agent_id=agent_id,
-                                    trace_id=trace_id,
                                     system_prompt=system_prompt,
                                     llm_config=llm_config,
                                     prompt_builder=prompt_builder,
                                     tool_box=tool_box)
-
+
+        general_tools = general_tools or ["complete"]
         custom_tools = custom_tools or []
         await engine.__async_init__(general_tools, custom_tools)
+
+        logging.info("*"*30 + f" XGATaskEngine Task'{engine.task_id}' Initialized " + "*"*30)
+        logging.info(f"model_name={engine.model_name}, is_stream={engine.is_stream}, trace_id={engine.trace_id}")
+        logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
+
         return engine
 
 
-    async def run_task(self,
+    async def run_task(self,
+                       task_message: Dict[str, Any],
+                       max_auto_run: int = 25,
+                       trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
         try:
-
+            self.trace_id = trace_id or self.trace_id or langfuse.create_trace_id()
 
+            self.task_no += 1
+            self.task_run_id = f"{self.task_id}[{self.task_no}]"
+
+            self.add_context_msg(type="user", content=task_message, is_llm_message=True)
+
+            if max_auto_run <= 1:
+                continuous_state:TaskRunContinuousState = {
+                    "accumulated_content": "",
+                    "auto_continue_count": 0,
+                    "auto_continue": False
+                }
+                async for chunk in self._run_task_once(continuous_state):
+                    yield chunk
+            else:
+                async for chunk in self._run_task_auto(max_auto_run):
+                    yield chunk
         finally:
             await self.tool_box.destroy_task_tool_box(self.task_id)
-
 
-    def _run_task_once(self):
-
+    async def _run_task_once(self, continuous_state: TaskRunContinuousState) -> AsyncGenerator[Dict[str, Any], None]:
+        llm_messages = [{"role": "system", "content": self.task_prompt}]
+        cxt_llm_contents = self._get_context_llm_contents()
+        llm_messages.extend(cxt_llm_contents)
+
+        partial_content = continuous_state.get('accumulated_content', '')
+        if partial_content:
+            temp_assistant_message = {
+                "role": "assistant",
+                "content": partial_content
+            }
+            llm_messages.append(temp_assistant_message)
+
+        llm_response = await self.llm_client.create_completion(llm_messages)
+        response_processor = self._create_response_processer()
+
+        async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
+            yield chunk
+
 
+    async def _run_task_auto(self, max_auto_run: int) -> AsyncGenerator:
+        continuous_state: TaskRunContinuousState = {
+            "accumulated_content": "",
+            "auto_continue_count": 0,
+            "auto_continue": True
+        }
 
-
-
-
-
-
+        def update_continuous_state(_auto_continue_count, _auto_continue):
+            continuous_state["auto_continue_count"] = _auto_continue_count
+            continuous_state["auto_continue"] = _auto_continue
+
+        auto_continue_count = 0
+        auto_continue = True
+        while auto_continue and auto_continue_count < max_auto_run:
+            auto_continue = False
+
+            try:
+                async for chunk in self._run_task_once(continuous_state):
+                    try:
+                        if chunk.get("type") == "status":
+                            content = json.loads(chunk.get('content', '{}'))
+                            status_type = content.get('status_type', None)
+                            if status_type == "error":
+                                logging.error(f"run_task_auto: task_response error: {chunk.get('message', 'Unknown error')}")
+                                yield chunk
+                                return
+                            elif status_type == 'finish':
+                                finish_reason = content.get('finish_reason', None)
+                                if finish_reason == 'stop':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='stop', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                                elif finish_reason == 'xml_tool_limit_reached':
+                                    logging.info(f"run_task_auto: Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
+                                    auto_continue = False
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                elif finish_reason == 'length':
+                                    auto_continue = True
+                                    auto_continue_count += 1
+                                    update_continuous_state(auto_continue_count, auto_continue)
+                                    logging.info(f"run_task_auto: Detected finish_reason='length', auto-continuing ({auto_continue_count}/{max_auto_run})")
+                                    continue
+                    except Exception as parse_error:
+                        logging.error(f"run_task_auto: Error in parse chunk: {str(parse_error)}")
+                        yield {
+                            "type": "status",
+                            "status": "error",
+                            "message": f"Error in parse chunk: {str(parse_error)}"
+                        }
+                        return
+
+                    # Otherwise just yield the chunk normally
+                    yield chunk
+
+                # If not auto-continuing, we're done
+                if not auto_continue:
+                    break
+            except Exception as run_error:
+                logging.error(f"run_task_auto: Call task_run_once error: {str(run_error)}")
+                yield {
+                    "type": "status",
+                    "status": "error",
+                    "message": f"Call task_run_once error: {str(run_error)}"
+                }
+                return
+
+    def add_context_msg(self, type: Literal["user", "status", "tool", "assistant", "assistant_response_end"],
+                        content: Union[Dict[str, Any], List[Any], str],
+                        is_llm_message: bool,
+                        metadata: Optional[Dict[str, Any]]=None)-> XGAContextMsg:
+        message = XGAContextMsg(
+            message_id = f"xga_msg_{uuid4()}",
+            type = type,
+            content = content,
+            is_llm_message=is_llm_message,
+            metadata = metadata,
+            session_id = self.session_id,
+            agent_id = self.agent_id,
+            task_id = self.task_id,
+            task_run_id = self.task_run_id,
+            trace_id = self.trace_id
+        )
+        self.task_context_msgs.append(message)
+
+        return message
+
+    def _get_context_llm_contents (self) -> List[Dict[str, Any]]:
+        llm_messages = []
+        for message in self.task_context_msgs:
+            if message["is_llm_message"]:
+                llm_messages.append(message)
+
+        cxt_llm_contents = []
+        for llm_message in llm_messages:
+            content = llm_message["content"]
+            # @todo content List type
+            if isinstance(content, str):
+                try:
+                    _content = json.loads(content)
+                    cxt_llm_contents.append(_content)
+                except json.JSONDecodeError as e:
+                    logging.error(f"get_context_llm_contents: Failed to decode json, content=:{content}")
+                    handle_error(e)
+            else:
+                cxt_llm_contents.append(content)
+
+        return cxt_llm_contents
+
+    def _create_response_processer(self) -> TaskResponseProcessor:
+        response_context = self._create_response_context()
+        is_stream = response_context.get("is_stream", False)
+        if is_stream:
+            from xgae.engine.responser.xga_stream_responser import StreamTaskResponser
+            return StreamTaskResponser(response_context)
+        else:
+            from xgae.engine.responser.xga_non_stream_responser import NonStreamTaskResponser
+            return NonStreamTaskResponser(response_context)
+
+    def _create_response_context(self) -> TaskResponseContext:
+        response_context: TaskResponseContext = {
+            "is_stream": self.is_stream,
+            "task_id": self.task_id,
+            "task_run_id": self.task_run_id,
+            "trace_id": self.trace_id,
+            "model_name": self.model_name,
+            "max_xml_tool_calls": 0,
+            "add_context_msg": self.add_context_msg,
+            "tool_box": self.tool_box,
+            "tool_execution_strategy": "parallel",
+            "xml_adding_strategy": "user_message",
+        }
+        return response_context
 
 if __name__ == "__main__":
     import asyncio
-
+    from xgae.utils.utils import read_file
     async def main():
         tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
-
-
+        system_prompt = read_file("templates/scp_test_prompt.txt")
+        engine = await XGATaskEngine.create(tool_box=tool_box,
+                                             general_tools=[],
+                                             custom_tools=["bomc_fault.*"],
+                                             llm_config=LLMConfig(stream=False),
+                                             system_prompt=system_prompt)
+        #engine = await XGATaskEngine.create()
 
-        async for chunk in engine.run_task(
+        async for chunk in engine.run_task(task_message={"role": "user", "content": "Diagnose the fault on 10.0.0.1"},
+                                           max_auto_run=8):
            print(chunk)
 
    asyncio.run(main())
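From the parsing logic in _run_task_auto, a finish-status chunk is a dict whose "content" field is a JSON string carrying status_type and finish_reason; 'stop' and 'length' trigger auto-continue, while 'xml_tool_limit_reached' halts it. A hand-built illustration of that shape (an inference from the loop above, not output captured from the library):

import json

chunk = {
    "type": "status",
    "content": json.dumps({"status_type": "finish", "finish_reason": "stop"}),
}
content = json.loads(chunk.get("content", "{}"))
assert content.get("status_type") == "finish"  # this combination would auto-continue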
xgae/engine/xga_mcp_tool_box.py
CHANGED
@@ -75,6 +75,12 @@ class XGAMcpToolBox(XGAToolBox):
         await self.call_tool(task_id, "end_task", {"task_id": task_id})
         self.task_tool_schemas.pop(task_id, None)
 
+    @override
+    def get_task_tool_names(self, task_id: str) -> List[str]:
+        task_tool_schema = self.task_tool_schemas.get(task_id, {})
+        task_tool_names = list(task_tool_schema.keys())
+        return task_tool_names
+
     @override
     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
         task_tool_schemas = []
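The new accessor simply returns the keys of the task's cached schema dict. A sketch (the task id and tool names are placeholders; it assumes the task's tool box has already been created, so task_tool_schemas is populated):

tool_box = XGAMcpToolBox()
# ... after the task's tools have been loaded for "xga_task_123" ...
names = tool_box.get_task_tool_names("xga_task_123")
print(names)  # e.g. ["complete", "end_task"]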
xgae/utils/json_helpers.py
ADDED
@@ -0,0 +1,174 @@
+"""
+JSON helper utilities for handling both legacy (string) and new (dict/list) formats.
+
+These utilities help with the transition from storing JSON as strings to storing
+them as proper JSONB objects in the database.
+"""
+
+import json
+from typing import Any, Union, Dict, List
+
+
+def ensure_dict(value: Union[str, Dict[str, Any], None], default: Dict[str, Any] = None) -> Dict[str, Any]:
+    """
+    Ensure a value is a dictionary.
+
+    Handles:
+    - None -> returns default or {}
+    - Dict -> returns as-is
+    - JSON string -> parses and returns dict
+    - Other -> returns default or {}
+
+    Args:
+        value: The value to ensure is a dict
+        default: Default value if conversion fails
+
+    Returns:
+        A dictionary
+    """
+    if default is None:
+        default = {}
+
+    if value is None:
+        return default
+
+    if isinstance(value, dict):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, dict):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -> List[Any]:
+    """
+    Ensure a value is a list.
+
+    Handles:
+    - None -> returns default or []
+    - List -> returns as-is
+    - JSON string -> parses and returns list
+    - Other -> returns default or []
+
+    Args:
+        value: The value to ensure is a list
+        default: Default value if conversion fails
+
+    Returns:
+        A list
+    """
+    if default is None:
+        default = []
+
+    if value is None:
+        return default
+
+    if isinstance(value, list):
+        return value
+
+    if isinstance(value, str):
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, list):
+                return parsed
+            return default
+        except (json.JSONDecodeError, TypeError):
+            return default
+
+    return default
+
+
+def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) -> Any:
+    """
+    Safely parse a value that might be JSON string or already parsed.
+
+    This handles the transition period where some data might be stored as
+    JSON strings (old format) and some as proper objects (new format).
+
+    Args:
+        value: The value to parse
+        default: Default value if parsing fails
+
+    Returns:
+        Parsed value or default
+    """
+    if value is None:
+        return default
+
+    # If it's already a dict or list, return as-is
+    if isinstance(value, (dict, list)):
+        return value
+
+    # If it's a string, try to parse it
+    if isinstance(value, str):
+        try:
+            return json.loads(value)
+        except (json.JSONDecodeError, TypeError):
+            # If it's not valid JSON, return the string itself
+            return value
+
+    # For any other type, return as-is
+    return value
+
+
+def to_json_string(value: Any) -> str:
+    """
+    Convert a value to a JSON string if needed.
+
+    This is used for backwards compatibility when yielding data that
+    expects JSON strings.
+
+    Args:
+        value: The value to convert
+
+    Returns:
+        JSON string representation
+    """
+    if isinstance(value, str):
+        # If it's already a string, check if it's valid JSON
+        try:
+            json.loads(value)
+            return value  # It's already a JSON string
+        except (json.JSONDecodeError, TypeError):
+            # It's a plain string, encode it as JSON
+            return json.dumps(value)
+
+    # For all other types, convert to JSON
+    return json.dumps(value)
+
+
+def format_for_yield(message_object: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Format a message object for yielding, ensuring content and metadata are JSON strings.
+
+    This maintains backward compatibility with clients expecting JSON strings
+    while the database now stores proper objects.
+
+    Args:
+        message_object: The message object from the database
+
+    Returns:
+        Message object with content and metadata as JSON strings
+    """
+    if not message_object:
+        return message_object
+
+    # Create a copy to avoid modifying the original
+    formatted = message_object.copy()
+
+    # Ensure content is a JSON string
+    if 'content' in formatted and not isinstance(formatted['content'], str):
+        formatted['content'] = json.dumps(formatted['content'])
+
+    # Ensure metadata is a JSON string
+    if 'metadata' in formatted and not isinstance(formatted['metadata'], str):
+        formatted['metadata'] = json.dumps(formatted['metadata'])
+
+    return formatted
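Expected behavior of the new helpers, derived directly from the code above (a sketch, assuming xgae 0.1.4 is importable):

from xgae.utils.json_helpers import ensure_dict, ensure_list, safe_json_parse, to_json_string

assert ensure_dict('{"a": 1}') == {"a": 1}             # legacy JSON string -> dict
assert ensure_dict({"a": 1}) == {"a": 1}               # already a dict -> returned as-is
assert ensure_dict("not json", {"x": 0}) == {"x": 0}   # unparseable -> default
assert ensure_list("[1, 2]") == [1, 2]
assert safe_json_parse("plain text") == "plain text"   # invalid JSON returned unchanged
assert to_json_string({"a": 1}) == '{"a": 1}'          # objects are dumped to JSON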
xgae/utils/llm_client.py
CHANGED
@@ -4,11 +4,24 @@ import logging
 import os
 import litellm
 
-from typing import Union, Dict, Any, Optional, List
+from typing import Union, Dict, Any, Optional, List, TypedDict
 
 from litellm.utils import ModelResponse, CustomStreamWrapper
 from openai import OpenAIError
 
+class LLMConfig(TypedDict, total=False):
+    model: str
+    model_name: str
+    model_id: str
+    api_key: str
+    api_base: str
+    temperature: float
+    max_tokens: int
+    stream: bool
+    enable_thinking: bool
+    reasoning_effort: str
+    response_format: str
+    top_p: int
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
@@ -18,7 +31,7 @@ class LLMClient:
     RATE_LIMIT_DELAY = 30
     RETRY_DELAY = 0.1
 
-    def __init__(self, llm_config:
+    def __init__(self, llm_config: LLMConfig=None) -> None:
         """
         Arg: llm_config (Optional[Dict[str, Any]], optional)
             model: Override default model to use, default set by .env LLM_MODEL
@@ -34,7 +47,7 @@ class LLMClient:
             reasoning_effort: Optional level of reasoning effort, default is 'low'
             top_p: Optional Top-p sampling parameter, default is None
         """
-        llm_config = llm_config or
+        llm_config = llm_config or LLMConfig()
         litellm.modify_params = True
         litellm.drop_params = True
 
@@ -214,9 +227,7 @@ class LLMClient:
 
 if __name__ == "__main__":
     async def llm_completion():
-        llm_client = LLMClient(
-            "stream": False #default is True
-        })
+        llm_client = LLMClient(LLMConfig(stream=False))
         messages = [{"role": "user", "content": "Today is August 15, 2025; give Beijing's temperature for each day this week"}]
         response = await llm_client.create_completion(messages)
         if llm_client.is_stream:
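Because LLMConfig is a total=False TypedDict, callers can pass any subset of keys and let the rest fall back to the .env defaults described in the __init__ docstring. A sketch (the model string is a placeholder, not a project default):

from xgae.utils.llm_client import LLMClient, LLMConfig

client = LLMClient(LLMConfig(model="openai/gpt-4o", temperature=0.2, stream=True))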
xgae/utils/setup_env.py
CHANGED
@@ -89,8 +89,5 @@ langfuse: Langfuse = Langfuse if _langfuse_initialized else setup_langfuse()
 
 
 if __name__ == "__main__":
-    try:
     trace_id = langfuse.create_trace_id()
-
-    except Exception as e:
-        handle_error(e)
+    logging.warning(f"trace_id={trace_id}")