xgae 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xgae might be problematic.

xgae/__init__.py ADDED
File without changes

xgae/engine/responser/xga_non_stream_responser.py ADDED
File without changes

xgae/engine/responser/xga_responser_utils.py ADDED
File without changes

xgae/engine/responser/xga_stream_reponser.py ADDED
File without changes

xgae/engine/xga_base.py ADDED
@@ -0,0 +1,46 @@
+ from typing import Union, Optional, Dict, List, Any, Literal
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod
+
+
+
+ @dataclass
+ class XGAMessage:
+     message_id: str
+     type: Literal["status", "tool", "assistant", "assistant_response_end"]
+     is_llm_message: bool
+     content: Union[Dict[str, Any], List[Any], str]
+     metadata: Optional[Dict[str, Any]]
+     session_id: Optional[str]
+     agent_id: Optional[str]
+     task_id: Optional[str]
+
+ @dataclass
+ class XGAToolSchema:
+     tool_name: str
+     server_name: str
+     description: str
+     input_schema: Optional[str]
+
+
+ @dataclass
+ class XGAToolResult:
+     success: bool
+     output: str
+
+ class XGAToolBox(ABC):
+     @abstractmethod
+     async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
+         pass
+
+     @abstractmethod
+     async def destroy_task_tool_box(self, task_id: str):
+         pass
+
+     @abstractmethod
+     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
+         pass
+
+     @abstractmethod
+     async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
+         pass
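
XGAToolBox is the seam the engine plugs tool backends into. As a rough, hypothetical sketch (not part of the package), an in-memory implementation of the contract could look like the following; the EchoToolBox class and its single echo tool are invented for illustration:

from typing import Any, Dict, List, Literal, Optional

from xgae.engine.xga_base import XGAToolBox, XGAToolSchema, XGAToolResult

class EchoToolBox(XGAToolBox):
    """Hypothetical toolbox exposing one general tool that echoes its arguments."""

    def __init__(self):
        self._schemas: Dict[str, Dict[str, XGAToolSchema]] = {}

    async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
        echo = XGAToolSchema(tool_name="echo", server_name="xga_general",
                             description="Echo the arguments back", input_schema=None)
        self._schemas[task_id] = {"echo": echo}

    async def destroy_task_tool_box(self, task_id: str):
        self._schemas.pop(task_id, None)

    def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
        schemas = self._schemas.get(task_id, {}).values()
        # general tools live on the "xga_general" server, custom tools anywhere else
        return [s for s in schemas if (type == "general_tool") == (s.server_name == "xga_general")]

    async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
        if tool_name != "echo":
            return XGAToolResult(success=False, output=f"unknown tool {tool_name}")
        return XGAToolResult(success=True, output=str(args or {}))
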
xgae/engine/xga_engine.py ADDED
@@ -0,0 +1,69 @@
+
+ from typing import List, Any, Dict, Optional, AsyncGenerator
+ from uuid import uuid4
+
+ from xgae.engine.xga_base import XGAMessage, XGAToolSchema, XGAToolBox
+ from xgae.utils.llm_client import LLMClient
+ from xgae.utils.setup_env import langfuse
+ from xgae.engine.xga_prompt_builder import XGAPromptBuilder
+ from xgae.engine.xga_mcp_tool_box import XGAMcpToolBox
+
+ class XGAEngine():
+     def __init__(self,
+                  session_id: Optional[str] = None,
+                  trace_id: Optional[str] = None,
+                  agent_id: Optional[str] = None,
+                  llm_config: Optional[Dict[str, Any]] = None,
+                  prompt_builder: Optional[XGAPromptBuilder] = None,
+                  tool_box: Optional[XGAToolBox] = None):
+         self.session_id = session_id if session_id else f"xga_sid_{uuid4()}"
+         self.agent_id = agent_id
+
+         self.messages: List[XGAMessage] = []
+         self.llm_client = LLMClient(llm_config)
+         self.model_name = self.llm_client.model_name
+         self.is_stream = self.llm_client.is_stream
+
+         self.prompt_builder = prompt_builder or XGAPromptBuilder()
+         self.tool_box = tool_box or XGAMcpToolBox()
+
+         self.task_id = None
+         self.trace_id = trace_id if trace_id else langfuse.create_trace_id()
+
+
+     async def run_task(self,
+                        task_messages: List[Dict[str, Any]],
+                        task_id: Optional[str],
+                        prompt_template: Optional[str] = None,
+                        general_tools: Optional[List[str]] = ["*"],
+                        custom_tools: Optional[List[str]] = []) -> AsyncGenerator:
+         try:
+             self.task_id = task_id if task_id else f"xga_task_{uuid4()}"
+             await self.tool_box.creat_task_tool_box(self.task_id, general_tools, custom_tools)
+             system_prompt = await self._build_system_prompt(prompt_template, general_tools, custom_tools)
+             yield system_prompt
+
+         finally:
+             await self.tool_box.destroy_task_tool_box(self.task_id)
+
+
+     def _run_task_once(self):
+         pass
+
+     async def _build_system_prompt(self, prompt_template: str, general_tools: List[str], custom_tools: List[str]) -> str:
+         self.task_tool_schemas: Dict[str, XGAToolSchema] = {}
+         system_prompt = self.prompt_builder.build_system_prompt(self.model_name, prompt_template)
+
+         tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "general_tool")
+         system_prompt = self.prompt_builder.build_general_tool_prompt(self.model_name, system_prompt, tool_schemas)
+
+         tool_schemas = self.tool_box.get_task_tool_schemas(self.task_id, "custom_tool")
+         system_prompt = self.prompt_builder.build_custom_tool_prompt(self.model_name, system_prompt, tool_schemas)
+
+         return system_prompt
+
+     def add_message(self, message: XGAMessage):
+         message.message_id = f"xga_msg_{uuid4()}"
+         message.session_id = self.session_id
+         message.agent_id = self.agent_id
+         self.messages.append(message)
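
For orientation, run_task is an async generator, so a caller drives it with async for. The sketch below is illustrative only: the identifiers are real, but the template text and trace id are placeholders, and because build_general_tool_prompt and build_custom_tool_prompt are still stubs in 0.1.1 the yielded value is not yet a finished prompt.

import asyncio

from xgae.engine.xga_engine import XGAEngine
from xgae.engine.xga_prompt_builder import XGAPromptBuilder
from xgae.utils.setup_env import setup_xga_env

async def demo():
    setup_xga_env()  # load .env, logging, and Langfuse configuration first
    engine = XGAEngine(trace_id="demo-trace",  # placeholder; otherwise Langfuse generates one
                       llm_config={"stream": False},
                       prompt_builder=XGAPromptBuilder(prompt_template="You are XGA."))
    task_messages = [{"role": "user", "content": "hello"}]
    async for event in engine.run_task(task_messages, task_id=None,
                                       general_tools=["*"], custom_tools=[]):
        print(event)  # in 0.1.1 this is only the (partially built) system prompt

asyncio.run(demo())
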
xgae/engine/xga_mcp_tool_box.py ADDED
@@ -0,0 +1,192 @@
+ import json
+ import logging
+ import os
+ from typing import List, Any, Dict, Optional, Literal, override
+
+ from langchain_mcp_adapters.client import MultiServerMCPClient
+ from langchain_mcp_adapters.tools import load_mcp_tools
+
+ from xgae.engine.xga_base import XGAToolSchema, XGAToolBox, XGAToolResult
+ from xgae.utils.setup_env import XGAError
+
+
+ class XGAMcpToolBox(XGAToolBox):
+     GENERAL_MCP_SERVER_NAME = "xga_general"
+
+     def __init__(self,
+                  custom_mcp_server_config: Optional[Dict[str, Any]] = None,
+                  custom_mcp_server_file: Optional[str] = None):
+         general_mcp_server_config = self._load_mcp_servers_config("mcpservers/xga_server.json")
+         tool_box_mcp_server_config = general_mcp_server_config.get("mcpServers", {})
+
+         if custom_mcp_server_config:
+             tool_box_mcp_server_config.update(custom_mcp_server_config)
+         elif custom_mcp_server_file:
+             custom_mcp_server_config = self._load_mcp_servers_config(custom_mcp_server_file)
+             custom_mcp_server_config = custom_mcp_server_config.get("mcpServers", {})
+             tool_box_mcp_server_config.update(custom_mcp_server_config)
+
+         self._mcp_client = MultiServerMCPClient(tool_box_mcp_server_config)
+         self.mcp_server_names: List[str] = [server_name for server_name in tool_box_mcp_server_config]
+         self.mcp_tool_schemas: Dict[str, List[XGAToolSchema]] = {}
+         self.task_tool_schemas: Dict[str, Dict[str, XGAToolSchema]] = {}
+
+     @override
+     async def creat_task_tool_box(self, task_id: str, general_tools: List[str], custom_tools: List[str]):
+         task_tool_schemas = {}
+         general_tool_schemas = self.mcp_tool_schemas.get(XGAMcpToolBox.GENERAL_MCP_SERVER_NAME, [])
+         if len(general_tools) > 0 and general_tools[0] == "*":
+             task_tool_schemas = {tool_schema.tool_name: tool_schema for tool_schema in general_tool_schemas}
+         else:
+             for tool_schema in general_tool_schemas:
+                 if tool_schema.tool_name in general_tools:
+                     task_tool_schemas[tool_schema.tool_name] = tool_schema
+         task_tool_schemas.pop("end_task", None)
+
+         for server_tool_name in custom_tools:
+             parts = server_tool_name.split(".")
+             if len(parts) != 2:
+                 continue
+             custom_server_name, custom_tool_name = parts
+             if (not custom_server_name) or (not custom_tool_name):
+                 continue
+
+             custom_tool_schemas = self.mcp_tool_schemas.get(custom_server_name, None)
+             if custom_tool_schemas is None:
+                 continue
+             if custom_tool_name == "*":
+                 custom_tool_schema_d = {tool_schema.tool_name: tool_schema for tool_schema in custom_tool_schemas}
+                 task_tool_schemas.update(custom_tool_schema_d)
+             else:
+                 for tool_schema in custom_tool_schemas:
+                     if custom_tool_name == tool_schema.tool_name:
+                         task_tool_schemas[custom_tool_name] = tool_schema
+
+         self.task_tool_schemas[task_id] = task_tool_schemas
+
+     @override
+     async def destroy_task_tool_box(self, task_id: str):
+         tool_schemas = self.get_task_tool_schemas(task_id, type="general_tool")
+         if len(tool_schemas) > 0:
+             await self.call_tool(task_id, "end_task", {"task_id": task_id})
+         self.task_tool_schemas.pop(task_id, None)
+
+     @override
+     def get_task_tool_schemas(self, task_id: str, type: Literal["general_tool", "custom_tool"]) -> List[XGAToolSchema]:
+         task_tool_schemas = []
+
+         all_task_tool_schemas = self.task_tool_schemas.get(task_id, {})
+         for tool_schema in all_task_tool_schemas.values():
+             if type == "general_tool" and tool_schema.server_name == self.GENERAL_MCP_SERVER_NAME:
+                 task_tool_schemas.append(tool_schema)
+             elif type == "custom_tool" and tool_schema.server_name != self.GENERAL_MCP_SERVER_NAME:
+                 task_tool_schemas.append(tool_schema)
+
+         return task_tool_schemas
+
+     @override
+     async def call_tool(self, task_id: str, tool_name: str, args: Optional[Dict[str, Any]] = None) -> XGAToolResult:
+         if tool_name == "end_task":
+             server_name = self.GENERAL_MCP_SERVER_NAME
+         else:
+             task_tool_schemas = self.task_tool_schemas.get(task_id, {})
+             tool_schema = task_tool_schemas.get(tool_name, None)
+             if tool_schema is None:
+                 raise XGAError(f"MCP tool not found: '{tool_name}'")
+             server_name = tool_schema.server_name
+
+         async with self._mcp_client.session(server_name) as session:
+             tools = await load_mcp_tools(session)
+             mcp_tool = next((t for t in tools if t.name == tool_name), None)
+
+             if mcp_tool:
+                 tool_args = args or {}
+                 if server_name == self.GENERAL_MCP_SERVER_NAME:
+                     pass
+                     # tool_args["task_id"] = task_id  # xga general tool, first param must be task_id
+
+                 try:
+                     tool_result = await mcp_tool.arun(tool_args)
+                     result = XGAToolResult(success=True, output=str(tool_result))
+                 except Exception as e:
+                     error = f"Call mcp tool '{tool_name}' error: {str(e)}"
+                     logging.error(f"XGAMcpToolBox.call_tool: {error}")
+                     result = XGAToolResult(success=False, output=error)
+             else:
+                 error = f"No MCP tool found with name: {tool_name}"
+                 logging.info(f"XGAMcpToolBox.call_tool: error={error}")
+                 result = XGAToolResult(success=False, output=error)
+
+         return result
+
+     async def load_mcp_tools_schema(self) -> None:
+         for server_name in self.mcp_server_names:
+             self.mcp_tool_schemas[server_name] = []
+             mcp_tools = await self._mcp_client.get_tools(server_name=server_name)
+             for tool in mcp_tools:
+                 input_schema = tool.args_schema
+                 if server_name == self.GENERAL_MCP_SERVER_NAME:
+                     input_schema = str(tool.args_schema)  # @todo remove task_id param
+                 input_schema_str = str(input_schema)  # @todo convert input tool.args_schema
+                 tool_schema = XGAToolSchema(tool_name=tool.name,
+                                             server_name=server_name,
+                                             description=tool.description,
+                                             input_schema=input_schema_str)
+                 self.mcp_tool_schemas[server_name].append(tool_schema)
+
+     @staticmethod
+     def _load_mcp_servers_config(mcp_config_path: str) -> Dict[str, Any]:
+         try:
+             if os.path.exists(mcp_config_path):
+                 with open(mcp_config_path, 'r', encoding='utf-8') as f:
+                     server_config = json.load(f)
+
+                 for server_name, server_info in server_config["mcpServers"].items():
+                     if "transport" not in server_info:
+                         if "url" in server_info:
+                             server_info["transport"] = "streamable_http" if "mcp" in server_info["url"] else "sse"
+                         else:
+                             server_info["transport"] = "stdio"
+
+                 return server_config
+             else:
+                 logging.warning("MCP servers config file not found at: %s", mcp_config_path)
+                 return {"mcpServers": {}}
+
+         except Exception as e:
+             logging.error("Failed to load MCP servers config: %s", str(e))
+             return {"mcpServers": {}}
+
+
+ if __name__ == "__main__":
+     import asyncio
+     from dataclasses import asdict
+
+     async def main():
+         task_id = "task1"
+         mcp_tool_box = XGAMcpToolBox(custom_mcp_server_file="mcpservers/custom_servers.json")
+         #mcp_tool_box = XGAMcpToolBox()
+         await mcp_tool_box.load_mcp_tools_schema()
+         await mcp_tool_box.creat_task_tool_box(task_id=task_id, general_tools=["*"], custom_tools=["bomc_fault.*"])
+         tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "general_tool")
+         print("general_tools_schemas" + "*" * 50)
+         for tool_schema in tool_schemas:
+             print(asdict(tool_schema))
+         print()
+
+         tool_schemas = mcp_tool_box.get_task_tool_schemas(task_id, "custom_tool")
+         print("custom_tools_schemas" + "*" * 50)
+         for tool_schema in tool_schemas:
+             print(asdict(tool_schema))
+         print()
+
+         result = await mcp_tool_box.call_tool(task_id=task_id, tool_name="web_search", args={"task_id": task_id, "query": "check the weather in Tianjin"})
+         print(f"call web_search result: {result}")
+
+         result = await mcp_tool_box.call_tool(task_id=task_id, tool_name="complete", args={"task_id": task_id})
+         print(f"call complete result: {result}")
+
+         await mcp_tool_box.destroy_task_tool_box(task_id)
+
+     asyncio.run(main())
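
The constructor accepts the custom server map either as a file (custom_mcp_server_file, in the same {"mcpServers": {...}} shape that _load_mcp_servers_config parses) or directly as a dict. Below is a hedged sketch of the dict form; the server names, URL, and command are placeholders, and transport is written out explicitly because it is only inferred automatically on the file-loading path.

from xgae.engine.xga_mcp_tool_box import XGAMcpToolBox

# Hypothetical custom MCP servers; keys follow the MultiServerMCPClient connection format.
custom_servers = {
    "bomc_fault": {
        "url": "http://localhost:8000/mcp",   # placeholder endpoint
        "transport": "streamable_http",
    },
    "local_tools": {
        "command": "python",                  # stdio server launched as a subprocess
        "args": ["my_mcp_server.py"],
        "transport": "stdio",
    },
}

tool_box = XGAMcpToolBox(custom_mcp_server_config=custom_servers)
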
xgae/engine/xga_prompt_builder.py ADDED
@@ -0,0 +1,38 @@
+ import datetime
+
+ from typing import Optional, List
+
+ from xgae.engine.xga_base import XGAToolSchema
+ from xgae.utils.setup_env import read_file, XGAError
+
+
+ class XGAPromptBuilder():
+     def __init__(self,
+                  prompt_template: Optional[str] = None,
+                  prompt_template_file: Optional[str] = None):
+         self.system_prompt_template = None
+         if prompt_template:
+             self.system_prompt_template = prompt_template
+         elif prompt_template_file:
+             self.system_prompt_template = read_file(prompt_template_file)
+         else:
+             _system_prompt_template = read_file("templates/system_prompt_template.txt")
+             self.system_prompt_template = _system_prompt_template.format(
+                 current_date=datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d'),
+                 current_time=datetime.datetime.now(datetime.timezone.utc).strftime('%H:%M:%S'),
+                 current_year=datetime.datetime.now(datetime.timezone.utc).strftime('%Y')
+             )
+
+
+     def build_system_prompt(self, model_name: str, prompt_template: Optional[str] = None) -> str:
+         system_prompt = prompt_template if prompt_template else self.system_prompt_template
+
+         return system_prompt
+
+
+     def build_general_tool_prompt(self, model_name: str, prompt_template: str, tool_schemas: List[XGAToolSchema]) -> str:
+         pass
+
+
+     def build_custom_tool_prompt(self, model_name: str, prompt_template: str, tool_schemas: List[XGAToolSchema]) -> str:
+         pass
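
A small usage sketch (the template wording and model name are just examples): a template passed to the constructor is used verbatim; only the bundled templates/system_prompt_template.txt gets the {current_date}, {current_time}, and {current_year} placeholders filled in, and the two tool-prompt methods are still stubs in this release.

from xgae.engine.xga_prompt_builder import XGAPromptBuilder

builder = XGAPromptBuilder(prompt_template="You are XGA, a general agent. Answer concisely.")
system_prompt = builder.build_system_prompt(model_name="openai/qwen3-235b-a22b")
print(system_prompt)  # prints the template unchanged
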
xgae/utils/llm_client.py ADDED
@@ -0,0 +1,239 @@
+ import asyncio
+ import json
+ import logging
+ import os
+ import litellm
+
+ from typing import Union, Dict, Any, Optional, List
+
+ from litellm.utils import ModelResponse, CustomStreamWrapper
+ from openai import OpenAIError
+
+ from xgae.utils.setup_env import setup_xga_env
+
+
+ class LLMError(Exception):
+     """Base exception for LLM-related errors."""
+     pass
+
+ class LLMClient:
+     RATE_LIMIT_DELAY = 30
+     RETRY_DELAY = 0.1
+
+     def __init__(self, llm_config: Optional[Dict[str, Any]] = None) -> None:
+         """
+         Arg: llm_config (Optional[Dict[str, Any]], optional)
+             model: Override default model to use, default set by .env LLM_MODEL
+             model_name: Optional name of the model to use, falls back to model if empty
+             model_id: Optional ARN for Bedrock inference profiles, default is None
+             api_key: Optional API key, overrides .env LLM_API_KEY or OS environment variable
+             api_base: Optional API base URL, overrides .env LLM_API_BASE
+             temperature: Optional sampling temperature (0-1), overrides .env LLM_TEMPERATURE
+             max_tokens: Optional maximum tokens in the response, overrides .env LLM_MAX_TOKENS
+             stream: Optional, whether to stream the response, default is True
+             response_format: Optional desired format for the response, default is None
+             enable_thinking: Optional, whether to enable thinking, default is False
+             reasoning_effort: Optional level of reasoning effort, default is 'low'
+             top_p: Optional top-p sampling parameter, default is None
+         """
+         llm_config = llm_config or {}
+         litellm.modify_params = True
+         litellm.drop_params = True
+
+         self.max_retries = int(os.getenv("LLM_MAX_RETRIES", 1))
+
+         env_llm_model = os.getenv("LLM_MODEL", "openai/qwen3-235b-a22b")
+         env_llm_api_key = os.getenv("LLM_API_KEY")
+         env_llm_api_base = os.getenv("LLM_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+         env_llm_max_tokens = int(os.getenv("LLM_MAX_TOKENS", 16384))
+         env_llm_temperature = float(os.getenv("LLM_TEMPERATURE", 0.7))
+
+         llm_config_params = {
+             "model": llm_config.get("model", env_llm_model),
+             "model_name": llm_config.get("model_name", env_llm_model),
+             "model_id": llm_config.get("model_id", None),
+             "api_key": llm_config.get("api_key", env_llm_api_key),
+             "api_base": llm_config.get("api_base", env_llm_api_base),
+             "temperature": llm_config.get("temperature", env_llm_temperature),
+             "max_tokens": llm_config.get("max_tokens", env_llm_max_tokens),
+             "stream": llm_config.get("stream", True),
+             "enable_thinking": llm_config.get("enable_thinking", False),
+             "reasoning_effort": llm_config.get("reasoning_effort", 'low'),
+             "response_format": llm_config.get("response_format", None),
+             "top_p": llm_config.get("top_p", None),
+             "tools": None,
+             "tool_choice": "none",
+         }
+
+         self.model_name = llm_config_params["model_name"]
+         self.is_stream = llm_config_params['stream']
+
+         self.lite_llm_params = self._prepare_llm_params(llm_config_params)
+         logging.info(f"📡 LLMClient initialized: model={self.model_name}, is_stream={self.is_stream}, enable_thinking={self.lite_llm_params['enable_thinking']}")
+
+
+     def _prepare_llm_params(self, llm_config_params: Dict[str, Any]) -> Dict[str, Any]:
+         prepared_llm_params = llm_config_params.copy()
+
+         model_name = llm_config_params.get("model_name")
+         max_tokens = llm_config_params.get("max_tokens")
+         model_id = llm_config_params.get("model_id")
+
+         # Handle token limits
+         if max_tokens is not None:
+             # For Claude 3.7 in Bedrock, do not set max_tokens or max_tokens_to_sample
+             # as it causes errors with inference profiles
+             if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
+                 prepared_llm_params.pop("max_tokens")
+                 logging.debug(f"prepare_llm_params: Remove 'max_tokens' param for model: {model_name}")
+             else:
+                 is_openai_o_series = 'o1' in model_name
+                 is_openai_gpt5 = 'gpt-5' in model_name
+                 param_name = "max_completion_tokens" if (is_openai_o_series or is_openai_gpt5) else "max_tokens"
+                 if param_name == "max_completion_tokens":
+                     prepared_llm_params[param_name] = max_tokens
+                     logging.debug(f"prepare_llm_params: Add 'max_completion_tokens' param for model: {model_name}")
+
+         # Add Claude-specific headers
+         if "claude" in model_name.lower() or "anthropic" in model_name.lower():
+             prepared_llm_params["extra_headers"] = {
+                 "anthropic-beta": "output-128k-2025-02-19"
+             }
+             logging.debug(f"prepare_llm_params: Add 'extra_headers' param for model: {model_name}")
+
+         # Add Bedrock-specific parameters
+         if model_name.startswith("bedrock/"):
+             if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
+                 prepared_llm_params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
+                 logging.debug(f"prepare_llm_params: Must set 'model_id' param for model: {model_name}")
+
+         # Apply Anthropic prompt caching (minimal implementation)
+         effective_model_name = llm_config_params.get("model", model_name)
+
+         # OpenAI GPT-5: drop unsupported temperature param (only default 1 allowed)
+         if "gpt-5" in effective_model_name and "temperature" in llm_config_params and llm_config_params["temperature"] != 1:
+             prepared_llm_params.pop("temperature", None)
+             logging.debug(f"prepare_llm_params: Remove 'temperature' param for model: {model_name}")
+
+         # OpenAI GPT-5: request priority service tier when calling OpenAI directly
+         # Pass via both top-level and extra_body for LiteLLM compatibility
+         if "gpt-5" in effective_model_name and not effective_model_name.startswith("openrouter/"):
+             prepared_llm_params["service_tier"] = "priority"
+             prepared_llm_params["extra_body"] = {"service_tier": "priority"}
+             logging.debug(f"prepare_llm_params: Add 'service_tier' and 'extra_body' param for model: {model_name}")
+
+         # Add reasoning_effort for Anthropic models if enabled
+         enable_thinking = llm_config_params.get("enable_thinking")
+         use_thinking = enable_thinking if enable_thinking is not None else False
+
+         is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
+         is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2")
+
+         if is_kimi_k2:
+             prepared_llm_params["provider"] = {
+                 "order": ["together/fp8", "novita/fp8", "baseten/fp8", "moonshotai", "groq"]
+             }
+             logging.debug(f"prepare_llm_params: Add 'provider' param for model: {model_name}")
+
+         reasoning_effort = llm_config_params.get("reasoning_effort")
+         if is_anthropic and use_thinking:
+             effort_level = reasoning_effort if reasoning_effort else 'low'
+             prepared_llm_params["reasoning_effort"] = effort_level
+             prepared_llm_params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
+             logging.debug(f"prepare_llm_params: Set 'temperature'=1.0 param for model: {model_name}")
+
+         return prepared_llm_params
+
+
+     def _prepare_complete_params(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
+         """Prepare parameters for the API call."""
+         complete_params = self.lite_llm_params.copy()
+         complete_params["messages"] = messages
+
+         model_name = self.lite_llm_params["model_name"]
+         effective_model_name = complete_params.get("model", model_name)
+
+         # Apply cache control to the first 3 text blocks across all messages, for anthropic and claude models
+         if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
+             messages = complete_params["messages"]
+
+             if not isinstance(messages, list):
+                 return complete_params
+
+             cache_control_count = 0
+             max_cache_control_blocks = 3
+
+             for message in messages:
+                 if cache_control_count >= max_cache_control_blocks:
+                     break
+
+                 content = message.get("content")
+
+                 if isinstance(content, str):
+                     message["content"] = [
+                         {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                     ]
+                     cache_control_count += 1
+                     logging.debug(f"prepare_complete_params: Add 'cache_control' in message content, for model: {model_name}")
+                 elif isinstance(content, list):
+                     for item in content:
+                         if cache_control_count >= max_cache_control_blocks:
+                             break
+                         if isinstance(item, dict) and item.get("type") == "text" and "cache_control" not in item:
+                             item["cache_control"] = {"type": "ephemeral"}
+                             cache_control_count += 1
+                             logging.debug(f"prepare_complete_params: Add 'cache_control' in message content list, for model: {model_name}")
+
+         return complete_params
+
+     async def _handle_llm_error(self, error: Exception, attempt: int) -> None:
+         """Handle API errors with appropriate delays and logging."""
+         if (attempt + 1) < self.max_retries:
+             delay = LLMClient.RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else LLMClient.RETRY_DELAY
+             logging.warning(f"LLMClient: Error on llm completion, retry={attempt + 1}/{self.max_retries}: {error}")
+             logging.debug(f"LLMClient: Waiting {delay} seconds before retrying llm completion...")
+             await asyncio.sleep(delay)
+
+
+     async def create_completion(self, messages: List[Dict[str, Any]]) -> Union[ModelResponse, CustomStreamWrapper]:
+         complete_params = self._prepare_complete_params(messages)
+
+         last_error = None
+         for attempt in range(self.max_retries):
+             try:
+                 logging.info(f"*** create_completion ***: LLM '{self.model_name}' completion attempt {attempt + 1}/{self.max_retries}")
+                 response = await litellm.acompletion(**complete_params)
+                 return response
+             except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
+                 last_error = e
+                 await self._handle_llm_error(e, attempt)
+             except Exception as e:
+                 logging.error(f"create_completion: Unexpected error during LLM completion: {str(e)}", exc_info=True)
+                 raise LLMError(f"LLM completion failed: {e}")
+
+         logging.error(f"create_completion: LLM completion failed after {self.max_retries} attempts: {last_error}", exc_info=True)
+         raise LLMError(f"LLM completion failed after {self.max_retries} attempts!")
+
+ if __name__ == "__main__":
+     setup_xga_env()
+
+     async def llm_completion():
+         llm_client = LLMClient({
+             "stream": False  # default is True
+         })
+         messages = [{"role": "user", "content": "Today is 2025-08-15. What is the temperature in Beijing for each day this week?"}]
+         response = await llm_client.create_completion(messages)
+         if llm_client.is_stream:
+             async for chunk in response:
+                 choices = chunk.get("choices", [{}])
+                 if not choices:
+                     continue
+                 delta = choices[0].get("delta", {})
+                 content = delta.get("content", "")
+                 if content:
+                     print(content, end="", flush=True)
+         else:
+             print(response.choices[0].message.content)
+
+     asyncio.run(llm_completion())
+
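
Configuration is layered: environment variables (normally from .env, loaded by setup_xga_env()) supply the defaults, and the llm_config dict overrides them per instance. A hedged sketch follows; the API key and retry count are placeholders.

import os

from xgae.utils.llm_client import LLMClient

# Environment-level defaults read in LLMClient.__init__ (placeholder values shown).
os.environ.setdefault("LLM_MODEL", "openai/qwen3-235b-a22b")
os.environ.setdefault("LLM_API_KEY", "sk-placeholder")
os.environ.setdefault("LLM_API_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
os.environ.setdefault("LLM_MAX_RETRIES", "3")

# Per-instance overrides win over the environment.
client = LLMClient({"temperature": 0.2, "max_tokens": 4096, "stream": False})
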
xgae/utils/setup_env.py ADDED
@@ -0,0 +1,108 @@
+ import logging
+ import os
+ import traceback
+
+ from langfuse import Langfuse
+
+
+ class XGAError(Exception):
+     """Custom exception for errors in the XGA system."""
+     pass
+
+ langfuse: Langfuse = None
+
+ def setup_langfuse() -> None:
+     env_public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+     env_secret_key = os.getenv("LANGFUSE_SECRET_KEY")
+     env_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
+     global langfuse
+     if env_public_key and env_secret_key:
+         langfuse = Langfuse(tracing_enabled=True,
+                             public_key=env_public_key,
+                             secret_key=env_secret_key,
+                             host=env_host)
+         logging.info("utils.setup_langfuse: Langfuse initialized!")
+     else:
+         langfuse = Langfuse(tracing_enabled=False)
+         logging.warning("utils.setup_langfuse: Langfuse is disabled!")
+
+ def setup_logging() -> None:
+     env_log_level = os.getenv("LOG_LEVEL", "INFO")
+     env_log_file = os.getenv("LOG_FILE", "log/xga.log")
+     log_level = getattr(logging, env_log_level.upper(), logging.INFO)
+
+     log_dir = os.path.dirname(env_log_file)
+     if log_dir and not os.path.exists(log_dir):
+         os.makedirs(log_dir, exist_ok=True)
+     elif os.path.exists(env_log_file):
+         os.remove(env_log_file)
+
+     logger = logging.getLogger()
+     for handler in logger.handlers[:]:
+         logger.removeHandler(handler)
+
+     import colorlog
+
+     log_colors = {
+         'DEBUG': 'cyan',
+         'INFO': 'green',
+         'WARNING': 'yellow',
+         'ERROR': 'red',
+         'CRITICAL': 'red,bg_white'
+     }
+
+     console_formatter = colorlog.ColoredFormatter('%(log_color)s%(asctime)s - %(levelname)-8s%(reset)s %(white)s%(message)s',
+                                                   log_colors=log_colors,
+                                                   datefmt='%Y-%m-%d %H:%M:%S'
+                                                   )
+
+     file_formatter = logging.Formatter(
+         '%(asctime)s - %(levelname)-8s %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S'
+     )
+
+     console_handler = logging.StreamHandler()
+     console_handler.setFormatter(console_formatter)
+
+     file_handler = logging.FileHandler(env_log_file, encoding='utf-8')
+     file_handler.setFormatter(file_formatter)
+
+     logger.addHandler(console_handler)
+     logger.addHandler(file_handler)
+
+     logger.setLevel(log_level)
+
+     logging.info(f"Logger is initialized, log_level={env_log_level}, log_file={env_log_file}")
+
+
+ def handle_error(e: Exception) -> None:
+     logging.error("An error occurred: %s", str(e))
+     logging.error("Traceback details:\n%s", traceback.format_exc())
+     raise e
+
+ def read_file(file_path: str) -> str:
+     if not os.path.exists(file_path):
+         logging.error(f"File '{file_path}' not found")
+         raise XGAError(f"File '{file_path}' not found")
+
+     try:
+         with open(file_path, "r", encoding="utf-8") as template_file:
+             content = template_file.read()
+             return content
+     except Exception as e:
+         logging.error(f"Read file '{file_path}' failed")
+         handle_error(e)
+
+ def setup_xga_env() -> None:
+     from dotenv import load_dotenv
+     load_dotenv()
+     setup_logging()
+     setup_langfuse()
+
+ if __name__ == "__main__":
+     try:
+         setup_xga_env()
+         trace_id = langfuse.create_trace_id()
+         print(f"trace_id={trace_id}")
+     except Exception as e:
+         handle_error(e)
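
Everything here is driven by environment variables: LOG_LEVEL and LOG_FILE for logging, and LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY, and LANGFUSE_HOST for tracing (tracing is disabled when the keys are missing). A minimal bootstrap sketch with placeholder keys; reading langfuse through the module avoids holding a reference taken before setup_langfuse() has replaced the global.

import logging
import os

from xgae.utils import setup_env

# Placeholder configuration; normally these live in a .env file read by load_dotenv().
os.environ.setdefault("LOG_LEVEL", "DEBUG")
os.environ.setdefault("LOG_FILE", "log/xga.log")
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-placeholder")
os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-placeholder")

setup_env.setup_xga_env()
logging.info("trace_id=%s", setup_env.langfuse.create_trace_id())
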
xgae-0.1.1.dist-info/METADATA ADDED
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.4
+ Name: xgae
+ Version: 0.1.1
+ Summary: Extreme General Agent Engine
+ Requires-Python: >=3.13
+ Requires-Dist: colorlog>=6.9.0
+ Requires-Dist: langchain-mcp-adapters>=0.1.4
+ Requires-Dist: langfuse>=2.60.5
+ Requires-Dist: langgraph>=0.3.21
+ Requires-Dist: litellm>=1.74.8
+ Requires-Dist: mcp>=1.12.1
xgae-0.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+ xgae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ xgae/engine/xga_base.py,sha256=ZuOVD5mLnfi8rTnDspY8qfEgwefVTTYgf6eY3B_001s,1213
+ xgae/engine/xga_engine.py,sha256=3U0F3_ISu5K1CGYhn0RZsIaZkm4OQIrMTIf3N22M7bE,3013
+ xgae/engine/xga_mcp_tool_box.py,sha256=3L48c-Wl8QghbudUsZdYemOzEHoOGF0QC-sAh0H1PVI,9047
+ xgae/engine/xga_prompt_builder.py,sha256=wPCB6g0QNpKDuvWs5Ix_Z8-OoCFjxqApodb-g-qZrbM,1503
+ xgae/engine/responser/xga_non_stream_responser.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ xgae/engine/responser/xga_responser_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ xgae/engine/responser/xga_stream_reponser.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ xgae/utils/llm_client.py,sha256=bsdWLTXRDDpiuEn72a5BuT4FYomE1LjY-thQHbTQ_Fg,12167
+ xgae/utils/setup_env.py,sha256=nJHllWuM9kLs_mJg_-j3Lhj8FHtyFjK4CIoDY_SHOxk,3253
+ xgae-0.1.1.dist-info/METADATA,sha256=_j2ZiXzJZa-sw_fk7P0sHkKiyYus_ncxoL9YoZEv4iQ,309
+ xgae-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ xgae-0.1.1.dist-info/RECORD,,
xgae-0.1.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any