yaicli 0.5.9__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pyproject.toml +35 -12
  2. yaicli/cli.py +31 -20
  3. yaicli/const.py +6 -5
  4. yaicli/entry.py +1 -1
  5. yaicli/llms/__init__.py +13 -0
  6. yaicli/llms/client.py +120 -0
  7. yaicli/llms/provider.py +78 -0
  8. yaicli/llms/providers/ai21_provider.py +66 -0
  9. yaicli/llms/providers/chatglm_provider.py +139 -0
  10. yaicli/llms/providers/chutes_provider.py +14 -0
  11. yaicli/llms/providers/cohere_provider.py +298 -0
  12. yaicli/llms/providers/deepseek_provider.py +14 -0
  13. yaicli/llms/providers/doubao_provider.py +53 -0
  14. yaicli/llms/providers/groq_provider.py +16 -0
  15. yaicli/llms/providers/infiniai_provider.py +20 -0
  16. yaicli/llms/providers/minimax_provider.py +13 -0
  17. yaicli/llms/providers/modelscope_provider.py +14 -0
  18. yaicli/llms/providers/ollama_provider.py +187 -0
  19. yaicli/llms/providers/openai_provider.py +211 -0
  20. yaicli/llms/providers/openrouter_provider.py +14 -0
  21. yaicli/llms/providers/sambanova_provider.py +30 -0
  22. yaicli/llms/providers/siliconflow_provider.py +14 -0
  23. yaicli/llms/providers/targon_provider.py +14 -0
  24. yaicli/llms/providers/yi_provider.py +14 -0
  25. yaicli/printer.py +4 -16
  26. yaicli/schemas.py +12 -3
  27. yaicli/tools.py +59 -3
  28. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/METADATA +238 -32
  29. yaicli-0.6.1.dist-info/RECORD +43 -0
  30. yaicli/client.py +0 -391
  31. yaicli-0.5.9.dist-info/RECORD +0 -24
  32. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/WHEEL +0 -0
  33. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/entry_points.txt +0 -0
  34. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/licenses/LICENSE +0 -0
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "yaicli"
-version = "0.5.9"
+version = "0.6.1"
 description = "A simple CLI tool to interact with LLM"
 authors = [{ name = "belingud", email = "im.victor@qq.com" }]
 readme = "README.md"
@@ -8,7 +8,7 @@ requires-python = ">=3.9"
 license = { file = "LICENSE" }
 classifiers = [
     "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
+    "License :: OSI Approved :: Apache Software License",
     "Operating System :: OS Independent",
 ]
 keywords = [
@@ -16,18 +16,32 @@ keywords = [
     "llm",
     "ai",
     "chatgpt",
-    "openai",
     "gpt",
     "llms",
-    "openai",
     "terminal",
     "interactive",
-    "interact",
-    "interact with llm",
-    "interact with chatgpt",
-    "interact with openai",
-    "interact with gpt",
-    "interact with llms",
+    "command-line",
+    "ai-assistant",
+    "language-model",
+    "text-generation",
+    "conversation",
+    "prompt",
+    "completion",
+    "console-application",
+    "shell-integration",
+    "nlp",
+    "inference",
+    "ai-chat",
+    "python-tool",
+    "terminal-interface",
+    "ai-interaction",
+    "openai",
+    "claude",
+    "gemini",
+    "mistral",
+    "anthropic",
+    "groq",
+    "cohere",
 ]
 dependencies = [
     "click>=8.1.8",
@@ -35,7 +49,6 @@ dependencies = [
     "httpx>=0.28.1",
     "instructor>=1.7.9",
     "json-repair>=0.44.1",
-    "litellm>=1.67.5",
     "openai>=1.76.0",
     "prompt-toolkit>=3.0.50",
     "rich>=13.9.4",
@@ -51,6 +64,12 @@ Documentation = "https://github.com/belingud/yaicli"
 ai = "yaicli.entry:app"
 yaicli = "yaicli.entry:app"
 
+[project.optional-dependencies]
+doubao = ["volcengine-python-sdk>=3.0.15"]
+ollama = ["ollama>=0.5.1"]
+cohere = ["cohere>=5.15.0"]
+all = ["volcengine-python-sdk>=3.0.15", "ollama>=0.5.1", "cohere>=5.15.0"]
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 python_files = ["test_*.py"]
@@ -60,7 +79,7 @@ filterwarnings = [
     "ignore::PendingDeprecationWarning",
     "ignore::UserWarning",
     "ignore::pydantic.PydanticDeprecatedSince20",
-    "ignore:.*There is no current event loop.*:DeprecationWarning"
+    "ignore:.*There is no current event loop.*:DeprecationWarning",
 ]
 
 [tool.uv]
@@ -81,6 +100,10 @@ profile = "black"
 line-length = 120
 fix = true
 
+[tool.ruff.lint]
+select = ["F"]
+fixable = ["F401"]
+
 [build-system]
 requires = ["hatchling>=1.18.0"]
 build-backend = "hatchling.build"
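
Note: the new [project.optional-dependencies] table moves provider-specific SDKs (volcengine-python-sdk, ollama, cohere) out of the core dependency list, so users opt in with an extra such as yaicli[ollama] or yaicli[all]. A minimal, illustrative sketch of how an optional SDK installed through an extra is usually guarded at import time (not code from this release):

    # Hypothetical guard for a dependency provided by the "ollama" extra.
    try:
        import ollama  # present only when the extra is installed
    except ImportError:
        ollama = None  # the provider can then raise a clear error if selected
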
yaicli/cli.py CHANGED
@@ -17,7 +17,6 @@ from rich.panel import Panel
 from rich.prompt import Prompt
 
 from .chat import Chat, FileChatManager, chat_mgr
-from .client import ChatMessage, LitellmClient
 from .config import cfg
 from .console import get_console
 from .const import (
@@ -41,8 +40,10 @@ from .const import (
 )
 from .exceptions import ChatSaveError
 from .history import LimitedFileHistory
+from .llms import LLMClient
 from .printer import Printer
 from .role import Role, RoleManager, role_mgr
+from .schemas import ChatMessage
 from .utils import detect_os, detect_shell, filter_command
 
 
@@ -66,7 +67,7 @@ class CLI:
         self.role_manager = role_manager or role_mgr
         self.role: Role = self.role_manager.get_role(self.role_name)
         self.printer = Printer()
-        self.client = client or LitellmClient(verbose=self.verbose)
+        self.client = client or self._create_client()
 
         self.bindings = KeyBindings()
 
@@ -338,7 +339,7 @@ class CLI:
         messages.append(ChatMessage(role="user", content=user_input))
         return messages
 
-    def _handle_llm_response(self, user_input: str) -> Optional[str]:
+    def _handle_llm_response(self, user_input: str) -> tuple[Optional[str], list[ChatMessage]]:
         """Get response from API (streaming or normal) and print it.
         Returns the full content string or None if an error occurred.
 
@@ -347,44 +348,50 @@
 
         Returns:
             Optional[str]: The assistant's response content or None if an error occurred.
+            list[ChatMessage]: The updated message history.
         """
         messages = self._build_messages(user_input)
-        if self.verbose:
-            self.console.print(messages)
-        if self.role != DefaultRoleNames.CODER:
+        if self.role.name != DefaultRoleNames.CODER:
             self.console.print("Assistant:", style="bold green")
         try:
-            response = self.client.completion(messages, stream=cfg["STREAM"])
-            if cfg["STREAM"]:
-                content, _ = self.printer.display_stream(response, messages)
-            else:
-                content, _ = self.printer.display_normal(response, messages)
+            response_iterator = self.client.completion_with_tools(messages, stream=cfg["STREAM"])
 
-            # Just return the content, message addition is handled in _process_user_input
-            return content if content is not None else None
+            content, _ = self.printer.display_stream(response_iterator)
+
+            # The 'messages' list is modified by the client in-place
+            return content, messages
         except Exception as e:
             self.console.print(f"Error processing LLM response: {e}", style="red")
             if self.verbose:
                 traceback.print_exc()
-            return None
+            return None, messages
 
     def _process_user_input(self, user_input: str) -> bool:
         """Process user input: get response, print, update history, maybe execute.
         Returns True to continue REPL, False to exit on critical error.
         """
-        content = self._handle_llm_response(user_input)
+        content, updated_messages = self._handle_llm_response(user_input)
 
-        if content is None:
+        if content is None and not any(msg.tool_calls for msg in updated_messages):
            return True
 
-        # Update chat history using Chat's add_message method
-        self.chat.add_message("user", user_input)
-        self.chat.add_message("assistant", content)
+        # The client modifies the message list in place, so the updated_messages
+        # contains the full history of the turn (system, history, user, assistant, tools).
+        # We replace the old history with the new one, removing the system prompt.
+        if updated_messages:
+            self.chat.history = updated_messages[1:]
 
         self._check_history_len()
 
         if self.current_mode == EXEC_MODE:
-            self._confirm_and_execute(content)
+            # We need to extract the executable command from the last assistant message
+            # in case of tool use.
+            final_content = ""
+            if self.chat.history:
+                last_message = self.chat.history[-1]
+                if last_message.role == "assistant":
+                    final_content = last_message.content or ""
+            self._confirm_and_execute(final_content)
         return True
 
     def _confirm_and_execute(self, raw_content: str) -> None:
@@ -555,3 +562,7 @@ class CLI:
         else:
            # Run in single-use mode
            self._run_once(user_input or "", shell=shell, code=code)
+
+    def _create_client(self):
+        """Create an LLM client instance based on configuration"""
+        return LLMClient(provider_name=cfg["PROVIDER"].lower(), verbose=self.verbose, config=cfg)
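
The change above swaps the old LitellmClient for the provider-agnostic LLMClient, which _create_client builds from the PROVIDER config key. A minimal sketch of the same construction outside the CLI class, assuming the 0.6.1 module layout shown in this diff:

    from yaicli.config import cfg
    from yaicli.llms import LLMClient

    # Build a client for whichever provider is configured (e.g. PROVIDER=openai).
    client = LLMClient(provider_name=cfg["PROVIDER"].lower(), verbose=False, config=cfg)
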
yaicli/const.py CHANGED
@@ -6,6 +6,7 @@ except ImportError:
     class StrEnum(str, Enum):
         """Compatible with python below 3.11"""
 
+
 from pathlib import Path
 from tempfile import gettempdir
 from typing import Any, Literal, Optional
@@ -51,7 +52,7 @@ DEFAULT_MODEL = "gpt-4o"
 DEFAULT_SHELL_NAME = "auto"
 DEFAULT_OS_NAME = "auto"
 DEFAULT_STREAM: BOOL_STR = "true"
-DEFAULT_TEMPERATURE: float = 0.5
+DEFAULT_TEMPERATURE: float = 0.3
 DEFAULT_TOP_P: float = 1.0
 DEFAULT_MAX_TOKENS: int = 1024
 DEFAULT_MAX_HISTORY: int = 500
@@ -99,9 +100,9 @@ class DefaultRoleNames(StrEnum):
 
 
 DEFAULT_ROLES: dict[str, dict[str, Any]] = {
-    DefaultRoleNames.SHELL: {"name": DefaultRoleNames.SHELL, "prompt": SHELL_PROMPT},
-    DefaultRoleNames.DEFAULT: {"name": DefaultRoleNames.DEFAULT, "prompt": DEFAULT_PROMPT},
-    DefaultRoleNames.CODER: {"name": DefaultRoleNames.CODER, "prompt": CODER_PROMPT},
+    DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT},
+    DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT},
+    DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT},
 }
 
 # DEFAULT_CONFIG_MAP is a dictionary of the configuration options.
@@ -112,7 +113,7 @@ DEFAULT_ROLES: dict[str, dict[str, Any]] = {
 # - type: the type of the configuration option
 DEFAULT_CONFIG_MAP = {
     # Core API settings
-    "BASE_URL": {"value": DEFAULT_BASE_URL, "env_key": "YAI_BASE_URL", "type": str},
+    "BASE_URL": {"value": "", "env_key": "YAI_BASE_URL", "type": str},
     "API_KEY": {"value": "", "env_key": "YAI_API_KEY", "type": str},
     "MODEL": {"value": DEFAULT_MODEL, "env_key": "YAI_MODEL", "type": str},
     # System detection hints
yaicli/entry.py CHANGED
@@ -82,7 +82,7 @@ def main(
     ),
     # ------------------- Role Options -------------------
     role: str = typer.Option(
-        DefaultRoleNames.DEFAULT,
+        DefaultRoleNames.DEFAULT.value,
         "--role",
         "-r",
         help="Specify the assistant role to use.",
yaicli/llms/__init__.py ADDED
@@ -0,0 +1,13 @@
+from ..config import cfg
+from .client import LLMClient
+from .provider import Provider, ProviderFactory
+
+__all__ = ["LLMClient", "Provider", "ProviderFactory"]
+
+
+class BaseProvider:
+    def __init__(self) -> None:
+        self.api_key = cfg["API_KEY"]
+        self.model = cfg["MODEL"]
+        self.base_url = cfg["BASE_URL"]
+        self.timeout = cfg["TIMEOUT"]
yaicli/llms/client.py ADDED
@@ -0,0 +1,120 @@
+from typing import Generator, List, Optional, Union
+
+from ..config import cfg
+from ..console import get_console
+from ..schemas import ChatMessage, LLMResponse, RefreshLive, ToolCall
+from ..tools import execute_tool_call
+from .provider import Provider, ProviderFactory
+
+
+class LLMClient:
+    """
+    LLM Client that coordinates provider interactions and tool calling
+
+    This class handles the higher level logic of:
+    1. Getting responses from LLM providers
+    2. Managing tool calls and their execution
+    3. Handling conversation flow with tools
+    """
+
+    def __init__(
+        self,
+        provider: Optional[Provider] = None,
+        provider_name: str = "",
+        config: dict = cfg,
+        verbose: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize LLM client
+
+        Args:
+            provider: Optional pre-initialized Provider instance
+            provider_name: Name of the provider to use if provider not provided
+            config: Configuration dictionary
+            verbose: Whether to enable verbose logging
+        """
+        self.config = config
+        self.verbose = verbose
+        self.console = get_console()
+
+        # Use provided provider or create one
+        if provider:
+            self.provider = provider
+        elif provider_name:
+            self.provider = ProviderFactory.create_provider(provider_name, config=config, verbose=verbose, **kwargs)
+        else:
+            provider_name = config.get("PROVIDER", "openai").lower()
+            self.provider = ProviderFactory.create_provider(provider_name, config=config, verbose=verbose, **kwargs)
+
+        self.max_recursion_depth = config.get("MAX_RECURSION_DEPTH", 5)
+
+    def completion_with_tools(
+        self,
+        messages: List[ChatMessage],
+        stream: bool = False,
+        recursion_depth: int = 0,
+    ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
+        """
+        Get completion from provider with tool calling support
+
+        Args:
+            messages: List of messages for the conversation
+            stream: Whether to stream the response
+            recursion_depth: Current recursion depth for tool calls
+
+        Yields:
+            LLMResponse objects and control signals
+        """
+        if recursion_depth >= self.max_recursion_depth:
+            self.console.print(
+                f"Maximum recursion depth ({self.max_recursion_depth}) reached, stopping further tool calls",
+                style="yellow",
+            )
+            return
+
+        # Get completion from provider
+        llm_response_generator = self.provider.completion(messages, stream=stream)
+
+        # To hold the full response
+        assistant_response_content = ""
+        tool_calls: List[ToolCall] = []
+
+        # Process all responses from the provider
+        for llm_response in llm_response_generator:
+            # Forward the response to the caller
+            yield llm_response
+
+            # Collect content and tool calls
+            if llm_response.content:
+                assistant_response_content += llm_response.content
+            if llm_response.tool_call and llm_response.tool_call not in tool_calls:
+                tool_calls.append(llm_response.tool_call)
+
+        # If we have tool calls, execute them and make recursive call
+        if tool_calls and self.config["ENABLE_FUNCTIONS"]:
+            # Yield a refresh signal to indicate new content is coming
+            yield RefreshLive()
+
+            # Append the assistant message with tool calls to history
+            messages.append(ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls))
+
+            # Execute each tool call and append the results
+            for tool_call in tool_calls:
+                function_result, _ = execute_tool_call(tool_call)
+
+                # Use provider's tool role detection
+                tool_role = self.provider.detect_tool_role()
+
+                # Append the tool result to history
+                messages.append(
+                    ChatMessage(
+                        role=tool_role,
+                        content=function_result,
+                        name=tool_call.name,
+                        tool_call_id=tool_call.id,
+                    )
+                )
+
+            # Make a recursive call with the updated history
+            yield from self.completion_with_tools(messages, stream=stream, recursion_depth=recursion_depth + 1)
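
completion_with_tools yields LLMResponse chunks interleaved with RefreshLive control signals: when the provider asks for tools, the client executes them, appends the results to messages, and recurses until no more tool calls arrive or the recursion limit is hit. A rough usage sketch of driving that generator (assumes ChatMessage is built with role/content keywords, as in cli.py above):

    from yaicli.schemas import ChatMessage, RefreshLive

    messages = [ChatMessage(role="user", content="What is the weather in Berlin?")]
    parts = []
    for event in client.completion_with_tools(messages, stream=True):
        if isinstance(event, RefreshLive):
            parts.clear()  # a fresh assistant turn begins after tool execution
            continue
        if event.content:
            parts.append(event.content)
    print("".join(parts))  # final assistant answer; messages now holds the whole turn
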
yaicli/llms/provider.py ADDED
@@ -0,0 +1,78 @@
+import importlib
+from abc import ABC, abstractmethod
+from typing import Generator, List
+
+from ..schemas import ChatMessage, LLMResponse
+
+
+class Provider(ABC):
+    """Base abstract class for LLM providers"""
+
+    APP_NAME = "yaicli"
+    APPA_REFERER = "https://github.com/halfrost/yaicli"
+
+    @abstractmethod
+    def completion(
+        self,
+        messages: List[ChatMessage],
+        stream: bool = False,
+    ) -> Generator[LLMResponse, None, None]:
+        """
+        Send a completion request to the LLM provider
+
+        Args:
+            messages: List of message objects representing the conversation
+            stream: Whether to stream the response
+
+        Returns:
+            Generator yielding LLMResponse objects
+        """
+        pass
+
+    @abstractmethod
+    def detect_tool_role(self) -> str:
+        """Return the role that should be used for tool responses"""
+        pass
+
+
+class ProviderFactory:
+    """Factory to create LLM provider instances"""
+
+    providers_map = {
+        "openai": (".providers.openai_provider", "OpenAIProvider"),
+        "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
+        "chatglm": (".providers.chatglm_provider", "ChatglmProvider"),
+        "openrouter": (".providers.openrouter_provider", "OpenRouterProvider"),
+        "siliconflow": (".providers.siliconflow_provider", "SiliconFlowProvider"),
+        "chutes": (".providers.chutes_provider", "ChutesProvider"),
+        "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
+        "yi": (".providers.yi_provider", "YiProvider"),
+        "deepseek": (".providers.deepseek_provider", "DeepSeekProvider"),
+        "doubao": (".providers.doubao_provider", "DoubaoProvider"),
+        "groq": (".providers.groq_provider", "GroqProvider"),
+        "ai21": (".providers.ai21_provider", "AI21Provider"),
+        "ollama": (".providers.ollama_provider", "OllamaProvider"),
+        "cohere": (".providers.cohere_provider", "CohereProvider"),
+        "sambanova": (".providers.sambanova_provider", "SambanovaProvider"),
+        "minimax": (".providers.minimax_provider", "MinimaxProvider"),
+        "targon": (".providers.targon_provider", "TargonProvider"),
+    }
+
+    @classmethod
+    def create_provider(cls, provider_type: str, verbose: bool = False, **kwargs) -> Provider:
+        """Create a provider instance based on provider type
+
+        Args:
+            provider_type: The type of provider to create
+            **kwargs: Additional parameters to pass to the provider
+
+        Returns:
+            A Provider instance
+        """
+        provider_type = provider_type.lower()
+        if provider_type not in cls.providers_map:
+            raise ValueError(f"Unknown provider: {provider_type}")
+
+        module_path, class_name = cls.providers_map[provider_type]
+        module = importlib.import_module(module_path, package="yaicli.llms")
+        return getattr(module, class_name)(verbose=verbose, **kwargs)
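
ProviderFactory lazily imports the module named in providers_map and instantiates the provider class, forwarding verbose plus any extra keyword arguments. A short sketch of calling the factory directly, passing the config dict the same way LLMClient does above:

    from yaicli.config import cfg
    from yaicli.llms.provider import ProviderFactory

    provider = ProviderFactory.create_provider("openai", verbose=True, config=cfg)
    for resp in provider.completion(messages, stream=False):
        print(resp.content or "", end="")
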
yaicli/llms/providers/ai21_provider.py ADDED
@@ -0,0 +1,66 @@
+from typing import Any, Dict, Generator, Optional
+
+from openai._streaming import Stream
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+from ...schemas import LLMResponse, ToolCall
+from .openai_provider import OpenAIProvider
+
+
+class AI21Provider(OpenAIProvider):
+    """AI21 provider implementation based on openai-compatible API"""
+
+    DEFAULT_BASE_URL = "https://api.ai21.com/studio/v1"
+
+    def get_completion_params(self) -> Dict[str, Any]:
+        params = super().get_completion_params()
+        params["max_tokens"] = params.pop("max_completion_tokens")
+        return params
+
+    def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+        """Handle streaming response from AI21 models
+
+        Processes chunks from streaming API, extracting content, reasoning and tool calls.
+        The tool call response is scattered across multiple chunks.
+
+        Args:
+            response: Stream of chat completion chunks from AI21 API
+
+        Yields:
+            Generator yielding LLMResponse objects containing:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+            - tool_call: Tool call information when applicable
+        """
+        # Initialize tool call object to accumulate tool call data across chunks
+        tool_call: Optional[ToolCall] = None
+
+        # Process each chunk in the response stream
+        for chunk in response:
+            choice = chunk.choices[0]
+            delta = choice.delta
+            finish_reason = choice.finish_reason
+
+            # Extract content from current chunk
+            content = delta.content or ""
+
+            # Extract reasoning content if available
+            reasoning = self._get_reasoning_content(getattr(delta, "model_extra", None) or delta)
+
+            # Process tool call information that may be scattered across chunks
+            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                tool_call = self._process_tool_call_chunk(delta.tool_calls, tool_call)
+
+            # AI21 specific handling: content cannot be empty for tool calls
+            if finish_reason == "tool_calls" and not content:
+                # tool call assistant message, content can't be empty
+                # Error code: 422 - {'detail': {'error': ['Value error, message content must not be an empty string']}}
+                content = tool_call.id
+
+            # Generate response object
+            yield LLMResponse(
+                reasoning=reasoning,
+                content=content,
+                tool_call=tool_call if finish_reason == "tool_calls" else None,
+                finish_reason=finish_reason,
+            )
yaicli/llms/providers/chatglm_provider.py ADDED
@@ -0,0 +1,139 @@
+import json
+from typing import Any, Dict, Generator, Optional
+
+from openai._streaming import Stream
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+from ...schemas import LLMResponse, ToolCall
+from .openai_provider import OpenAIProvider
+
+
+class ChatglmProvider(OpenAIProvider):
+    """Chatglm provider support"""
+
+    DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"
+
+    def get_completion_params(self) -> Dict[str, Any]:
+        params = super().get_completion_params()
+        params["max_tokens"] = params.pop("max_completion_tokens")
+        return params
+
+    def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
+        """Handle normal (non-streaming) response
+        Support both openai capabilities and chatglm
+
+        Returns:
+            LLMContent object with:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+        """
+        choice = response.choices[0]
+        content = choice.message.content or ""  # type: ignore
+        reasoning = choice.message.reasoning_content  # type: ignore
+        finish_reason = choice.finish_reason
+        tool_call: Optional[ToolCall] = None
+
+        # Check if the response contains reasoning content
+        if "<think>" in content and "</think>" in content:
+            # Extract reasoning content
+            content = content.lstrip()
+            if content.startswith("<think>"):
+                think_end = content.find("</think>")
+                if think_end != -1:
+                    reasoning = content[7:think_end].strip()  # Start after <think>
+                    # Remove the <think> block from the main content
+                    content = content[think_end + 8 :].strip()  # Start after </think>
+        # Check if the response contains reasoning content in model_extra
+        elif hasattr(choice.message, "model_extra") and choice.message.model_extra:  # type: ignore
+            model_extra = choice.message.model_extra  # type: ignore
+            reasoning = self._get_reasoning_content(model_extra)
+        if finish_reason == "tool_calls":
+            if '{"index":' in content or '"tool_calls":' in content:
+                # Tool call data may in content after the <think> block
+                # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                tool_index = content.find('{"index":')
+                if tool_index != -1:
+                    tmp_content = content[tool_index:]
+                    # Tool call data may in content after the <think> block
+                    try:
+                        choice = self.parse_choice_from_content(tmp_content)
+                    except ValueError:
+                        pass
+            if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:  # type: ignore
+                tool = choice.message.tool_calls[0]  # type: ignore
+                tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+        yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+    def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+        """Handle streaming response
+        Support both openai capabilities and chatglm
+
+        Returns:
+            Generator yielding LLMContent objects with:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+        """
+        full_reasoning = ""
+        full_content = ""
+        content = ""
+        reasoning = ""
+        tool_id = ""
+        tool_call_name = ""
+        arguments = ""
+        tool_call: Optional[ToolCall] = None
+        for chunk in response:
+            # Check if the response contains reasoning content
+            choice = chunk.choices[0]  # type: ignore
+            delta = choice.delta
+            finish_reason = choice.finish_reason
+
+            # Concat content
+            content = delta.content or ""
+            full_content += content
+
+            # Concat reasoning
+            reasoning = self._get_reasoning_content(delta)
+            full_reasoning += reasoning or ""
+
+            if finish_reason:
+                pass
+            if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
+                # Tool call data may in content after the <think> block
+                # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                tool_index = full_content.find('{"index":')
+                if tool_index != -1:
+                    tmp_content = full_content[tool_index:]
+                    try:
+                        choice = self.parse_choice_from_content(tmp_content)
+                    except ValueError:
+                        pass
+                if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls:  # type: ignore
+                    # Handle tool calls
+                    tool_id = choice.delta.tool_calls[0].id or ""  # type: ignore
+                    for tool in choice.delta.tool_calls:  # type: ignore
+                        if not tool.function:
+                            continue
+                        tool_call_name = tool.function.name or ""
+                        arguments += tool.function.arguments or ""
+                    tool_call = ToolCall(tool_id, tool_call_name, arguments)
+            yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
+
+    def parse_choice_from_content(self, content: str) -> "Choice":
+        """
+        Parse the choice from the content after <think>...</think> block.
+        Args:
+            content: The content from the LLM response
+        Returns:
+            The choice object
+        Raises ValueError if the content is not valid JSON
+        """
+        try:
+            content_dict = json.loads(content)
+        except json.JSONDecodeError:
+            raise ValueError(f"Invalid message from LLM: {content}")
+        try:
+            return Choice.model_validate(content_dict)
+        except Exception as e:
+            raise ValueError(f"Invalid message from LLM: {content}") from e
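
The <think> handling in _handle_normal_response slices the reasoning out of the raw content: characters 7..think_end hold the reasoning (skipping the 7-character "<think>" tag) and everything after think_end + 8 (past "</think>") is the visible answer. A standalone toy example of that arithmetic, not taken from the package:

    raw = "<think>The user wants a one-line answer.</think>Use ls -la to list all files."
    think_end = raw.find("</think>")
    reasoning = raw[7:think_end].strip()   # "The user wants a one-line answer."
    content = raw[think_end + 8:].strip()  # "Use ls -la to list all files."
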