yaicli 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- pyproject.toml +37 -14
- yaicli/cli.py +31 -20
- yaicli/const.py +6 -5
- yaicli/entry.py +1 -1
- yaicli/llms/__init__.py +13 -0
- yaicli/llms/client.py +120 -0
- yaicli/llms/provider.py +76 -0
- yaicli/llms/providers/ai21_provider.py +65 -0
- yaicli/llms/providers/chatglm_provider.py +134 -0
- yaicli/llms/providers/chutes_provider.py +7 -0
- yaicli/llms/providers/cohere_provider.py +298 -0
- yaicli/llms/providers/deepseek_provider.py +11 -0
- yaicli/llms/providers/doubao_provider.py +51 -0
- yaicli/llms/providers/groq_provider.py +14 -0
- yaicli/llms/providers/infiniai_provider.py +14 -0
- yaicli/llms/providers/modelscope_provider.py +11 -0
- yaicli/llms/providers/ollama_provider.py +187 -0
- yaicli/llms/providers/openai_provider.py +187 -0
- yaicli/llms/providers/openrouter_provider.py +11 -0
- yaicli/llms/providers/sambanova_provider.py +28 -0
- yaicli/llms/providers/siliconflow_provider.py +11 -0
- yaicli/llms/providers/yi_provider.py +7 -0
- yaicli/printer.py +4 -16
- yaicli/schemas.py +12 -3
- yaicli/tools.py +59 -3
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/METADATA +240 -34
- yaicli-0.6.0.dist-info/RECORD +41 -0
- yaicli/client.py +0 -391
- yaicli-0.5.8.dist-info/RECORD +0 -24
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/WHEEL +0 -0
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/entry_points.txt +0 -0
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/licenses/LICENSE +0 -0
yaicli/llms/providers/openai_provider.py
ADDED
@@ -0,0 +1,187 @@
+from typing import Any, Dict, Generator, List, Optional
+
+import openai
+from openai._streaming import Stream
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+from ...config import cfg
+from ...console import get_console
+from ...schemas import ChatMessage, LLMResponse, ToolCall
+from ...tools import get_openai_schemas
+from ..provider import Provider
+
+
+class OpenAIProvider(Provider):
+    """OpenAI provider implementation based on openai library"""
+
+    DEFAULT_BASE_URL = "https://api.openai.com/v1"
+
+    def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
+        self.config = config
+        self.enable_function = self.config["ENABLE_FUNCTIONS"]
+        self.verbose = verbose
+        # Initialize client params
+        self.client_params = {
+            "api_key": self.config["API_KEY"],
+            "base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
+        }
+
+        # Add extra headers if set
+        if self.config["EXTRA_HEADERS"]:
+            self.client_params["default_headers"] = {
+                **self.config["EXTRA_HEADERS"],
+                "X-Title": self.APP_NAME,
+                "HTTP-Referer": self.APP_REFERER,
+            }
+
+        # Initialize client
+        self.client = openai.OpenAI(**self.client_params)
+        self.console = get_console()
+
+        # Store completion params
+        self.completion_params = {
+            "model": self.config["MODEL"],
+            "temperature": self.config["TEMPERATURE"],
+            "top_p": self.config["TOP_P"],
+            "max_completion_tokens": self.config["MAX_TOKENS"],
+            "timeout": self.config["TIMEOUT"],
+        }
+
+        # Add extra body params if set
+        if self.config["EXTRA_BODY"]:
+            self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
+
+    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+        """Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
+        converted_messages = []
+        for msg in messages:
+            message = {"role": msg.role, "content": msg.content or ""}
+
+            if msg.name:
+                message["name"] = msg.name
+
+            if msg.role == "assistant" and msg.tool_calls:
+                message["tool_calls"] = [
+                    {"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": tc.arguments}}
+                    for tc in msg.tool_calls
+                ]
+
+            if msg.role == "tool" and msg.tool_call_id:
+                message["tool_call_id"] = msg.tool_call_id
+
+            converted_messages.append(message)
+
+        return converted_messages
+
+    def completion(
+        self,
+        messages: List[ChatMessage],
+        stream: bool = False,
+    ) -> Generator[LLMResponse, None, None]:
+        """Send completion request to OpenAI and return responses"""
+        openai_messages = self._convert_messages(messages)
+        if self.verbose:
+            self.console.print("Messages:")
+            self.console.print(openai_messages)
+
+        params = self.completion_params.copy()
+        params["messages"] = openai_messages
+        params["stream"] = stream
+
+        if self.enable_function:
+            tools = get_openai_schemas()
+            if tools:
+                params["tools"] = tools
+
+        if stream:
+            response = self.client.chat.completions.create(**params)
+            yield from self._handle_stream_response(response)
+        else:
+            response = self.client.chat.completions.create(**params)
+            yield from self._handle_normal_response(response)
+
+    def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
+        """Handle normal (non-streaming) response"""
+        choice = response.choices[0]
+        content = choice.message.content or ""  # type: ignore
+        reasoning = choice.message.reasoning_content  # type: ignore
+        finish_reason = choice.finish_reason
+        tool_call: Optional[ToolCall] = None
+
+        # Check if the response contains reasoning content in model_extra
+        if hasattr(choice.message, "model_extra") and choice.message.model_extra:
+            model_extra = choice.message.model_extra
+            reasoning = self._get_reasoning_content(model_extra)
+
+        if finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
+            tool = choice.message.tool_calls[0]
+            tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+        yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+    def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+        """Handle streaming response from OpenAI API"""
+        # Initialize tool call object to accumulate tool call data across chunks
+        tool_call: Optional[ToolCall] = None
+
+        # Process each chunk in the response stream
+        for chunk in response:
+            if not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            delta = choice.delta
+            finish_reason = choice.finish_reason
+
+            # Extract content from current chunk
+            content = delta.content or ""
+
+            # Extract reasoning content if available
+            reasoning = self._get_reasoning_content(getattr(delta, "model_extra", None) or delta)
+
+            # Process tool call information that may be scattered across chunks
+            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                tool_call = self._process_tool_call_chunk(delta.tool_calls, tool_call)
+
+            # Generate response object with tool_call only when finish_reason indicates completion
+            yield LLMResponse(
+                reasoning=reasoning,
+                content=content,
+                tool_call=tool_call if finish_reason == "tool_calls" else None,
+                finish_reason=finish_reason,
+            )
+
+    def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
+        """Process tool call data from a response chunk"""
+        # Initialize tool call object if this is the first chunk with tool call data
+        if existing_tool_call is None and tool_calls:
+            existing_tool_call = ToolCall(tool_calls[0].id or "", tool_calls[0].function.name or "", "")
+
+        # Accumulate arguments from multiple chunks
+        if existing_tool_call:
+            for tool in tool_calls:
+                if not tool.function:
+                    continue
+                existing_tool_call.arguments += tool.function.arguments or ""
+
+        return existing_tool_call
+
+    def _get_reasoning_content(self, delta: Any) -> Optional[str]:
+        """Extract reasoning content from delta if available based on specific keys."""
+        if not delta:
+            return None
+        if not isinstance(delta, dict):
+            delta = dict(delta)
+        # Reasoning content keys from API:
+        # reasoning_content: deepseek/infi-ai
+        # reasoning: openrouter
+        # <think> block implementation not in here
+        for key in ("reasoning_content", "reasoning"):
+            if key in delta:
+                return delta[key]
+        return None
+
+    def detect_tool_role(self) -> str:
+        """Return the role that should be used for tool responses"""
+        return "tool"
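Note: the provider yields LLMResponse objects whether or not streaming is enabled, so callers consume both paths the same way. A minimal driver sketch (not part of the diff; it assumes API_KEY and MODEL are set in yaicli's config, which the constructor reads by default):

from yaicli.llms.providers.openai_provider import OpenAIProvider
from yaicli.schemas import ChatMessage

provider = OpenAIProvider(verbose=False)  # falls back to cfg for API_KEY, MODEL, etc.
messages = [
    ChatMessage(role="system", content="You are a concise assistant."),
    ChatMessage(role="user", content="What time is it in UTC?"),
]
for resp in provider.completion(messages, stream=True):
    # reasoning deltas (e.g. deepseek's reasoning_content) arrive separately from content
    print(resp.content, end="", flush=True)
    if resp.finish_reason == "tool_calls" and resp.tool_call:
        # arguments were accumulated chunk-by-chunk in _process_tool_call_chunk
        print(f"\nTool requested: {resp.tool_call.name}({resp.tool_call.arguments})")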
yaicli/llms/providers/openrouter_provider.py
ADDED
@@ -0,0 +1,11 @@
+from .openai_provider import OpenAIProvider
+
+
+class OpenRouterProvider(OpenAIProvider):
+    """OpenRouter provider implementation based on openai-compatible API"""
+
+    DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
+
+    def __init__(self, config: dict = ..., **kwargs):
+        super().__init__(config, **kwargs)
+        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
yaicli/llms/providers/sambanova_provider.py
ADDED
@@ -0,0 +1,28 @@
+from ...const import DEFAULT_TEMPERATURE
+from .openai_provider import OpenAIProvider
+
+
+class SambanovaProvider(OpenAIProvider):
+    """Sambanova provider implementation based on OpenAI API"""
+
+    DEFAULT_BASE_URL = "https://api.sambanova.ai/v1"
+    SUPPORT_FUNCTION_CALL_MOELS = (
+        "Meta-Llama-3.1-8B-Instruct",
+        "Meta-Llama-3.1-405B-Instruct",
+        "Meta-Llama-3.3-70B-Instruct",
+        "Llama-4-Scout-17B-16E-Instruct",
+        "DeepSeek-V3-0324",
+    )
+
+    def __init__(self, config: dict = ..., verbose: bool = False, **kwargs):
+        super().__init__(config, verbose, **kwargs)
+        self.completion_params.pop("presence_penalty", None)
+        self.completion_params.pop("frequency_penalty", None)
+        if self.completion_params.get("temperature") < 0 or self.completion_params.get("temperature") > 1:
+            self.console.print("Sambanova temperature must be between 0 and 1, setting to 0.4", style="yellow")
+            self.completion_params["temperature"] = DEFAULT_TEMPERATURE
+        if self.enable_function and self.config["MODEL"] not in self.SUPPORT_FUNCTION_CALL_MOELS:
+            self.console.print(
+                f"Sambanova supports function call models: {', '.join(self.SUPPORT_FUNCTION_CALL_MOELS)}",
+                style="yellow",
+            )
yaicli/llms/providers/siliconflow_provider.py
ADDED
@@ -0,0 +1,11 @@
+from .openai_provider import OpenAIProvider
+
+
+class SiliconFlowProvider(OpenAIProvider):
+    """SiliconFlow provider implementation based on openai-compatible API"""
+
+    DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
+
+    def __init__(self, config: dict = ..., **kwargs):
+        super().__init__(config, **kwargs)
+        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
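Note: OpenRouter and SiliconFlow show the extension pattern the new llms package is built around: subclass OpenAIProvider, override DEFAULT_BASE_URL, and remap any completion parameter the upstream API names differently. A hypothetical provider (not in this release; the class name and URL are placeholders) would follow the same shape:

from yaicli.llms.providers.openai_provider import OpenAIProvider

class ExampleCompatProvider(OpenAIProvider):
    """Sketch of an OpenAI-compatible provider; the base URL is a placeholder."""

    DEFAULT_BASE_URL = "https://api.example.com/v1"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        # Older OpenAI-compatible APIs expect max_tokens rather than max_completion_tokens
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")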
yaicli/printer.py
CHANGED
@@ -1,18 +1,14 @@
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Iterator, List, Tuple, Union
+from typing import Iterator, List, Tuple, Union
 
 from rich.console import Group, RenderableType
 from rich.live import Live
 
-from .client import RefreshLive
 from .config import Config, get_config
 from .console import YaiConsole, get_console
 from .render import Markdown, plain_formatter
-from .schemas import ChatMessage
-
-if TYPE_CHECKING:
-    from .schemas import LLMResponse
+from .schemas import LLMResponse, RefreshLive
 
 
 @dataclass
@@ -147,9 +143,7 @@ class Printer:
         # Use Rich Group to combine multiple renderables
         return Group(*display_elements)
 
-    def display_normal(
-        self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-    ) -> tuple[str, str]:
+    def display_normal(self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
         """Process and display non-stream LLMContent, including reasoning and content parts."""
         self._reset_state()
         full_content = full_reasoning = ""
@@ -174,13 +168,9 @@ class Printer:
         self.console.print()
         self.console.print(self.content_formatter(full_content))
 
-        messages.append(ChatMessage(role="assistant", content=full_content))
-
         return full_content, full_reasoning
 
-    def display_stream(
-        self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-    ) -> tuple[str, str]:
+    def display_stream(self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
         """Process and display LLMContent stream, including reasoning and content parts."""
         self._reset_state()
         full_content = full_reasoning = ""
@@ -191,7 +181,6 @@ class Printer:
             if isinstance(chunk, RefreshLive):
                 # Refresh live display when in next completion
                 live.stop()
-                messages.append(ChatMessage(role="assistant", content=full_content))
                 live = Live(console=self.console)
                 live.start()
                 # Initialize full_content and full_reasoning for the next completion
@@ -210,5 +199,4 @@
             time.sleep(self._UPDATE_INTERVAL)
 
         live.stop()
-        messages.append(ChatMessage(role="assistant", content=full_content))
         return full_content, full_reasoning
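Note: with the messages parameter removed from both display methods, the printer is now display-only; the caller keeps the history and appends the assistant ChatMessage itself. RefreshLive, relocated from yaicli/client.py to yaicli/schemas.py, acts as a sentinel between consecutive completions that makes display_stream restart its Live region. A sketch of the new contract (hypothetical caller code using only names from this diff; printer is an existing Printer instance and history is the caller's own list):

from yaicli.schemas import ChatMessage, LLMResponse, RefreshLive

def two_completions():
    yield LLMResponse(content="First answer...")
    yield RefreshLive()  # printer stops the current Live display and starts a fresh one
    yield LLMResponse(content="Second answer, e.g. after a tool round-trip...")

content, reasoning = printer.display_stream(two_completions())
history.append(ChatMessage(role="assistant", content=content))  # bookkeeping now lives in the caller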
yaicli/schemas.py
CHANGED
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import List, Optional
 
 
 @dataclass
@@ -7,9 +7,10 @@ class ChatMessage:
     """Chat message class"""
 
     role: str
-    content: str
+    content: Optional[str] = None
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
+    tool_calls: List["ToolCall"] = field(default_factory=list)
 
 
 @dataclass
@@ -29,3 +30,11 @@ class LLMResponse:
     content: str = ""
     finish_reason: Optional[str] = None
     tool_call: Optional[ToolCall] = None
+
+
+class RefreshLive:
+    """Refresh live display"""
+
+
+class StopLive:
+    """Stop live display"""
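Note: together with _convert_messages in the OpenAI provider, the new tool_calls field lets a complete tool round-trip live in chat history. Illustrative message shapes (the function name and values are made up):

from yaicli.schemas import ChatMessage, ToolCall

tc = ToolCall("call_1", "execute_shell_command", '{"shell_command": "date -u"}')
history = [
    ChatMessage(role="user", content="What time is it in UTC?"),
    # assistant turn carrying only a tool call; content may now be None
    ChatMessage(role="assistant", tool_calls=[tc]),
    # tool result linked back to the originating call via tool_call_id
    ChatMessage(role="tool", content="Mon Jun  2 12:00:00 UTC 2025", tool_call_id=tc.id),
]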
yaicli/tools.py
CHANGED
@@ -1,14 +1,20 @@
 import importlib.util
 import sys
-from typing import Any, Dict, List, NewType, Optional
+from typing import Any, Dict, List, NewType, Optional, Tuple, cast
 
 from instructor import OpenAISchema
+from json_repair import repair_json
+from rich.panel import Panel
 
+from .config import cfg
 from .console import get_console
 from .const import FUNCTIONS_DIR
+from .schemas import ToolCall
 
 console = get_console()
 
+FunctionName = NewType("FunctionName", str)
+
 
 class Function:
     """Function description class"""
@@ -20,8 +26,6 @@ class Function:
         self.execute = function.execute  # type: ignore
 
 
-FunctionName = NewType("FunctionName", str)
-
 _func_name_map: Optional[dict[FunctionName, Function]] = None
 
 
@@ -101,3 +105,55 @@ def get_openai_schemas() -> List[Dict[str, Any]]:
         }
         transformed_schemas.append(schema)
     return transformed_schemas
+
+
+def execute_tool_call(tool_call: ToolCall) -> Tuple[str, bool]:
+    """Execute a tool call and return the result
+
+    Args:
+        tool_call: The tool call to execute
+
+    Returns:
+        Tuple[str, bool]: (result text, success flag)
+    """
+    console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")
+
+    # 1. Get the function
+    try:
+        function = get_function(FunctionName(tool_call.name))
+    except ValueError as e:
+        error_msg = f"Function '{tool_call.name!r}' not exists: {e}"
+        console.print(error_msg, style="red")
+        return error_msg, False
+
+    # 2. Parse function arguments
+    try:
+        arguments = repair_json(tool_call.arguments, return_objects=True)
+        if not isinstance(arguments, dict):
+            error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
+            console.print(error_msg, style="red")
+            return error_msg, False
+        arguments = cast(dict, arguments)
+    except Exception as e:
+        error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
+        console.print(error_msg, style="red")
+        return error_msg, False
+
+    # 3. Execute the function
+    try:
+        function_result = function.execute(**arguments)
+        if cfg["SHOW_FUNCTION_OUTPUT"]:
+            panel = Panel(
+                function_result,
+                title="Function output",
+                title_align="left",
+                expand=False,
+                border_style="blue",
+                style="dim",
+            )
+            console.print(panel)
+        return function_result, True
+    except Exception as e:
+        error_msg = f"Call function error: {e}\nFunction name: {tool_call.name!r}\nArguments: {arguments!r}"
+        console.print(error_msg, style="red")
+        return error_msg, False
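Note: execute_tool_call never raises; it returns (text, ok) so the caller can feed failures back to the model as tool output, and repair_json makes it tolerant of the slightly malformed argument JSON streaming models sometimes emit. A sketch (the function name and arguments are illustrative):

from yaicli.schemas import ToolCall
from yaicli.tools import execute_tool_call

# the closing quote is missing: json.loads would fail here, repair_json recovers it
call = ToolCall("call_1", "execute_shell_command", '{"shell_command": "date -u}')
result, ok = execute_tool_call(call)
if not ok:
    print(result)  # error text becomes the tool message content sent back to the LLM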