yaicli 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,25 +1,18 @@
  import json
- from functools import wraps
  from typing import Any, Callable, Dict, Generator, List

  import google.genai as genai
  from google.genai import types

+ from yaicli.tools.mcp import get_mcp_manager
+
  from ...config import cfg
  from ...console import get_console
  from ...schemas import ChatMessage, LLMResponse
- from ...tools import get_func_name_map
+ from ...tools.function import get_functions_gemini_format
  from ..provider import Provider


- def wrap_function(func):
-     @wraps(func)
-     def wrapper(*args, **kwargs):
-         return func(*args, **kwargs)
-
-     return wrapper
-
-
  class GeminiProvider(Provider):
      """Gemini provider implementation based on google-genai library"""

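The local `wrap_function` helper removed here survives elsewhere in the package: the new `yaicli/tools/function.py` (added later in this diff) imports an equivalent from `..utils`. For reference, a pass-through wrapper matching the removed code (the actual `yaicli.utils` version is not part of this diff):

```python
from functools import wraps


def wrap_function(func):
    """Return a new function object that delegates to `func`, so callers can
    overwrite __name__/__doc__ without touching the original (used when
    handing tools to google-genai's automatic function calling)."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
```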
@@ -28,6 +21,7 @@ class GeminiProvider(Provider):
      def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
          self.config = config
          self.enable_function = self.config["ENABLE_FUNCTIONS"]
+         self.enable_mcp = self.config["ENABLE_MCP"]
          self.verbose = verbose

          # Initialize client
@@ -67,16 +61,17 @@
              config_map["frequency_penalty"] = self.config["FREQUENCY_PENALTY"]
          if self.config.get("SEED"):
              config_map["seed"] = self.config["SEED"]
-         # Indicates whether to include thoughts in the response. If true, thoughts are returned only if the model supports thought and thoughts are available.
+         # Indicates whether to include thoughts in the response.
+         # If true, thoughts are returned only if the model supports thought and thoughts are available.
          thinking_config_map = {"include_thoughts": self.config.get("INCLUDE_THOUGHTS", True)}
          if self.config.get("THINKING_BUDGET"):
              thinking_config_map["thinking_budget"] = int(self.config["THINKING_BUDGET"])
          config_map["thinking_config"] = types.ThinkingConfig(**thinking_config_map)
-         config = types.GenerateContentConfig(**config_map)
-         if self.enable_function:
+         if self.enable_function or self.enable_mcp:
              # TODO: support disable automatic function calling
              # config.automatic_function_calling = types.AutomaticFunctionCallingConfig(disable=False)
-             config.tools = self.gen_gemini_functions()
+             config_map["tools"] = self.gen_gemini_functions()
+         config = types.GenerateContentConfig(**config_map)
          return config

      def _convert_messages(self, messages: List[ChatMessage]) -> List[types.Content]:
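This hunk also reorders construction: `tools` now goes into `config_map` before `types.GenerateContentConfig` is built, so the config is created in one shot instead of being mutated afterwards. A minimal sketch, relying on google-genai accepting plain Python callables as `tools` for automatic function calling (the callable here is hypothetical):

```python
from google.genai import types


def ping() -> str:
    """Trivial hypothetical tool."""
    return "pong"


# All fields, tools included, are validated together at construction time.
config = types.GenerateContentConfig(temperature=0.7, tools=[ping])
```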
@@ -88,7 +83,9 @@ class GeminiProvider(Provider):
              content = types.Content(role=self._map_role(msg.role), parts=[types.Part(text=msg.content)])
              if msg.role == "tool":
                  content.role = "user"
-                 content.parts = [types.Part.from_function_response(name=msg.name, response={"result": msg.content})]
+                 content.parts = [
+                     types.Part.from_function_response(name=msg.name or "", response={"result": msg.content})
+                 ]
              converted_messages.append(content)
          return converted_messages

@@ -101,15 +98,19 @@ class GeminiProvider(Provider):

      def gen_gemini_functions(self) -> List[Callable[..., Any]]:
          """Wrap Gemini functions from OpenAI functions for automatic function calling"""
-         func_name_map = get_func_name_map()
-         if not func_name_map:
-             return []
          funcs = []
-         for func_name, func in func_name_map.items():
-             wrapped_func = wrap_function(func.execute)
-             wrapped_func.__name__ = func_name
-             wrapped_func.__doc__ = func.__doc__
-             funcs.append(wrapped_func)
+
+         # Add regular functions
+         if self.enable_function:
+             funcs.extend(get_functions_gemini_format())
+
+         # Add MCP functions if enabled
+         if self.enable_mcp:
+             try:
+                 mcp_tools = get_mcp_manager().to_gemini_tools()
+                 funcs.extend(mcp_tools)
+             except (ImportError, Exception) as e:
+                 self.console.print(f"Failed to load MCP tools for Gemini: {e}", style="red")
          return funcs

      def completion(
@@ -137,14 +138,14 @@
              self.console.print(gemini_messages)
          chat_config = self.get_chat_config()
          chat_config.system_instruction = messages[0].content
-         chat = self.client.chats.create(model=self.config["MODEL"], history=gemini_messages, config=chat_config)
+         chat = self.client.chats.create(model=self.config["MODEL"], history=gemini_messages, config=chat_config)  # type: ignore
          message = messages[-1].content

          if stream:
-             response = chat.send_message_stream(message=message)
+             response = chat.send_message_stream(message=message)  # type: ignore
              yield from self._handle_stream_response(response)
          else:
-             response = chat.send_message(message=message)
+             response = chat.send_message(message=message)  # type: ignore
              yield from self._handle_normal_response(response)

      def _handle_normal_response(self, response) -> Generator[LLMResponse, None, None]:
@@ -158,7 +159,7 @@
              return
          for part in response.candidates[0].content.parts:
              if part.thought:
-                 yield LLMResponse(reasoning=part.text, content=None, finish_reason="stop")
+                 yield LLMResponse(reasoning=part.text, finish_reason="stop")
              else:
                  yield LLMResponse(reasoning=None, content=part.text, finish_reason="stop")

@@ -181,7 +182,7 @@
                  reasoning = None
              yield LLMResponse(
                  reasoning=reasoning,
-                 content=content,
+                 content=content or "",
                  tool_call=tool_call if finish_reason == "tool_calls" else None,
                  finish_reason=finish_reason or None,
              )
@@ -0,0 +1,40 @@
+ from typing import Any, Dict
+
+ from huggingface_hub import InferenceClient
+
+ from .chatglm_provider import ChatglmProvider
+
+
+ class HuggingFaceProvider(ChatglmProvider):
+     """
+     HuggingFaceProvider is a provider for the HuggingFace API.
+     """
+
+     CLIENT_CLS = InferenceClient
+     DEFAULT_PROVIDER = "auto"
+
+     COMPLETION_PARAMS_KEYS = {
+         "model": "MODEL",
+         "temperature": "TEMPERATURE",
+         "top_p": "TOP_P",
+         "max_tokens": "MAX_TOKENS",
+         "extra_body": "EXTRA_BODY",
+     }
+
+     def get_client_params(self) -> Dict[str, Any]:
+         client_params = {
+             "api_key": self.config["API_KEY"],
+             "timeout": self.config["TIMEOUT"],
+             "provider": self.config.get("HF_PROVIDER") or self.DEFAULT_PROVIDER,
+         }
+         if self.config["BASE_URL"]:
+             client_params["base_url"] = self.config["BASE_URL"]
+         if self.config["EXTRA_HEADERS"]:
+             client_params["headers"] = {
+                 **self.config["EXTRA_HEADERS"],
+                 "X-Title": self.APP_NAME,
+                 "HTTP-Referer": self.APP_REFERER,
+             }
+         if self.config.get("BILL_TO"):
+             client_params["bill_to"] = self.config["BILL_TO"]
+         return client_params
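The assembled `client_params` map directly onto `huggingface_hub.InferenceClient`'s constructor in recent huggingface_hub releases; `provider` picks the inference backend, with `"auto"` delegating the choice to the Hub. A sketch with placeholder values:

```python
from huggingface_hub import InferenceClient

# Placeholder credentials; base_url/headers/bill_to are only added when configured.
client = InferenceClient(api_key="hf_xxx", timeout=60, provider="auto")
```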
@@ -1,5 +1,6 @@
  from typing import Any, Dict

+ from ...config import cfg
  from .openai_provider import OpenAIProvider


@@ -8,7 +9,7 @@ class InfiniAIProvider(OpenAIProvider):

      DEFAULT_BASE_URL = "https://cloud.infini-ai.com/maas/v1"

-     def __init__(self, config: dict = ..., **kwargs):
+     def __init__(self, config: dict = cfg, **kwargs):
          super().__init__(config, **kwargs)
          if self.enable_function:
              self.console.print("InfiniAI does not support functions, disabled", style="yellow")
@@ -16,5 +17,6 @@ class InfiniAIProvider(OpenAIProvider):

      def get_completion_params(self) -> Dict[str, Any]:
          params = super().get_completion_params()
-         params["max_tokens"] = params.pop("max_completion_tokens")
+         if "max_completion_tokens" in params:
+             params["max_tokens"] = params.pop("max_completion_tokens")
          return params
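The identical guard is applied to the ModelScope and SiliconFlow providers below: these OpenAI-compatible endpoints still expect the legacy `max_tokens` name, and checking for the key first avoids a `KeyError` when `max_completion_tokens` was never set. In isolation:

```python
# Illustrative: rename only when the OpenAI-style key is actually present.
params = {"model": "some-model", "max_completion_tokens": 1024}
if "max_completion_tokens" in params:
    params["max_tokens"] = params.pop("max_completion_tokens")
assert params == {"model": "some-model", "max_tokens": 1024}
```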
@@ -10,5 +10,6 @@ class ModelScopeProvider(OpenAIProvider):

      def get_completion_params(self) -> Dict[str, Any]:
          params = super().get_completion_params()
-         params["max_tokens"] = params.pop("max_completion_tokens")
+         if "max_completion_tokens" in params:
+             params["max_tokens"] = params.pop("max_completion_tokens")
          return params
@@ -10,6 +10,7 @@ from ...config import cfg
  from ...console import get_console
  from ...schemas import ChatMessage, LLMResponse, ToolCall
  from ...tools import get_openai_schemas
+ from ...tools.mcp import get_mcp_manager
  from ..provider import Provider


@@ -19,7 +20,7 @@ class OpenAIProvider(Provider):
      DEFAULT_BASE_URL = "https://api.openai.com/v1"
      CLIENT_CLS = openai.OpenAI
      # Base mapping between config keys and API parameter names
-     _BASE_COMPLETION_PARAMS_KEYS = {
+     COMPLETION_PARAMS_KEYS = {
          "model": "MODEL",
          "temperature": "TEMPERATURE",
          "top_p": "TOP_P",
@@ -34,6 +35,7 @@ class OpenAIProvider(Provider):
          if not self.config.get("API_KEY"):
              raise ValueError("API_KEY is required")
          self.enable_function = self.config["ENABLE_FUNCTIONS"]
+         self.enable_mcp = self.config["ENABLE_MCP"]
          self.verbose = verbose

          # Initialize client
@@ -50,15 +52,12 @@
          client_params = {
              "api_key": self.config["API_KEY"],
              "base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
+             "default_headers": {"X-Title": self.APP_NAME, "HTTP_Referer": self.APP_REFERER},
          }

          # Add extra headers if set
          if self.config["EXTRA_HEADERS"]:
-             client_params["default_headers"] = {
-                 **self.config["EXTRA_HEADERS"],
-                 "X-Title": self.APP_NAME,
-                 "HTTP-Referer": self.APP_REFERER,
-             }
+             client_params["default_headers"] = {**self.config["EXTRA_HEADERS"], **client_params["default_headers"]}
          return client_params

      def get_completion_params_keys(self) -> Dict[str, str]:
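Since later entries win in a dict merge, spreading the defaults last keeps the app's own `X-Title`/`HTTP_Referer` entries authoritative even when `EXTRA_HEADERS` defines the same keys (note that the new default uses `HTTP_Referer` where the removed block sent `HTTP-Referer`). The precedence, with hypothetical values:

```python
extra = {"X-Title": "user-title", "X-Custom": "1"}
defaults = {"X-Title": "Yaicli", "HTTP_Referer": "app-referer"}
merged = {**extra, **defaults}
assert merged["X-Title"] == "Yaicli"  # the built-in default wins
assert merged["X-Custom"] == "1"      # user extras are preserved
```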
@@ -69,7 +68,7 @@
          Returns:
              Dict[str, str]: Mapping from API parameter names to config keys
          """
-         return self._BASE_COMPLETION_PARAMS_KEYS.copy()
+         return self.COMPLETION_PARAMS_KEYS.copy()

      def get_completion_params(self) -> Dict[str, Any]:
          """
@@ -89,7 +88,7 @@
          """Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
          converted_messages = []
          for msg in messages:
-             message = {"role": msg.role, "content": msg.content or ""}
+             message: Dict[str, Any] = {"role": msg.role, "content": msg.content or ""}

              if msg.name:
                  message["name"] = msg.name
@@ -134,11 +133,21 @@
          params = self.completion_params.copy()
          params["messages"] = openai_messages
          params["stream"] = stream
+         tools = []

          if self.enable_function:
-             tools = get_openai_schemas()
-             if tools:
-                 params["tools"] = tools
+             tools.extend(get_openai_schemas())
+
+         # Add MCP tools if enabled
+         if self.enable_mcp:
+             try:
+                 mcp_tools = get_mcp_manager().to_openai_tools()
+             except (ValueError, FileNotFoundError) as e:
+                 self.console.print(f"Failed to load MCP tools: {e}", style="red")
+                 mcp_tools = []
+             tools.extend(mcp_tools)
+         if tools:
+             params["tools"] = tools

          try:
              if stream:
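Built-in function schemas and converted MCP tools share the OpenAI tool format, which is why they can be concatenated into a single `tools` list. A sketch of the shape both sides produce (function name and parameters are hypothetical):

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "echo",  # hypothetical
            "description": "Echo the input back.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        },
    },
]
```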
@@ -10,5 +10,6 @@ class SiliconFlowProvider(OpenAIProvider):

      def get_completion_params(self) -> Dict[str, Any]:
          params = super().get_completion_params()
-         params["max_tokens"] = params.pop("max_completion_tokens")
+         if "max_completion_tokens" in params:
+             params["max_tokens"] = params.pop("max_completion_tokens")
          return params
@@ -0,0 +1,127 @@
+ from typing import Any, Dict, List, Tuple, cast
+
+ from json_repair import repair_json
+ from mcp import types
+ from rich.panel import Panel
+
+ from ..config import cfg
+ from ..console import get_console
+ from ..schemas import ToolCall
+ from .function import get_function, list_functions
+ from .mcp import MCP_TOOL_NAME_PREFIX, get_mcp, get_mcp_manager, parse_mcp_tool_name
+
+ console = get_console()
+
+
+ def get_openai_schemas() -> List[Dict[str, Any]]:
+     """Get OpenAI-compatible function schemas
+
+     Returns:
+         List of function schemas in OpenAI format
+     """
+     transformed_schemas = []
+     for function in list_functions():
+         schema = {
+             "type": "function",
+             "function": {
+                 "name": function.name,
+                 "description": function.description,
+                 "parameters": function.parameters,
+             },
+         }
+         transformed_schemas.append(schema)
+     return transformed_schemas
+
+
+ def get_openai_mcp_tools() -> list[dict[str, Any]]:
+     """Get OpenAI-compatible MCP tool schemas
+
+     Returns:
+         List of MCP tool schemas in OpenAI format
+     """
+     return get_mcp_manager().to_openai_tools()
+
+
+ def execute_mcp_tool(tool_name: str, tool_kwargs: dict) -> str:
+     """Execute an MCP tool
+
+     Args:
+         tool_name: The name of the tool to execute
+         tool_kwargs: The arguments to pass to the tool
+     """
+     manager = get_mcp_manager()
+     tool = manager.get_tool(tool_name)
+     try:
+         result = tool.execute(**tool_kwargs)
+         if isinstance(result, list) and len(result) > 0:
+             result = result[0]
+         if isinstance(result, types.TextContent):
+             return result.text
+         else:
+             return str(result)
+     except Exception as e:
+         error_msg = f"Call MCP tool error:\nTool name: {tool_name!r}\nArguments: {tool_kwargs!r}\nError: {e}"
+         console.print(error_msg, style="red")
+         return error_msg
+
+
+ def execute_tool_call(tool_call: ToolCall) -> Tuple[str, bool]:
+     """Execute a tool call and return the result
+
+     Args:
+         tool_call: The tool call to execute
+
+     Returns:
+         Tuple[str, bool]: (result text, success flag)
+     """
+     is_function_call = not tool_call.name.startswith(MCP_TOOL_NAME_PREFIX)
+     if is_function_call:
+         get_tool_func = get_function
+         show_output = cfg["SHOW_FUNCTION_OUTPUT"]
+         _type = "function"
+     else:
+         tool_call.name = parse_mcp_tool_name(tool_call.name)
+         get_tool_func = get_mcp
+         show_output = cfg["SHOW_MCP_OUTPUT"]
+         _type = "mcp"
+
+     console.print(f"@{_type.title()} call: {tool_call.name}({tool_call.arguments})", style="blue")
+     # 1. Get the tool
+     try:
+         tool = get_tool_func(tool_call.name)
+     except ValueError as e:
+         error_msg = f"{_type.title()} '{tool_call.name!r}' not exists: {e}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 2. Parse tool arguments
+     try:
+         arguments = repair_json(tool_call.arguments, return_objects=True)
+         if not isinstance(arguments, dict):
+             error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
+             console.print(error_msg, style="red")
+             return error_msg, False
+         arguments = cast(dict, arguments)
+     except Exception as e:
+         error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 3. Execute the tool
+     try:
+         result = tool.execute(**arguments)
+         if show_output:
+             panel = Panel(
+                 result,
+                 title=f"{_type.title()} output",
+                 title_align="left",
+                 expand=False,
+                 border_style="blue",
+                 style="dim",
+             )
+             console.print(panel)
+         return result, True
+     except Exception as e:
+         error_msg = f"Call {_type} error: {e}\n{_type} name: {tool_call.name!r}\nArguments: {arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
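Dispatch between the two tool kinds hinges entirely on the name prefix: names starting with `MCP_TOOL_NAME_PREFIX` go to the MCP manager, everything else to the built-in function registry. A hedged usage sketch (field names follow their use above; the exact `ToolCall` constructor lives in `yaicli.schemas`):

```python
# Hypothetical tool call, as a provider would hand it over after a
# "tool_calls" finish reason.
call = ToolCall(id="call_1", name="echo", arguments='{"text": "hi"}')
result, ok = execute_tool_call(call)
if not ok:
    console.print(f"Tool failed: {result}", style="red")
```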
@@ -0,0 +1,90 @@
+ import importlib.util
+ import sys
+ from typing import Callable, List, Optional
+
+ from instructor import OpenAISchema
+
+ from ..const import FUNCTIONS_DIR
+ from ..utils import wrap_function
+
+
+ class Function:
+     """Function description class"""
+
+     def __init__(self, function: type[OpenAISchema]):
+         self.name = function.openai_schema["name"]
+         self.description = function.openai_schema.get("description", "")
+         self.parameters = function.openai_schema.get("parameters", {})
+         self.execute = function.execute  # type: ignore
+
+
+ _func_name_map: Optional[dict[str, Function]] = None
+
+
+ def get_func_name_map() -> dict[str, Function]:
+     """Get function name map"""
+     global _func_name_map
+     if _func_name_map:
+         return _func_name_map
+     if not FUNCTIONS_DIR.exists():
+         FUNCTIONS_DIR.mkdir(parents=True, exist_ok=True)
+         return {}
+     functions = []
+     for file in FUNCTIONS_DIR.glob("*.py"):
+         if file.name.startswith("_"):
+             continue
+         module_name = str(file).replace("/", ".").rstrip(".py")
+         spec = importlib.util.spec_from_file_location(module_name, str(file))
+         module = importlib.util.module_from_spec(spec)  # type: ignore
+         sys.modules[module_name] = module
+         spec.loader.exec_module(module)  # type: ignore
+
+         if not issubclass(module.Function, OpenAISchema):
+             raise TypeError(f"Function {module_name} must be a subclass of instructor.OpenAISchema")
+         if not hasattr(module.Function, "execute"):
+             raise TypeError(f"Function {module_name} must have an 'execute' classmethod")
+
+         # Add to function list
+         functions.append(Function(function=module.Function))
+
+     # Cache the function list
+     _func_name_map = {func.name: func for func in functions}
+     return _func_name_map
+
+
+ def list_functions() -> list[Function]:
+     """List all available builtin functions"""
+     global _func_name_map
+     if not _func_name_map:
+         _func_name_map = get_func_name_map()
+
+     return list(_func_name_map.values())
+
+
+ def get_function(name: str) -> Function:
+     """Get a function by name
+
+     Args:
+         name: Function name
+
+     Returns:
+         Function: The matching function object
+
+     Raises:
+         ValueError: If function not found
+     """
+     func_map = get_func_name_map()
+     if name in func_map:
+         return func_map[name]
+     raise ValueError(f"Function {name!r} not found")
+
+
+ def get_functions_gemini_format() -> List[Callable]:
+     """Get functions in gemini format"""
+     gemini_functions = []
+     for func_name, func in get_func_name_map().items():
+         wrapped_func = wrap_function(func.execute)
+         wrapped_func.__name__ = func_name
+         wrapped_func.__doc__ = func.description
+         gemini_functions.append(wrapped_func)
+     return gemini_functions
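Taken together, the loader above defines the contract for user-supplied functions: a `*.py` file in `FUNCTIONS_DIR` whose module-level `Function` class subclasses `instructor.OpenAISchema` and exposes an `execute` classmethod. A minimal example of such a file (name and contents illustrative; `FUNCTIONS_DIR` itself is defined in `yaicli.const`):

```python
from typing import Any

from instructor import OpenAISchema
from pydantic import Field


class Function(OpenAISchema):
    """Echo the given text back to the caller."""

    text: str = Field(..., description="Text to echo")

    @classmethod
    def execute(cls, text: str, **kwargs: Any) -> str:
        # Invoked by execute_tool_call() with the parsed JSON arguments.
        return text
```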