xgae-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xgae has been flagged as potentially problematic.
- xgae/__init__.py +0 -0
- xgae/engine/responser/xga_non_stream_responser.py +0 -0
- xgae/engine/responser/xga_responser_utils.py +0 -0
- xgae/engine/responser/xga_stream_reponser.py +0 -0
- xgae/engine/xga_base.py +46 -0
- xgae/engine/xga_engine.py +69 -0
- xgae/engine/xga_mcp_tool_box.py +192 -0
- xgae/engine/xga_prompt_builder.py +38 -0
- xgae/utils/llm_client.py +239 -0
- xgae/utils/setup_env.py +108 -0
- xgae-0.1.1.dist-info/METADATA +11 -0
- xgae-0.1.1.dist-info/RECORD +13 -0
- xgae-0.1.1.dist-info/WHEEL +4 -0
xgae/__init__.py
ADDED
File without changes

xgae/engine/responser/xga_non_stream_responser.py
ADDED
File without changes

xgae/engine/responser/xga_responser_utils.py
ADDED
File without changes

xgae/engine/responser/xga_stream_reponser.py
ADDED
File without changes
xgae/engine/xga_base.py
ADDED
@@ -0,0 +1,46 @@
+from typing import Union, Optional, Dict, List, Any, Literal
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+
+
+@dataclass
+class XGAMessage:
+    message_id: str
+    type: Literal["status", "tool", "assistant", "assistant_response_end"]
+    is_llm_message: bool
+    content: Union[Dict[str, Any], List[Any], str]
+    metadata: Optional[Dict[str, Any]]
+    session_id: Optional[str]
+    agent_id: Optional[str]
+    task_id: Optional[str]
+
+@dataclass
+class XGAToolSchema:
+    tool_name: str
+    server_name: str
+    description: str
+    input_schema: Optional[str]
+
+
+@dataclass
+class XGAToolResult:
+    success: bool
+    output: str
+
+class XGAToolBox(ABC):
+    @abstractmethod
+    async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
+        pass
+
+    @abstractmethod
+    async def destroy_task_tool_box(self, task_id: str):
+        pass
+
+    @abstractmethod
+    def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
+        pass
+
+    @abstractmethod
+    async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
+        pass

xgae/engine/xga_engine.py
ADDED
@@ -0,0 +1,69 @@
+
+from typing import List, Any, Dict, Optional, AsyncGenerator
+from uuid import uuid4
+
+from xgae.engine.xga_base import XGAMessage, XGAToolSchema, XGAToolBox
+from xgae.utils.llm_client import LLMClient
+from xgae.utils.setup_env import langfuse
+from xga_prompt_builder import XGAPromptBuilder
+from xga_mcp_tool_box import XGAMcpToolBox
+
+class XGAEngine():
+    def __init__(self,
+                 session_id: Optional[str] = None,
+                 trace_id: Optional[str] = None,
+                 agent_id: Optional[str] = None,
+                 llm_config: Optional[Dict[str, Any]] = None,
+                 prompt_builder: Optional[XGAPromptBuilder] = None,
+                 tool_box: Optional[XGAToolBox] = None):
+        self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
+        self.agent_id = agent_id
+
+        self.messages: List[XGAMessage] = []
+        self.llm_client = LLMClient(llm_config)
+        self.model_name = self.llm_client.model_name
+        self.is_stream = self.llm_client.is_stream
+
+        self.prompt_builder = prompt_builder or XGAPromptBuilder()
+        self.tool_box = tool_box or XGAMcpToolBox()
+
+        self.task_id = None
+        self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
+
+
+    async def run_task(self,
+                       task_messages: List[Dict[str, Any]],
+                       task_id: Optional[str],
+                       prompt_template: Optional[str] = None,
+                       general_tools: Optional[List[str]] = ["*"],
+                       custom_tools: Optional[List[str]] = []) -> AsyncGenerator:
+        try:
+            self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
+            await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
+            system_prompt = await self._build_system_prompt(prompt_template, general_tools, custom_tools)
+            yield system_prompt
+
+        finally:
+            await self.tool_box.destroy_task_tool_box(self.task_id)
+
+
+    def _run_task_once(self):
+        pass
+
+    async def _build_system_prompt(self, prompt_template: str, general_tools: List[str], custom_tools: List[str]) -> str:
+        self.task_tool_schemas: Dict[str, XGAToolSchema] = {}
+        system_prompt = self.prompt_builder.build_system_prompt(self.model_name, prompt_template)
+
+        tool_schemas = await self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
+        system_prompt = self.prompt_builder.build_general_tool_prompt(self.model_name, system_prompt, tool_schemas)
+
+        tool_schemas = await self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
+        system_prompt = self.prompt_builder.build_custom_tool_prompt(self.model_name, system_prompt, tool_schemas)
+
+        return system_prompt
+
+    def add_message(self, message: XGAMessage):
+        message.message_id = f"xga_msg_{uuid4()}"
+        message.session_id = self.session_id
+        message.agent_id = self.agent_id
+        self.messages.append(message)
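
Note (not part of the release): a rough sketch of the intended call pattern for XGAEngine. It assumes a .env set up for setup_xga_env() and a reachable general MCP server; since the prompt-builder and responser pieces are still stubs in 0.1.1, this shows the call shape rather than a guaranteed end-to-end run.

# Illustrative only; assumptions as noted above.
import asyncio
from xgae.utils.setup_env import setup_xga_env
from xgae.engine.xga_engine import XGAEngine

async def demo():
    setup_xga_env()
    engine = XGAEngine(llm_config={"stream": False})
    task_messages = [{"role": "user", "content": "Check the weather in Beijing"}]
    # run_task is an async generator; in 0.1.1 it yields the built system prompt.
    async for item in engine.run_task(task_messages, task_id=None):
        print(item)

asyncio.run(demo())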

xgae/engine/xga_mcp_tool_box.py
ADDED
@@ -0,0 +1,192 @@
+import json
+import logging
+import os
+from typing import List, Any, Dict, Optional, Literal, override
+
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_mcp_adapters.tools import load_mcp_tools
+
+from xgae.engine.xga_base import XGAToolSchema, XGAToolBox, XGAToolResult
+from xgae.utils.setup_env import XGAError
+
+
+class XGAMcpToolBox(XGAToolBox):
+    GENERAL_MCP_SERVER_NAME = "xga_general"
+
+    def __init__(self,
+                 custom_mcp_server_config: Optional[Dict[str, Any]] = None,
+                 custom_mcp_server_file: Optional[str] = None):
+        general_mcp_server_config = self._load_mcp_servers_config("mcpservers/xga_server.json")
+        tool_box_mcp_server_config = general_mcp_server_config.get("mcpServers", {})
+
+        if custom_mcp_server_config:
+            tool_box_mcp_server_config.update(custom_mcp_server_config)
+        elif custom_mcp_server_file:
+            custom_mcp_server_config = self._load_mcp_servers_config(custom_mcp_server_file)
+            custom_mcp_server_config = custom_mcp_server_config.get("mcpServers", {})
+            tool_box_mcp_server_config.update(custom_mcp_server_config)
+
+        self._mcp_client = MultiServerMCPClient(tool_box_mcp_server_config)
+        self.mcp_server_names: List[str] = [server_name for server_name in tool_box_mcp_server_config]
+        self.mcp_tool_schemas: Dict[str, List[XGAToolSchema]] = {}
+        self.task_tool_schemas: Dict[str, Dict[str,XGAToolSchema]] = {}
+
+    @override
+    async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
+        task_tool_schemas = {}
+        general_tool_schemas = self.mcp_tool_schemas.get(XGAMcpToolBox.GENERAL_MCP_SERVER_NAME, {})
+        if len(general_tools) > 0 and general_tools[0] == "*":
+            task_tool_schemas = {tool_schema.tool_name: tool_schema for tool_schema in general_tool_schemas}
+        else:
+            for tool_schema in general_tool_schemas:
+                if tool_schema.tool_name in general_tools:
+                    task_tool_schemas[tool_schema.tool_name] = tool_schema
+            task_tool_schemas.pop("end_task")
+
+        for server_tool_name in custom_tools:
+            parts = server_tool_name.split(".")
+            if len(parts) != 2:
+                continue
+            custom_server_name, custom_tool_name = parts
+            if (not custom_server_name ) or (not custom_tool_name):
+                continue
+
+            custom_tool_schemas = self.mcp_tool_schemas.get(custom_server_name, None)
+            if custom_tool_schemas is None:
+                continue
+            if custom_tool_name == "*":
+                custom_tool_schema_d = {tool_schema.tool_name: tool_schema for tool_schema in custom_tool_schemas}
+                task_tool_schemas.update(custom_tool_schema_d)
+            else:
+                for tool_schema in custom_tool_schemas:
+                    if custom_tool_name == tool_schema.tool_name:
+                        task_tool_schemas[custom_tool_name] = tool_schema
+
+
+        self.task_tool_schemas[task_id] = task_tool_schemas
+
+    @override
+    async def destroy_task_tool_box(self, task_id: str):
+        tool_schemas = self.get_task_tool_schemas(task_id, type="general_tool")
+        if len(tool_schemas) > 0:
+            await self.call_tool(task_id, "end_task", {"task_id": task_id})
+        self.task_tool_schemas.pop(task_id, None)
+
+    @override
+    def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
+        task_tool_schemas = []
+
+        all_task_tool_schemas = self.task_tool_schemas.get(task_id, {})
+        for tool_schema in all_task_tool_schemas.values():
+            if type == "general_tool" and tool_schema.server_name == self.GENERAL_MCP_SERVER_NAME:
+                task_tool_schemas.append(tool_schema)
+            elif type == "custom_tool" and tool_schema.server_name != self.GENERAL_MCP_SERVER_NAME:
+                task_tool_schemas.append(tool_schema)
+
+        return task_tool_schemas
+
+    @override
+    async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
+        if tool_name == "end_task":
+            server_name = self.GENERAL_MCP_SERVER_NAME
+        else:
+            task_tool_schemas = self.task_tool_schemas.get(task_id, {})
+            tool_schema = task_tool_schemas.get(tool_name, None)
+            if tool_schema is None:
+                raise XGAError(f"MCP tool not found: '{tool_name}'")
+            server_name = tool_schema.server_name
+
+        async with self._mcp_client.session(server_name) as session:
+            tools = await load_mcp_tools(session)
+            mcp_tool = next((t for t in tools if t.name == tool_name), None)
+
+            if mcp_tool:
+                tool_args = args or {}
+                if server_name == self.GENERAL_MCP_SERVER_NAME:
+                    pass
+                    #tool_args["task_id"] = task_id #xga general tool, first param must be task_id
+                else:
+                    tool_args = args
+
+                try:
+                    tool_result = await mcp_tool.arun(tool_args)
+                    result = XGAToolResult(success=True, output=str(tool_result))
+                except Exception as e:
+                    error = f"Call mcp tool '{tool_name}' error: {str(e)}"
+                    logging.error(f"XGAMcpToolBox.call_tool: {error}")
+                    result = XGAToolResult(success=False, output=error)
+            else:
+                error = f"No MCP tool found with name: {tool_name}"
+                logging.info(f"XGAMcpToolBox.call_tool: error={error}")
+                result = XGAToolResult(success=False, output=error)
+
+        return result
+
+    async def load_mcp_tools_schema(self)-> None:
+        for server_name in self.mcp_server_names:
+            self.mcp_tool_schemas[server_name] = []
+            mcp_tools = await self._mcp_client.get_tools(server_name=server_name)
+            for tool in mcp_tools:
+                input_schema = tool.args_schema
+                if server_name == self.GENERAL_MCP_SERVER_NAME:
+                    input_schema = str(tool.args_schema) # @todo remove task_id param
+                input_schema_str = str(input_schema) # @todo convert input tool.args_schema
+                tool_schema = XGAToolSchema(tool_name=tool.name,
+                                            server_name=server_name,
+                                            description=tool.description,
+                                            input_schema=input_schema_str)
+                self.mcp_tool_schemas[server_name].append(tool_schema)
+
+    @staticmethod
+    def _load_mcp_servers_config(mcp_config_path: str) -> Dict[str, Any]:
+        try:
+            if os.path.exists(mcp_config_path):
+                with open(mcp_config_path, 'r', encoding='utf-8') as f:
+                    server_config = json.load(f)
+
+                for server_name, server_info in server_config["mcpServers"].items():
+                    if "transport" not in server_info:
+                        if "url" in server_info:
+                            server_info["transport"] = "streamable_http" if "mcp" in server_info["url"] else "sse"
+                        else:
+                            server_info["transport"] = "stdio"
+
+                return server_config
+            else:
+                logging.warning("MCP servers config file not found at: %s", mcp_config_path)
+                return {"mcpServers": {}}
+
+        except Exception as e:
+            logging.error("Failed to load MCP servers config: %s", str(e))
+            return {"mcpServers": {}}
+
+
+if __name__ == "__main__":
+    import asyncio
+    from dataclasses import asdict
+    async def main():
+        task_id = "task1"
+        mcp_tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
+        #mcp_tool_box = XGAMcpToolBox()
+        await mcp_tool_box.load_mcp_tools_schema()
+        await mcp_tool_box.creat_task_tool_box(task_id=task_id, general_tools=["*"], custom_tools=["bomc_fault.*"])
+        tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "general_tool")
+        print("general_tools_schemas" + "*"*50)
+        for tool_schema in tool_schemas:
+            print(asdict(tool_schema))
+        print()
+
+        tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "custom_tool")
+        print("custom_tools_schemas" + "*" * 50)
+        for tool_schema in tool_schemas:
+            print(asdict(tool_schema))
+        print()
+
+        result = await mcp_tool_box.call_tool(task_id=task_id, tool_name="web_search", args={"task_id": task_id, "query": "查询天津天气"})
+        print(f"call web_search result: {result}")
+
+        result = await mcp_tool_box.call_tool(task_id=task_id, tool_name="complete", args={"task_id": task_id})
+        print(f"call complete result: {result}")
+
+        await mcp_tool_box.destroy_task_tool_box(task_id)
+    asyncio.run(main())
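
Note (not part of the release): a hypothetical custom MCP server mapping, to illustrate the shape XGAMcpToolBox.__init__ merges. Passed directly as custom_mcp_server_config it is the server-name to server-info mapping; loaded from custom_mcp_server_file the same mapping sits under a top-level "mcpServers" key and a "transport" value is inferred from the "url". Server names, URLs, and commands below are invented.

# Illustrative only: shape of the custom server mapping merged in __init__.
custom_servers = {
    "bomc_fault": {                      # referenced as "bomc_fault.*" in the __main__ example above
        "url": "http://localhost:8000/mcp",
        "transport": "streamable_http",  # would also be inferred from the url by _load_mcp_servers_config
    },
    "local_tools": {
        "command": "python",
        "args": ["my_tools_server.py"],
        "transport": "stdio",
    },
}

tool_box = XGAMcpToolBox(custom_mcp_server_config=custom_servers)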

xgae/engine/xga_prompt_builder.py
ADDED
@@ -0,0 +1,38 @@
+import datetime
+
+from typing import Optional, List
+
+from xga_base import XGAToolSchema
+from xgae.utils.setup_env import read_file, XGAError
+
+
+class XGAPromptBuilder():
+    def __init__(self,
+                 prompt_template: Optional[str] = None,
+                 prompt_template_file: Optional[str] = None):
+        self.system_prompt_template = None
+        if prompt_template:
+            self.system_prompt_template = prompt_template
+        elif prompt_template_file:
+            self.system_prompt_template = read_file(prompt_template_file)
+        else:
+            _system_prompt_template = read_file("templates/system_prompt_template.txt")
+            self.system_prompt_template = _system_prompt_template.format(
+                current_date=datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d'),
+                current_time=datetime.datetime.now(datetime.timezone.utc).strftime('%H:%M:%S'),
+                current_year=datetime.datetime.now(datetime.timezone.utc).strftime('%Y')
+            )
+
+
+    def build_system_prompt(self, model_name:str, prompt_template: Optional[str]=None)-> str:
+        system_prompt = prompt_template if prompt_template else self.system_prompt_template
+
+        return system_prompt
+
+
+    def build_general_tool_prompt(self, model_name:str, prompt_template: str, tool_schemas:List[XGAToolSchema])-> str:
+        pass
+
+
+    def build_custom_tool_prompt(self, model_name:str, prompt_template: str, tool_schemas:List[XGAToolSchema])-> str:
+        pass
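
Note (not part of the release): a made-up template illustrating the placeholders that the default branch of __init__ fills from the current UTC time when it reads templates/system_prompt_template.txt. A template passed via prompt_template is stored verbatim and not formatted.

# Illustrative only: hypothetical template text using the documented placeholders.
EXAMPLE_TEMPLATE = (
    "You are XGA, a general-purpose agent.\n"
    "Current date: {current_date}, time: {current_time} UTC, year: {current_year}.\n"
)

builder = XGAPromptBuilder(prompt_template=EXAMPLE_TEMPLATE)  # passed verbatim, not formatted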
xgae/utils/llm_client.py
ADDED
@@ -0,0 +1,239 @@
+import asyncio
+import json
+import logging
+import os
+import litellm
+
+from typing import Union, Dict, Any, Optional, List
+
+from litellm.utils import ModelResponse, CustomStreamWrapper
+from openai import OpenAIError
+
+from setup_env import setup_xga_env
+
+
+class LLMError(Exception):
+    """Base exception for LLM-related errors."""
+    pass
+
+class LLMClient:
+    RATE_LIMIT_DELAY = 30
+    RETRY_DELAY = 0.1
+
+    def __init__(self, llm_config: Optional[Dict[str, Any]]={}) -> None:
+        """
+        Arg: llm_config (Optional[Dict[str, Any]], optional)
+            model: Override default model to use, default set by .env LLM_MODEL
+            model_name: Optional Name of the model to use, use model if empty
+            model_id: Optional ARN for Bedrock inference profiles, default is None
+            api_key: Optional API key, Override .env LLM_API_KEY or OS environment variable
+            api_base: Optional API base URL, Override .env LLM_API_BASE
+            temperature: Optional Sampling temperature (0-1), Override .env LLM_TEMPERATURE
+            max_tokens: Optional Maximum tokens in the response, Override .env LLM_MAX_TOKENS
+            stream: Optional whether to stream the response, default is True
+            response_format: Optional desired format for the response, default is None
+            enable_thinking: Optional whether to enable thinking, default is False
+            reasoning_effort: Optional level of reasoning effort, default is 'low'
+            top_p: Optional Top-p sampling parameter, default is None
+        """
+        litellm.modify_params = True
+        litellm.drop_params = True
+
+        self.max_retries = int(os.getenv("LLM_MAX_RETRIES", 1))
+
+        env_llm_model = os.getenv("LLM_MODEL", "openai/qwen3-235b-a22b")
+        env_llm_api_key = os.getenv("LLM_API_KEY")
+        env_llm_api_base = os.getenv("LLM_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+        env_llm_max_tokens = int(os.getenv("LLM_MAX_TOKENS", 16384))
+        env_llm_temperature = float(os.getenv("LLM_TEMPERATURE", 0.7))
+
+        llm_config_params = {
+            "model": llm_config.get("model", env_llm_model),
+            "model_name": llm_config.get("model_name", env_llm_model),
+            "model_id": llm_config.get("model_id", None),
+            "api_key": llm_config.get("api_key", env_llm_api_key),
+            "api_base": llm_config.get("api_base", env_llm_api_base),
+            "temperature": llm_config.get("temperature", env_llm_temperature),
+            "max_tokens": llm_config.get("max_tokens", env_llm_max_tokens),
+            "stream": llm_config.get("stream", True),
+            "enable_thinking": llm_config.get("enable_thinking", False),
+            "reasoning_effort": llm_config.get("reasoning_effort", 'low'),
+            "response_format": llm_config.get("response_format", None),
+            "top_p": llm_config.get("top_p", None),
+            "tools": None,
+            "tool_choice": "none",
+        }
+
+        self.model_name = llm_config_params["model_name"]
+        self.is_stream = llm_config_params['stream']
+
+        self.lite_llm_params = self._prepare_llm_params(llm_config_params)
+        logging.info(f"📡 LLMClient initialed : model={self.model_name}, is_stream={self.is_stream}, enable thinking={self.lite_llm_params['enable_thinking']}")
+
+
+    def _prepare_llm_params(self, llm_config_params: Dict[str, Any]) -> Dict[str, Any]:
+        prepared_llm_params = llm_config_params.copy()
+
+        model_name = llm_config_params.get("model_name")
+        max_tokens = llm_config_params.get("max_tokens")
+        model_id = llm_config_params.get("model_id")
+
+        # Handle token limits
+        if max_tokens is not None:
+            # For Claude 3.7 in Bedrock, do not set max_tokens or max_tokens_to_sample
+            # as it causes errors with inference profiles
+            if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
+                prepared_llm_params.pop("max_tokens")
+                logging.debug(f"prepare_llm_params: Remove 'max_tokens' param for model: {model_name}")
+            else:
+                is_openai_o_series = 'o1' in model_name
+                is_openai_gpt5 = 'gpt-5' in model_name
+                param_name = "max_completion_tokens" if (is_openai_o_series or is_openai_gpt5) else "max_tokens"
+                if param_name == "max_completion_tokens":
+                    prepared_llm_params[param_name] = max_tokens
+                    logging.debug(f"prepare_llm_params: Add 'max_completion_tokens' param for model: {model_name}")
+
+        # Add Claude-specific headers
+        if "claude" in model_name.lower() or "anthropic" in model_name.lower():
+            prepared_llm_params["extra_headers"] = {
+                "anthropic-beta": "output-128k-2025-02-19"
+            }
+            logging.debug(f"prepare_llm_params: Add 'extra_headers' param for model: {model_name}")
+
+        # Add Bedrock-specific parameters
+        if model_name.startswith("bedrock/"):
+            if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
+                prepared_llm_params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
+                logging.debug(f"prepare_llm_params: Must Set 'model_id' param for model: {model_name}")
+
+        # Apply Anthropic prompt caching (minimal implementation)
+        effective_model_name = llm_config_params.get("model", model_name)
+
+        # OpenAI GPT-5: drop unsupported temperature param (only default 1 allowed)
+        if "gpt-5" in effective_model_name and "temperature" in llm_config_params and llm_config_params["temperature"] != 1:
+            prepared_llm_params.pop("temperature", None)
+            logging.debug(f"prepare_llm_params: Remove 'temperature' param for model: {model_name}")
+
+        # OpenAI GPT-5: request priority service tier when calling OpenAI directly
+        # Pass via both top-level and extra_body for LiteLLM compatibility
+        if "gpt-5" in effective_model_name and not effective_model_name.startswith("openrouter/"):
+            prepared_llm_params["service_tier"] = "priority"
+            prepared_llm_params["extra_body"] = {"service_tier": "priority"}
+            logging.debug(f"prepare_llm_params: Add 'service_tier' and 'extra_body' param for model: {model_name}")
+
+        # Add reasoning_effort for Anthropic models if enabled
+        enable_thinking = llm_config_params.get("enable_thinking")
+        use_thinking = enable_thinking if enable_thinking is not None else False
+
+        is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
+        is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2")
+
+        if is_kimi_k2:
+            prepared_llm_params["provider"] = {
+                "order": ["together/fp8", "novita/fp8", "baseten/fp8", "moonshotai", "groq"]
+            }
+            logging.debug(f"prepare_llm_params: Add 'provider' param for model: {model_name}")
+
+        reasoning_effort = llm_config_params.get("reasoning_effort")
+        if is_anthropic and use_thinking:
+            effort_level = reasoning_effort if reasoning_effort else 'low'
+            prepared_llm_params["reasoning_effort"] = effort_level
+            prepared_llm_params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used
+            logging.debug(f"prepare_llm_params: Set 'temperature'=1.0 param for model: {model_name}")
+
+        return prepared_llm_params
+
+
+    def _prepare_complete_params(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """Prepare parameters for the API call."""
+        complete_params = self.lite_llm_params.copy()
+        complete_params["messages"] = messages
+
+        model_name = self.lite_llm_params["model_name"]
+        effective_model_name = complete_params.get("model", model_name)
+
+        # Apply cache control to the first text blocks across all messages, for anthropic and claude models
+        if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
+            messages = complete_params["messages"]
+
+            if not isinstance(messages, list):
+                return complete_params
+
+            cache_control_count = 0
+            max_cache_control_blocks = 3
+
+            for message in messages:
+                if cache_control_count >= max_cache_control_blocks:
+                    break
+
+                content = message.get("content")
+
+                if isinstance(content, str):
+                    message["content"] = [
+                        {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                    ]
+                    cache_control_count += 1
+                    logging.debug(f"prepare_complete_params: Add 'cache_control' in message content, for model: {model_name}")
+                elif isinstance(content, list):
+                    for item in content:
+                        if cache_control_count >= max_cache_control_blocks:
+                            break
+                        if isinstance(item, dict) and item.get("type") == "text" and "cache_control" not in item:
+                            item["cache_control"] = {"type": "ephemeral"}
+                            cache_control_count += 1
+                            logging.debug(f"prepare_complete_params: Add 'cache_control' in message content list, for model: {model_name}")
+
+        return complete_params
+
+    async def _handle_llm_error(self, error: Exception, attempt: int) -> None:
+        """Handle API errors with appropriate delays and logging."""
+        if (attempt + 1) < self.max_retries:
+            delay = LLMClient.RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else LLMClient.RETRY_DELAY
+            logging.warning(f"LLMClient: Error on llm completion, retry={attempt + 1}/{self.max_retries}: {error}")
+            logging.debug(f"LLMClient: Waiting {delay} seconds before retry llm completion...")
+            await asyncio.sleep(delay)
+
+
+    async def create_completion(self, messages: List[Dict[str, Any]]) -> Union[ModelResponse, CustomStreamWrapper]:
+        complete_params = self._prepare_complete_params(messages)
+
+        last_error = None
+        for attempt in range(self.max_retries):
+            try:
+                logging.info(f"*** create_completion ***: LLM '{self.model_name}' completion attempt {attempt + 1}/{self.max_retries}")
+                response = await litellm.acompletion(**complete_params)
+                return response
+            except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
+                last_error = e
+                await self._handle_llm_error(e, attempt)
+            except Exception as e:
+                logging.error(f"create_completion: Unexpected error during LLM completion: {str(e)}", exc_info=True)
+                raise LLMError(f"LLM completion failed: {e}")
+
+        logging.error(f"create_completion: LLM completion failed after {self.max_retries} attempts: {last_error}", exc_info=True)
+        raise LLMError(f"LLM completion failed after {self.max_retries} attempts !")
+
+if __name__ == "__main__":
+    setup_xga_env()
+
+    async def llm_completion():
+        llm_client = LLMClient({
+            "stream": False #default is True
+        })
+        messages = [{"role": "user", "content": "今天是2025年8月15日,北京本周每天温度"}]
+        response = await llm_client.create_completion(messages)
+        if llm_client.is_stream:
+            async for chunk in response:
+                choices = chunk.get("choices", [{}])
+                if not choices:
+                    continue
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    print(content, end="", flush=True)
+        else:
+            print(response.choices[0].message.content)
+
+    asyncio.run(llm_completion())
+
+
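
Note (not part of the release): a fuller llm_config example using the keys documented in the constructor docstring above. The model name and base URL mirror the module's own defaults; the API key is a placeholder, and any key left out falls back to the corresponding .env value.

# Illustrative only: config keys from the docstring; values are placeholders/defaults.
llm_client = LLMClient({
    "model": "openai/qwen3-235b-a22b",
    "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    "api_key": "sk-...",            # placeholder
    "temperature": 0.3,
    "max_tokens": 8192,
    "stream": True,
    "enable_thinking": False,
    "reasoning_effort": "low",
})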
xgae/utils/setup_env.py
ADDED
@@ -0,0 +1,108 @@
+import logging
+import os
+import traceback
+
+from langfuse import Langfuse
+
+
+class XGAError(Exception):
+    """Custom exception for errors in the XGA system."""
+    pass
+
+langfuse: Langfuse = None
+
+def setup_langfuse() -> None:
+    env_public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+    env_secret_key = os.getenv("LANGFUSE_SECRET_KEY")
+    env_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
+    global langfuse
+    if env_public_key and env_secret_key:
+        langfuse = Langfuse(tracing_enabled=True,
+                            public_key=env_public_key,
+                            secret_key=env_secret_key,
+                            host=env_host)
+        logging.info("utils.setup_langfuse: Langfuse initialized!")
+    else:
+        langfuse = Langfuse(tracing_enabled=False)
+        logging.warning("utils.setup_langfuse: Langfuse is disabled!")
+
+def setup_logging() -> None:
+    env_log_level = os.getenv("LOG_LEVEL", "INFO")
+    env_log_file = os.getenv("LOG_FILE", "log/xga.log")
+    log_level = getattr(logging, env_log_level.upper(), logging.INFO)
+
+    log_dir = os.path.dirname(env_log_file)
+    if log_dir and not os.path.exists(log_dir):
+        os.makedirs(log_dir, exist_ok=True)
+    else:
+        os.remove(env_log_file)
+
+    logger = logging.getLogger()
+    for handler in logger.handlers[:]:
+        logger.removeHandler(handler)
+
+    import colorlog
+
+    log_colors = {
+        'DEBUG': 'cyan',
+        'INFO': 'green',
+        'WARNING': 'yellow',
+        'ERROR': 'red',
+        'CRITICAL': 'red,bg_white'
+    }
+
+    console_formatter = colorlog.ColoredFormatter('%(log_color)s%(asctime)s - %(levelname)-8s%(reset)s %(white)s%(message)s',
+                                                  log_colors=log_colors,
+                                                  datefmt='%Y-%m-%d %H:%M:%S'
+                                                  )
+
+    file_formatter = logging.Formatter(
+        '%(asctime)s -%(levelname)-8s %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(console_formatter)
+
+    file_handler = logging.FileHandler(env_log_file, encoding='utf-8')
+    file_handler.setFormatter(file_formatter)
+
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+
+    logger.setLevel(log_level)
+
+    logging.info(f"Logger is initialized, log_level={env_log_level}, log_file={env_log_file}")
+
+
+def handle_error(e: Exception) -> None:
+    logging.error("An error occurred: %s", str(e))
+    logging.error("Traceback details:\n%s", traceback.format_exc())
+    raise (e) from e
+
+def read_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        logging.error(f"File '{file_path}' not found")
+        raise XGAError(f"File '{file_path}' not found")
+
+    try:
+        with open(file_path, "r", encoding="utf-8") as template_file:
+            content = template_file.read()
+            return content
+    except Exception as e:
+        logging.error(f"Read file '{file_path}' failed")
+        handle_error(e)
+
+def setup_xga_env() -> None:
+    from dotenv import load_dotenv
+    load_dotenv()
+    setup_logging()
+    setup_langfuse()
+
+if __name__ == "__main__":
+    try:
+        setup_xga_env()
+        trace_id = langfuse.create_trace_id()
+        print(f"trace_id={trace_id}")
+    except Exception as e:
+        handle_error(e)
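
Note (not part of the release): a sketch of the environment variables read via os.getenv in setup_env.py and llm_client.py. In practice setup_xga_env() loads them from a .env file via python-dotenv; the values below are placeholders or the module defaults.

# Illustrative only: variables consumed by setup_env.py and llm_client.py.
import os
os.environ.setdefault("LOG_LEVEL", "INFO")
os.environ.setdefault("LOG_FILE", "log/xga.log")
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-...")   # leave unset to disable tracing
os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-...")
os.environ.setdefault("LANGFUSE_HOST", "https://cloud.langfuse.com")
os.environ.setdefault("LLM_MODEL", "openai/qwen3-235b-a22b")
os.environ.setdefault("LLM_API_KEY", "your-api-key")     # placeholder
os.environ.setdefault("LLM_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
os.environ.setdefault("LLM_MAX_TOKENS", "16384")
os.environ.setdefault("LLM_TEMPERATURE", "0.7")
os.environ.setdefault("LLM_MAX_RETRIES", "1")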

xgae-0.1.1.dist-info/METADATA
ADDED
@@ -0,0 +1,11 @@
+Metadata-Version: 2.4
+Name: xgae
+Version: 0.1.1
+Summary: Extreme General Agent Engine
+Requires-Python: >=3.13
+Requires-Dist: colorlog>=6.9.0
+Requires-Dist: langchain-mcp-adapters>=0.1.4
+Requires-Dist: langfuse>=2.60.5
+Requires-Dist: langgraph>=0.3.21
+Requires-Dist: litellm>=1.74.8
+Requires-Dist: mcp>=1.12.1

xgae-0.1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
+xgae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xgae/engine/xga_base.py,sha256=ZuOVD5mLnfi8rTnDspY8qfEgwefVTTYgf6eY3B_001s,1213
+xgae/engine/xga_engine.py,sha256=3U0F3_ISu5K1CGYhn0RZsIaZkm4OQIrMTIf3N22M7bE,3013
+xgae/engine/xga_mcp_tool_box.py,sha256=3L48c-Wl8QghbudUsZdYemOzEHoOGF0QC-sAh0H1PVI,9047
+xgae/engine/xga_prompt_builder.py,sha256=wPCB6g0QNpKDuvWs5Ix_Z8-OoCFjxqApodb-g-qZrbM,1503
+xgae/engine/responser/xga_non_stream_responser.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xgae/engine/responser/xga_responser_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xgae/engine/responser/xga_stream_reponser.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xgae/utils/llm_client.py,sha256=bsdWLTXRDDpiuEn72a5BuT4FYomE1LjY-thQHbTQ_Fg,12167
+xgae/utils/setup_env.py,sha256=nJHllWuM9kLs_mJg_-j3Lhj8FHtyFjK4CIoDY_SHOxk,3253
+xgae-0.1.1.dist-info/METADATA,sha256=_j2ZiXzJZa-sw_fk7P0sHkKiyYus_ncxoL9YoZEv4iQ,309
+xgae-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+xgae-0.1.1.dist-info/RECORD,,