yaicli 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class ChutesProvider(OpenAIProvider):
+     """Chutes provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://llm.chutes.ai/v1"
@@ -0,0 +1,298 @@
+ """
+ Cohere API provider implementation
+
+ This module implements Cohere provider classes for different deployment options:
+ - CohereProvider: Standard Cohere API
+ - CohereBadrockProvider: AWS Bedrock integration
+ - CohereSagemaker: AWS Sagemaker integration
+ """
+
+ from typing import Any, Dict, Generator, List, Optional
+
+ from cohere import BedrockClientV2, ClientV2, SagemakerClientV2
+ from cohere.types.tool_call_v2 import ToolCallV2, ToolCallV2Function
+
+ from ...config import cfg
+ from ...console import get_console
+ from ...schemas import ChatMessage, LLMResponse, ToolCall
+ from ...tools import get_openai_schemas
+ from ..provider import Provider
+
+
+ class CohereProvider(Provider):
+     """Cohere provider implementation based on cohere library"""
+
+     DEFAULT_BASE_URL = "https://api.cohere.com/v2"
+     CLIENT_CLS = ClientV2
+     DEFAULT_MODEL = "command-a-03-2025"
+
+     def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
+         """
+         Initialize the Cohere provider
+
+         Args:
+             config: Configuration dictionary
+             verbose: Whether to enable verbose logging
+             **kwargs: Additional parameters passed to the client
+         """
+         self.config = config
+         self.verbose = verbose
+         self.client_params = {
+             "api_key": self.config["API_KEY"],
+             "timeout": self.config["TIMEOUT"],
+         }
+         if self.config["BASE_URL"]:
+             self.client_params["base_url"] = self.config["BASE_URL"]
+         self.client = self.create_client()
+         self.console = get_console()
+
+     def create_client(self):
+         """Create and return Cohere client instance"""
+         if self.config.get("ENVIRONMENT"):
+             self.client_params["environment"] = self.config["ENVIRONMENT"]
+         return self.CLIENT_CLS(**self.client_params)
+
+     def detect_tool_role(self) -> str:
+         """Return the role name for tool response messages"""
+         return "tool"
+
+     def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+         """
+         Convert a list of ChatMessage objects to a list of Cohere message dicts
+
+         {
+             "role": "tool",
+             "tool_call_id": tc.id,
+             "content": {
+                 "type": "document",
+                 "document": {"data": string},
+             },
+         }
+
+         Args:
+             messages: List of ChatMessage objects
+
+         Returns:
+             List of message dicts formatted for Cohere API
+         """
+         converted_messages = []
+         for msg in messages:
+             # Create base message
+             message = {}
+
+             # Set role always
+             message["role"] = msg.role
+
+             # Add tool calls for assistant messages
+             if msg.role == "assistant" and msg.tool_calls:
+                 # {
+                 #     "role": "assistant",
+                 #     "tool_calls": response.message.tool_calls,
+                 #     "tool_plan": response.message.tool_plan,
+                 # }
+                 message["tool_calls"] = [
+                     ToolCallV2(
+                         id=tc.id,
+                         type="function",
+                         function=ToolCallV2Function(name=tc.name, arguments=tc.arguments),
+                     )
+                     for tc in msg.tool_calls
+                 ]
+             else:
+                 # Add content for non-tool-call messages
+                 message["content"] = msg.content or ""
+
+             # Add tool call ID for tool messages
+             if msg.role == "tool" and msg.tool_call_id:
+                 message["tool_call_id"] = msg.tool_call_id
+
+                 # For tool messages, convert content to the expected document format
+                 if msg.content:
+                     message["content"] = [{"type": "document", "document": {"data": msg.content}}]
+
+             converted_messages.append(message)
+
+         return converted_messages
+
+     def _prepare_tools(self) -> Optional[List[Dict[str, Any]]]:
+         """
+         Prepare tools for Cohere API if enabled
+
+         Returns:
+             List of tool definitions or None if disabled
+         """
+         if not self.config.get("ENABLE_FUNCTIONS", False):
+             return None
+
+         tools = get_openai_schemas()
+         if not tools and self.verbose:
+             self.console.print("No tools available", style="yellow")
+         return tools
+
+     def _handle_streaming_response(self, response_stream) -> Generator[LLMResponse, None, None]:
+         """
+         Process streaming response from Cohere API
+
+         doc: https://docs.cohere.com/v2/docs/streaming
+
+         According to Cohere docs, there are multiple event types:
+         - message-start: First event with metadata
+         - content-start: Start of content block
+         - content-delta: Chunk of generated text
+         - content-end: End of content block
+         - message-end: End of message
+         - tool-plan-delta: Part of tool planning
+         - tool-call-start: Start of tool call
+         - tool-call-delta: Part of tool call
+         - tool-call-end: End of tool call
+         - citation-start/end: For citations in RAG
+
+         Args:
+             response_stream: Stream from Cohere client
+
+         Yields:
+             LLMResponse objects with content or tool calls
+         """
+         tool_call: Optional[ToolCall] = None
+         for chunk in response_stream:
+             if not chunk:
+                 continue
+
+             # Handle different event types
+             if chunk.type == "content-delta":
+                 # Text generation chunks
+                 content = chunk.delta.message.content.text or ""
+                 yield LLMResponse(content=content)
+
+             elif chunk.type == "tool-plan-delta":
+                 # Tool planning - when model is deciding which tool to use: cohere.types.chat_tool_plan_delta_event_delta_message.ChatToolPlanDeltaEventDeltaMessage
+                 content = chunk.delta.message.tool_plan or ""
+                 yield LLMResponse(content=content)
+
+             elif chunk.type == "tool-call-start":
+                 # Start of tool call
+                 tool_call_msg = chunk.delta.message.tool_calls
+                 tool_call = ToolCall(
+                     id=tool_call_msg.id, name=tool_call_msg.function.name, arguments=tool_call_msg.function.arguments
+                 )
+                 # Tool call started, waiting for tool-calls-delta events
+                 continue
+             elif chunk.type == "tool-call-delta":
+                 # Tool call arguments being generated: cohere.types.chat_tool_call_delta_event_delta_message.ChatToolCallDeltaEventDeltaMessage
+                 tool_call.arguments += chunk.delta.message.tool_calls.function.arguments
+                 # Waiting for tool-call-end event
+                 continue
+
+             elif chunk.type == "tool-call-end":
+                 # End of a tool call, empty chunk
+                 yield LLMResponse(tool_call=tool_call)
+
+     def _handle_normal_response(self, response) -> Generator[LLMResponse, None, None]:
+         """
+         Process non-streaming response from Cohere API
+
+         Args:
+             response: Response from Cohere client
+
+         Yields:
+             LLMResponse objects with content or tool calls
+         """
+         # Handle content
+         if response.message.content:
+             for content_item in response.message.content:
+                 if hasattr(content_item, "text") and content_item.text:
+                     yield LLMResponse(content=content_item.text)
+
+         # Handle tool calls
+         if response.message.tool_calls:
+             yield LLMResponse(content=response.message.tool_plan)
+             for tool_call in response.message.tool_calls:
+                 yield LLMResponse(
+                     tool_call=ToolCall(
+                         id=tool_call.id,
+                         name=tool_call.function.name,
+                         arguments=tool_call.function.arguments,
+                     )
+                 )
+
+     def completion(
+         self, messages: List[ChatMessage], stream: bool = False, **kwargs
+     ) -> Generator[LLMResponse, None, None]:
+         """
+         Get completion from Cohere models
+
+         Args:
+             messages: List of messages for the conversation
+             stream: Whether to stream the response
+             **kwargs: Additional parameters to pass to the Cohere client
+
+         Yields:
+             LLMResponse objects with content or tool calls
+         """
+         # Get configuration values
+         model = self.config.get("MODEL", self.DEFAULT_MODEL)
+         temperature = float(self.config.get("TEMPERATURE", 0.7))
+
+         # Prepare messages and tools
+         cohere_messages = self._convert_messages(messages)
+         if self.verbose:
+             self.console.print("Messages:")
+             self.console.print(cohere_messages)
+         tools = self._prepare_tools()
+
+         # Common request parameters
+         request_params = {"model": model, "messages": cohere_messages, "temperature": temperature, **kwargs}
+
+         # Add tools if available
+         if tools:
+             request_params["tools"] = tools
+
+         # Call Cohere API
+         try:
+             if stream:
+                 # Streaming mode
+                 response_stream = self.client.chat_stream(**request_params)
+                 yield from self._handle_streaming_response(response_stream)
+             else:
+                 # Non-streaming mode
+                 response = self.client.chat(**request_params)
+                 yield from self._handle_normal_response(response)
+
+         except Exception as e:
+             error_msg = f"Error in Cohere API call: {e}"
+             if self.verbose:
+                 import traceback
+
+                 self.console.print("Error in Cohere completion:")
+                 traceback.print_exc()
+             yield LLMResponse(content=error_msg)
+
+
+ class CohereBadrockProvider(CohereProvider):
+     """Cohere provider for AWS Bedrock integration"""
+
+     CLIENT_CLS = BedrockClientV2
+     DOC_URL = "https://docs.cohere.com/v2/docs/text-gen-quickstart"
+     CLIENT_KEYS = (
+         ("AWS_REGION", "aws_region"),
+         ("AWS_ACCESS_KEY_ID", "aws_access_key"),
+         ("AWS_SECRET_ACCESS_KEY", "aws_secret_key"),
+         ("AWS_SESSION_TOKEN", "aws_session_token"),
+     )
+
+     def create_client(self):
+         """Create Bedrock client with AWS credentials"""
+         for k, p in self.CLIENT_KEYS:
+             v = self.config.get(k, None)
+             if v is None:
+                 raise ValueError(
+                     f"You have to set key `{k}` to use {self.__class__.__name__}, see cohere doc `{self.DOC_URL}`"
+                 )
+             self.client_params[p] = v
+         return self.CLIENT_CLS(**self.client_params)
+
+
+ class CohereSagemaker(CohereBadrockProvider):
+     """Cohere provider for AWS Sagemaker integration"""
+
+     CLIENT_CLS = SagemakerClientV2
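
For orientation, a minimal sketch of driving the new CohereProvider directly, not part of the diff. The config keys mirror what __init__ and completion() read above; the import path for ChatMessage and its constructor signature are assumptions, since file paths are not shown in this diff.

from yaicli.schemas import ChatMessage  # assumed import path for the schemas used above

cohere_config = {
    "API_KEY": "co-...",              # read into client_params by __init__
    "BASE_URL": "",                   # falsy -> the cohere client default endpoint is used
    "TIMEOUT": 60,
    "MODEL": "command-a-03-2025",     # DEFAULT_MODEL above
    "TEMPERATURE": 0.7,
    "ENABLE_FUNCTIONS": False,        # skip tool schemas in _prepare_tools()
}
provider = CohereProvider(config=cohere_config, verbose=False)

# completion() yields LLMResponse objects; with stream=True content arrives chunk by chunk
for resp in provider.completion([ChatMessage(role="user", content="Hello")], stream=True):
    if resp.content:
        print(resp.content, end="")
    if resp.tool_call:
        print(f"\n[tool call] {resp.tool_call.name}({resp.tool_call.arguments})")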
@@ -0,0 +1,11 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class DeepSeekProvider(OpenAIProvider):
+     """DeepSeek provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -0,0 +1,51 @@
+ from volcenginesdkarkruntime import Ark
+
+ from ...config import cfg
+ from ...console import get_console
+ from .openai_provider import OpenAIProvider
+
+
+ class DoubaoProvider(OpenAIProvider):
+     """Doubao provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
+
+     def __init__(self, config: dict = cfg, **kwargs):
+         self.config = config
+         self.enable_function = self.config["ENABLE_FUNCTIONS"]
+         # Initialize client params
+         self.client_params = {"base_url": self.DEFAULT_BASE_URL}
+         if self.config.get("API_KEY", None):
+             self.client_params["api_key"] = self.config["API_KEY"]
+         if self.config.get("BASE_URL", None):
+             self.client_params["base_url"] = self.config["BASE_URL"]
+         if self.config.get("AK", None):
+             self.client_params["ak"] = self.config["AK"]
+         if self.config.get("SK", None):
+             self.client_params["sk"] = self.config["SK"]
+         if self.config.get("REGION", None):
+             self.client_params["region"] = self.config["REGION"]
+
+         # Initialize client
+         self.client = Ark(**self.client_params)
+         self.console = get_console()
+
+         # Store completion params
+         self.completion_params = {
+             "model": self.config["MODEL"],
+             "temperature": self.config["TEMPERATURE"],
+             "top_p": self.config["TOP_P"],
+             "max_tokens": self.config["MAX_TOKENS"],
+             "timeout": self.config["TIMEOUT"],
+         }
+         # Add extra headers if set
+         if self.config.get("EXTRA_HEADERS", None):
+             self.completion_params["extra_headers"] = {
+                 **self.config["EXTRA_HEADERS"],
+                 "X-Title": self.APP_NAME,
+                 "HTTP-Referer": self.APPA_REFERER,
+             }
+
+         # Add extra body params if set
+         if self.config.get("EXTRA_BODY", None):
+             self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
@@ -0,0 +1,14 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class GroqProvider(OpenAIProvider):
+     """Groq provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         if self.config.get("EXTRA_BODY") and "N" in self.config["EXTRA_BODY"] and self.config["EXTRA_BODY"]["N"] != 1:
+             self.console.print("Groq does not support N parameter, setting N to 1 as Groq default", style="yellow")
+             if "extra_body" in self.completion_params:
+                 self.completion_params["extra_body"]["N"] = 1
@@ -0,0 +1,14 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class InfiniAIProvider(OpenAIProvider):
+     """InfiniAI provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://cloud.infini-ai.com/maas/v1"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         if self.enable_function:
+             self.console.print("InfiniAI does not support functions, disabled", style="yellow")
+             self.enable_function = False
+         self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -0,0 +1,11 @@
+ from .openai_provider import OpenAIProvider
+
+
+ class ModelScopeProvider(OpenAIProvider):
+     """ModelScope provider implementation based on openai-compatible API"""
+
+     DEFAULT_BASE_URL = "https://api-inference.modelscope.cn/v1/"
+
+     def __init__(self, config: dict = ..., **kwargs):
+         super().__init__(config, **kwargs)
+         self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -0,0 +1,187 @@
+ import json
+ import time
+ from typing import Any, Dict, Generator, List
+
+ import ollama
+
+ from ...config import cfg
+ from ...console import get_console
+ from ...schemas import ChatMessage, LLMResponse, ToolCall
+ from ...tools import get_openai_schemas
+ from ...utils import str2bool
+ from ..provider import Provider
+
+
+ class OllamaProvider(Provider):
+     """Ollama provider implementation based on ollama Python library"""
+
+     DEFAULT_BASE_URL = "http://localhost:11434"
+     OPTION_KEYS = (
+         ("SEED", "seed"),
+         ("NUM_PREDICT", "num_predict"),
+         ("NUM_CTX", "num_ctx"),
+         ("NUM_BATCH", "num_batch"),
+         ("NUM_GPU", "num_gpu"),
+         ("MAIN_GPU", "main_gpu"),
+         ("LOW_VRAM", "low_vram"),
+         ("F16_KV", "f16_kv"),
+         ("LOGITS_ALL", "logits_all"),
+         ("VOCAB_ONLY", "vocab_only"),
+         ("USE_MMAP", "use_mmap"),
+         ("USE_MLOCK", "use_mlock"),
+         ("NUM_THREAD", "num_thread"),
+     )
+
+     def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
+         self.config = config
+         self.enable_function = self.config.get("ENABLE_FUNCTIONS", False)
+         self.verbose = verbose
+         self.think = str2bool(self.config.get("THINK", False))
+
+         # Initialize client params - Ollama host support
+         self.host = self.config.get("BASE_URL") or self.DEFAULT_BASE_URL
+
+         # Initialize console
+         self.console = get_console()
+
+         self.client = ollama.Client(host=self.host, timeout=self.config["TIMEOUT"])
+
+     def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+         """Convert a list of ChatMessage objects to a list of Ollama message dicts."""
+         converted_messages = []
+         for msg in messages:
+             message = {"role": msg.role, "content": msg.content or ""}
+
+             if msg.name:
+                 message["name"] = msg.name
+
+             # Handle tool calls - Ollama now supports the OpenAI format directly
+             if msg.role == "assistant" and msg.tool_calls:
+                 message["tool_calls"] = [
+                     {
+                         "id": tc.id,
+                         "type": "function",
+                         "function": {"name": tc.name, "arguments": json.loads(tc.arguments)},
+                     }
+                     for tc in msg.tool_calls
+                 ]
+
+             # Handle tool responses - Ollama supports tool_call_id directly
+             if msg.role == "tool" and msg.tool_call_id:
+                 message["tool_call_id"] = msg.tool_call_id
+
+             converted_messages.append(message)
+
+         return converted_messages
+
+     def completion(
+         self,
+         messages: List[ChatMessage],
+         stream: bool = False,
+     ) -> Generator[LLMResponse, None, None]:
+         """Send messages to Ollama and get response"""
+         # Convert message format
+         ollama_messages = self._convert_messages(messages)
+         if self.verbose:
+             self.console.print("Messages:")
+             self.console.print(ollama_messages)
+         options = {"temperature": self.config["TEMPERATURE"], "top_p": self.config["TOP_P"]}
+         for k, v in self.OPTION_KEYS:
+             if self.config.get(k, None) is not None:
+                 options[v] = self.config[k]
+
+         # Prepare parameters
+         params = {
+             "model": self.config.get("MODEL", "llama3"),
+             "messages": ollama_messages,
+             "stream": stream,
+             "think": self.think,
+             "options": options,
+         }
+
+         # Add tools if enabled
+         if self.enable_function:
+             params["tools"] = get_openai_schemas()
+
+         if self.verbose:
+             self.console.print("Ollama API params:")
+             self.console.print(params)
+         try:
+             if stream:
+                 response_generator = self.client.chat(**params)
+                 yield from self._handle_stream_response(response_generator)
+             else:
+                 response = self.client.chat(**params)
+                 yield from self._handle_normal_response(response)
+
+         except Exception as e:
+             self.console.print(f"Ollama API error: {e}", style="red")
+             yield LLMResponse(content=f"Error calling Ollama API: {str(e)}")
+
+     def _handle_normal_response(self, response: Dict[str, Any]) -> Generator[LLMResponse, None, None]:
+         """Handle normal (non-streaming) response"""
+         content = response.message.content or ""
+         reasoning = response.message.thinking or ""
+
+         # Check for tool calls in the response
+         tool_call = None
+         tool_calls = response.message.tool_calls or []
+
+         if tool_calls and self.enable_function:
+             # Get the first tool call
+             tc = tool_calls[0]
+             function_data = tc.get("function", {})
+
+             # Create tool call with appropriate data type handling
+             arguments = function_data.get("arguments", "")
+             if isinstance(arguments, dict):
+                 arguments = json.dumps(arguments)
+
+             tool_call = ToolCall(
+                 id=tc.get("id", f"tc_{hash(function_data.get('name', ''))}_{int(time.time())}"),
+                 name=function_data.get("name", ""),
+                 arguments=arguments,
+             )
+
+         yield LLMResponse(content=content, reasoning=reasoning, tool_call=tool_call)
+
+     def _handle_stream_response(self, response_generator) -> Generator[LLMResponse, None, None]:
+         """Handle streaming response"""
+         accumulated_content = ""
+         tool_call = None
+
+         for chunk in response_generator:
+             # Extract content from the current chunk
+             message = chunk.message
+             content = message.content or ""
+             reasoning = message.thinking or ""
+
+             if content or reasoning:
+                 accumulated_content += content
+                 yield LLMResponse(content=content, reasoning=reasoning)
+
+             # Check for tool calls in the chunk
+             tool_calls = message.tool_calls or []
+             if tool_calls and self.enable_function:
+                 # Only handle the first tool call for now
+                 tc = tool_calls[0]
+                 function_data = tc.get("function", {})
+
+                 # Create tool call with appropriate data type handling
+                 arguments = function_data.get("arguments", "")
+                 if isinstance(arguments, dict):
+                     arguments = json.dumps(arguments)
+
+                 tool_call = ToolCall(
+                     id=tc.get("id", None) or f"tc_{hash(function_data.get('name', ''))}_{int(time.time())}",
+                     name=function_data.get("name", ""),
+                     arguments=arguments,
+                 )
+
+         # After streaming is complete, if we found a tool call, yield it
+         if tool_call:
+             yield LLMResponse(tool_call=tool_call)
+
+     def detect_tool_role(self) -> str:
+         """Return the role to be used for tool responses"""
+         return "tool"