yaicli 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +14 -1
- yaicli/cli.py +2 -2
- yaicli/config.py +1 -1
- yaicli/const.py +19 -8
- yaicli/entry.py +24 -1
- yaicli/functions/__init__.py +13 -1
- yaicli/llms/client.py +71 -38
- yaicli/llms/provider.py +3 -0
- yaicli/llms/providers/ai21_provider.py +33 -1
- yaicli/llms/providers/chatglm_provider.py +38 -11
- yaicli/llms/providers/cohere_provider.py +6 -3
- yaicli/llms/providers/deepseek_provider.py +2 -1
- yaicli/llms/providers/doubao_provider.py +0 -14
- yaicli/llms/providers/gemini_provider.py +29 -28
- yaicli/llms/providers/huggingface_provider.py +40 -0
- yaicli/llms/providers/infiniai_provider.py +4 -2
- yaicli/llms/providers/modelscope_provider.py +2 -1
- yaicli/llms/providers/openai_provider.py +20 -11
- yaicli/llms/providers/siliconflow_provider.py +2 -1
- yaicli/tools/__init__.py +127 -0
- yaicli/tools/function.py +90 -0
- yaicli/tools/mcp.py +459 -0
- yaicli/utils.py +34 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/METADATA +231 -19
- yaicli-0.7.0.dist-info/RECORD +49 -0
- yaicli/tools.py +0 -159
- yaicli-0.6.3.dist-info/RECORD +0 -46
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/WHEEL +0 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/entry_points.txt +0 -0
- {yaicli-0.6.3.dist-info → yaicli-0.7.0.dist-info}/licenses/LICENSE +0 -0
pyproject.toml
CHANGED
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "yaicli"
-version = "0.6.3"
+version = "0.7.0"
 description = "A simple CLI tool to interact with LLM"
 authors = [{ name = "belingud", email = "im.victor@qq.com" }]
 readme = "README.md"
@@ -42,6 +42,15 @@ keywords = [
     "anthropic",
     "groq",
     "cohere",
+    "huggingface",
+    "chatglm",
+    "sambanova",
+    "siliconflow",
+    "xai",
+    "vertexai",
+    "deepseek",
+    "modelscope",
+    "ollama",
 ]
 dependencies = [
     "click>=8.1.8",
@@ -70,11 +79,15 @@ all = [
     "ollama>=0.5.1",
     "cohere>=5.15.0",
     "google-genai>=1.20.0",
+    "huggingface-hub>=0.33.0",
 ]
 doubao = ["volcengine-python-sdk>=3.0.15"]
 ollama = ["ollama>=0.5.1"]
 cohere = ["cohere>=5.15.0"]
 gemini = ["google-genai>=1.20.0"]
+huggingface = [
+    "huggingface-hub>=0.33.0",
+]
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
```
yaicli/cli.py
CHANGED
```diff
@@ -267,7 +267,7 @@ class CLI:
             assistant_msg = self.chat.history[i + 1] if (i + 1) < len(self.chat.history) else None
             self.console.print(f"[dim]{i // 2 + 1}[/dim] [bold blue]User:[/bold blue] {user_msg.content}")
             if assistant_msg:
-                md = Markdown(assistant_msg.content, code_theme=cfg["CODE_THEME"])
+                md = Markdown(assistant_msg.content or "", code_theme=cfg["CODE_THEME"])
                 padded_md = Padding(md, (0, 0, 0, 4))
                 self.console.print(" Assistant:", style="bold green")
                 self.console.print(padded_md)
@@ -384,7 +384,7 @@ class CLI:
         self._check_history_len()
 
         if self.current_mode == EXEC_MODE:
-            self._confirm_and_execute(content)
+            self._confirm_and_execute(content or "")
         return True
 
     def _confirm_and_execute(self, raw_content: str) -> None:
```
yaicli/config.py
CHANGED
```diff
@@ -142,7 +142,7 @@ class Config(dict):
                 if target_type is bool:
                     converted_value = str2bool(raw_value)
                 elif target_type in (int, float, str):
-                    converted_value = target_type(raw_value)
+                    converted_value = target_type(raw_value) if raw_value else raw_value
                 elif target_type is dict and raw_value:
                     converted_value = json.loads(raw_value)
             except (ValueError, TypeError, json.JSONDecodeError) as e:
```
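The config.py change guards the cast against empty values: an unset environment variable or a blank INI field arrives as `""`, and `int("")` or `float("")` raises `ValueError`, so the raw empty value is now passed through untouched. A minimal standalone sketch of the behavior (not yaicli code, `convert` is a hypothetical helper):

```python
# Minimal sketch of the guard added above; convert() is a hypothetical helper,
# not part of yaicli.
def convert(raw_value: str, target_type: type):
    if target_type in (int, float, str):
        return target_type(raw_value) if raw_value else raw_value
    return raw_value


print(convert("8192", int))  # 8192
print(convert("", int))      # '' -- previously int("") would raise ValueError
```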
yaicli/const.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 try:
-    from enum import StrEnum
+    from enum import StrEnum  # type: ignore
 except ImportError:
     from enum import Enum
 
@@ -16,7 +16,7 @@ from rich.console import JustifyMethod
 BOOL_STR = Literal["true", "false", "yes", "no", "y", "n", "1", "0", "on", "off"]
 
 
-class JustifyEnum(StrEnum):
+class JustifyEnum(StrEnum):  # type: ignore
     DEFAULT = "default"
     LEFT = "left"
     CENTER = "center"
@@ -43,6 +43,7 @@ HISTORY_FILE = Path("~/.yaicli_history").expanduser()
 CONFIG_PATH = Path("~/.config/yaicli/config.ini").expanduser()
 ROLES_DIR = CONFIG_PATH.parent / "roles"
 FUNCTIONS_DIR = CONFIG_PATH.parent / "functions"
+MCP_JSON_PATH = CONFIG_PATH.parent / "mcp.json"
 
 # Default configuration values
 DEFAULT_CODE_THEME = "monokai"
@@ -68,7 +69,9 @@ DEFAULT_JUSTIFY: JustifyMethod = "default"
 DEFAULT_ROLE_MODIFY_WARNING: BOOL_STR = "true"
 DEFAULT_ENABLE_FUNCTIONS: BOOL_STR = "true"
 DEFAULT_SHOW_FUNCTION_OUTPUT: BOOL_STR = "true"
-DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] =
+DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] = None
+DEFAULT_ENABLE_MCP: BOOL_STR = "false"
+DEFAULT_SHOW_MCP_OUTPUT: BOOL_STR = "false"
 
 
 SHELL_PROMPT = """You are YAICLI, a shell command generator.
@@ -93,16 +96,16 @@ CODER_PROMPT = (
 )
 
 
-class DefaultRoleNames(StrEnum):
+class DefaultRoleNames(StrEnum):  # type: ignore
     SHELL = "Shell Command Generator"
     DEFAULT = "DEFAULT"
     CODER = "Code Assistant"
 
 
 DEFAULT_ROLES: dict[str, dict[str, Any]] = {
-    DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT},
-    DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT},
-    DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT},
+    DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT},  # type: ignore
+    DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT},  # type: ignore
+    DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT},  # type: ignore
 }
 
 # DEFAULT_CONFIG_MAP is a dictionary of the configuration options.
@@ -151,6 +154,8 @@ DEFAULT_CONFIG_MAP = {
         "env_key": "YAI_SHOW_FUNCTION_OUTPUT",
         "type": bool,
     },
+    "ENABLE_MCP": {"value": DEFAULT_ENABLE_MCP, "env_key": "YAI_ENABLE_MCP", "type": bool},
+    "SHOW_MCP_OUTPUT": {"value": DEFAULT_SHOW_MCP_OUTPUT, "env_key": "YAI_SHOW_MCP_OUTPUT", "type": bool},
 }
 
 DEFAULT_CONFIG_INI = f"""[core]
@@ -201,6 +206,12 @@ ROLE_MODIFY_WARNING={DEFAULT_CONFIG_MAP["ROLE_MODIFY_WARNING"]["value"]}
 # Function settings
 # Set to false to disable sending functions in API requests
 ENABLE_FUNCTIONS={DEFAULT_CONFIG_MAP["ENABLE_FUNCTIONS"]["value"]}
-# Set to false to disable showing function output
+# Set to false to disable showing function output when calling functions
 SHOW_FUNCTION_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_FUNCTION_OUTPUT"]["value"]}
+
+# MCP settings
+# Set to false to disable MCP in API requests
+ENABLE_MCP={DEFAULT_CONFIG_MAP["ENABLE_MCP"]["value"]}
+# Set to false to disable showing MCP output when calling MCP tools
+SHOW_MCP_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_MCP_OUTPUT"]["value"]}
 """
```
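const.py now points at a user-maintained MCP config file (`MCP_JSON_PATH`, i.e. `~/.config/yaicli/mcp.json`) and adds the `ENABLE_MCP`/`SHOW_MCP_OUTPUT` settings. The expected schema of `mcp.json` is not shown in this diff; the sketch below writes a file using the `mcpServers` layout common to other MCP clients purely as an assumed example:

```python
# Hypothetical example: the "mcpServers" layout and server entry are assumptions
# based on the usual MCP client convention, not taken from yaicli itself.
import json
from pathlib import Path

mcp_json_path = Path("~/.config/yaicli/mcp.json").expanduser()
mcp_json_path.parent.mkdir(parents=True, exist_ok=True)
mcp_json_path.write_text(
    json.dumps(
        {"mcpServers": {"example": {"command": "uvx", "args": ["some-mcp-server"]}}},
        indent=2,
    )
)
```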
yaicli/entry.py
CHANGED
```diff
@@ -6,7 +6,7 @@ import typer
 from .chat import FileChatManager
 from .config import cfg
 from .const import DEFAULT_CONFIG_INI, DefaultRoleNames, JustifyEnum
-from .functions import install_functions, print_functions
+from .functions import install_functions, print_functions, print_mcp
 from .role import RoleManager
 
 app = typer.Typer(
@@ -209,6 +209,29 @@ def main(
         show_default=False,
         callback=override_config,
     ),
+    # ------------------- MCP Options -------------------
+    enable_mcp: bool = typer.Option(  # noqa: F841
+        cfg["ENABLE_MCP"],
+        "--enable-mcp/--disable-mcp",
+        help=f"Enable/disable MCP in API requests [dim](default: {'enabled' if cfg['ENABLE_MCP'] else 'disabled'})[/dim]",
+        rich_help_panel="MCP Options",
+        callback=override_config,
+    ),
+    show_mcp_output: bool = typer.Option(  # noqa: F841
+        cfg["SHOW_MCP_OUTPUT"],
+        "--show-mcp-output/--hide-mcp-output",
+        help=f"Show the output of MCP [dim](default: {'show' if cfg['SHOW_MCP_OUTPUT'] else 'hide'})[/dim]",
+        rich_help_panel="MCP Options",
+        show_default=False,
+        callback=override_config,
+    ),
+    list_mcp: bool = typer.Option(  # noqa: F841
+        False,
+        "--list-mcp",
+        help="List all available mcp.",
+        rich_help_panel="MCP Options",
+        callback=print_mcp,
+    ),
 ):
     """YAICLI: Your AI assistant in the command line.
 
```
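entry.py wires three new options into the main command (`--enable-mcp/--disable-mcp`, `--show-mcp-output/--hide-mcp-output`, `--list-mcp`). A rough sketch of checking that wiring with typer's test runner; the exact runtime behavior (what `--list-mcp` prints and whether it exits immediately) is assumed from the callbacks shown above:

```python
# Sketch only: exercises the new flags via typer's CliRunner.
from typer.testing import CliRunner

from yaicli.entry import app

runner = CliRunner()
help_result = runner.invoke(app, ["--help"])
print("--enable-mcp" in help_result.output)  # True once the MCP options are registered

list_result = runner.invoke(app, ["--list-mcp"])  # prints the parsed mcp.json, if configured
print(list_result.exit_code)
```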
yaicli/functions/__init__.py
CHANGED
```diff
@@ -1,9 +1,10 @@
+import json
 import shutil
 from pathlib import Path
 from typing import Any
 
 from ..console import get_console
-from ..const import FUNCTIONS_DIR
+from ..const import FUNCTIONS_DIR, MCP_JSON_PATH
 from ..utils import option_callback
 
 console = get_console()
@@ -37,3 +38,14 @@ def print_functions(cls, _: Any) -> None:
         if file.name.startswith("_"):
             continue
         console.print(file)
+
+
+@option_callback
+def print_mcp(cls, _: Any) -> None:
+    """List all available mcp"""
+    if not MCP_JSON_PATH.exists():
+        console.print("No mcp config found, please add your mcp config in ~/.config/yaicli/mcp.json")
+        return
+    with open(MCP_JSON_PATH, "r") as f:
+        mcp_config = json.load(f)
+    console.print_json(data=mcp_config)
```
yaicli/llms/client.py
CHANGED
```diff
@@ -4,6 +4,7 @@ from ..config import cfg
 from ..console import get_console
 from ..schemas import ChatMessage, LLMResponse, RefreshLive, ToolCall
 from ..tools import execute_tool_call
+from ..tools.mcp import MCP_TOOL_NAME_PREFIX
 from .provider import Provider, ProviderFactory
 
 
@@ -37,6 +38,8 @@ class LLMClient:
         self.config = config
         self.verbose = verbose
         self.console = get_console()
+        self.enable_function = self.config["ENABLE_FUNCTIONS"]
+        self.enable_mcp = self.config["ENABLE_MCP"]
 
         # Use provided provider or create one
         if provider:
@@ -73,48 +76,78 @@ class LLMClient:
             )
             return
 
-        # Get completion from provider
-        llm_response_generator = self.provider.completion(messages, stream=stream)
-
-        # To hold the full response
+        # Get completion from provider and collect response data
         assistant_response_content = ""
-
+        # Providers may return identical tool calls with the same ID in a single response during streaming
+        tool_calls: dict[str, ToolCall] = {}
 
-        #
-        for llm_response in
-            # Forward
-            yield llm_response
+        # Stream responses and collect data
+        for llm_response in self.provider.completion(messages, stream=stream):
+            yield llm_response  # Forward response to caller
 
-            # Collect content and tool calls
+            # Collect content and tool calls for potential tool execution
             if llm_response.content:
                 assistant_response_content += llm_response.content
-            if llm_response.tool_call and llm_response.tool_call not in tool_calls:
-                tool_calls.
-
-        #
-        if tool_calls
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if llm_response.tool_call and llm_response.tool_call.id not in tool_calls:
+                tool_calls[llm_response.tool_call.id] = llm_response.tool_call
+
+        # Check if we need to execute tools
+        if not tool_calls or not (self.enable_function or self.enable_mcp):
+            return
+
+        # Filter valid tool calls based on enabled features
+        valid_tool_calls = self._get_valid_tool_calls(tool_calls)
+        if not valid_tool_calls:
+            return
+
+        # Execute tools and continue conversation
+        yield from self._execute_tools_and_continue(
+            messages, assistant_response_content, valid_tool_calls, stream, recursion_depth
+        )
+
+    def _get_valid_tool_calls(self, tool_calls: dict[str, ToolCall]) -> List[ToolCall]:
+        """Filter tool calls based on enabled features"""
+        valid_tool_calls = []
+
+        for tool_call in tool_calls.values():
+            is_mcp = tool_call.name.startswith(MCP_TOOL_NAME_PREFIX)
+
+            if is_mcp and self.enable_mcp:
+                valid_tool_calls.append(tool_call)
+            elif not is_mcp and self.enable_function:
+                valid_tool_calls.append(tool_call)
+
+        return valid_tool_calls
+
+    def _execute_tools_and_continue(
+        self,
+        messages: List[ChatMessage],
+        assistant_response_content: str,
+        tool_calls: List[ToolCall],
+        stream: bool,
+        recursion_depth: int,
+    ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
+        """Execute tool calls and continue the conversation"""
+        # Signal that new content is coming
+        yield RefreshLive()
+
+        # Add assistant message with tool calls to history (only once)
+        messages.append(ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls))
+
+        # Execute each tool call and add results to messages
+        tool_role = self.provider.detect_tool_role()
+
+        for tool_call in tool_calls:
+            function_result, _ = execute_tool_call(tool_call)
+
+            messages.append(
+                ChatMessage(
+                    role=tool_role,
+                    content=function_result,
+                    name=tool_call.name,
+                    tool_call_id=tool_call.id,
                 )
+            )
 
-
-
+        # Continue the conversation with updated history
+        yield from self.completion_with_tools(messages, stream=stream, recursion_depth=recursion_depth + 1)
```
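The reworked client keys streamed tool calls by ID, so a provider repeating the same call within one response is deduplicated, and then routes each call by name prefix: names starting with `MCP_TOOL_NAME_PREFIX` run only when `ENABLE_MCP` is on, everything else only when `ENABLE_FUNCTIONS` is on. A standalone sketch of that filtering; the `ToolCall` class and the prefix value are stand-ins, since the real prefix lives in `yaicli/tools/mcp.py` and is not shown in this diff:

```python
# Stand-alone illustration of the dedup + routing logic in LLMClient above.
from dataclasses import dataclass

MCP_TOOL_NAME_PREFIX = "mcp_"  # assumed value, for illustration only


@dataclass
class ToolCall:  # simplified stand-in for yaicli.schemas.ToolCall
    id: str
    name: str
    arguments: str


def valid_tool_calls(tool_calls: dict[str, ToolCall], enable_function: bool, enable_mcp: bool) -> list[ToolCall]:
    selected = []
    for call in tool_calls.values():  # keyed by id, so duplicate IDs collapse to one entry
        is_mcp = call.name.startswith(MCP_TOOL_NAME_PREFIX)
        if (is_mcp and enable_mcp) or (not is_mcp and enable_function):
            selected.append(call)
    return selected


calls = {
    "1": ToolCall("1", "mcp_search", "{}"),
    "2": ToolCall("2", "local_helper", "{}"),
}
print([c.name for c in valid_tool_calls(calls, enable_function=True, enable_mcp=False)])
# ['local_helper'] -- MCP calls are skipped while ENABLE_MCP is off
```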
yaicli/llms/provider.py
CHANGED
```diff
@@ -43,10 +43,13 @@ class ProviderFactory:
         "chatglm": (".providers.chatglm_provider", "ChatglmProvider"),
         "chutes": (".providers.chutes_provider", "ChutesProvider"),
         "cohere": (".providers.cohere_provider", "CohereProvider"),
+        "cohere-bedrock": (".providers.cohere_provider", "CohereBadrockProvider"),
+        "cohere-sagemaker": (".providers.cohere_provider", "CohereSagemakerProvider"),
         "deepseek": (".providers.deepseek_provider", "DeepSeekProvider"),
         "doubao": (".providers.doubao_provider", "DoubaoProvider"),
         "gemini": (".providers.gemini_provider", "GeminiProvider"),
         "groq": (".providers.groq_provider", "GroqProvider"),
+        "huggingface": (".providers.huggingface_provider", "HuggingFaceProvider"),
         "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
         "minimax": (".providers.minimax_provider", "MinimaxProvider"),
         "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
```
yaicli/llms/providers/ai21_provider.py
CHANGED
```diff
@@ -63,7 +63,7 @@ class AI21Provider(OpenAIProvider):
             if finish_reason == "tool_calls" and not content:
                 # tool call assistant message, content can't be empty
                 # Error code: 422 - {'detail': {'error': ['Value error, message content must not be an empty string']}}
-                content = tool_call.id
+                content = tool_call.id if tool_call else ""
 
             # Generate response object
             yield LLMResponse(
@@ -72,3 +72,35 @@ class AI21Provider(OpenAIProvider):
                 tool_call=tool_call if finish_reason == "tool_calls" else None,
                 finish_reason=finish_reason,
             )
+
+    def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
+        """Process a tool call chunk from AI21 API response
+
+        Tool calls from AI21 are delivered across multiple chunks:
+        - First chunk contains function name
+        - Subsequent chunks contain arguments data
+        - Final chunk (with finish_reason='tool_calls') contains complete arguments
+
+        Args:
+            tool_calls: Tool call data from current chunk
+            existing_tool_call: Previously accumulated tool call object
+
+        Returns:
+            ToolCall: Updated tool call object with accumulated data
+        """
+        # Get the first (and only) tool call from the chunk
+        call = tool_calls[0]
+
+        if existing_tool_call is None:
+            # First chunk - create new tool call with function name
+            return ToolCall(id=call.id, name=call.function.name, arguments="{}")
+        else:
+            # Update existing tool call with new arguments data
+            # Keep existing data and update with new information
+            existing_arguments = existing_tool_call.arguments
+            new_arguments = call.function.arguments if hasattr(call.function, "arguments") else "{}"
+
+            # Combine arguments (new arguments should override if available)
+            combined_arguments = new_arguments if new_arguments else existing_arguments
+
+            return ToolCall(id=existing_tool_call.id, name=existing_tool_call.name, arguments=combined_arguments)
```
yaicli/llms/providers/chatglm_provider.py
CHANGED
```diff
@@ -1,9 +1,10 @@
 import json
-from typing import
+from typing import Generator, Optional, Union, overload
 
 from openai._streaming import Stream
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
 
 from ...schemas import LLMResponse, ToolCall
 from .openai_provider import OpenAIProvider
@@ -14,10 +15,14 @@ class ChatglmProvider(OpenAIProvider):
 
     DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"
 
-
-
-
-
+    COMPLETION_PARAMS_KEYS = {
+        "model": "MODEL",
+        "temperature": "TEMPERATURE",
+        "top_p": "TOP_P",
+        "max_tokens": "MAX_TOKENS",
+        "do_sample": "DO_SAMPLE",
+        "extra_body": "EXTRA_BODY",
+    }
 
     def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
         """Handle normal (non-streaming) response
@@ -57,7 +62,7 @@ class ChatglmProvider(OpenAIProvider):
             tmp_content = content[tool_index:]
             # Tool call data may in content after the <think> block
             try:
-                choice = self.parse_choice_from_content(tmp_content)
+                choice = self.parse_choice_from_content(tmp_content, Choice)
             except ValueError:
                 pass
             if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:  # type: ignore
@@ -97,8 +102,6 @@ class ChatglmProvider(OpenAIProvider):
             reasoning = self._get_reasoning_content(delta)
             full_reasoning += reasoning or ""
 
-            if finish_reason:
-                pass
             if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
                 # Tool call data may in content after the <think> block
                 # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
@@ -106,7 +109,7 @@ class ChatglmProvider(OpenAIProvider):
                 if tool_index != -1:
                     tmp_content = full_content[tool_index:]
                     try:
-                        choice = self.parse_choice_from_content(tmp_content)
+                        choice = self.parse_choice_from_content(tmp_content, ChoiceChunk)
                     except ValueError:
                         pass
                     if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls:  # type: ignore
@@ -120,7 +123,31 @@ class ChatglmProvider(OpenAIProvider):
                 tool_call = ToolCall(tool_id, tool_call_name, arguments)
             yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
 
-
+    @overload
+    def parse_choice_from_content(self, content: str, choice_class: type[ChoiceChunk] = ChoiceChunk) -> "ChoiceChunk":
+        """
+        Parse the choice from the content after <think>...</think> block.
+        Args:
+            content: The content from the LLM response
+        Returns:
+            The choice object
+        Raises ValueError if the content is not valid JSON
+        """
+
+    @overload
+    def parse_choice_from_content(self, content: str, choice_class: type[Choice] = Choice) -> "Choice":
+        """
+        Parse the choice from the content after <think>...</think> block.
+        Args:
+            content: The content from the LLM response
+        Returns:
+            The choice object
+        Raises ValueError if the content is not valid JSON
+        """
+
+    def parse_choice_from_content(
+        self, content: str, choice_class: type[Union[Choice, ChoiceChunk]] = Choice
+    ) -> Union[Choice, ChoiceChunk]:
         """
         Parse the choice from the content after <think>...</think> block.
         Args:
@@ -134,6 +161,6 @@ class ChatglmProvider(OpenAIProvider):
         except json.JSONDecodeError:
             raise ValueError(f"Invalid message from LLM: {content}")
         try:
-            return
+            return choice_class.model_validate(content_dict)
         except Exception as e:
             raise ValueError(f"Invalid message from LLM: {content}") from e
```
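`parse_choice_from_content` now takes the target class explicitly and validates the JSON fragment found after the `</think>` block into either a non-streaming `Choice` or a streaming `ChoiceChunk`; the two `@overload` stubs only exist to give type checkers a precise return type. A small illustration of the `model_validate` call it relies on, with a made-up minimal payload:

```python
# Illustration only: the payload is invented; it just shows that the openai
# Choice models are pydantic models that can be rebuilt from a parsed dict.
import json

from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk

payload = '{"index": 0, "delta": {"role": "assistant", "content": "hi"}, "finish_reason": null}'
choice = ChoiceChunk.model_validate(json.loads(payload))
print(choice.delta.content)  # hi
```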
yaicli/llms/providers/cohere_provider.py
CHANGED
```diff
@@ -10,7 +10,8 @@ This module implements Cohere provider classes for different deployment options:
 from typing import Any, Dict, Generator, List, Optional
 
 from cohere import BedrockClientV2, ClientV2, SagemakerClientV2
-from cohere.types.tool_call_v2 import ToolCallV2
+from cohere.types.tool_call_v2 import ToolCallV2
+from cohere.types.tool_call_v2function import ToolCallV2Function
 
 from ...config import cfg
 from ...console import get_console
@@ -179,7 +180,9 @@ class CohereProvider(Provider):
                 continue
             elif chunk.type == "tool-call-delta":
                 # Tool call arguments being generated: cohere.types.chat_tool_call_delta_event_delta_message.ChatToolCallDeltaEventDeltaMessage
-
+                if not tool_call:
+                    continue
+                tool_call.arguments += chunk.delta.message.tool_calls.function.arguments or ""
                 # Waiting for tool-call-end event
                 continue
 
@@ -292,7 +295,7 @@ class CohereBadrockProvider(CohereProvider):
         return self.CLIENT_CLS(**self.client_params)
 
 
-class
+class CohereSagemakerProvider(CohereBadrockProvider):
     """Cohere provider for AWS Sagemaker integration"""
 
     CLIENT_CLS = SagemakerClientV2
```
yaicli/llms/providers/deepseek_provider.py
CHANGED
```diff
@@ -10,5 +10,6 @@ class DeepSeekProvider(OpenAIProvider):
 
     def get_completion_params(self) -> Dict[str, Any]:
         params = super().get_completion_params()
-
+        if "max_completion_tokens" in params:
+            params["max_tokens"] = params.pop("max_completion_tokens")
         return params
```
yaicli/llms/providers/doubao_provider.py
CHANGED
```diff
@@ -2,8 +2,6 @@ from typing import Any, Dict
 
 from volcenginesdkarkruntime import Ark
 
-from ...config import cfg
-from ...console import get_console
 from .openai_provider import OpenAIProvider
 
 
@@ -13,18 +11,6 @@ class DoubaoProvider(OpenAIProvider):
     DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
     CLIENT_CLS = Ark
 
-    def __init__(self, config: dict = cfg, **kwargs):
-        self.config = config
-        self.enable_function = self.config["ENABLE_FUNCTIONS"]
-        self.client_params = self.get_client_params()
-
-        # Initialize client
-        self.client = self.CLIENT_CLS(**self.client_params)
-        self.console = get_console()
-
-        # Store completion params
-        self.completion_params = self.get_completion_params()
-
     def get_client_params(self) -> Dict[str, Any]:
         # Initialize client params
         client_params = {"base_url": self.DEFAULT_BASE_URL}
```
|