yaicli 0.5.9__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pyproject.toml +35 -12
  2. yaicli/cli.py +31 -20
  3. yaicli/const.py +6 -5
  4. yaicli/entry.py +1 -1
  5. yaicli/llms/__init__.py +13 -0
  6. yaicli/llms/client.py +120 -0
  7. yaicli/llms/provider.py +78 -0
  8. yaicli/llms/providers/ai21_provider.py +66 -0
  9. yaicli/llms/providers/chatglm_provider.py +139 -0
  10. yaicli/llms/providers/chutes_provider.py +14 -0
  11. yaicli/llms/providers/cohere_provider.py +298 -0
  12. yaicli/llms/providers/deepseek_provider.py +14 -0
  13. yaicli/llms/providers/doubao_provider.py +53 -0
  14. yaicli/llms/providers/groq_provider.py +16 -0
  15. yaicli/llms/providers/infiniai_provider.py +20 -0
  16. yaicli/llms/providers/minimax_provider.py +13 -0
  17. yaicli/llms/providers/modelscope_provider.py +14 -0
  18. yaicli/llms/providers/ollama_provider.py +187 -0
  19. yaicli/llms/providers/openai_provider.py +211 -0
  20. yaicli/llms/providers/openrouter_provider.py +14 -0
  21. yaicli/llms/providers/sambanova_provider.py +30 -0
  22. yaicli/llms/providers/siliconflow_provider.py +14 -0
  23. yaicli/llms/providers/targon_provider.py +14 -0
  24. yaicli/llms/providers/yi_provider.py +14 -0
  25. yaicli/printer.py +4 -16
  26. yaicli/schemas.py +12 -3
  27. yaicli/tools.py +59 -3
  28. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/METADATA +238 -32
  29. yaicli-0.6.1.dist-info/RECORD +43 -0
  30. yaicli/client.py +0 -391
  31. yaicli-0.5.9.dist-info/RECORD +0 -24
  32. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/WHEEL +0 -0
  33. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/entry_points.txt +0 -0
  34. {yaicli-0.5.9.dist-info → yaicli-0.6.1.dist-info}/licenses/LICENSE +0 -0
yaicli/llms/providers/openai_provider.py ADDED
@@ -0,0 +1,211 @@
+ import json
+ from typing import Any, Dict, Generator, List, Optional
+
+ import openai
+ from openai._streaming import Stream
+ from openai.types.chat.chat_completion import ChatCompletion
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+ from ...config import cfg
+ from ...console import get_console
+ from ...schemas import ChatMessage, LLMResponse, ToolCall
+ from ...tools import get_openai_schemas
+ from ..provider import Provider
+
+
+ class OpenAIProvider(Provider):
+     """OpenAI provider implementation based on openai library"""
+
+     DEFAULT_BASE_URL = "https://api.openai.com/v1"
+     CLIENT_CLS = openai.OpenAI
+
+     def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
+         self.config = config
+         self.enable_function = self.config["ENABLE_FUNCTIONS"]
+         self.verbose = verbose
+
+         # Initialize client
+         self.client_params = self.get_client_params()
+         self.client = self.CLIENT_CLS(**self.client_params)
+         self.console = get_console()
+
+         # Store completion params
+         self.completion_params = self.get_completion_params()
+
+     def get_client_params(self) -> Dict[str, Any]:
+         """Get the client parameters"""
+         # Initialize client params
+         client_params = {
+             "api_key": self.config["API_KEY"],
+             "base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
+         }
+
+         # Add extra headers if set
+         if self.config["EXTRA_HEADERS"]:
+             client_params["default_headers"] = {
+                 **self.config["EXTRA_HEADERS"],
+                 "X-Title": self.APP_NAME,
+                 "HTTP-Referer": self.APPA_REFERER,
+             }
+         return client_params
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         """Get the completion parameters"""
+         completion_params = {
+             "model": self.config["MODEL"],
+             "temperature": self.config["TEMPERATURE"],
+             "top_p": self.config["TOP_P"],
+             "max_completion_tokens": self.config["MAX_TOKENS"],
+             "timeout": self.config["TIMEOUT"],
+         }
+         # Add extra body params if set
+         if self.config["EXTRA_BODY"]:
+             completion_params["extra_body"] = self.config["EXTRA_BODY"]
+         return completion_params
+
+     def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+         """Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
+         converted_messages = []
+         for msg in messages:
+             message = {"role": msg.role, "content": msg.content or ""}
+
+             if msg.name:
+                 message["name"] = msg.name
+
+             if msg.role == "assistant" and msg.tool_calls:
+                 message["tool_calls"] = [
+                     {"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": tc.arguments}}
+                     for tc in msg.tool_calls
+                 ]
+
+             if msg.role == "tool" and msg.tool_call_id:
+                 message["tool_call_id"] = msg.tool_call_id
+
+             converted_messages.append(message)
+
+         return converted_messages
+
+     def completion(
+         self,
+         messages: List[ChatMessage],
+         stream: bool = False,
+     ) -> Generator[LLMResponse, None, None]:
+         """Send completion request to OpenAI and return responses"""
+         openai_messages = self._convert_messages(messages)
+         if self.verbose:
+             self.console.print("Messages:")
+             self.console.print(openai_messages)
+
+         params = self.completion_params.copy()
+         params["messages"] = openai_messages
+         params["stream"] = stream
+
+         if self.enable_function:
+             tools = get_openai_schemas()
+             if tools:
+                 params["tools"] = tools
+
+         if stream:
+             response = self.client.chat.completions.create(**params)
+             yield from self._handle_stream_response(response)
+         else:
+             response = self.client.chat.completions.create(**params)
+             yield from self._handle_normal_response(response)
+
+     def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
+         """Handle normal (non-streaming) response"""
+         if not response.choices:
+             yield LLMResponse(
+                 content=json.dumps(getattr(response, "base_resp", None) or response.to_dict()), finish_reason="stop"
+             )
+             return
+         choice = response.choices[0]
+         content = choice.message.content or ""  # type: ignore
+         reasoning = choice.message.reasoning_content  # type: ignore
+         finish_reason = choice.finish_reason
+         tool_call: Optional[ToolCall] = None
+
+         # Check if the response contains reasoning content in model_extra
+         if hasattr(choice.message, "model_extra") and choice.message.model_extra:
+             model_extra = choice.message.model_extra
+             reasoning = self._get_reasoning_content(model_extra)
+
+         if finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
+             tool = choice.message.tool_calls[0]
+             tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+         yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+     def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+         """Handle streaming response from OpenAI API"""
+         # Initialize tool call object to accumulate tool call data across chunks
+         tool_call: Optional[ToolCall] = None
+         started = False
+         # Process each chunk in the response stream
+         for chunk in response:
+             if not chunk.choices and not started:
+                 # Some api could return error message in the first chunk, no choices to handle, return raw response to show the message
+                 yield LLMResponse(
+                     content=json.dumps(getattr(chunk, "base_resp", None) or chunk.to_dict()), finish_reason="stop"
+                 )
+                 started = True
+                 continue
+
+             if not chunk.choices:
+                 continue
+             started = True
+             choice = chunk.choices[0]
+             delta = choice.delta
+             finish_reason = choice.finish_reason
+
+             # Extract content from current chunk
+             content = delta.content or ""
+
+             # Extract reasoning content if available
+             reasoning = self._get_reasoning_content(getattr(delta, "model_extra", None) or delta)
+
+             # Process tool call information that may be scattered across chunks
+             if hasattr(delta, "tool_calls") and delta.tool_calls:
+                 tool_call = self._process_tool_call_chunk(delta.tool_calls, tool_call)
+
+             # Generate response object with tool_call only when finish_reason indicates completion
+             yield LLMResponse(
+                 reasoning=reasoning,
+                 content=content,
+                 tool_call=tool_call if finish_reason == "tool_calls" else None,
+                 finish_reason=finish_reason,
+             )
+
+     def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
+         """Process tool call data from a response chunk"""
+         # Initialize tool call object if this is the first chunk with tool call data
+         if existing_tool_call is None and tool_calls:
+             existing_tool_call = ToolCall(tool_calls[0].id or "", tool_calls[0].function.name or "", "")
+
+         # Accumulate arguments from multiple chunks
+         if existing_tool_call:
+             for tool in tool_calls:
+                 if not tool.function:
+                     continue
+                 existing_tool_call.arguments += tool.function.arguments or ""
+
+         return existing_tool_call
+
+     def _get_reasoning_content(self, delta: Any) -> Optional[str]:
+         """Extract reasoning content from delta if available based on specific keys."""
+         if not delta:
+             return None
+         if not isinstance(delta, dict):
+             delta = dict(delta)
+         # Reasoning content keys from API:
+         # reasoning_content: deepseek/infi-ai
+         # reasoning: openrouter
+         # <think> block implementation not in here
+         for key in ("reasoning_content", "reasoning"):
+             if key in delta:
+                 return delta[key]
+         return None
+
+     def detect_tool_role(self) -> str:
+         """Return the role that should be used for tool responses"""
+         return "tool"
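For context, a minimal usage sketch of the new provider (not part of the diff). It assumes the config keys read above (`API_KEY`, `BASE_URL`, `MODEL`, ...) are already populated in `cfg`:

```python
# Minimal sketch (assumption: cfg already carries API_KEY, MODEL, etc.)
from yaicli.llms.providers.openai_provider import OpenAIProvider
from yaicli.schemas import ChatMessage

provider = OpenAIProvider()  # uses the global cfg by default
history = [
    ChatMessage(role="system", content="You are a concise assistant."),
    ChatMessage(role="user", content="Explain Python generators in one sentence."),
]

# completion() yields LLMResponse objects; with stream=True each one carries an
# incremental piece of content (and possibly reasoning or a finished tool_call).
for resp in provider.completion(history, stream=True):
    print(resp.content, end="", flush=True)
print()
```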
yaicli/llms/providers/openrouter_provider.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Any, Dict
+
+ from .openai_provider import OpenAIProvider
+
+
+ class OpenRouterProvider(OpenAIProvider):
+     """OpenRouter provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         params = super().get_completion_params()
+         params["max_tokens"] = params.pop("max_completion_tokens")
+         return params
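The thin subclasses in this release all follow the same recipe. A hypothetical provider for some other OpenAI-compatible endpoint (class name and URL below are placeholders, not part of yaicli) would look like this:

```python
# Hypothetical sketch of the subclass pattern shown above: override the base URL
# and, if the backend rejects "max_completion_tokens", rename it back to "max_tokens".
from typing import Any, Dict

from yaicli.llms.providers.openai_provider import OpenAIProvider


class ExampleCompatProvider(OpenAIProvider):  # illustrative name only
    """Example provider for an OpenAI-compatible API (not part of yaicli)."""

    DEFAULT_BASE_URL = "https://api.example-llm.invalid/v1"  # placeholder URL

    def get_completion_params(self) -> Dict[str, Any]:
        params = super().get_completion_params()
        params["max_tokens"] = params.pop("max_completion_tokens")
        return params
```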
yaicli/llms/providers/sambanova_provider.py ADDED
@@ -0,0 +1,30 @@
+ from typing import Any, Dict
+
+ from ...const import DEFAULT_TEMPERATURE
+ from .openai_provider import OpenAIProvider
+
+
+ class SambanovaProvider(OpenAIProvider):
+     """Sambanova provider implementation based on OpenAI API"""
+
+     DEFAULT_BASE_URL = "https://api.sambanova.ai/v1"
+     SUPPORT_FUNCTION_CALL_MOELS = (
+         "Meta-Llama-3.1-8B-Instruct",
+         "Meta-Llama-3.1-405B-Instruct",
+         "Meta-Llama-3.3-70B-Instruct",
+         "Llama-4-Scout-17B-16E-Instruct",
+         "DeepSeek-V3-0324",
+     )
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         params = super().get_completion_params()
+         params.pop("presence_penalty", None)
+         params.pop("frequency_penalty", None)
+         if params.get("temperature") < 0 or params.get("temperature") > 1:
+             self.console.print("Sambanova temperature must be between 0 and 1, setting to 0.4", style="yellow")
+             params["temperature"] = DEFAULT_TEMPERATURE
+         if self.enable_function and self.config["MODEL"] not in self.SUPPORT_FUNCTION_CALL_MOELS:
+             self.console.print(
+                 f"Sambanova supports function call models: {', '.join(self.SUPPORT_FUNCTION_CALL_MOELS)}",
+                 style="yellow",
+             )
yaicli/llms/providers/siliconflow_provider.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Any, Dict
+
+ from .openai_provider import OpenAIProvider
+
+
+ class SiliconFlowProvider(OpenAIProvider):
+     """SiliconFlow provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         params = super().get_completion_params()
+         params["max_tokens"] = params.pop("max_completion_tokens")
+         return params
yaicli/llms/providers/targon_provider.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Any, Dict
+
+ from .openai_provider import OpenAIProvider
+
+
+ class TargonProvider(OpenAIProvider):
+     """Targon provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.targon.com/v1"
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         params = super().get_completion_params()
+         params["max_tokens"] = params.pop("max_completion_tokens")
+         return params
yaicli/llms/providers/yi_provider.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Any, Dict
+
+ from .openai_provider import OpenAIProvider
+
+
+ class YiProvider(OpenAIProvider):
+     """Lingyiwanwu provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.lingyiwanwu.com/v1"
+
+     def get_completion_params(self) -> Dict[str, Any]:
+         params = super().get_completion_params()
+         params["max_tokens"] = params.pop("max_completion_tokens")
+         return params
yaicli/printer.py CHANGED
@@ -1,18 +1,14 @@
  import time
  from dataclasses import dataclass, field
- from typing import TYPE_CHECKING, Iterator, List, Tuple, Union
+ from typing import Iterator, List, Tuple, Union

  from rich.console import Group, RenderableType
  from rich.live import Live

- from .client import RefreshLive
  from .config import Config, get_config
  from .console import YaiConsole, get_console
  from .render import Markdown, plain_formatter
- from .schemas import ChatMessage
-
- if TYPE_CHECKING:
-     from .schemas import LLMResponse
+ from .schemas import LLMResponse, RefreshLive


  @dataclass
@@ -147,9 +143,7 @@ class Printer:
          # Use Rich Group to combine multiple renderables
          return Group(*display_elements)

-     def display_normal(
-         self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-     ) -> tuple[str, str]:
+     def display_normal(self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
          """Process and display non-stream LLMContent, including reasoning and content parts."""
          self._reset_state()
          full_content = full_reasoning = ""
@@ -174,13 +168,9 @@
          self.console.print()
          self.console.print(self.content_formatter(full_content))

-         messages.append(ChatMessage(role="assistant", content=full_content))
-
          return full_content, full_reasoning

-     def display_stream(
-         self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-     ) -> tuple[str, str]:
+     def display_stream(self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
          """Process and display LLMContent stream, including reasoning and content parts."""
          self._reset_state()
          full_content = full_reasoning = ""
@@ -191,7 +181,6 @@
              if isinstance(chunk, RefreshLive):
                  # Refresh live display when in next completion
                  live.stop()
-                 messages.append(ChatMessage(role="assistant", content=full_content))
                  live = Live(console=self.console)
                  live.start()
                  # Initialize full_content and full_reasoning for the next completion
@@ -210,5 +199,4 @@
              time.sleep(self._UPDATE_INTERVAL)

          live.stop()
-         messages.append(ChatMessage(role="assistant", content=full_content))
          return full_content, full_reasoning
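With the `messages` parameter gone, appending the assistant reply is now the caller's responsibility. A rough caller-side sketch (assumed glue code, not taken from yaicli/cli.py):

```python
from yaicli.schemas import ChatMessage

def render_and_record(printer, stream_iterator, messages: list) -> None:
    # display_stream now takes only the iterator and returns (content, reasoning);
    # the caller keeps the chat history itself.
    content, _reasoning = printer.display_stream(stream_iterator)
    messages.append(ChatMessage(role="assistant", content=content))
```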
yaicli/schemas.py CHANGED
@@ -1,5 +1,5 @@
- from dataclasses import dataclass
- from typing import Optional
+ from dataclasses import dataclass, field
+ from typing import List, Optional


  @dataclass
@@ -7,9 +7,10 @@ class ChatMessage:
      """Chat message class"""

      role: str
-     content: str
+     content: Optional[str] = None
      name: Optional[str] = None
      tool_call_id: Optional[str] = None
+     tool_calls: List["ToolCall"] = field(default_factory=list)


  @dataclass
@@ -29,3 +30,11 @@ class LLMResponse:
      content: str = ""
      finish_reason: Optional[str] = None
      tool_call: Optional[ToolCall] = None
+
+
+ class RefreshLive:
+     """Refresh live display"""
+
+
+ class StopLive:
+     """Stop live display"""
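An illustrative round trip with the extended `ChatMessage` (all values below are made up): the assistant message records the `ToolCall`, and the follow-up `tool` message points back at it via `tool_call_id`:

```python
from yaicli.schemas import ChatMessage, ToolCall

call = ToolCall("call_123", "list_files", '{"path": "."}')  # hypothetical tool call

history = [
    ChatMessage(role="assistant", content=None, tool_calls=[call]),
    ChatMessage(role="tool", content="file_a.txt\nfile_b.txt", tool_call_id=call.id),
]
```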
yaicli/tools.py CHANGED
@@ -1,14 +1,20 @@
  import importlib.util
  import sys
- from typing import Any, Dict, List, NewType, Optional
+ from typing import Any, Dict, List, NewType, Optional, Tuple, cast

  from instructor import OpenAISchema
+ from json_repair import repair_json
+ from rich.panel import Panel

+ from .config import cfg
  from .console import get_console
  from .const import FUNCTIONS_DIR
+ from .schemas import ToolCall

  console = get_console()

+ FunctionName = NewType("FunctionName", str)
+


  class Function:
@@ -20,8 +26,6 @@ class Function:
          self.execute = function.execute  # type: ignore


- FunctionName = NewType("FunctionName", str)
-
  _func_name_map: Optional[dict[FunctionName, Function]] = None


@@ -101,3 +105,55 @@ def get_openai_schemas() -> List[Dict[str, Any]]:
          }
          transformed_schemas.append(schema)
      return transformed_schemas
+
+
+ def execute_tool_call(tool_call: ToolCall) -> Tuple[str, bool]:
+     """Execute a tool call and return the result
+
+     Args:
+         tool_call: The tool call to execute
+
+     Returns:
+         Tuple[str, bool]: (result text, success flag)
+     """
+     console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")
+
+     # 1. Get the function
+     try:
+         function = get_function(FunctionName(tool_call.name))
+     except ValueError as e:
+         error_msg = f"Function '{tool_call.name!r}' not exists: {e}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 2. Parse function arguments
+     try:
+         arguments = repair_json(tool_call.arguments, return_objects=True)
+         if not isinstance(arguments, dict):
+             error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
+             console.print(error_msg, style="red")
+             return error_msg, False
+         arguments = cast(dict, arguments)
+     except Exception as e:
+         error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 3. Execute the function
+     try:
+         function_result = function.execute(**arguments)
+         if cfg["SHOW_FUNCTION_OUTPUT"]:
+             panel = Panel(
+                 function_result,
+                 title="Function output",
+                 title_align="left",
+                 expand=False,
+                 border_style="blue",
+                 style="dim",
+             )
+             console.print(panel)
+         return function_result, True
+     except Exception as e:
+         error_msg = f"Call function error: {e}\nFunction name: {tool_call.name!r}\nArguments: {arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
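A usage sketch for the new helper (assumed glue code, values are illustrative): execute the call, then feed the result back to the model as a `tool` message:

```python
from yaicli.schemas import ChatMessage, ToolCall
from yaicli.tools import execute_tool_call

tool_call = ToolCall("call_123", "list_files", '{"path": "."}')  # illustrative values

result, ok = execute_tool_call(tool_call)  # (result text, success flag)
# Even on failure the error text is returned, so it can be sent back to the model
# and the model gets a chance to recover or retry.
tool_message = ChatMessage(
    role="tool",
    content=result,
    tool_call_id=tool_call.id,
    name=tool_call.name,
)
```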