yaicli 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,187 @@
+ from typing import Any, Dict, Generator, List, Optional
+
+ import openai
+ from openai._streaming import Stream
+ from openai.types.chat.chat_completion import ChatCompletion
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+ from ...config import cfg
+ from ...console import get_console
+ from ...schemas import ChatMessage, LLMResponse, ToolCall
+ from ...tools import get_openai_schemas
+ from ..provider import Provider
+
+
+ class OpenAIProvider(Provider):
+     """OpenAI provider implementation based on openai library"""
+
+     DEFAULT_BASE_URL = "https://api.openai.com/v1"
+
+     def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
+         self.config = config
+         self.enable_function = self.config["ENABLE_FUNCTIONS"]
+         self.verbose = verbose
+         # Initialize client params
+         self.client_params = {
+             "api_key": self.config["API_KEY"],
+             "base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
+         }
+
+         # Add extra headers if set
+         if self.config["EXTRA_HEADERS"]:
+             self.client_params["default_headers"] = {
+                 **self.config["EXTRA_HEADERS"],
+                 "X-Title": self.APP_NAME,
+                 "HTTP-Referer": self.APP_REFERER,
+             }
+
+         # Initialize client
+         self.client = openai.OpenAI(**self.client_params)
+         self.console = get_console()
+
+         # Store completion params
+         self.completion_params = {
+             "model": self.config["MODEL"],
+             "temperature": self.config["TEMPERATURE"],
+             "top_p": self.config["TOP_P"],
+             "max_completion_tokens": self.config["MAX_TOKENS"],
+             "timeout": self.config["TIMEOUT"],
+         }
+
+         # Add extra body params if set
+         if self.config["EXTRA_BODY"]:
+             self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
+
+     def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+         """Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
+         converted_messages = []
+         for msg in messages:
+             message = {"role": msg.role, "content": msg.content or ""}
+
+             if msg.name:
+                 message["name"] = msg.name
+
+             if msg.role == "assistant" and msg.tool_calls:
+                 message["tool_calls"] = [
+                     {"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": tc.arguments}}
+                     for tc in msg.tool_calls
+                 ]
+
+             if msg.role == "tool" and msg.tool_call_id:
+                 message["tool_call_id"] = msg.tool_call_id
+
+             converted_messages.append(message)
+
+         return converted_messages
+
+     def completion(
+         self,
+         messages: List[ChatMessage],
+         stream: bool = False,
+     ) -> Generator[LLMResponse, None, None]:
+         """Send completion request to OpenAI and return responses"""
+         openai_messages = self._convert_messages(messages)
+         if self.verbose:
+             self.console.print("Messages:")
+             self.console.print(openai_messages)
+
+         params = self.completion_params.copy()
+         params["messages"] = openai_messages
+         params["stream"] = stream
+
+         if self.enable_function:
+             tools = get_openai_schemas()
+             if tools:
+                 params["tools"] = tools
+
+         if stream:
+             response = self.client.chat.completions.create(**params)
+             yield from self._handle_stream_response(response)
+         else:
+             response = self.client.chat.completions.create(**params)
+             yield from self._handle_normal_response(response)
+
+     def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
+         """Handle normal (non-streaming) response"""
+         choice = response.choices[0]
+         content = choice.message.content or ""  # type: ignore
+         reasoning = choice.message.reasoning_content  # type: ignore
+         finish_reason = choice.finish_reason
+         tool_call: Optional[ToolCall] = None
+
+         # Check if the response contains reasoning content in model_extra
+         if hasattr(choice.message, "model_extra") and choice.message.model_extra:
+             model_extra = choice.message.model_extra
+             reasoning = self._get_reasoning_content(model_extra)
+
+         if finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
+             tool = choice.message.tool_calls[0]
+             tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+         yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+     def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
+         """Handle streaming response from OpenAI API"""
+         # Initialize tool call object to accumulate tool call data across chunks
+         tool_call: Optional[ToolCall] = None
+
+         # Process each chunk in the response stream
+         for chunk in response:
+             if not chunk.choices:
+                 continue
+
+             choice = chunk.choices[0]
+             delta = choice.delta
+             finish_reason = choice.finish_reason
+
+             # Extract content from current chunk
+             content = delta.content or ""
+
+             # Extract reasoning content if available
+             reasoning = self._get_reasoning_content(getattr(delta, "model_extra", None) or delta)
+
+             # Process tool call information that may be scattered across chunks
+             if hasattr(delta, "tool_calls") and delta.tool_calls:
+                 tool_call = self._process_tool_call_chunk(delta.tool_calls, tool_call)
+
+             # Generate response object with tool_call only when finish_reason indicates completion
+             yield LLMResponse(
+                 reasoning=reasoning,
+                 content=content,
+                 tool_call=tool_call if finish_reason == "tool_calls" else None,
+                 finish_reason=finish_reason,
+             )
+
+     def _process_tool_call_chunk(self, tool_calls, existing_tool_call=None):
+         """Process tool call data from a response chunk"""
+         # Initialize tool call object if this is the first chunk with tool call data
+         if existing_tool_call is None and tool_calls:
+             existing_tool_call = ToolCall(tool_calls[0].id or "", tool_calls[0].function.name or "", "")
+
+         # Accumulate arguments from multiple chunks
+         if existing_tool_call:
+             for tool in tool_calls:
+                 if not tool.function:
+                     continue
+                 existing_tool_call.arguments += tool.function.arguments or ""
+
+         return existing_tool_call
+
+     def _get_reasoning_content(self, delta: Any) -> Optional[str]:
+         """Extract reasoning content from delta if available based on specific keys."""
+         if not delta:
+             return None
+         if not isinstance(delta, dict):
+             delta = dict(delta)
+         # Reasoning content keys from API:
+         # reasoning_content: deepseek/infi-ai
+         # reasoning: openrouter
+         # <think> block implementation not in here
+         for key in ("reasoning_content", "reasoning"):
+             if key in delta:
+                 return delta[key]
+         return None
+
+     def detect_tool_role(self) -> str:
+         """Return the role that should be used for tool responses"""
+         return "tool"
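Editor's sketch, not part of the diff: minimal consumption of the new provider, assuming OpenAIProvider is imported from the module above and the global cfg carries a valid API_KEY and MODEL; everything else uses only names visible in this diff.

from yaicli.schemas import ChatMessage

provider = OpenAIProvider()  # falls back to the global cfg
history = [
    ChatMessage(role="system", content="You are a concise assistant."),
    ChatMessage(role="user", content="What does a context manager do?"),
]

answer = ""
for resp in provider.completion(history, stream=True):
    answer += resp.content  # streamed text deltas accumulate chunk by chunk
    if resp.tool_call is not None:  # only set when finish_reason == "tool_calls"
        print("tool requested:", resp.tool_call.name, resp.tool_call.arguments)
print(answer)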
@@ -0,0 +1,11 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class OpenRouterProvider(OpenAIProvider):
+     """OpenRouter provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -0,0 +1,28 @@
+ from ...const import DEFAULT_TEMPERATURE
+ from .openai_provider import OpenAIProvider
+
+
+ class SambanovaProvider(OpenAIProvider):
+     """Sambanova provider implementation based on OpenAI API"""
+
+     DEFAULT_BASE_URL = "https://api.sambanova.ai/v1"
+     SUPPORT_FUNCTION_CALL_MOELS = (
+         "Meta-Llama-3.1-8B-Instruct",
+         "Meta-Llama-3.1-405B-Instruct",
+         "Meta-Llama-3.3-70B-Instruct",
+         "Llama-4-Scout-17B-16E-Instruct",
+         "DeepSeek-V3-0324",
+     )
+
+     def __init__(self, config: dict = ..., verbose: bool = False, **kwargs):
+         super().__init__(config, verbose, **kwargs)
+         self.completion_params.pop("presence_penalty", None)
+         self.completion_params.pop("frequency_penalty", None)
+         if self.completion_params.get("temperature") < 0 or self.completion_params.get("temperature") > 1:
+             self.console.print("Sambanova temperature must be between 0 and 1, setting to 0.4", style="yellow")
+             self.completion_params["temperature"] = DEFAULT_TEMPERATURE
+         if self.enable_function and self.config["MODEL"] not in self.SUPPORT_FUNCTION_CALL_MOELS:
+             self.console.print(
+                 f"Sambanova supports function call models: {', '.join(self.SUPPORT_FUNCTION_CALL_MOELS)}",
+                 style="yellow",
+             )
@@ -0,0 +1,11 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class SiliconFlowProvider(OpenAIProvider):
+     """SiliconFlow provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -0,0 +1,7 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class YiProvider(OpenAIProvider):
+     """Yi provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.lingyiwanwu.com/v1"
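The provider modules above all follow the same pattern: subclass OpenAIProvider, override DEFAULT_BASE_URL, and adjust completion_params where the backend deviates from OpenAI defaults. A hypothetical provider written the same way (class name and base URL are illustrative only, not part of the package):

from .openai_provider import OpenAIProvider


class ExampleProvider(OpenAIProvider):
    """Illustrative OpenAI-compatible provider, not shipped in yaicli"""

    DEFAULT_BASE_URL = "https://api.example.com/v1"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        # Rename for backends that only accept max_tokens, as OpenRouter and SiliconFlow do
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")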
yaicli/printer.py CHANGED
@@ -1,18 +1,14 @@
  import time
  from dataclasses import dataclass, field
- from typing import TYPE_CHECKING, Iterator, List, Tuple, Union
+ from typing import Iterator, List, Tuple, Union

  from rich.console import Group, RenderableType
  from rich.live import Live

- from .client import RefreshLive
  from .config import Config, get_config
  from .console import YaiConsole, get_console
  from .render import Markdown, plain_formatter
- from .schemas import ChatMessage
-
- if TYPE_CHECKING:
-     from .schemas import LLMResponse
+ from .schemas import LLMResponse, RefreshLive


  @dataclass
@@ -147,9 +143,7 @@ class Printer:
          # Use Rich Group to combine multiple renderables
          return Group(*display_elements)

-     def display_normal(
-         self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-     ) -> tuple[str, str]:
+     def display_normal(self, content_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
          """Process and display non-stream LLMContent, including reasoning and content parts."""
          self._reset_state()
          full_content = full_reasoning = ""
@@ -174,13 +168,9 @@ class Printer:
          self.console.print()
          self.console.print(self.content_formatter(full_content))

-         messages.append(ChatMessage(role="assistant", content=full_content))
-
          return full_content, full_reasoning

-     def display_stream(
-         self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]], messages: list["ChatMessage"]
-     ) -> tuple[str, str]:
+     def display_stream(self, stream_iterator: Iterator[Union["LLMResponse", RefreshLive]]) -> tuple[str, str]:
          """Process and display LLMContent stream, including reasoning and content parts."""
          self._reset_state()
          full_content = full_reasoning = ""
@@ -191,7 +181,6 @@ class Printer:
              if isinstance(chunk, RefreshLive):
                  # Refresh live display when in next completion
                  live.stop()
-                 messages.append(ChatMessage(role="assistant", content=full_content))
                  live = Live(console=self.console)
                  live.start()
                  # Initialize full_content and full_reasoning for the next completion
@@ -210,5 +199,4 @@ class Printer:
              time.sleep(self._UPDATE_INTERVAL)

          live.stop()
-         messages.append(ChatMessage(role="assistant", content=full_content))
          return full_content, full_reasoning
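Since display_normal and display_stream no longer receive or append to the message list, the caller presumably owns the history now. An assumed calling sketch, using only names from this diff:

from yaicli.schemas import ChatMessage


def run_stream_round(printer, provider, history):
    # printer: a yaicli Printer; provider: any provider in this release; history: list of ChatMessage
    content, _reasoning = printer.display_stream(provider.completion(history, stream=True))
    history.append(ChatMessage(role="assistant", content=content))
    return content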
yaicli/schemas.py CHANGED
@@ -1,5 +1,5 @@
- from dataclasses import dataclass
- from typing import Optional
+ from dataclasses import dataclass, field
+ from typing import List, Optional


  @dataclass
@@ -7,9 +7,10 @@ class ChatMessage:
      """Chat message class"""

      role: str
-     content: str
+     content: Optional[str] = None
      name: Optional[str] = None
      tool_call_id: Optional[str] = None
+     tool_calls: List["ToolCall"] = field(default_factory=list)


  @dataclass
@@ -29,3 +30,11 @@ class LLMResponse:
      content: str = ""
      finish_reason: Optional[str] = None
      tool_call: Optional[ToolCall] = None
+
+
+ class RefreshLive:
+     """Refresh live display"""
+
+
+ class StopLive:
+     """Stop live display"""
yaicli/tools.py CHANGED
@@ -1,14 +1,20 @@
  import importlib.util
  import sys
- from typing import Any, Dict, List, NewType, Optional
+ from typing import Any, Dict, List, NewType, Optional, Tuple, cast

  from instructor import OpenAISchema
+ from json_repair import repair_json
+ from rich.panel import Panel

+ from .config import cfg
  from .console import get_console
  from .const import FUNCTIONS_DIR
+ from .schemas import ToolCall

  console = get_console()

+ FunctionName = NewType("FunctionName", str)
+

  class Function:
      """Function description class"""
@@ -20,8 +26,6 @@ class Function:
          self.execute = function.execute  # type: ignore


- FunctionName = NewType("FunctionName", str)
-
  _func_name_map: Optional[dict[FunctionName, Function]] = None


@@ -101,3 +105,55 @@ def get_openai_schemas() -> List[Dict[str, Any]]:
          }
          transformed_schemas.append(schema)
      return transformed_schemas
+
+
+ def execute_tool_call(tool_call: ToolCall) -> Tuple[str, bool]:
+     """Execute a tool call and return the result
+
+     Args:
+         tool_call: The tool call to execute
+
+     Returns:
+         Tuple[str, bool]: (result text, success flag)
+     """
+     console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")
+
+     # 1. Get the function
+     try:
+         function = get_function(FunctionName(tool_call.name))
+     except ValueError as e:
+ error_msg = f"Function '{tool_call.name!r}' not exists: {e}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 2. Parse function arguments
+     try:
+         arguments = repair_json(tool_call.arguments, return_objects=True)
+         if not isinstance(arguments, dict):
+             error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
+             console.print(error_msg, style="red")
+             return error_msg, False
+         arguments = cast(dict, arguments)
+     except Exception as e:
+         error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
+
+     # 3. Execute the function
+     try:
+         function_result = function.execute(**arguments)
+         if cfg["SHOW_FUNCTION_OUTPUT"]:
+             panel = Panel(
+                 function_result,
+                 title="Function output",
+                 title_align="left",
+                 expand=False,
+                 border_style="blue",
+                 style="dim",
+             )
+             console.print(panel)
+         return function_result, True
+     except Exception as e:
+         error_msg = f"Call function error: {e}\nFunction name: {tool_call.name!r}\nArguments: {arguments!r}"
+         console.print(error_msg, style="red")
+         return error_msg, False
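A hypothetical call of the new helper (function name and arguments are placeholders); because arguments go through repair_json, slightly malformed JSON such as a trailing comma is tolerated rather than rejected.

from yaicli.schemas import ToolCall
from yaicli.tools import execute_tool_call

call = ToolCall(id="call_1", name="some_function", arguments='{"arg": "value",}')
result, ok = execute_tool_call(call)  # returns (result text, success flag)
if not ok:
    print("tool call failed:", result)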