xgae 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xgae might be problematic.
- xgae/engine/engine_base.py +2 -2
- xgae/engine/mcp_tool_box.py +7 -6
- xgae/engine/prompt_builder.py +1 -1
- xgae/engine/responser/non_stream_responser.py +8 -7
- xgae/engine/responser/responser_base.py +54 -56
- xgae/engine/responser/stream_responser.py +24 -25
- xgae/engine/task_engine.py +64 -61
- xgae/engine/task_langfuse.py +63 -0
- xgae/tools/without_general_tools_app.py +1 -1
- xgae/utils/__init__.py +7 -5
- xgae/utils/json_helpers.py +7 -13
- xgae/utils/llm_client.py +62 -39
- xgae/utils/misc.py +3 -1
- xgae/utils/setup_env.py +36 -33
- xgae/utils/xml_tool_parser.py +4 -80
- xgae-0.1.9.dist-info/METADATA +11 -0
- xgae-0.1.9.dist-info/RECORD +20 -0
- {xgae-0.1.7.dist-info → xgae-0.1.9.dist-info}/entry_points.txt +1 -0
- xgae-0.1.7.dist-info/METADATA +0 -11
- xgae-0.1.7.dist-info/RECORD +0 -19
- {xgae-0.1.7.dist-info → xgae-0.1.9.dist-info}/WHEEL +0 -0
xgae/engine/task_engine.py
CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import json
 import os
@@ -6,15 +5,15 @@ import os
 from typing import List, Any, Dict, Optional, AsyncGenerator, Union, Literal
 from uuid import uuid4
 
-from xgae.
-from xgae.engine.engine_base import XGAResponseMsgType, XGAResponseMessage, XGAToolBox, XGATaskResult
-
-from xgae.utils import langfuse, handle_error
+from xgae.utils import handle_error
 from xgae.utils.llm_client import LLMClient, LLMConfig
-
 from xgae.utils.json_helpers import format_for_yield
+
+from xgae.engine.engine_base import XGAResponseMsgType, XGAResponseMessage, XGAToolBox, XGATaskResult
+from xgae.engine.task_langfuse import XGATaskLangFuse
 from xgae.engine.prompt_builder import XGAPromptBuilder
 from xgae.engine.mcp_tool_box import XGAMcpToolBox
+from xgae.engine.responser.responser_base import TaskResponserContext, TaskResponseProcessor, TaskRunContinuousState
 
 class XGATaskEngine:
     def __init__(self,
@@ -39,6 +38,7 @@ class XGATaskEngine:
 
         self.prompt_builder = prompt_builder or XGAPromptBuilder(system_prompt)
         self.tool_box: XGAToolBox = tool_box or XGAMcpToolBox()
+        self.task_langfuse: XGATaskLangFuse = None
 
         self.general_tools:List[str] = general_tools
         self.custom_tools:List[str] = custom_tools
@@ -50,21 +50,17 @@
 
         self.task_no = -1
         self.task_run_id :str = None
-
         self.task_prompt :str = None
-
-        self.root_span_id :str = None
+
 
     async def run_task_with_final_answer(self,
                                          task_message: Dict[str, Any],
                                          trace_id: Optional[str] = None) -> XGATaskResult:
-
-
-
-
-
-                                         ) as root_span:
-            self.root_span_id = root_span.id
+        final_result:XGATaskResult = None
+        try:
+            await self._init_task()
+
+            self.task_langfuse.start_root_span("run_task_with_final_answer", task_message, trace_id)
 
             chunks = []
             async for chunk in self.run_task(task_message=task_message, trace_id=trace_id):
@@ -75,8 +71,9 @@ class XGATaskEngine:
             else:
                 final_result = XGATaskResult(type="error", content="LLM Answer is Empty")
 
-            root_span.update(output=final_result)
             return final_result
+        finally:
+            self.task_langfuse.end_root_span("run_task_with_final_answer", final_result)
 
 
     async def run_task(self,
@@ -84,60 +81,58 @@
                        trace_id: Optional[str] = None) -> AsyncGenerator[Dict[str, Any], None]:
         try:
             await self._init_task()
-            if self.root_span_id is None:
-                self.trace_id = trace_id or langfuse.create_trace_id()
-                with langfuse.start_as_current_span(trace_context={"trace_id": self.trace_id},
-                                                    name="run_task",
-                                                    input=task_message
-                                                    ) as root_span:
-                    self.root_span_id = root_span.id
 
+            self.task_langfuse.start_root_span("run_task", task_message, trace_id)
 
             self.add_response_message(type="user", content=task_message, is_llm_message=True)
 
-
-                "accumulated_content": "",
-                "auto_continue_count": 0,
-                "auto_continue": False if self.max_auto_run <= 1 else True
-            }
-            async for chunk in self._run_task_auto(continuous_state):
+            async for chunk in self._run_task_auto():
                 yield chunk
         finally:
             await self.tool_box.destroy_task_tool_box(self.task_id)
-            self.
-
+            self.task_langfuse.end_root_span("run_task")
+            self.task_run_id = None
 
     async def _init_task(self) -> None:
-
-
+        if self.task_run_id is None:
+            self.task_no = self.task_no + 1
+            self.task_run_id = f"{self.task_id}[{self.task_no}]"
+
+            self.task_langfuse =self._create_task_langfuse()
 
-
-
-
-
-
-
+        general_tools = self.general_tools or ["complete", "ask"]
+        if "*" not in general_tools:
+            if "complete" not in general_tools:
+                general_tools.append("complete")
+            elif "ask" not in general_tools:
+                general_tools.append("ask")
 
-
-
-
+        custom_tools = self.custom_tools or []
+        if isinstance(self.tool_box, XGAMcpToolBox):
+            await self.tool_box.load_mcp_tools_schema()
 
-
-
-
+        await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
+        general_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
+        custom_tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "custom_tool")
 
-
+        self.task_prompt = self.prompt_builder.build_task_prompt(self.model_name, general_tool_schemas, custom_tool_schemas)
 
-
-
-
+        logging.info("*" * 30 + f" XGATaskEngine Task'{self.task_id}' Initialized " + "*" * 30)
+        logging.info(f"model_name={self.model_name}, is_stream={self.is_stream}")
+        logging.info(f"general_tools={general_tools}, custom_tools={custom_tools}")
 
 
-    async def _run_task_auto(self
+    async def _run_task_auto(self) -> AsyncGenerator[Dict[str, Any], None]:
         def update_continuous_state(_auto_continue_count, _auto_continue):
             continuous_state["auto_continue_count"] = _auto_continue_count
             continuous_state["auto_continue"] = _auto_continue
 
+        continuous_state: TaskRunContinuousState = {
+            "accumulated_content": "",
+            "auto_continue_count": 0,
+            "auto_continue": False if self.max_auto_run <= 1 else True
+        }
+
         auto_continue_count = 0
         auto_continue = True
         while auto_continue and auto_continue_count < self.max_auto_run:
@@ -196,14 +191,16 @@
             }
             llm_messages.append(temp_assistant_message)
 
-
+        llm_count = continuous_state.get("auto_continue_count")
+        langfuse_metadata = self.task_langfuse.create_llm_langfuse_meta(llm_count)
+
+        llm_response = await self.llm_client.create_completion(llm_messages, langfuse_metadata)
         response_processor = self._create_response_processer()
 
         async for chunk in response_processor.process_response(llm_response, llm_messages, continuous_state):
             self._logging_reponse_chunk(chunk)
             yield chunk
 
-
     def _parse_final_result(self, chunks: List[Dict[str, Any]]) -> XGATaskResult:
         final_result: XGATaskResult = None
         try:
@@ -241,7 +238,7 @@
                 result_type = "answer" if success else "error"
                 result_content = f"Task execute '{tool_name}' {result_type}: {output}"
                 final_result = XGATaskResult(type=result_type, content=result_content)
-            elif chunk_type == "
+            elif chunk_type == "assistant_complete" and finish_reason == 'stop':
                 assis_content = chunk.get('content', {})
                 result_content = assis_content.get("content", "LLM output is empty")
                 final_result = XGATaskResult(type="answer", content=result_content)
@@ -263,7 +260,7 @@
         metadata = metadata or {}
         metadata["task_id"] = self.task_id
         metadata["task_run_id"] = self.task_run_id
-        metadata["trace_id"] = self.trace_id
+        metadata["trace_id"] = self.task_langfuse.trace_id
         metadata["session_id"] = self.session_id
         metadata["agent_id"] = self.agent_id
 
@@ -300,7 +297,6 @@
 
         return response_llm_contents
 
-
     def _create_response_processer(self) -> TaskResponseProcessor:
         response_context = self._create_response_context()
         is_stream = response_context.get("is_stream", False)
@@ -316,18 +312,22 @@
             "is_stream": self.is_stream,
             "task_id": self.task_id,
             "task_run_id": self.task_run_id,
-            "
-            "root_span_id": self.root_span_id,
+            "task_no": self.task_no,
             "model_name": self.model_name,
             "max_xml_tool_calls": 0,
+            "tool_execution_strategy": "parallel" if self.tool_exec_parallel else "sequential", # ,
+            "xml_adding_strategy": "user_message",
             "add_response_msg_func": self.add_response_message,
             "tool_box": self.tool_box,
-            "
-            "xml_adding_strategy": "user_message",
+            "task_langfuse": self.task_langfuse,
         }
         return response_context
 
 
+    def _create_task_langfuse(self)-> XGATaskLangFuse:
+        return XGATaskLangFuse(self.session_id, self.task_id, self.task_run_id, self.task_no, self.agent_id)
+
+
     def _logging_reponse_chunk(self, chunk):
         chunk_type = chunk.get('type')
         prefix = ""
@@ -350,6 +350,7 @@ if __name__ == "__main__":
     from xgae.utils.misc import read_file
 
     async def main():
+        # Before Run Exec: uv run custom_fault_tools
         tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
         system_prompt = read_file("templates/example_user_prompt.txt")
         engine = XGATaskEngine(tool_box=tool_box,
@@ -357,7 +358,9 @@
                                custom_tools=["*"],
                                llm_config=LLMConfig(stream=False),
                                system_prompt=system_prompt,
-                               max_auto_run=8
+                               max_auto_run=8,
+                               session_id="session_1",
+                               agent_id="agent_1",)
 
         final_result = await engine.run_task_with_final_answer(task_message={"role": "user",
                                                                              "content": "locate 10.0.0.1 fault and solution"})
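Note (not part of the package): a minimal sketch of consuming the reworked run_task generator directly, mirroring the __main__ example above; the system prompt string is a placeholder, and the only chunk fields read are the ones the engine itself uses ('type', 'content').

import asyncio

from xgae.engine.task_engine import XGATaskEngine
from xgae.engine.mcp_tool_box import XGAMcpToolBox
from xgae.utils.llm_client import LLMConfig

async def stream_task():
    # Hypothetical setup mirroring the __main__ block in the diff above.
    engine = XGATaskEngine(tool_box=XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json"),
                           custom_tools=["*"],
                           llm_config=LLMConfig(stream=True),
                           system_prompt="You are a fault-location assistant.",  # placeholder prompt
                           max_auto_run=8,
                           session_id="session_1",
                           agent_id="agent_1")
    task_message = {"role": "user", "content": "locate 10.0.0.1 fault and solution"}
    # run_task is an AsyncGenerator: each yielded chunk is a dict whose 'type'
    # distinguishes tool results, assistant output, status messages, etc.
    async for chunk in engine.run_task(task_message=task_message):
        print(chunk.get("type"), chunk.get("content"))

asyncio.run(stream_task())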
xgae/engine/task_langfuse.py
ADDED
@@ -0,0 +1,63 @@
+
+from typing import Any, Dict, Optional
+from langfuse import Langfuse
+
+from xgae.utils.setup_env import setup_langfuse, setup_env_logging
+from xgae.utils.llm_client import LangfuseMetadata
+from xgae.engine.engine_base import XGATaskResult
+
+setup_env_logging()
+langfuse:Langfuse = setup_langfuse()
+
+class XGATaskLangFuse:
+    def __init__(self,
+                 session_id: str,
+                 task_id:str,
+                 task_run_id: str,
+                 task_no: int,
+                 agent_id: str) -> None:
+        self.session_id = session_id
+        self.task_id = task_id
+        self.task_run_id = task_run_id
+        self.task_no = task_no
+        self.agent_id = agent_id
+
+        self.trace_id = None
+        self.root_span = None
+        self.root_span_name = None
+
+
+    def start_root_span(self,
+                        root_span_name: str,
+                        task_message: Dict[str, Any],
+                        trace_id: Optional[str] = None):
+        if self.root_span is None:
+            trace = None
+            if trace_id:
+                self.trace_id = trace_id
+                trace = langfuse.trace(id=trace_id)
+            else:
+                trace = langfuse.trace(name="xga_task_engine")
+                self.trace_id = trace.id
+
+            metadata = {"task_id": self.task_id, "session_id": self.session_id, "agent_id": self.agent_id}
+            self.root_span = trace.span(id=self.task_run_id, name=root_span_name, input=task_message,metadata=metadata)
+            self.root_span_name = root_span_name
+
+
+    def end_root_span(self, root_span_name:str, output: Optional[XGATaskResult]=None):
+        if self.root_span and self.root_span_name == root_span_name:
+            self.root_span.end(output=output)
+            self.root_span = None
+            self.root_span_name = None
+
+
+    def create_llm_langfuse_meta(self, llm_count:int)-> LangfuseMetadata:
+        generation_name = f"xga_task_engine_llm_completion[{self.task_no}]({llm_count})"
+        generation_id = f"{self.task_run_id}({llm_count})"
+        return LangfuseMetadata(
+            generation_name=generation_name,
+            generation_id=generation_id,
+            existing_trace_id=self.trace_id,
+            session_id=self.session_id,
+        )
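Not from the package — a sketch of the lifecycle XGATaskEngine drives on this new helper. The ids below are hypothetical; the engine itself builds task_run_id as f"{task_id}[{task_no}]" in _init_task.

from xgae.engine.task_langfuse import XGATaskLangFuse

task_lf = XGATaskLangFuse(session_id="session_1",
                          task_id="task_42",          # hypothetical id
                          task_run_id="task_42[0]",
                          task_no=0,
                          agent_id="agent_1")

task_message = {"role": "user", "content": "locate 10.0.0.1 fault and solution"}
task_lf.start_root_span("run_task", task_message)   # opens a trace plus the root span exactly once
meta = task_lf.create_llm_langfuse_meta(0)           # per-LLM-call metadata handed to LLMClient.create_completion
task_lf.end_root_span("run_task")                    # closes the span only if the name matches start_root_span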
xgae/tools/without_general_tools_app.py
CHANGED
@@ -41,7 +41,7 @@ async def end_task(task_id: str) :
 
 
 def main():
-
+    print("="*20 + " XGAE Message Tools Sever Started in Stdio mode " + "="*20)
    mcp.run(transport="stdio")
 
 if __name__ == "__main__":
xgae/utils/__init__.py
CHANGED
@@ -1,13 +1,15 @@
 import logging
 
-from .setup_env import setup_langfuse, setup_logging
-
-setup_logging()
-langfuse = setup_langfuse()
-
 def handle_error(e: Exception) -> None:
     import traceback
 
     logging.error("An error occurred: %s", str(e))
     logging.error("Traceback details:\n%s", traceback.format_exc())
     raise (e) from e
+
+
+def to_bool(value: str) -> bool:
+    if value is None:
+        return False
+
+    return True if value.lower() == "true" else False
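For reference, the new to_bool helper treats only the literal string "true" (any case) as truthy; every other string and None falls through to False:

from xgae.utils import to_bool

to_bool("True")   # True  -- comparison is case-insensitive
to_bool("false")  # False
to_bool("1")      # False -- only "true" maps to True
to_bool(None)     # False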
xgae/utils/json_helpers.py
CHANGED
@@ -26,8 +26,7 @@ def ensure_dict(value: Union[str, Dict[str, Any], None], default: Dict[str, Any]
     Returns:
         A dictionary
     """
-
-    default = {}
+    default = default or {}
 
     if value is None:
         return default
@@ -64,8 +63,7 @@ def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -
     Returns:
         A list
     """
-
-    default = []
+    default = default or []
 
     if value is None:
         return default
@@ -84,7 +82,7 @@ def ensure_list(value: Union[str, List[Any], None], default: List[Any] = None) -
 
     return default
 
-
+# @todo if all call value is str, delete useless code
 def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) -> Any:
     """
     Safely parse a value that might be JSON string or already parsed.
@@ -105,16 +103,13 @@ def safe_json_parse(value: Union[str, Dict, List, Any], default: Any = None) ->
     # If it's already a dict or list, return as-is
     if isinstance(value, (dict, list)):
         return value
-
-    # If it's a string, try to parse it
+
     if isinstance(value, str):
         try:
             return json.loads(value)
         except (json.JSONDecodeError, TypeError):
-            # If it's not valid JSON, return the string itself
             return value
-
-    # For any other type, return as-is
+
     return value
 
 
@@ -137,9 +132,8 @@ def to_json_string(value: Any) -> str:
         json.loads(value)
         return value  # It's already a JSON string
     except (json.JSONDecodeError, TypeError):
-
-
-
+        pass
+
     # For all other types, convert to JSON
     return json.dumps(value)
 
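The ensure_dict/ensure_list change above fixes a caller-supplied fallback being discarded: 0.1.7 reset default unconditionally, while 0.1.9 keeps it. For example:

from xgae.utils.json_helpers import ensure_dict, ensure_list

ensure_dict(None, {"mode": "fallback"})  # 0.1.9 -> {'mode': 'fallback'}; 0.1.7 returned {}
ensure_list(None, ["fallback"])          # 0.1.9 -> ['fallback'];        0.1.7 returned []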
xgae/utils/llm_client.py
CHANGED
@@ -1,27 +1,36 @@
 import asyncio
 import json
-import logging
 import os
+import logging
 import litellm
 
 from typing import Union, Dict, Any, Optional, List, TypedDict
-
-from litellm.utils import ModelResponse, CustomStreamWrapper
 from openai import OpenAIError
+from litellm.utils import ModelResponse, CustomStreamWrapper
+
+from xgae.utils.setup_env import setup_langfuse
 
 class LLMConfig(TypedDict, total=False):
-    model: str
-    model_name: str
-    model_id: str
-    api_key: str
-    api_base: str
-    temperature: float
-    max_tokens: int
-    stream: bool
-    enable_thinking: bool
-    reasoning_effort: str
-    response_format: str
-    top_p: int
+    model: str  # Optional Name of the model to use , Override .env LLM_MODEL
+    model_name: str  # Optional Name of the model to use , use model if empty
+    model_id: str  # Optional ARN for Bedrock inference profiles, default is None
+    api_key: str  # Optional API key, Override .env LLM_API_KEY or OS env variable
+    api_base: str  # Optional API base URL, Override .env LLM_API_BASE
+    temperature: float  # temperature: Optional Sampling temperature (0-1), Override .env LLM_TEMPERATURE
+    max_tokens: int  # max_tokens: Optional Maximum tokens in the response, Override .env LLM_MAX_TOKENS
+    stream: bool  # stream: Optional whether to stream the response, default is True
+    enable_thinking: bool  # Optional whether to enable thinking, default is False
+    reasoning_effort: str  # Optional level of reasoning effort, default is ‘low’
+    response_format: str  # response_format: Optional desired format for the response, default is None
+    top_p: int  # Optional Top-p sampling parameter, default is None
+
+
+class LangfuseMetadata(TypedDict, total=False):
+    generation_name: str
+    generation_id: str
+    existing_trace_id: str
+    session_id: str
 
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
@@ -31,27 +40,15 @@ class LLMClient:
     RATE_LIMIT_DELAY = 30
     RETRY_DELAY = 0.1
 
-
-
-    Arg: llm_config (Optional[Dict[str, Any]], optional)
-        model: Override default model to use, default set by .env LLM_MODEL
-        model_name: Optional Name of the model to use , use model if empty
-        model_id: Optional ARN for Bedrock inference profiles, default is None
-        api_key: Optional API key, Override .env LLM_API_KEY or OS environment variable
-        api_base: Optional API base URL, Override .env LLM_API_BASE
-        temperature: Optional Sampling temperature (0-1), Override .env LLM_TEMPERATURE
-        max_tokens: Optional Maximum tokens in the response, Override .env LLM_MAX_TOKENS
-        stream: Optional whether to stream the response, default is True
-        response_format: Optional desired format for the response, default is None
-        enable_thinking: Optional whether to enable thinking, default is False
-        reasoning_effort: Optional level of reasoning effort, default is ‘low’
-        top_p: Optional Top-p sampling parameter, default is None
-    """
+    langfuse_inited = False
+    langfuse_enabled = False
 
-
+    def __init__(self, llm_config: LLMConfig=None):
         litellm.modify_params = True
         litellm.drop_params = True
+        self._init_langfuse()
 
+        llm_config = llm_config or LLMConfig()
         self.max_retries = int(os.getenv("LLM_MAX_RETRIES", 1))
 
         env_llm_model = os.getenv("LLM_MODEL", "openai/qwen3-235b-a22b")
@@ -83,6 +80,20 @@
         self.lite_llm_params = self._prepare_llm_params(llm_config_params)
         logging.info(f"📡 LLMClient initialed : model={self.model_name}, is_stream={self.is_stream}, enable thinking={self.lite_llm_params['enable_thinking']}")
 
+    @staticmethod
+    def _init_langfuse():
+        if not LLMClient.langfuse_inited:
+            LLMClient.langfuse_inited =True
+            env_langfuse = setup_langfuse()
+            if env_langfuse and env_langfuse.enabled:
+                litellm.success_callback = ["langfuse"]
+                litellm.failure_callback = ["langfuse"]
+                LLMClient.langfuse_enabled = True
+                logging.info("=== LiteLLM Langfuse is enable !")
+            else:
+                LLMClient.langfuse_enabled = False
+                logging.warning("*** LiteLLM Langfuse is disable !")
+
 
     def _prepare_llm_params(self, llm_config_params: Dict[str, Any]) -> Dict[str, Any]:
         prepared_llm_params = llm_config_params.copy()
@@ -206,10 +217,10 @@
             logging.debug(f"LLMClient: Waiting {delay} seconds before retry llm completion...")
             await asyncio.sleep(delay)
 
-    async def create_completion(self, messages: List[Dict[str, Any]],
+    async def create_completion(self, messages: List[Dict[str, Any]], langfuse_metadata: Optional[LangfuseMetadata]=None) -> Union[ModelResponse, CustomStreamWrapper]:
         complete_params = self._prepare_complete_params(messages)
-        if
-        complete_params["
+        if LLMClient.langfuse_enabled and langfuse_metadata:
+            complete_params["metadata"] = langfuse_metadata
 
         last_error = None
         for attempt in range(self.max_retries):
@@ -228,13 +239,24 @@
         raise LLMError(f"LLM completion failed after {self.max_retries} attempts !")
 
 if __name__ == "__main__":
-    from xgae.utils import
+    from xgae.utils.setup_env import setup_logging
+
+    setup_logging()
+    langfuse = setup_langfuse()
 
     async def llm_completion():
         llm_client = LLMClient(LLMConfig(stream=False))
-
-
-
+
+        messages = [{"role": "user", "content": "1+1="}]
+        trace_id = langfuse.trace(name = "xgae_litellm_test").trace_id
+        meta = LangfuseMetadata(
+            generation_name="llm_completion_test",
+            generation_id="generation_id",
+            existing_trace_id=trace_id,
+            session_id="session_0",
+        )
+
+        response = await llm_client.create_completion(messages, meta)
         if llm_client.is_stream:
             async for chunk in response:
                 choices = chunk.get("choices", [{}])
@@ -247,6 +269,7 @@
         else:
             print(response.choices[0].message.content)
 
+
     asyncio.run(llm_completion())
 
 
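Not from the package — a minimal sketch of the non-Langfuse path: when setup_langfuse reports Langfuse as disabled, create_completion skips the metadata entirely, so the client can be driven without any tracing arguments. The model/temperature values below are illustrative; unset fields fall back to the .env settings (LLM_MODEL, LLM_API_KEY, LLM_API_BASE, LLM_TEMPERATURE, LLM_MAX_TOKENS).

import asyncio

from xgae.utils.llm_client import LLMClient, LLMConfig

async def ask():
    client = LLMClient(LLMConfig(model="openai/qwen3-235b-a22b",  # same default as LLM_MODEL in the diff
                                 temperature=0.2,
                                 stream=False))
    # langfuse_metadata defaults to None, so nothing is attached to complete_params["metadata"].
    response = await client.create_completion([{"role": "user", "content": "1+1="}])
    print(response.choices[0].message.content)

asyncio.run(ask())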
xgae/utils/misc.py
CHANGED
@@ -4,6 +4,8 @@ import sys
 
 from typing import Any, Dict
 
+from xgae.utils import handle_error
+
 def read_file(file_path: str) -> str:
     if not os.path.exists(file_path):
         logging.error(f"File '{file_path}' not found")
@@ -31,4 +33,4 @@ def format_file_with_args(file_content:str, args: Dict[str, Any])-> str:
     finally:
         sys.stdout = original_stdout
 
-    return formated
+    return formated