yaicli 0.5.9__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +35 -12
- yaicli/cli.py +31 -20
- yaicli/const.py +6 -5
- yaicli/entry.py +1 -1
- yaicli/llms/__init__.py +13 -0
- yaicli/llms/client.py +120 -0
- yaicli/llms/provider.py +78 -0
- yaicli/llms/providers/ai21_provider.py +66 -0
- yaicli/llms/providers/chatglm_provider.py +139 -0
- yaicli/llms/providers/chutes_provider.py +14 -0
- yaicli/llms/providers/cohere_provider.py +298 -0
- yaicli/llms/providers/deepseek_provider.py +14 -0
- yaicli/llms/providers/doubao_provider.py +53 -0
- yaicli/llms/providers/groq_provider.py +16 -0
- yaicli/llms/providers/infiniai_provider.py +20 -0
- yaicli/llms/providers/minimax_provider.py +13 -0
- yaicli/llms/providers/modelscope_provider.py +14 -0
- yaicli/llms/providers/ollama_provider.py +187 -0
- yaicli/llms/providers/openai_provider.py +211 -0
- yaicli/llms/providers/openrouter_provider.py +14 -0
- yaicli/llms/providers/sambanova_provider.py +30 -0
- yaicli/llms/providers/siliconflow_provider.py +14 -0
- yaicli/llms/providers/targon_provider.py +14 -0
- yaicli/llms/providers/yi_provider.py +14 -0
- yaicli/printer.py +4 -16
- yaicli/schemas.py +12 -3
- yaicli/tools.py +59 -3
- {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/METADATA +238 -32
- yaicli-0.6.1.dist-info/RECORD +43 -0
- yaicli/client.py +0 -391
- yaicli-0.5.9.dist-info/RECORD +0 -24
- {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/WHEEL +0 -0
- {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/entry_points.txt +0 -0
- {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/licenses/LICENSE +0 -0
pyproject.toml
CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "yaicli"
-version = "0.5.9"
+version = "0.6.1"
 description = "A simple CLI tool to interact with LLM"
 authors = [{ name = "belingud", email = "im.victor@qq.com" }]
 readme = "README.md"
@@ -8,7 +8,7 @@ requires-python = ">=3.9"
 license = { file = "LICENSE" }
 classifiers = [
     "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: …",
+    "License :: OSI Approved :: Apache Software License",
     "Operating System :: OS Independent",
 ]
 keywords = [
@@ -16,18 +16,32 @@ keywords = [
     "llm",
     "ai",
     "chatgpt",
-    "openai",
     "gpt",
     "llms",
-    "openai",
     "terminal",
     "interactive",
-    …  (six removed keyword entries are truncated in the source diff)
+    "command-line",
+    "ai-assistant",
+    "language-model",
+    "text-generation",
+    "conversation",
+    "prompt",
+    "completion",
+    "console-application",
+    "shell-integration",
+    "nlp",
+    "inference",
+    "ai-chat",
+    "python-tool",
+    "terminal-interface",
+    "ai-interaction",
+    "openai",
+    "claude",
+    "gemini",
+    "mistral",
+    "anthropic",
+    "groq",
+    "cohere",
 ]
 dependencies = [
     "click>=8.1.8",
@@ -35,7 +49,6 @@ dependencies = [
     "httpx>=0.28.1",
     "instructor>=1.7.9",
     "json-repair>=0.44.1",
-    "litellm>=1.67.5",
     "openai>=1.76.0",
     "prompt-toolkit>=3.0.50",
     "rich>=13.9.4",
@@ -51,6 +64,12 @@ Documentation = "https://github.com/belingud/yaicli"
 ai = "yaicli.entry:app"
 yaicli = "yaicli.entry:app"
 
+[project.optional-dependencies]
+doubao = ["volcengine-python-sdk>=3.0.15"]
+ollama = ["ollama>=0.5.1"]
+cohere = ["cohere>=5.15.0"]
+all = ["volcengine-python-sdk>=3.0.15", "ollama>=0.5.1", "cohere>=5.15.0"]
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 python_files = ["test_*.py"]
@@ -60,7 +79,7 @@ filterwarnings = [
     "ignore::PendingDeprecationWarning",
     "ignore::UserWarning",
    "ignore::pydantic.PydanticDeprecatedSince20",
-    "ignore:.*There is no current event loop.*:DeprecationWarning"
+    "ignore:.*There is no current event loop.*:DeprecationWarning",
 ]
 
 [tool.uv]
@@ -81,6 +100,10 @@ profile = "black"
 line-length = 120
 fix = true
 
+[tool.ruff.lint]
+select = ["F"]
+fixable = ["F401"]
+
 [build-system]
 requires = ["hatchling>=1.18.0"]
 build-backend = "hatchling.build"
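
Note: with litellm dropped from the core dependency list, provider-specific SDKs are now opt-in extras. Assuming standard pip extras syntax, they would be installed as, for example, pip install "yaicli[ollama]", pip install "yaicli[cohere]", or pip install "yaicli[all]" for all three optional groups.
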
yaicli/cli.py
CHANGED
@@ -17,7 +17,6 @@ from rich.panel import Panel
 from rich.prompt import Prompt
 
 from .chat import Chat, FileChatManager, chat_mgr
-from .client import ChatMessage, LitellmClient
 from .config import cfg
 from .console import get_console
 from .const import (
@@ -41,8 +40,10 @@ from .const import (
 )
 from .exceptions import ChatSaveError
 from .history import LimitedFileHistory
+from .llms import LLMClient
 from .printer import Printer
 from .role import Role, RoleManager, role_mgr
+from .schemas import ChatMessage
 from .utils import detect_os, detect_shell, filter_command
 
 
@@ -66,7 +67,7 @@ class CLI:
         self.role_manager = role_manager or role_mgr
         self.role: Role = self.role_manager.get_role(self.role_name)
         self.printer = Printer()
-        self.client = client or …
+        self.client = client or self._create_client()
 
         self.bindings = KeyBindings()
 
@@ -338,7 +339,7 @@ class CLI:
         messages.append(ChatMessage(role="user", content=user_input))
         return messages
 
-    def _handle_llm_response(self, user_input: str) -> Optional[str]:
+    def _handle_llm_response(self, user_input: str) -> tuple[Optional[str], list[ChatMessage]]:
         """Get response from API (streaming or normal) and print it.
         Returns the full content string or None if an error occurred.
 
@@ -347,44 +348,50 @@
 
         Returns:
             Optional[str]: The assistant's response content or None if an error occurred.
+            list[ChatMessage]: The updated message history.
         """
         messages = self._build_messages(user_input)
-        if self.…
-            self.console.print(messages)
-        if self.role != DefaultRoleNames.CODER:
+        if self.role.name != DefaultRoleNames.CODER:
             self.console.print("Assistant:", style="bold green")
         try:
-            …
-            if cfg["STREAM"]:
-                content, _ = self.printer.display_stream(response, messages)
-            else:
-                content, _ = self.printer.display_normal(response, messages)
+            response_iterator = self.client.completion_with_tools(messages, stream=cfg["STREAM"])
 
-            …
-            …
+            content, _ = self.printer.display_stream(response_iterator)
+
+            # The 'messages' list is modified by the client in-place
+            return content, messages
         except Exception as e:
             self.console.print(f"Error processing LLM response: {e}", style="red")
             if self.verbose:
                 traceback.print_exc()
-            return None
+            return None, messages
 
     def _process_user_input(self, user_input: str) -> bool:
         """Process user input: get response, print, update history, maybe execute.
         Returns True to continue REPL, False to exit on critical error.
         """
-        content = self._handle_llm_response(user_input)
+        content, updated_messages = self._handle_llm_response(user_input)
 
-        if content is None:
+        if content is None and not any(msg.tool_calls for msg in updated_messages):
             return True
 
-        # …
-        …
-        …
+        # The client modifies the message list in place, so the updated_messages
+        # contains the full history of the turn (system, history, user, assistant, tools).
+        # We replace the old history with the new one, removing the system prompt.
+        if updated_messages:
+            self.chat.history = updated_messages[1:]
 
         self._check_history_len()
 
         if self.current_mode == EXEC_MODE:
-            …
+            # We need to extract the executable command from the last assistant message
+            # in case of tool use.
+            final_content = ""
+            if self.chat.history:
+                last_message = self.chat.history[-1]
+                if last_message.role == "assistant":
+                    final_content = last_message.content or ""
+            self._confirm_and_execute(final_content)
             return True
 
     def _confirm_and_execute(self, raw_content: str) -> None:
@@ -555,3 +562,7 @@ class CLI:
         else:
             # Run in single-use mode
             self._run_once(user_input or "", shell=shell, code=code)
+
+    def _create_client(self):
+        """Create an LLM client instance based on configuration"""
+        return LLMClient(provider_name=cfg["PROVIDER"].lower(), verbose=self.verbose, config=cfg)
yaicli/const.py
CHANGED
@@ -6,6 +6,7 @@ except ImportError:
     class StrEnum(str, Enum):
         """Compatible with python below 3.11"""
 
+
 from pathlib import Path
 from tempfile import gettempdir
 from typing import Any, Literal, Optional
@@ -51,7 +52,7 @@ DEFAULT_MODEL = "gpt-4o"
 DEFAULT_SHELL_NAME = "auto"
 DEFAULT_OS_NAME = "auto"
 DEFAULT_STREAM: BOOL_STR = "true"
-DEFAULT_TEMPERATURE: float = 0.…
+DEFAULT_TEMPERATURE: float = 0.3
 DEFAULT_TOP_P: float = 1.0
 DEFAULT_MAX_TOKENS: int = 1024
 DEFAULT_MAX_HISTORY: int = 500
@@ -99,9 +100,9 @@ class DefaultRoleNames(StrEnum):
 
 
 DEFAULT_ROLES: dict[str, dict[str, Any]] = {
-    DefaultRoleNames.SHELL: {"name": DefaultRoleNames.SHELL, "prompt": SHELL_PROMPT},
-    DefaultRoleNames.DEFAULT: {"name": DefaultRoleNames.DEFAULT, "prompt": DEFAULT_PROMPT},
-    DefaultRoleNames.CODER: {"name": DefaultRoleNames.CODER, "prompt": CODER_PROMPT},
+    DefaultRoleNames.SHELL.value: {"name": DefaultRoleNames.SHELL.value, "prompt": SHELL_PROMPT},
+    DefaultRoleNames.DEFAULT.value: {"name": DefaultRoleNames.DEFAULT.value, "prompt": DEFAULT_PROMPT},
+    DefaultRoleNames.CODER.value: {"name": DefaultRoleNames.CODER.value, "prompt": CODER_PROMPT},
 }
 
 # DEFAULT_CONFIG_MAP is a dictionary of the configuration options.
@@ -112,7 +113,7 @@ DEFAULT_ROLES: dict[str, dict[str, Any]] = {
 # - type: the type of the configuration option
 DEFAULT_CONFIG_MAP = {
     # Core API settings
-    "BASE_URL": {"value": …
+    "BASE_URL": {"value": "", "env_key": "YAI_BASE_URL", "type": str},
     "API_KEY": {"value": "", "env_key": "YAI_API_KEY", "type": str},
     "MODEL": {"value": DEFAULT_MODEL, "env_key": "YAI_MODEL", "type": str},
     # System detection hints
yaicli/entry.py
CHANGED
yaicli/llms/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from ..config import cfg
+from .client import LLMClient
+from .provider import Provider, ProviderFactory
+
+__all__ = ["LLMClient", "Provider", "ProviderFactory"]
+
+
+class BaseProvider:
+    def __init__(self) -> None:
+        self.api_key = cfg["API_KEY"]
+        self.model = cfg["MODEL"]
+        self.base_url = cfg["BASE_URL"]
+        self.timeout = cfg["TIMEOUT"]
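
A quick import sketch (an illustration, assuming the package layout introduced in this release): code that previously imported from the removed yaicli/client.py now pulls the client from the yaicli.llms package and the message schema from yaicli.schemas, mirroring the import changes shown in the cli.py diff above.

    # 0.5.9 (module removed in 0.6.1):
    # from yaicli.client import ChatMessage, LitellmClient

    # 0.6.1:
    from yaicli.llms import LLMClient, Provider, ProviderFactory
    from yaicli.schemas import ChatMessage
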
yaicli/llms/client.py
ADDED
@@ -0,0 +1,120 @@
+from typing import Generator, List, Optional, Union
+
+from ..config import cfg
+from ..console import get_console
+from ..schemas import ChatMessage, LLMResponse, RefreshLive, ToolCall
+from ..tools import execute_tool_call
+from .provider import Provider, ProviderFactory
+
+
+class LLMClient:
+    """
+    LLM Client that coordinates provider interactions and tool calling
+
+    This class handles the higher level logic of:
+    1. Getting responses from LLM providers
+    2. Managing tool calls and their execution
+    3. Handling conversation flow with tools
+    """
+
+    def __init__(
+        self,
+        provider: Optional[Provider] = None,
+        provider_name: str = "",
+        config: dict = cfg,
+        verbose: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize LLM client
+
+        Args:
+            provider: Optional pre-initialized Provider instance
+            provider_name: Name of the provider to use if provider not provided
+            config: Configuration dictionary
+            verbose: Whether to enable verbose logging
+        """
+        self.config = config
+        self.verbose = verbose
+        self.console = get_console()
+
+        # Use provided provider or create one
+        if provider:
+            self.provider = provider
+        elif provider_name:
+            self.provider = ProviderFactory.create_provider(provider_name, config=config, verbose=verbose, **kwargs)
+        else:
+            provider_name = config.get("PROVIDER", "openai").lower()
+            self.provider = ProviderFactory.create_provider(provider_name, config=config, verbose=verbose, **kwargs)
+
+        self.max_recursion_depth = config.get("MAX_RECURSION_DEPTH", 5)
+
+    def completion_with_tools(
+        self,
+        messages: List[ChatMessage],
+        stream: bool = False,
+        recursion_depth: int = 0,
+    ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
+        """
+        Get completion from provider with tool calling support
+
+        Args:
+            messages: List of messages for the conversation
+            stream: Whether to stream the response
+            recursion_depth: Current recursion depth for tool calls
+
+        Yields:
+            LLMResponse objects and control signals
+        """
+        if recursion_depth >= self.max_recursion_depth:
+            self.console.print(
+                f"Maximum recursion depth ({self.max_recursion_depth}) reached, stopping further tool calls",
+                style="yellow",
+            )
+            return
+
+        # Get completion from provider
+        llm_response_generator = self.provider.completion(messages, stream=stream)
+
+        # To hold the full response
+        assistant_response_content = ""
+        tool_calls: List[ToolCall] = []
+
+        # Process all responses from the provider
+        for llm_response in llm_response_generator:
+            # Forward the response to the caller
+            yield llm_response
+
+            # Collect content and tool calls
+            if llm_response.content:
+                assistant_response_content += llm_response.content
+            if llm_response.tool_call and llm_response.tool_call not in tool_calls:
+                tool_calls.append(llm_response.tool_call)
+
+        # If we have tool calls, execute them and make recursive call
+        if tool_calls and self.config["ENABLE_FUNCTIONS"]:
+            # Yield a refresh signal to indicate new content is coming
+            yield RefreshLive()
+
+            # Append the assistant message with tool calls to history
+            messages.append(ChatMessage(role="assistant", content=assistant_response_content, tool_calls=tool_calls))
+
+            # Execute each tool call and append the results
+            for tool_call in tool_calls:
+                function_result, _ = execute_tool_call(tool_call)
+
+                # Use provider's tool role detection
+                tool_role = self.provider.detect_tool_role()
+
+                # Append the tool result to history
+                messages.append(
+                    ChatMessage(
+                        role=tool_role,
+                        content=function_result,
+                        name=tool_call.name,
+                        tool_call_id=tool_call.id,
+                    )
+                )
+
+            # Make a recursive call with the updated history
+            yield from self.completion_with_tools(messages, stream=stream, recursion_depth=recursion_depth + 1)
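
To make the tool-calling loop above concrete, here is a minimal consumption sketch. It is an illustration only: it assumes an API key and model are already configured in the yaicli config, and it uses only the pieces shown in this diff (ChatMessage, RefreshLive as the "new assistant turn" signal, and the fact that the client appends assistant/tool messages to the list in place).

    from yaicli.llms import LLMClient
    from yaicli.schemas import ChatMessage, RefreshLive

    client = LLMClient(provider_name="openai")  # falls back to cfg for API key, model, etc.
    messages = [
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="What files are in the current directory?"),
    ]

    for event in client.completion_with_tools(messages, stream=True):
        if isinstance(event, RefreshLive):
            print()  # tool results were appended; a fresh assistant turn follows
            continue
        if event.content:
            print(event.content, end="")

    # 'messages' now also holds the assistant and tool messages added during the turn.
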
yaicli/llms/provider.py
ADDED
@@ -0,0 +1,78 @@
+import importlib
+from abc import ABC, abstractmethod
+from typing import Generator, List
+
+from ..schemas import ChatMessage, LLMResponse
+
+
+class Provider(ABC):
+    """Base abstract class for LLM providers"""
+
+    APP_NAME = "yaicli"
+    APPA_REFERER = "https://github.com/halfrost/yaicli"
+
+    @abstractmethod
+    def completion(
+        self,
+        messages: List[ChatMessage],
+        stream: bool = False,
+    ) -> Generator[LLMResponse, None, None]:
+        """
+        Send a completion request to the LLM provider
+
+        Args:
+            messages: List of message objects representing the conversation
+            stream: Whether to stream the response
+
+        Returns:
+            Generator yielding LLMResponse objects
+        """
+        pass
+
+    @abstractmethod
+    def detect_tool_role(self) -> str:
+        """Return the role that should be used for tool responses"""
+        pass
+
+
+class ProviderFactory:
+    """Factory to create LLM provider instances"""
+
+    providers_map = {
+        "openai": (".providers.openai_provider", "OpenAIProvider"),
+        "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
+        "chatglm": (".providers.chatglm_provider", "ChatglmProvider"),
+        "openrouter": (".providers.openrouter_provider", "OpenRouterProvider"),
+        "siliconflow": (".providers.siliconflow_provider", "SiliconFlowProvider"),
+        "chutes": (".providers.chutes_provider", "ChutesProvider"),
+        "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
+        "yi": (".providers.yi_provider", "YiProvider"),
+        "deepseek": (".providers.deepseek_provider", "DeepSeekProvider"),
+        "doubao": (".providers.doubao_provider", "DoubaoProvider"),
+        "groq": (".providers.groq_provider", "GroqProvider"),
+        "ai21": (".providers.ai21_provider", "AI21Provider"),
+        "ollama": (".providers.ollama_provider", "OllamaProvider"),
+        "cohere": (".providers.cohere_provider", "CohereProvider"),
+        "sambanova": (".providers.sambanova_provider", "SambanovaProvider"),
+        "minimax": (".providers.minimax_provider", "MinimaxProvider"),
+        "targon": (".providers.targon_provider", "TargonProvider"),
+    }
+
+    @classmethod
+    def create_provider(cls, provider_type: str, verbose: bool = False, **kwargs) -> Provider:
+        """Create a provider instance based on provider type
+
+        Args:
+            provider_type: The type of provider to create
+            **kwargs: Additional parameters to pass to the provider
+
+        Returns:
+            A Provider instance
+        """
+        provider_type = provider_type.lower()
+        if provider_type not in cls.providers_map:
+            raise ValueError(f"Unknown provider: {provider_type}")
+
+        module_path, class_name = cls.providers_map[provider_type]
+        module = importlib.import_module(module_path, package="yaicli.llms")
+        return getattr(module, class_name)(verbose=verbose, **kwargs)
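
A small sketch of using the factory directly (hypothetical usage; in the released code the factory is normally driven by LLMClient, which forwards config and verbose as keyword arguments):

    from yaicli.config import cfg
    from yaicli.llms.provider import ProviderFactory

    # "openai" is one of the keys registered in providers_map above;
    # unknown names raise ValueError before any provider module is imported.
    provider = ProviderFactory.create_provider("openai", verbose=False, config=cfg)
    print(provider.detect_tool_role())  # role this provider expects for tool results
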
yaicli/llms/providers/ai21_provider.py
ADDED
@@ -0,0 +1,66 @@
+from typing import Any, Dict, Generator, Optional
+
+from openai._streaming import Stream
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+from ...schemas import LLMResponse, ToolCall
+from .openai_provider import OpenAIProvider
+
+
+class AI21Provider(OpenAIProvider):
+    """AI21 provider implementation based on openai-compatible API"""
+
+    DEFAULT_BASE_URL = "https://api.ai21.com/studio/v1"
+
+    def get_completion_params(self) -> Dict[str, Any]:
+        params = super().get_completion_params()
+        params["max_tokens"] = params.pop("max_completion_tokens")
+        return params
+
+    def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+        """Handle streaming response from AI21 models
+
+        Processes chunks from streaming API, extracting content, reasoning and tool calls.
+        The tool call response is scattered across multiple chunks.
+
+        Args:
+            response: Stream of chat completion chunks from AI21 API
+
+        Yields:
+            Generator yielding LLMResponse objects containing:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+            - tool_call: Tool call information when applicable
+        """
+        # Initialize tool call object to accumulate tool call data across chunks
+        tool_call: Optional[ToolCall] = None
+
+        # Process each chunk in the response stream
+        for chunk in response:
+            choice = chunk.choices[0]
+            delta = choice.delta
+            finish_reason = choice.finish_reason
+
+            # Extract content from current chunk
+            content = delta.content or ""
+
+            # Extract reasoning content if available
+            reasoning = self._get_reasoning_content(getattr(delta, "model_extra", None) or delta)
+
+            # Process tool call information that may be scattered across chunks
+            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                tool_call = self._process_tool_call_chunk(delta.tool_calls, tool_call)
+
+            # AI21 specific handling: content cannot be empty for tool calls
+            if finish_reason == "tool_calls" and not content:
+                # tool call assistant message, content can't be empty
+                # Error code: 422 - {'detail': {'error': ['Value error, message content must not be an empty string']}}
+                content = tool_call.id
+
+            # Generate response object
+            yield LLMResponse(
+                reasoning=reasoning,
+                content=content,
+                tool_call=tool_call if finish_reason == "tool_calls" else None,
+                finish_reason=finish_reason,
+            )
yaicli/llms/providers/chatglm_provider.py
ADDED
@@ -0,0 +1,139 @@
+import json
+from typing import Any, Dict, Generator, Optional
+
+from openai._streaming import Stream
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+from ...schemas import LLMResponse, ToolCall
+from .openai_provider import OpenAIProvider
+
+
+class ChatglmProvider(OpenAIProvider):
+    """Chatglm provider support"""
+
+    DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"
+
+    def get_completion_params(self) -> Dict[str, Any]:
+        params = super().get_completion_params()
+        params["max_tokens"] = params.pop("max_completion_tokens")
+        return params
+
+    def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
+        """Handle normal (non-streaming) response
+        Support both openai capabilities and chatglm
+
+        Returns:
+            LLMContent object with:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+        """
+        choice = response.choices[0]
+        content = choice.message.content or ""  # type: ignore
+        reasoning = choice.message.reasoning_content  # type: ignore
+        finish_reason = choice.finish_reason
+        tool_call: Optional[ToolCall] = None
+
+        # Check if the response contains reasoning content
+        if "<think>" in content and "</think>" in content:
+            # Extract reasoning content
+            content = content.lstrip()
+            if content.startswith("<think>"):
+                think_end = content.find("</think>")
+                if think_end != -1:
+                    reasoning = content[7:think_end].strip()  # Start after <think>
+                    # Remove the <think> block from the main content
+                    content = content[think_end + 8 :].strip()  # Start after </think>
+        # Check if the response contains reasoning content in model_extra
+        elif hasattr(choice.message, "model_extra") and choice.message.model_extra:  # type: ignore
+            model_extra = choice.message.model_extra  # type: ignore
+            reasoning = self._get_reasoning_content(model_extra)
+        if finish_reason == "tool_calls":
+            if '{"index":' in content or '"tool_calls":' in content:
+                # Tool call data may in content after the <think> block
+                # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                tool_index = content.find('{"index":')
+                if tool_index != -1:
+                    tmp_content = content[tool_index:]
+                    # Tool call data may in content after the <think> block
+                    try:
+                        choice = self.parse_choice_from_content(tmp_content)
+                    except ValueError:
+                        pass
+            if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:  # type: ignore
+                tool = choice.message.tool_calls[0]  # type: ignore
+                tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+        yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+    def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+        """Handle streaming response
+        Support both openai capabilities and chatglm
+
+        Returns:
+            Generator yielding LLMContent objects with:
+            - reasoning: The thinking/reasoning content (if any)
+            - content: The normal response content
+        """
+        full_reasoning = ""
+        full_content = ""
+        content = ""
+        reasoning = ""
+        tool_id = ""
+        tool_call_name = ""
+        arguments = ""
+        tool_call: Optional[ToolCall] = None
+        for chunk in response:
+            # Check if the response contains reasoning content
+            choice = chunk.choices[0]  # type: ignore
+            delta = choice.delta
+            finish_reason = choice.finish_reason
+
+            # Concat content
+            content = delta.content or ""
+            full_content += content
+
+            # Concat reasoning
+            reasoning = self._get_reasoning_content(delta)
+            full_reasoning += reasoning or ""
+
+            if finish_reason:
+                pass
+            if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
+                # Tool call data may in content after the <think> block
+                # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                tool_index = full_content.find('{"index":')
+                if tool_index != -1:
+                    tmp_content = full_content[tool_index:]
+                    try:
+                        choice = self.parse_choice_from_content(tmp_content)
+                    except ValueError:
+                        pass
+                if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls:  # type: ignore
+                    # Handle tool calls
+                    tool_id = choice.delta.tool_calls[0].id or ""  # type: ignore
+                    for tool in choice.delta.tool_calls:  # type: ignore
+                        if not tool.function:
+                            continue
+                        tool_call_name = tool.function.name or ""
+                        arguments += tool.function.arguments or ""
+                    tool_call = ToolCall(tool_id, tool_call_name, arguments)
+            yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
+
+    def parse_choice_from_content(self, content: str) -> "Choice":
+        """
+        Parse the choice from the content after <think>...</think> block.
+        Args:
+            content: The content from the LLM response
+        Returns:
+            The choice object
+        Raises ValueError if the content is not valid JSON
+        """
+        try:
+            content_dict = json.loads(content)
+        except json.JSONDecodeError:
+            raise ValueError(f"Invalid message from LLM: {content}")
+        try:
+            return Choice.model_validate(content_dict)
+        except Exception as e:
+            raise ValueError(f"Invalid message from LLM: {content}") from e
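
For readers skimming the <think> handling in _handle_normal_response above, the string slicing reduces to the following standalone sketch (the offsets 7 and 8 are len("<think>") and len("</think>")):

    def split_think_block(content: str) -> tuple[str, str]:
        """Return (reasoning, remainder) for a reply prefixed with <think>...</think>."""
        content = content.lstrip()
        if content.startswith("<think>"):
            think_end = content.find("</think>")
            if think_end != -1:
                reasoning = content[7:think_end].strip()     # text inside the think block
                remainder = content[think_end + 8:].strip()  # text after </think>
                return reasoning, remainder
        return "", content

    # split_think_block("<think>check ls flags</think>ls -la") == ("check ls flags", "ls -la")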