yaicli 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
  [project]
  name = "yaicli"
- version = "0.6.3"
+ version = "0.7.0"
  description = "A simple CLI tool to interact with LLM"
  authors = [{ name = "belingud", email = "im.victor@qq.com" }]
  readme = "README.md"
@@ -42,6 +42,15 @@ keywords = [
  "anthropic",
  "groq",
  "cohere",
+ "huggingface",
+ "chatglm",
+ "sambanova",
+ "siliconflow",
+ "xai",
+ "vertexai",
+ "deepseek",
+ "modelscope",
+ "ollama",
  ]
  dependencies = [
  "click>=8.1.8",
@@ -70,11 +79,15 @@ all = [
  "ollama>=0.5.1",
  "cohere>=5.15.0",
  "google-genai>=1.20.0",
+ "huggingface-hub>=0.33.0",
  ]
  doubao = ["volcengine-python-sdk>=3.0.15"]
  ollama = ["ollama>=0.5.1"]
  cohere = ["cohere>=5.15.0"]
  gemini = ["google-genai>=1.20.0"]
+ huggingface = [
+ "huggingface-hub>=0.33.0",
+ ]

  [tool.pytest.ini_options]
  testpaths = ["tests"]
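The expanded keyword list only affects PyPI search metadata; the functional packaging change is the new huggingface optional dependency group, which is also folded into the all extra. Assuming standard pip extras syntax, it can be pulled in with pip install "yaicli[huggingface]" (or "yaicli[all]"), which installs huggingface-hub>=0.33.0.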
yaicli/cli.py CHANGED
@@ -267,7 +267,7 @@ class CLI:
  assistant_msg = self.chat.history[i + 1] if (i + 1) < len(self.chat.history) else None
  self.console.print(f"[dim]{i // 2 + 1}[/dim] [bold blue]User:[/bold blue] {user_msg.content}")
  if assistant_msg:
- md = Markdown(assistant_msg.content, code_theme=cfg["CODE_THEME"])
+ md = Markdown(assistant_msg.content or "", code_theme=cfg["CODE_THEME"])
  padded_md = Padding(md, (0, 0, 0, 4))
  self.console.print(" Assistant:", style="bold green")
  self.console.print(padded_md)
@@ -384,7 +384,7 @@ class CLI:
  self._check_history_len()

  if self.current_mode == EXEC_MODE:
- self._confirm_and_execute(content)
+ self._confirm_and_execute(content or "")
  return True

  def _confirm_and_execute(self, raw_content: str) -> None:
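Both hunks are the same defensive fix: an assistant message's content can be None or empty (for example, a turn that only carried tool calls), while Markdown() and _confirm_and_execute() expect a plain str, so the value is coalesced to "" before use.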
yaicli/config.py CHANGED
@@ -142,7 +142,7 @@ class Config(dict):
  if target_type is bool:
  converted_value = str2bool(raw_value)
  elif target_type in (int, float, str):
- converted_value = target_type(raw_value)
+ converted_value = target_type(raw_value) if raw_value else raw_value
  elif target_type is dict and raw_value:
  converted_value = json.loads(raw_value)
  except (ValueError, TypeError, json.JSONDecodeError) as e:
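The guard matters because casting an empty string to int or float raises ValueError, so a key left blank in config.ini or in a YAI_* environment variable would previously drop into the except branch shown above instead of simply staying empty. A minimal, standalone illustration (plain Python, not part of the diff):

    for caster in (int, float):
        try:
            caster("")  # what the old code effectively did for blank numeric values
        except ValueError as exc:
            print(f"{caster.__name__}(''): {exc}")
    # With the 0.7.0 change, a blank raw_value is passed through unchanged
    # instead of being cast, so the conversion step no longer raises.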
yaicli/const.py CHANGED
@@ -1,5 +1,5 @@
  try:
- from enum import StrEnum
+ from enum import StrEnum # type: ignore
  except ImportError:
  from enum import Enum

@@ -16,7 +16,7 @@ from rich.console import JustifyMethod
  BOOL_STR = Literal["true", "false", "yes", "no", "y", "n", "1", "0", "on", "off"]


- class JustifyEnum(StrEnum):
+ class JustifyEnum(StrEnum): # type: ignore
  DEFAULT = "default"
  LEFT = "left"
  CENTER = "center"
@@ -43,6 +43,7 @@ HISTORY_FILE = Path("~/.yaicli_history").expanduser()
  CONFIG_PATH = Path("~/.config/yaicli/config.ini").expanduser()
  ROLES_DIR = CONFIG_PATH.parent / "roles"
  FUNCTIONS_DIR = CONFIG_PATH.parent / "functions"
+ MCP_JSON_PATH = CONFIG_PATH.parent / "mcp.json"

  # Default configuration values
  DEFAULT_CODE_THEME = "monokai"
@@ -68,7 +69,9 @@ DEFAULT_JUSTIFY: JustifyMethod = "default"
  DEFAULT_ROLE_MODIFY_WARNING: BOOL_STR = "true"
  DEFAULT_ENABLE_FUNCTIONS: BOOL_STR = "true"
  DEFAULT_SHOW_FUNCTION_OUTPUT: BOOL_STR = "true"
- DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] = ""
+ DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] = None
+ DEFAULT_ENABLE_MCP: BOOL_STR = "false"
+ DEFAULT_SHOW_MCP_OUTPUT: BOOL_STR = "false"


  SHELL_PROMPT = """You are YAICLI, a shell command generator.
@@ -93,16 +96,16 @@ CODER_PROMPT = (
  )


- class DefaultRoleNames(StrEnum):
+ class DefaultRoleNames(StrEnum): # type: ignore
  SHELL = "Shell Command Generator"
  DEFAULT = "DEFAULT"
  CODER = "Code Assistant"


  DEFAULT_ROLES: dict[str, dict[str, Any]] = {
- DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT},
- DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT},
- DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT},
+ DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT}, # type: ignore
+ DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT}, # type: ignore
+ DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT}, # type: ignore
  }

  # DEFAULT_CONFIG_MAP is a dictionary of the configuration options.
@@ -151,6 +154,8 @@ DEFAULT_CONFIG_MAP = {
  "env_key": "YAI_SHOW_FUNCTION_OUTPUT",
  "type": bool,
  },
+ "ENABLE_MCP": {"value": DEFAULT_ENABLE_MCP, "env_key": "YAI_ENABLE_MCP", "type": bool},
+ "SHOW_MCP_OUTPUT": {"value": DEFAULT_SHOW_MCP_OUTPUT, "env_key": "YAI_SHOW_MCP_OUTPUT", "type": bool},
  }

  DEFAULT_CONFIG_INI = f"""[core]
@@ -201,6 +206,12 @@ ROLE_MODIFY_WARNING={DEFAULT_CONFIG_MAP["ROLE_MODIFY_WARNING"]["value"]}
  # Function settings
  # Set to false to disable sending functions in API requests
  ENABLE_FUNCTIONS={DEFAULT_CONFIG_MAP["ENABLE_FUNCTIONS"]["value"]}
- # Set to false to disable showing function output in the response
+ # Set to false to disable showing function output when calling functions
  SHOW_FUNCTION_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_FUNCTION_OUTPUT"]["value"]}
+
+ # MCP settings
+ # Set to false to disable MCP in API requests
+ ENABLE_MCP={DEFAULT_CONFIG_MAP["ENABLE_MCP"]["value"]}
+ # Set to false to disable showing MCP output when calling MCP tools
+ SHOW_MCP_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_MCP_OUTPUT"]["value"]}
  """
yaicli/entry.py CHANGED
@@ -6,7 +6,7 @@ import typer
  from .chat import FileChatManager
  from .config import cfg
  from .const import DEFAULT_CONFIG_INI, DefaultRoleNames, JustifyEnum
- from .functions import install_functions, print_functions
+ from .functions import install_functions, print_functions, print_mcp
  from .role import RoleManager

  app = typer.Typer(
@@ -209,6 +209,29 @@ def main(
  show_default=False,
  callback=override_config,
  ),
+ # ------------------- MCP Options -------------------
+ enable_mcp: bool = typer.Option( # noqa: F841
+ cfg["ENABLE_MCP"],
+ "--enable-mcp/--disable-mcp",
+ help=f"Enable/disable MCP in API requests [dim](default: {'enabled' if cfg['ENABLE_MCP'] else 'disabled'})[/dim]",
+ rich_help_panel="MCP Options",
+ callback=override_config,
+ ),
+ show_mcp_output: bool = typer.Option( # noqa: F841
+ cfg["SHOW_MCP_OUTPUT"],
+ "--show-mcp-output/--hide-mcp-output",
+ help=f"Show the output of MCP [dim](default: {'show' if cfg['SHOW_MCP_OUTPUT'] else 'hide'})[/dim]",
+ rich_help_panel="MCP Options",
+ show_default=False,
+ callback=override_config,
+ ),
+ list_mcp: bool = typer.Option( # noqa: F841
+ False,
+ "--list-mcp",
+ help="List all available mcp.",
+ rich_help_panel="MCP Options",
+ callback=print_mcp,
+ ),
  ):
  """YAICLI: Your AI assistant in the command line.

@@ -1,9 +1,10 @@
+ import json
  import shutil
  from pathlib import Path
  from typing import Any

  from ..console import get_console
- from ..const import FUNCTIONS_DIR
+ from ..const import FUNCTIONS_DIR, MCP_JSON_PATH
  from ..utils import option_callback

  console = get_console()
@@ -37,3 +38,14 @@ def print_functions(cls, _: Any) -> None:
  if file.name.startswith("_"):
  continue
  console.print(file)
+
+
+ @option_callback
+ def print_mcp(cls, _: Any) -> None:
+ """List all available mcp"""
+ if not MCP_JSON_PATH.exists():
+ console.print("No mcp config found, please add your mcp config in ~/.config/yaicli/mcp.json")
+ return
+ with open(MCP_JSON_PATH, "r") as f:
+ mcp_config = json.load(f)
+ console.print_json(data=mcp_config)
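Together with the new flags in entry.py (--enable-mcp/--disable-mcp and --show-mcp-output/--hide-mcp-output go through override_config, so they apply to a single invocation without editing config.ini), --list-mcp routes to this print_mcp callback, which does no validation of its own: it loads ~/.config/yaicli/mcp.json and pretty-prints it. The schema yaicli expects is not visible in this diff; purely as a hypothetical illustration using the common MCP-client mcpServers convention, a minimal file might look like:

    {
      "mcpServers": {
        "time": {
          "command": "uvx",
          "args": ["mcp-server-time"]
        }
      }
    }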
yaicli/llms/client.py CHANGED
@@ -4,6 +4,7 @@ from ..config import cfg
  from ..console import get_console
  from ..schemas import ChatMessage, LLMResponse, RefreshLive, ToolCall
  from ..tools import execute_tool_call
+ from ..tools.mcp import MCP_TOOL_NAME_PREFIX
  from .provider import Provider, ProviderFactory


@@ -37,6 +38,8 @@ class LLMClient:
  self.config = config
  self.verbose = verbose
  self.console = get_console()
+ self.enable_function = self.config["ENABLE_FUNCTIONS"]
+ self.enable_mcp = self.config["ENABLE_MCP"]

  # Use provided provider or create one
  if provider:
@@ -73,48 +76,78 @@
  )
  return

- # Get completion from provider
- llm_response_generator = self.provider.completion(messages, stream=stream)
-
- # To hold the full response
+ # Get completion from provider and collect response data
  assistant_response_content = ""
- tool_calls: List[ToolCall] = []
+ # Providers may return identical tool calls with the same ID in a single response during streaming
+ tool_calls: dict[str, ToolCall] = {}

- # Process all responses from the provider
- for llm_response in llm_response_generator:
- # Forward the response to the caller
- yield llm_response
+ # Stream responses and collect data
+ for llm_response in self.provider.completion(messages, stream=stream):
+ yield llm_response # Forward response to caller

- # Collect content and tool calls
+ # Collect content and tool calls for potential tool execution
  if llm_response.content:
  assistant_response_content += llm_response.content
- if llm_response.tool_call and llm_response.tool_call not in tool_calls:
- tool_calls.append(llm_response.tool_call)
-
- # If we have tool calls, execute them and make recursive call
- if tool_calls and self.config["ENABLE_FUNCTIONS"]:
- # Yield a refresh signal to indicate new content is coming
- yield RefreshLive()
-
- # Append the assistant message with tool calls to history
- messages.append(ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls))
-
- # Execute each tool call and append the results
- for tool_call in tool_calls:
- function_result, _ = execute_tool_call(tool_call)
-
- # Use provider's tool role detection
- tool_role = self.provider.detect_tool_role()
-
- # Append the tool result to history
- messages.append(
- ChatMessage(
- role=tool_role,
- content=function_result,
- name=tool_call.name,
- tool_call_id=tool_call.id,
- )
+ if llm_response.tool_call and llm_response.tool_call.id not in tool_calls:
+ tool_calls[llm_response.tool_call.id] = llm_response.tool_call
+
+ # Check if we need to execute tools
+ if not tool_calls or not (self.enable_function or self.enable_mcp):
+ return
+
+ # Filter valid tool calls based on enabled features
+ valid_tool_calls = self._get_valid_tool_calls(tool_calls)
+ if not valid_tool_calls:
+ return
+
+ # Execute tools and continue conversation
+ yield from self._execute_tools_and_continue(
+ messages, assistant_response_content, valid_tool_calls, stream, recursion_depth
+ )
+
+ def _get_valid_tool_calls(self, tool_calls: dict[str, ToolCall]) -> List[ToolCall]:
+ """Filter tool calls based on enabled features"""
+ valid_tool_calls = []
+
+ for tool_call in tool_calls.values():
+ is_mcp = tool_call.name.startswith(MCP_TOOL_NAME_PREFIX)
+
+ if is_mcp and self.enable_mcp:
+ valid_tool_calls.append(tool_call)
+ elif not is_mcp and self.enable_function:
+ valid_tool_calls.append(tool_call)
+
+ return valid_tool_calls
+
+ def _execute_tools_and_continue(
+ self,
+ messages: List[ChatMessage],
+ assistant_response_content: str,
+ tool_calls: List[ToolCall],
+ stream: bool,
+ recursion_depth: int,
+ ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
+ """Execute tool calls and continue the conversation"""
+ # Signal that new content is coming
+ yield RefreshLive()
+
+ # Add assistant message with tool calls to history (only once)
+ messages.append(ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls))
+
+ # Execute each tool call and add results to messages
+ tool_role = self.provider.detect_tool_role()
+
+ for tool_call in tool_calls:
+ function_result, _ = execute_tool_call(tool_call)
+
+ messages.append(
+ ChatMessage(
+ role=tool_role,
+ content=function_result,
+ name=tool_call.name,
+ tool_call_id=tool_call.id,
  )
+ )

- # Make a recursive call with the updated history
- yield from self.completion_with_tools(messages, stream=stream, recursion_depth=recursion_depth + 1)
+ # Continue the conversation with updated history
+ yield from self.completion_with_tools(messages, stream=stream, recursion_depth=recursion_depth + 1)
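Net effect of the refactor: completion_with_tools now keys streamed tool calls by ID (deduplicating providers that repeat the same call across chunks), filters them against the two feature flags, appends the assistant message carrying the tool calls to history once, then executes the calls and recurses. A self-contained restatement of the filter rule (the real prefix comes from yaicli.tools.mcp.MCP_TOOL_NAME_PREFIX, whose value is not shown in this diff):

    def keep_tool_call(name: str, enable_function: bool, enable_mcp: bool, mcp_prefix: str) -> bool:
        # MCP-prefixed tools require ENABLE_MCP; everything else requires ENABLE_FUNCTIONS.
        is_mcp = name.startswith(mcp_prefix)
        return (is_mcp and enable_mcp) or (not is_mcp and enable_function)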
yaicli/llms/provider.py CHANGED
@@ -43,10 +43,13 @@ class ProviderFactory:
  "chatglm": (".providers.chatglm_provider", "ChatglmProvider"),
  "chutes": (".providers.chutes_provider", "ChutesProvider"),
  "cohere": (".providers.cohere_provider", "CohereProvider"),
+ "cohere-bedrock": (".providers.cohere_provider", "CohereBadrockProvider"),
+ "cohere-sagemaker": (".providers.cohere_provider", "CohereSagemakerProvider"),
  "deepseek": (".providers.deepseek_provider", "DeepSeekProvider"),
  "doubao": (".providers.doubao_provider", "DoubaoProvider"),
  "gemini": (".providers.gemini_provider", "GeminiProvider"),
  "groq": (".providers.groq_provider", "GroqProvider"),
+ "huggingface": (".providers.huggingface_provider", "HuggingFaceProvider"),
  "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
  "minimax": (".providers.minimax_provider", "MinimaxProvider"),
  "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
@@ -63,7 +63,7 @@ class AI21Provider(OpenAIProvider):
  if finish_reason == "tool_calls" and not content:
  # tool call assistant message, content can't be empty
  # Error code: 422 - {'detail': {'error': ['Value error, message content must not be an empty string']}}
- content = tool_call.id
+ content = tool_call.id if tool_call else ""

  # Generate response object
  yield LLMResponse(
@@ -72,3 +72,35 @@ class AI21Provider(OpenAIProvider):
  tool_call=tool_call if finish_reason == "tool_calls" else None,
  finish_reason=finish_reason,
  )
+
+ def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
+ """Process a tool call chunk from AI21 API response
+
+ Tool calls from AI21 are delivered across multiple chunks:
+ - First chunk contains function name
+ - Subsequent chunks contain arguments data
+ - Final chunk (with finish_reason='tool_calls') contains complete arguments
+
+ Args:
+ tool_calls: Tool call data from current chunk
+ existing_tool_call: Previously accumulated tool call object
+
+ Returns:
+ ToolCall: Updated tool call object with accumulated data
+ """
+ # Get the first (and only) tool call from the chunk
+ call = tool_calls[0]
+
+ if existing_tool_call is None:
+ # First chunk - create new tool call with function name
+ return ToolCall(id=call.id, name=call.function.name, arguments="{}")
+ else:
+ # Update existing tool call with new arguments data
+ # Keep existing data and update with new information
+ existing_arguments = existing_tool_call.arguments
+ new_arguments = call.function.arguments if hasattr(call.function, "arguments") else "{}"
+
+ # Combine arguments (new arguments should override if available)
+ combined_arguments = new_arguments if new_arguments else existing_arguments
+
+ return ToolCall(id=existing_tool_call.id, name=existing_tool_call.name, arguments=combined_arguments)
@@ -1,9 +1,10 @@
  import json
- from typing import Any, Dict, Generator, Optional
+ from typing import Generator, Optional, Union, overload

  from openai._streaming import Stream
  from openai.types.chat.chat_completion import ChatCompletion, Choice
  from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+ from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk

  from ...schemas import LLMResponse, ToolCall
  from .openai_provider import OpenAIProvider
@@ -14,10 +15,14 @@ class ChatglmProvider(OpenAIProvider):

  DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"

- def get_completion_params(self) -> Dict[str, Any]:
- params = super().get_completion_params()
- params["max_tokens"] = params.pop("max_completion_tokens")
- return params
+ COMPLETION_PARAMS_KEYS = {
+ "model": "MODEL",
+ "temperature": "TEMPERATURE",
+ "top_p": "TOP_P",
+ "max_tokens": "MAX_TOKENS",
+ "do_sample": "DO_SAMPLE",
+ "extra_body": "EXTRA_BODY",
+ }

  def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
  """Handle normal (non-streaming) response
@@ -57,7 +62,7 @@ class ChatglmProvider(OpenAIProvider):
  tmp_content = content[tool_index:]
  # Tool call data may in content after the <think> block
  try:
- choice = self.parse_choice_from_content(tmp_content)
+ choice = self.parse_choice_from_content(tmp_content, Choice)
  except ValueError:
  pass
  if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls: # type: ignore
@@ -97,8 +102,6 @@ class ChatglmProvider(OpenAIProvider):
  reasoning = self._get_reasoning_content(delta)
  full_reasoning += reasoning or ""

- if finish_reason:
- pass
  if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
  # Tool call data may in content after the <think> block
  # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
@@ -106,7 +109,7 @@ class ChatglmProvider(OpenAIProvider):
  if tool_index != -1:
  tmp_content = full_content[tool_index:]
  try:
- choice = self.parse_choice_from_content(tmp_content)
+ choice = self.parse_choice_from_content(tmp_content, ChoiceChunk)
  except ValueError:
  pass
  if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls: # type: ignore
@@ -120,7 +123,31 @@ class ChatglmProvider(OpenAIProvider):
  tool_call = ToolCall(tool_id, tool_call_name, arguments)
  yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)

- def parse_choice_from_content(self, content: str) -> "Choice":
+ @overload
+ def parse_choice_from_content(self, content: str, choice_class: type[ChoiceChunk] = ChoiceChunk) -> "ChoiceChunk":
+ """
+ Parse the choice from the content after <think>...</think> block.
+ Args:
+ content: The content from the LLM response
+ Returns:
+ The choice object
+ Raises ValueError if the content is not valid JSON
+ """
+
+ @overload
+ def parse_choice_from_content(self, content: str, choice_class: type[Choice] = Choice) -> "Choice":
+ """
+ Parse the choice from the content after <think>...</think> block.
+ Args:
+ content: The content from the LLM response
+ Returns:
+ The choice object
+ Raises ValueError if the content is not valid JSON
+ """
+
+ def parse_choice_from_content(
+ self, content: str, choice_class: type[Union[Choice, ChoiceChunk]] = Choice
+ ) -> Union[Choice, ChoiceChunk]:
  """
  Parse the choice from the content after <think>...</think> block.
  Args:
@@ -134,6 +161,6 @@ class ChatglmProvider(OpenAIProvider):
  except json.JSONDecodeError:
  raise ValueError(f"Invalid message from LLM: {content}")
  try:
- return Choice.model_validate(content_dict)
+ return choice_class.model_validate(content_dict)
  except Exception as e:
  raise ValueError(f"Invalid message from LLM: {content}") from e
@@ -10,7 +10,8 @@ This module implements Cohere provider classes for different deployment options:
  from typing import Any, Dict, Generator, List, Optional

  from cohere import BedrockClientV2, ClientV2, SagemakerClientV2
- from cohere.types.tool_call_v2 import ToolCallV2, ToolCallV2Function
+ from cohere.types.tool_call_v2 import ToolCallV2
+ from cohere.types.tool_call_v2function import ToolCallV2Function

  from ...config import cfg
  from ...console import get_console
@@ -179,7 +180,9 @@ class CohereProvider(Provider):
  continue
  elif chunk.type == "tool-call-delta":
  # Tool call arguments being generated: cohere.types.chat_tool_call_delta_event_delta_message.ChatToolCallDeltaEventDeltaMessage
- tool_call.arguments += chunk.delta.message.tool_calls.function.arguments
+ if not tool_call:
+ continue
+ tool_call.arguments += chunk.delta.message.tool_calls.function.arguments or ""
  # Waiting for tool-call-end event
  continue
@@ -292,7 +295,7 @@ class CohereBadrockProvider(CohereProvider):
  return self.CLIENT_CLS(**self.client_params)


- class CohereSagemaker(CohereBadrockProvider):
+ class CohereSagemakerProvider(CohereBadrockProvider):
  """Cohere provider for AWS Sagemaker integration"""

  CLIENT_CLS = SagemakerClientV2
@@ -10,5 +10,6 @@ class DeepSeekProvider(OpenAIProvider):

  def get_completion_params(self) -> Dict[str, Any]:
  params = super().get_completion_params()
- params["max_tokens"] = params.pop("max_completion_tokens")
+ if "max_completion_tokens" in params:
+ params["max_tokens"] = params.pop("max_completion_tokens")
  return params
@@ -2,8 +2,6 @@ from typing import Any, Dict

  from volcenginesdkarkruntime import Ark

- from ...config import cfg
- from ...console import get_console
  from .openai_provider import OpenAIProvider


@@ -13,18 +11,6 @@ class DoubaoProvider(OpenAIProvider):
  DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
  CLIENT_CLS = Ark

- def __init__(self, config: dict = cfg, **kwargs):
- self.config = config
- self.enable_function = self.config["ENABLE_FUNCTIONS"]
- self.client_params = self.get_client_params()
-
- # Initialize client
- self.client = self.CLIENT_CLS(**self.client_params)
- self.console = get_console()
-
- # Store completion params
- self.completion_params = self.get_completion_params()
-
  def get_client_params(self) -> Dict[str, Any]:
  # Initialize client params
  client_params = {"base_url": self.DEFAULT_BASE_URL}
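Dropping the copied __init__ (and the now-unused cfg/get_console imports) means DoubaoProvider relies on the base OpenAIProvider initialization rather than the duplicated copy removed here; the get_client_params override shown above still points the client at the Ark endpoint.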