yaicli 0.6.4__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +1 -1
- yaicli/const.py +17 -6
- yaicli/entry.py +24 -1
- yaicli/functions/__init__.py +13 -1
- yaicli/llms/client.py +71 -38
- yaicli/llms/providers/ai21_provider.py +33 -1
- yaicli/llms/providers/chatglm_provider.py +30 -7
- yaicli/llms/providers/doubao_provider.py +0 -14
- yaicli/llms/providers/gemini_provider.py +21 -22
- yaicli/llms/providers/huggingface_provider.py +1 -1
- yaicli/llms/providers/openai_provider.py +17 -8
- yaicli/tools/__init__.py +127 -0
- yaicli/tools/function.py +90 -0
- yaicli/tools/mcp.py +459 -0
- yaicli/utils.py +34 -0
- {yaicli-0.6.4.dist-info → yaicli-0.7.0.dist-info}/METADATA +91 -4
- {yaicli-0.6.4.dist-info → yaicli-0.7.0.dist-info}/RECORD +20 -18
- yaicli/tools.py +0 -159
- {yaicli-0.6.4.dist-info → yaicli-0.7.0.dist-info}/WHEEL +0 -0
- {yaicli-0.6.4.dist-info → yaicli-0.7.0.dist-info}/entry_points.txt +0 -0
- {yaicli-0.6.4.dist-info → yaicli-0.7.0.dist-info}/licenses/LICENSE +0 -0
pyproject.toml
CHANGED
yaicli/const.py
CHANGED
@@ -16,7 +16,7 @@ from rich.console import JustifyMethod
|
|
16
16
|
BOOL_STR = Literal["true", "false", "yes", "no", "y", "n", "1", "0", "on", "off"]
|
17
17
|
|
18
18
|
|
19
|
-
class JustifyEnum(StrEnum):
|
19
|
+
class JustifyEnum(StrEnum): # type: ignore
|
20
20
|
DEFAULT = "default"
|
21
21
|
LEFT = "left"
|
22
22
|
CENTER = "center"
|
@@ -43,6 +43,7 @@ HISTORY_FILE = Path("~/.yaicli_history").expanduser()
|
|
43
43
|
CONFIG_PATH = Path("~/.config/yaicli/config.ini").expanduser()
|
44
44
|
ROLES_DIR = CONFIG_PATH.parent / "roles"
|
45
45
|
FUNCTIONS_DIR = CONFIG_PATH.parent / "functions"
|
46
|
+
MCP_JSON_PATH = CONFIG_PATH.parent / "mcp.json"
|
46
47
|
|
47
48
|
# Default configuration values
|
48
49
|
DEFAULT_CODE_THEME = "monokai"
|
@@ -69,6 +70,8 @@ DEFAULT_ROLE_MODIFY_WARNING: BOOL_STR = "true"
|
|
69
70
|
DEFAULT_ENABLE_FUNCTIONS: BOOL_STR = "true"
|
70
71
|
DEFAULT_SHOW_FUNCTION_OUTPUT: BOOL_STR = "true"
|
71
72
|
DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] = None
|
73
|
+
DEFAULT_ENABLE_MCP: BOOL_STR = "false"
|
74
|
+
DEFAULT_SHOW_MCP_OUTPUT: BOOL_STR = "false"
|
72
75
|
|
73
76
|
|
74
77
|
SHELL_PROMPT = """You are YAICLI, a shell command generator.
|
@@ -93,16 +96,16 @@ CODER_PROMPT = (
|
|
93
96
|
)
|
94
97
|
|
95
98
|
|
96
|
-
class DefaultRoleNames(StrEnum):
|
99
|
+
class DefaultRoleNames(StrEnum): # type: ignore
|
97
100
|
SHELL = "Shell Command Generator"
|
98
101
|
DEFAULT = "DEFAULT"
|
99
102
|
CODER = "Code Assistant"
|
100
103
|
|
101
104
|
|
102
105
|
# Built-in role definitions keyed by role name; each entry holds the role's
# display name and its system prompt.
DEFAULT_ROLES: dict[str, dict[str, Any]] = {
    role.value: {"name": role.value, "prompt": prompt}  # type: ignore
    for role, prompt in (
        (DefaultRoleNames.SHELL, SHELL_PROMPT),
        (DefaultRoleNames.DEFAULT, DEFAULT_PROMPT),
        (DefaultRoleNames.CODER, CODER_PROMPT),
    )
}
|
107
110
|
|
108
111
|
# DEFAULT_CONFIG_MAP is a dictionary of the configuration options.
|
@@ -151,6 +154,8 @@ DEFAULT_CONFIG_MAP = {
|
|
151
154
|
"env_key": "YAI_SHOW_FUNCTION_OUTPUT",
|
152
155
|
"type": bool,
|
153
156
|
},
|
157
|
+
"ENABLE_MCP": {"value": DEFAULT_ENABLE_MCP, "env_key": "YAI_ENABLE_MCP", "type": bool},
|
158
|
+
"SHOW_MCP_OUTPUT": {"value": DEFAULT_SHOW_MCP_OUTPUT, "env_key": "YAI_SHOW_MCP_OUTPUT", "type": bool},
|
154
159
|
}
|
155
160
|
|
156
161
|
DEFAULT_CONFIG_INI = f"""[core]
|
@@ -201,6 +206,12 @@ ROLE_MODIFY_WARNING={DEFAULT_CONFIG_MAP["ROLE_MODIFY_WARNING"]["value"]}
|
|
201
206
|
# Function settings
|
202
207
|
# Set to false to disable sending functions in API requests
|
203
208
|
ENABLE_FUNCTIONS={DEFAULT_CONFIG_MAP["ENABLE_FUNCTIONS"]["value"]}
|
204
|
-
# Set to false to disable showing function output
|
209
|
+
# Set to false to disable showing function output when calling functions
|
205
210
|
SHOW_FUNCTION_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_FUNCTION_OUTPUT"]["value"]}
|
211
|
+
|
212
|
+
# MCP settings
|
213
|
+
# Set to false to disable MCP in API requests
|
214
|
+
ENABLE_MCP={DEFAULT_CONFIG_MAP["ENABLE_MCP"]["value"]}
|
215
|
+
# Set to false to disable showing MCP output when calling MCP tools
|
216
|
+
SHOW_MCP_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_MCP_OUTPUT"]["value"]}
|
206
217
|
"""
|
yaicli/entry.py
CHANGED
@@ -6,7 +6,7 @@ import typer
|
|
6
6
|
from .chat import FileChatManager
|
7
7
|
from .config import cfg
|
8
8
|
from .const import DEFAULT_CONFIG_INI, DefaultRoleNames, JustifyEnum
|
9
|
-
from .functions import install_functions, print_functions
|
9
|
+
from .functions import install_functions, print_functions, print_mcp
|
10
10
|
from .role import RoleManager
|
11
11
|
|
12
12
|
app = typer.Typer(
|
@@ -209,6 +209,29 @@ def main(
|
|
209
209
|
show_default=False,
|
210
210
|
callback=override_config,
|
211
211
|
),
|
212
|
+
# ------------------- MCP Options -------------------
|
213
|
+
enable_mcp: bool = typer.Option( # noqa: F841
|
214
|
+
cfg["ENABLE_MCP"],
|
215
|
+
"--enable-mcp/--disable-mcp",
|
216
|
+
help=f"Enable/disable MCP in API requests [dim](default: {'enabled' if cfg['ENABLE_MCP'] else 'disabled'})[/dim]",
|
217
|
+
rich_help_panel="MCP Options",
|
218
|
+
callback=override_config,
|
219
|
+
),
|
220
|
+
show_mcp_output: bool = typer.Option( # noqa: F841
|
221
|
+
cfg["SHOW_MCP_OUTPUT"],
|
222
|
+
"--show-mcp-output/--hide-mcp-output",
|
223
|
+
help=f"Show the output of MCP [dim](default: {'show' if cfg['SHOW_MCP_OUTPUT'] else 'hide'})[/dim]",
|
224
|
+
rich_help_panel="MCP Options",
|
225
|
+
show_default=False,
|
226
|
+
callback=override_config,
|
227
|
+
),
|
228
|
+
list_mcp: bool = typer.Option( # noqa: F841
|
229
|
+
False,
|
230
|
+
"--list-mcp",
|
231
|
+
help="List all available mcp.",
|
232
|
+
rich_help_panel="MCP Options",
|
233
|
+
callback=print_mcp,
|
234
|
+
),
|
212
235
|
):
|
213
236
|
"""YAICLI: Your AI assistant in the command line.
|
214
237
|
|
yaicli/functions/__init__.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
|
+
import json
|
1
2
|
import shutil
|
2
3
|
from pathlib import Path
|
3
4
|
from typing import Any
|
4
5
|
|
5
6
|
from ..console import get_console
|
6
|
-
from ..const import FUNCTIONS_DIR
|
7
|
+
from ..const import FUNCTIONS_DIR, MCP_JSON_PATH
|
7
8
|
from ..utils import option_callback
|
8
9
|
|
9
10
|
console = get_console()
|
@@ -37,3 +38,14 @@ def print_functions(cls, _: Any) -> None:
|
|
37
38
|
if file.name.startswith("_"):
|
38
39
|
continue
|
39
40
|
console.print(file)
|
41
|
+
|
42
|
+
|
43
|
+
@option_callback
def print_mcp(cls, _: Any) -> None:
    """List all available mcp.

    Prints the JSON stored at ``MCP_JSON_PATH``. If the file is missing, a hint
    about where to create it is printed instead. If the file cannot be read or
    is not valid JSON, an error message is printed rather than letting the
    exception escape the CLI.

    Args:
        cls: Forwarded by the ``option_callback`` decorator (defined elsewhere); unused here.
        _: The option value; unused here.
    """
    if not MCP_JSON_PATH.exists():
        console.print("No mcp config found, please add your mcp config in ~/.config/yaicli/mcp.json")
        return
    try:
        # The config file is JSON, so decode it as UTF-8 regardless of locale.
        with open(MCP_JSON_PATH, "r", encoding="utf-8") as f:
            mcp_config = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # A broken or unreadable config should produce a readable message, not a traceback.
        console.print(f"Failed to load mcp config from {MCP_JSON_PATH}: {e}", style="red")
        return
    console.print_json(data=mcp_config)
|
yaicli/llms/client.py
CHANGED
@@ -4,6 +4,7 @@ from ..config import cfg
|
|
4
4
|
from ..console import get_console
|
5
5
|
from ..schemas import ChatMessage, LLMResponse, RefreshLive, ToolCall
|
6
6
|
from ..tools import execute_tool_call
|
7
|
+
from ..tools.mcp import MCP_TOOL_NAME_PREFIX
|
7
8
|
from .provider import Provider, ProviderFactory
|
8
9
|
|
9
10
|
|
@@ -37,6 +38,8 @@ class LLMClient:
|
|
37
38
|
self.config = config
|
38
39
|
self.verbose = verbose
|
39
40
|
self.console = get_console()
|
41
|
+
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
42
|
+
self.enable_mcp = self.config["ENABLE_MCP"]
|
40
43
|
|
41
44
|
# Use provided provider or create one
|
42
45
|
if provider:
|
@@ -73,48 +76,78 @@ class LLMClient:
|
|
73
76
|
)
|
74
77
|
return
|
75
78
|
|
76
|
-
# Get completion from provider
|
77
|
-
llm_response_generator = self.provider.completion(messages, stream=stream)
|
78
|
-
|
79
|
-
# To hold the full response
|
79
|
+
# Get completion from provider and collect response data
|
80
80
|
assistant_response_content = ""
|
81
|
-
|
81
|
+
# Providers may return identical tool calls with the same ID in a single response during streaming
|
82
|
+
tool_calls: dict[str, ToolCall] = {}
|
82
83
|
|
83
|
-
#
|
84
|
-
for llm_response in
|
85
|
-
# Forward
|
86
|
-
yield llm_response
|
84
|
+
# Stream responses and collect data
|
85
|
+
for llm_response in self.provider.completion(messages, stream=stream):
|
86
|
+
yield llm_response # Forward response to caller
|
87
87
|
|
88
|
-
# Collect content and tool calls
|
88
|
+
# Collect content and tool calls for potential tool execution
|
89
89
|
if llm_response.content:
|
90
90
|
assistant_response_content += llm_response.content
|
91
|
-
if llm_response.tool_call and llm_response.tool_call not in tool_calls:
|
92
|
-
tool_calls.
|
93
|
-
|
94
|
-
#
|
95
|
-
if tool_calls
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
91
|
+
if llm_response.tool_call and llm_response.tool_call.id not in tool_calls:
|
92
|
+
tool_calls[llm_response.tool_call.id] = llm_response.tool_call
|
93
|
+
|
94
|
+
# Check if we need to execute tools
|
95
|
+
if not tool_calls or not (self.enable_function or self.enable_mcp):
|
96
|
+
return
|
97
|
+
|
98
|
+
# Filter valid tool calls based on enabled features
|
99
|
+
valid_tool_calls = self._get_valid_tool_calls(tool_calls)
|
100
|
+
if not valid_tool_calls:
|
101
|
+
return
|
102
|
+
|
103
|
+
# Execute tools and continue conversation
|
104
|
+
yield from self._execute_tools_and_continue(
|
105
|
+
messages, assistant_response_content, valid_tool_calls, stream, recursion_depth
|
106
|
+
)
|
107
|
+
|
108
|
+
def _get_valid_tool_calls(self, tool_calls: dict[str, ToolCall]) -> List[ToolCall]:
    """Keep only the tool calls whose kind is currently enabled.

    A call whose name starts with ``MCP_TOOL_NAME_PREFIX`` counts as an MCP call
    and is kept when ``self.enable_mcp`` is truthy. Every other call counts as a
    regular function call and is kept when ``self.enable_function`` is truthy.

    Args:
        tool_calls: Collected tool calls keyed by tool-call ID.

    Returns:
        The permitted calls, in the dict's iteration order.
    """

    def _permitted(call: ToolCall) -> bool:
        if call.name.startswith(MCP_TOOL_NAME_PREFIX):
            return bool(self.enable_mcp)
        return bool(self.enable_function)

    return [call for call in tool_calls.values() if _permitted(call)]
|
121
|
+
|
122
|
+
def _execute_tools_and_continue(
    self,
    messages: List[ChatMessage],
    assistant_response_content: str,
    tool_calls: List[ToolCall],
    stream: bool,
    recursion_depth: int,
) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
    """Run the requested tools, record the results, and resume the conversation.

    The steps run in this order:

    1. Yield a ``RefreshLive`` marker.
    2. Append, in place, one assistant message carrying ``tool_calls`` to ``messages``.
    3. Append one tool-result message per call to ``messages``.
    4. Delegate to ``completion_with_tools`` with ``recursion_depth + 1``.
    """
    # Let the consumer know a new block of output is about to begin.
    yield RefreshLive()

    # The assistant turn that requested the tools is recorded exactly once.
    assistant_turn = ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls)
    messages.append(assistant_turn)

    # The role name for tool results is provider specific.
    result_role = self.provider.detect_tool_role()

    for call in tool_calls:
        # execute_tool_call returns a pair; only the first element (the result text) is used.
        result_text, _ignored = execute_tool_call(call)
        result_message = ChatMessage(
            role=result_role,
            content=result_text,
            name=call.name,
            tool_call_id=call.id,
        )
        messages.append(result_message)

    # Ask the model again now that the tool outputs are part of the history.
    next_depth = recursion_depth + 1
    yield from self.completion_with_tools(messages, stream=stream, recursion_depth=next_depth)
|
@@ -63,7 +63,7 @@ class AI21Provider(OpenAIProvider):
|
|
63
63
|
if finish_reason == "tool_calls" and not content:
|
64
64
|
# tool call assistant message, content can't be empty
|
65
65
|
# Error code: 422 - {'detail': {'error': ['Value error, message content must not be an empty string']}}
|
66
|
-
content = tool_call.id
|
66
|
+
content = tool_call.id if tool_call else ""
|
67
67
|
|
68
68
|
# Generate response object
|
69
69
|
yield LLMResponse(
|
@@ -72,3 +72,35 @@ class AI21Provider(OpenAIProvider):
|
|
72
72
|
tool_call=tool_call if finish_reason == "tool_calls" else None,
|
73
73
|
finish_reason=finish_reason,
|
74
74
|
)
|
75
|
+
|
76
|
+
def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
    """Process a tool call chunk from AI21 API response

    Tool calls from AI21 are delivered across multiple chunks:
    - First chunk contains function name
    - Subsequent chunks contain arguments data
    - Final chunk (with finish_reason='tool_calls') contains complete arguments

    Args:
        tool_calls: Tool call data from current chunk; only the first entry is used.
        existing_tool_call: Previously accumulated tool call object, or None for the first chunk.

    Returns:
        ToolCall: Updated tool call object with accumulated data
    """
    # Get the first (and only) tool call from the chunk
    call = tool_calls[0]
    # A chunk may omit ``arguments`` or carry None/""; treat all of these as "no new data".
    chunk_arguments = getattr(call.function, "arguments", None)

    if existing_tool_call is None:
        # First chunk: create the tool call with the function name. Keep any
        # arguments it already carries, so a tool call delivered in a single
        # chunk does not lose them. Fall back to an empty JSON object.
        return ToolCall(id=call.id, name=call.function.name, arguments=chunk_arguments or "{}")

    # Later chunks: non-empty arguments replace the accumulated value, because
    # the final chunk carries the complete arguments. A chunk without
    # arguments must not clobber what was already collected.
    return ToolCall(
        id=existing_tool_call.id,
        name=existing_tool_call.name,
        arguments=chunk_arguments or existing_tool_call.arguments,
    )
|
@@ -1,9 +1,10 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Generator, Optional
|
2
|
+
from typing import Generator, Optional, Union, overload
|
3
3
|
|
4
4
|
from openai._streaming import Stream
|
5
5
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
6
6
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
7
|
+
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
|
7
8
|
|
8
9
|
from ...schemas import LLMResponse, ToolCall
|
9
10
|
from .openai_provider import OpenAIProvider
|
@@ -61,7 +62,7 @@ class ChatglmProvider(OpenAIProvider):
|
|
61
62
|
tmp_content = content[tool_index:]
|
62
63
|
# Tool call data may in content after the <think> block
|
63
64
|
try:
|
64
|
-
choice = self.parse_choice_from_content(tmp_content)
|
65
|
+
choice = self.parse_choice_from_content(tmp_content, Choice)
|
65
66
|
except ValueError:
|
66
67
|
pass
|
67
68
|
if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls: # type: ignore
|
@@ -101,8 +102,6 @@ class ChatglmProvider(OpenAIProvider):
|
|
101
102
|
reasoning = self._get_reasoning_content(delta)
|
102
103
|
full_reasoning += reasoning or ""
|
103
104
|
|
104
|
-
if finish_reason:
|
105
|
-
pass
|
106
105
|
if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
|
107
106
|
# Tool call data may in content after the <think> block
|
108
107
|
# >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
|
@@ -110,7 +109,7 @@ class ChatglmProvider(OpenAIProvider):
|
|
110
109
|
if tool_index != -1:
|
111
110
|
tmp_content = full_content[tool_index:]
|
112
111
|
try:
|
113
|
-
choice = self.parse_choice_from_content(tmp_content)
|
112
|
+
choice = self.parse_choice_from_content(tmp_content, ChoiceChunk)
|
114
113
|
except ValueError:
|
115
114
|
pass
|
116
115
|
if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls: # type: ignore
|
@@ -124,7 +123,31 @@ class ChatglmProvider(OpenAIProvider):
|
|
124
123
|
tool_call = ToolCall(tool_id, tool_call_name, arguments)
|
125
124
|
yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
|
126
125
|
|
127
|
-
|
126
|
+
@overload
|
127
|
+
def parse_choice_from_content(self, content: str, choice_class: type[ChoiceChunk] = ChoiceChunk) -> "ChoiceChunk":
|
128
|
+
"""
|
129
|
+
Parse the choice from the content after <think>...</think> block.
|
130
|
+
Args:
|
131
|
+
content: The content from the LLM response
|
132
|
+
Returns:
|
133
|
+
The choice object
|
134
|
+
Raises ValueError if the content is not valid JSON
|
135
|
+
"""
|
136
|
+
|
137
|
+
@overload
|
138
|
+
def parse_choice_from_content(self, content: str, choice_class: type[Choice] = Choice) -> "Choice":
|
139
|
+
"""
|
140
|
+
Parse the choice from the content after <think>...</think> block.
|
141
|
+
Args:
|
142
|
+
content: The content from the LLM response
|
143
|
+
Returns:
|
144
|
+
The choice object
|
145
|
+
Raises ValueError if the content is not valid JSON
|
146
|
+
"""
|
147
|
+
|
148
|
+
def parse_choice_from_content(
|
149
|
+
self, content: str, choice_class: type[Union[Choice, ChoiceChunk]] = Choice
|
150
|
+
) -> Union[Choice, ChoiceChunk]:
|
128
151
|
"""
|
129
152
|
Parse the choice from the content after <think>...</think> block.
|
130
153
|
Args:
|
@@ -138,6 +161,6 @@ class ChatglmProvider(OpenAIProvider):
|
|
138
161
|
except json.JSONDecodeError:
|
139
162
|
raise ValueError(f"Invalid message from LLM: {content}")
|
140
163
|
try:
|
141
|
-
return
|
164
|
+
return choice_class.model_validate(content_dict)
|
142
165
|
except Exception as e:
|
143
166
|
raise ValueError(f"Invalid message from LLM: {content}") from e
|
@@ -2,8 +2,6 @@ from typing import Any, Dict
|
|
2
2
|
|
3
3
|
from volcenginesdkarkruntime import Ark
|
4
4
|
|
5
|
-
from ...config import cfg
|
6
|
-
from ...console import get_console
|
7
5
|
from .openai_provider import OpenAIProvider
|
8
6
|
|
9
7
|
|
@@ -13,18 +11,6 @@ class DoubaoProvider(OpenAIProvider):
|
|
13
11
|
DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
14
12
|
CLIENT_CLS = Ark
|
15
13
|
|
16
|
-
def __init__(self, config: dict = cfg, **kwargs):
|
17
|
-
self.config = config
|
18
|
-
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
19
|
-
self.client_params = self.get_client_params()
|
20
|
-
|
21
|
-
# Initialize client
|
22
|
-
self.client = self.CLIENT_CLS(**self.client_params)
|
23
|
-
self.console = get_console()
|
24
|
-
|
25
|
-
# Store completion params
|
26
|
-
self.completion_params = self.get_completion_params()
|
27
|
-
|
28
14
|
def get_client_params(self) -> Dict[str, Any]:
|
29
15
|
# Initialize client params
|
30
16
|
client_params = {"base_url": self.DEFAULT_BASE_URL}
|
@@ -1,25 +1,18 @@
|
|
1
1
|
import json
|
2
|
-
from functools import wraps
|
3
2
|
from typing import Any, Callable, Dict, Generator, List
|
4
3
|
|
5
4
|
import google.genai as genai
|
6
5
|
from google.genai import types
|
7
6
|
|
7
|
+
from yaicli.tools.mcp import get_mcp_manager
|
8
|
+
|
8
9
|
from ...config import cfg
|
9
10
|
from ...console import get_console
|
10
11
|
from ...schemas import ChatMessage, LLMResponse
|
11
|
-
from ...tools import
|
12
|
+
from ...tools.function import get_functions_gemini_format
|
12
13
|
from ..provider import Provider
|
13
14
|
|
14
15
|
|
15
|
-
def wrap_function(func):
|
16
|
-
@wraps(func)
|
17
|
-
def wrapper(*args, **kwargs):
|
18
|
-
return func(*args, **kwargs)
|
19
|
-
|
20
|
-
return wrapper
|
21
|
-
|
22
|
-
|
23
16
|
class GeminiProvider(Provider):
|
24
17
|
"""Gemini provider implementation based on google-genai library"""
|
25
18
|
|
@@ -28,6 +21,7 @@ class GeminiProvider(Provider):
|
|
28
21
|
def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
|
29
22
|
self.config = config
|
30
23
|
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
24
|
+
self.enable_mcp = self.config["ENABLE_MCP"]
|
31
25
|
self.verbose = verbose
|
32
26
|
|
33
27
|
# Initialize client
|
@@ -67,16 +61,17 @@ class GeminiProvider(Provider):
|
|
67
61
|
config_map["frequency_penalty"] = self.config["FREQUENCY_PENALTY"]
|
68
62
|
if self.config.get("SEED"):
|
69
63
|
config_map["seed"] = self.config["SEED"]
|
70
|
-
# Indicates whether to include thoughts in the response.
|
64
|
+
# Indicates whether to include thoughts in the response.
|
65
|
+
# If true, thoughts are returned only if the model supports thought and thoughts are available.
|
71
66
|
thinking_config_map = {"include_thoughts": self.config.get("INCLUDE_THOUGHTS", True)}
|
72
67
|
if self.config.get("THINKING_BUDGET"):
|
73
68
|
thinking_config_map["thinking_budget"] = int(self.config["THINKING_BUDGET"])
|
74
69
|
config_map["thinking_config"] = types.ThinkingConfig(**thinking_config_map)
|
75
|
-
|
76
|
-
if self.enable_function:
|
70
|
+
if self.enable_function or self.enable_mcp:
|
77
71
|
# TODO: support disable automatic function calling
|
78
72
|
# config.automatic_function_calling = types.AutomaticFunctionCallingConfig(disable=False)
|
79
|
-
|
73
|
+
config_map["tools"] = self.gen_gemini_functions()
|
74
|
+
config = types.GenerateContentConfig(**config_map)
|
80
75
|
return config
|
81
76
|
|
82
77
|
def _convert_messages(self, messages: List[ChatMessage]) -> List[types.Content]:
|
@@ -103,15 +98,19 @@ class GeminiProvider(Provider):
|
|
103
98
|
|
104
99
|
def gen_gemini_functions(self) -> List[Callable[..., Any]]:
    """Collect the tool callables to pass to Gemini.

    Regular functions from ``get_functions_gemini_format()`` are included when
    ``self.enable_function`` is truthy. MCP tools from
    ``get_mcp_manager().to_gemini_tools()`` are included when ``self.enable_mcp``
    is truthy.

    Loading MCP tools is best-effort: any error is printed to the console, and
    the regular functions collected so far are still returned.

    Returns:
        A (possibly empty) list of tool callables.
    """
    funcs: List[Callable[..., Any]] = []

    # Add regular functions
    if self.enable_function:
        funcs.extend(get_functions_gemini_format())

    # Add MCP functions if enabled
    if self.enable_mcp:
        try:
            funcs.extend(get_mcp_manager().to_gemini_tools())
        except Exception as e:
            # Best-effort: a broken MCP setup should not block the request.
            self.console.print(f"Failed to load MCP tools for Gemini: {e}", style="red")

    return funcs
|
116
115
|
|
117
116
|
def completion(
|
@@ -10,6 +10,7 @@ from ...config import cfg
|
|
10
10
|
from ...console import get_console
|
11
11
|
from ...schemas import ChatMessage, LLMResponse, ToolCall
|
12
12
|
from ...tools import get_openai_schemas
|
13
|
+
from ...tools.mcp import get_mcp_manager
|
13
14
|
from ..provider import Provider
|
14
15
|
|
15
16
|
|
@@ -34,6 +35,7 @@ class OpenAIProvider(Provider):
|
|
34
35
|
if not self.config.get("API_KEY"):
|
35
36
|
raise ValueError("API_KEY is required")
|
36
37
|
self.enable_function = self.config["ENABLE_FUNCTIONS"]
|
38
|
+
self.enable_mcp = self.config["ENABLE_MCP"]
|
37
39
|
self.verbose = verbose
|
38
40
|
|
39
41
|
# Initialize client
|
@@ -50,15 +52,12 @@ class OpenAIProvider(Provider):
|
|
50
52
|
client_params = {
|
51
53
|
"api_key": self.config["API_KEY"],
|
52
54
|
"base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
|
55
|
+
"default_headers": {"X-Title": self.APP_NAME, "HTTP_Referer": self.APP_REFERER},
|
53
56
|
}
|
54
57
|
|
55
58
|
# Add extra headers if set
|
56
59
|
if self.config["EXTRA_HEADERS"]:
|
57
|
-
client_params["default_headers"] = {
|
58
|
-
**self.config["EXTRA_HEADERS"],
|
59
|
-
"X-Title": self.APP_NAME,
|
60
|
-
"HTTP-Referer": self.APP_REFERER,
|
61
|
-
}
|
60
|
+
client_params["default_headers"] = {**self.config["EXTRA_HEADERS"], **client_params["default_headers"]}
|
62
61
|
return client_params
|
63
62
|
|
64
63
|
def get_completion_params_keys(self) -> Dict[str, str]:
|
@@ -134,11 +133,21 @@ class OpenAIProvider(Provider):
|
|
134
133
|
params = self.completion_params.copy()
|
135
134
|
params["messages"] = openai_messages
|
136
135
|
params["stream"] = stream
|
136
|
+
tools = []
|
137
137
|
|
138
138
|
if self.enable_function:
|
139
|
-
tools
|
140
|
-
|
141
|
-
|
139
|
+
tools.extend(get_openai_schemas())
|
140
|
+
|
141
|
+
# Add MCP tools if enabled
|
142
|
+
if self.enable_mcp:
|
143
|
+
try:
|
144
|
+
mcp_tools = get_mcp_manager().to_openai_tools()
|
145
|
+
except (ValueError, FileNotFoundError) as e:
|
146
|
+
self.console.print(f"Failed to load MCP tools: {e}", style="red")
|
147
|
+
mcp_tools = []
|
148
|
+
tools.extend(mcp_tools)
|
149
|
+
if tools:
|
150
|
+
params["tools"] = tools
|
142
151
|
|
143
152
|
try:
|
144
153
|
if stream:
|