yaicli 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- pyproject.toml +37 -14
- yaicli/cli.py +31 -20
- yaicli/const.py +6 -5
- yaicli/entry.py +1 -1
- yaicli/llms/__init__.py +13 -0
- yaicli/llms/client.py +120 -0
- yaicli/llms/provider.py +76 -0
- yaicli/llms/providers/ai21_provider.py +65 -0
- yaicli/llms/providers/chatglm_provider.py +134 -0
- yaicli/llms/providers/chutes_provider.py +7 -0
- yaicli/llms/providers/cohere_provider.py +298 -0
- yaicli/llms/providers/deepseek_provider.py +11 -0
- yaicli/llms/providers/doubao_provider.py +51 -0
- yaicli/llms/providers/groq_provider.py +14 -0
- yaicli/llms/providers/infiniai_provider.py +14 -0
- yaicli/llms/providers/modelscope_provider.py +11 -0
- yaicli/llms/providers/ollama_provider.py +187 -0
- yaicli/llms/providers/openai_provider.py +187 -0
- yaicli/llms/providers/openrouter_provider.py +11 -0
- yaicli/llms/providers/sambanova_provider.py +28 -0
- yaicli/llms/providers/siliconflow_provider.py +11 -0
- yaicli/llms/providers/yi_provider.py +7 -0
- yaicli/printer.py +4 -16
- yaicli/schemas.py +12 -3
- yaicli/tools.py +59 -3
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/METADATA +240 -34
- yaicli-0.6.0.dist-info/RECORD +41 -0
- yaicli/client.py +0 -391
- yaicli-0.5.8.dist-info/RECORD +0 -24
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/WHEEL +0 -0
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/entry_points.txt +0 -0
- {yaicli-0.5.8.dist-info → yaicli-0.6.0.dist-info}/licenses/LICENSE +0 -0
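The 0.6.0 release removes the single yaicli/client.py module (-391 lines) and introduces a yaicli/llms package: a client, a Provider base class, and one module per backend under yaicli/llms/providers/. A hedged sketch of the import paths this layout implies, using only paths from the list above and class names that appear in the hunks below; how llms/__init__.py and llms/client.py wire providers together is not shown in this section:

# Paths from the file list; provider class names appear in the hunks shown below.
from yaicli.llms.provider import Provider                        # base class
from yaicli.llms.providers.cohere_provider import CohereProvider
from yaicli.llms.providers.ollama_provider import OllamaProvider
from yaicli.llms.providers.deepseek_provider import DeepSeekProvider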
yaicli/llms/providers/cohere_provider.py (new file)
@@ -0,0 +1,298 @@
"""
Cohere API provider implementation

This module implements Cohere provider classes for different deployment options:
- CohereProvider: Standard Cohere API
- CohereBadrockProvider: AWS Bedrock integration
- CohereSagemaker: AWS Sagemaker integration
"""

from typing import Any, Dict, Generator, List, Optional

from cohere import BedrockClientV2, ClientV2, SagemakerClientV2
from cohere.types.tool_call_v2 import ToolCallV2, ToolCallV2Function

from ...config import cfg
from ...console import get_console
from ...schemas import ChatMessage, LLMResponse, ToolCall
from ...tools import get_openai_schemas
from ..provider import Provider


class CohereProvider(Provider):
    """Cohere provider implementation based on cohere library"""

    DEFAULT_BASE_URL = "https://api.cohere.com/v2"
    CLIENT_CLS = ClientV2
    DEFAULT_MODEL = "command-a-03-2025"

    def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
        """
        Initialize the Cohere provider

        Args:
            config: Configuration dictionary
            verbose: Whether to enable verbose logging
            **kwargs: Additional parameters passed to the client
        """
        self.config = config
        self.verbose = verbose
        self.client_params = {
            "api_key": self.config["API_KEY"],
            "timeout": self.config["TIMEOUT"],
        }
        if self.config["BASE_URL"]:
            self.client_params["base_url"] = self.config["BASE_URL"]
        self.client = self.create_client()
        self.console = get_console()

    def create_client(self):
        """Create and return Cohere client instance"""
        if self.config.get("ENVIRONMENT"):
            self.client_params["environment"] = self.config["ENVIRONMENT"]
        return self.CLIENT_CLS(**self.client_params)

    def detect_tool_role(self) -> str:
        """Return the role name for tool response messages"""
        return "tool"

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Convert a list of ChatMessage objects to a list of Cohere message dicts

        {
            "role": "tool",
            "tool_call_id": tc.id,
            "content": {
                "type": "document",
                "document": {"data": string},
            },
        }

        Args:
            messages: List of ChatMessage objects

        Returns:
            List of message dicts formatted for Cohere API
        """
        converted_messages = []
        for msg in messages:
            # Create base message
            message = {}

            # Set role always
            message["role"] = msg.role

            # Add tool calls for assistant messages
            if msg.role == "assistant" and msg.tool_calls:
                # {
                #     "role": "assistant",
                #     "tool_calls": response.message.tool_calls,
                #     "tool_plan": response.message.tool_plan,
                # }
                message["tool_calls"] = [
                    ToolCallV2(
                        id=tc.id,
                        type="function",
                        function=ToolCallV2Function(name=tc.name, arguments=tc.arguments),
                    )
                    for tc in msg.tool_calls
                ]
            else:
                # Add content for non-tool-call messages
                message["content"] = msg.content or ""

            # Add tool call ID for tool messages
            if msg.role == "tool" and msg.tool_call_id:
                message["tool_call_id"] = msg.tool_call_id

                # For tool messages, convert content to the expected document format
                if msg.content:
                    message["content"] = [{"type": "document", "document": {"data": msg.content}}]

            converted_messages.append(message)

        return converted_messages

    def _prepare_tools(self) -> Optional[List[Dict[str, Any]]]:
        """
        Prepare tools for Cohere API if enabled

        Returns:
            List of tool definitions or None if disabled
        """
        if not self.config.get("ENABLE_FUNCTIONS", False):
            return None

        tools = get_openai_schemas()
        if not tools and self.verbose:
            self.console.print("No tools available", style="yellow")
        return tools

    def _handle_streaming_response(self, response_stream) -> Generator[LLMResponse, None, None]:
        """
        Process streaming response from Cohere API

        doc: https://docs.cohere.com/v2/docs/streaming

        According to Cohere docs, there are multiple event types:
        - message-start: First event with metadata
        - content-start: Start of content block
        - content-delta: Chunk of generated text
        - content-end: End of content block
        - message-end: End of message
        - tool-plan-delta: Part of tool planning
        - tool-call-start: Start of tool call
        - tool-call-delta: Part of tool call
        - tool-call-end: End of tool call
        - citation-start/end: For citations in RAG

        Args:
            response_stream: Stream from Cohere client

        Yields:
            LLMResponse objects with content or tool calls
        """
        tool_call: Optional[ToolCall] = None
        for chunk in response_stream:
            if not chunk:
                continue

            # Handle different event types
            if chunk.type == "content-delta":
                # Text generation chunks
                content = chunk.delta.message.content.text or ""
                yield LLMResponse(content=content)

            elif chunk.type == "tool-plan-delta":
                # Tool planning - when model is deciding which tool to use: cohere.types.chat_tool_plan_delta_event_delta_message.ChatToolPlanDeltaEventDeltaMessage
                content = chunk.delta.message.tool_plan or ""
                yield LLMResponse(content=content)

            elif chunk.type == "tool-call-start":
                # Start of tool call
                tool_call_msg = chunk.delta.message.tool_calls
                tool_call = ToolCall(
                    id=tool_call_msg.id, name=tool_call_msg.function.name, arguments=tool_call_msg.function.arguments
                )
                # Tool call started, waiting for tool-calls-delta events
                continue
            elif chunk.type == "tool-call-delta":
                # Tool call arguments being generated: cohere.types.chat_tool_call_delta_event_delta_message.ChatToolCallDeltaEventDeltaMessage
                tool_call.arguments += chunk.delta.message.tool_calls.function.arguments
                # Waiting for tool-call-end event
                continue

            elif chunk.type == "tool-call-end":
                # End of a tool call, empty chunk
                yield LLMResponse(tool_call=tool_call)

    def _handle_normal_response(self, response) -> Generator[LLMResponse, None, None]:
        """
        Process non-streaming response from Cohere API

        Args:
            response: Response from Cohere client

        Yields:
            LLMResponse objects with content or tool calls
        """
        # Handle content
        if response.message.content:
            for content_item in response.message.content:
                if hasattr(content_item, "text") and content_item.text:
                    yield LLMResponse(content=content_item.text)

        # Handle tool calls
        if response.message.tool_calls:
            yield LLMResponse(content=response.message.tool_plan)
            for tool_call in response.message.tool_calls:
                yield LLMResponse(
                    tool_call=ToolCall(
                        id=tool_call.id,
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    )
                )

    def completion(
        self, messages: List[ChatMessage], stream: bool = False, **kwargs
    ) -> Generator[LLMResponse, None, None]:
        """
        Get completion from Cohere models

        Args:
            messages: List of messages for the conversation
            stream: Whether to stream the response
            **kwargs: Additional parameters to pass to the Cohere client

        Yields:
            LLMResponse objects with content or tool calls
        """
        # Get configuration values
        model = self.config.get("MODEL", self.DEFAULT_MODEL)
        temperature = float(self.config.get("TEMPERATURE", 0.7))

        # Prepare messages and tools
        cohere_messages = self._convert_messages(messages)
        if self.verbose:
            self.console.print("Messages:")
            self.console.print(cohere_messages)
        tools = self._prepare_tools()

        # Common request parameters
        request_params = {"model": model, "messages": cohere_messages, "temperature": temperature, **kwargs}

        # Add tools if available
        if tools:
            request_params["tools"] = tools

        # Call Cohere API
        try:
            if stream:
                # Streaming mode
                response_stream = self.client.chat_stream(**request_params)
                yield from self._handle_streaming_response(response_stream)
            else:
                # Non-streaming mode
                response = self.client.chat(**request_params)
                yield from self._handle_normal_response(response)

        except Exception as e:
            error_msg = f"Error in Cohere API call: {e}"
            if self.verbose:
                import traceback

                self.console.print("Error in Cohere completion:")
                traceback.print_exc()
            yield LLMResponse(content=error_msg)


class CohereBadrockProvider(CohereProvider):
    """Cohere provider for AWS Bedrock integration"""

    CLIENT_CLS = BedrockClientV2
    DOC_URL = "https://docs.cohere.com/v2/docs/text-gen-quickstart"
    CLIENT_KEYS = (
        ("AWS_REGION", "aws_region"),
        ("AWS_ACCESS_KEY_ID", "aws_access_key"),
        ("AWS_SECRET_ACCESS_KEY", "aws_secret_key"),
        ("AWS_SESSION_TOKEN", "aws_session_token"),
    )

    def create_client(self):
        """Create Bedrock client with AWS credentials"""
        for k, p in self.CLIENT_KEYS:
            v = self.config.get(k, None)
            if v is None:
                raise ValueError(
                    f"You have to set key `{k}` to use {self.__class__.__name__}, see cohere doc `{self.DOC_URL}`"
                )
            self.client_params[p] = v
        return self.CLIENT_CLS(**self.client_params)


class CohereSagemaker(CohereBadrockProvider):
    """Cohere provider for AWS Sagemaker integration"""

    CLIENT_CLS = SagemakerClientV2
yaicli/llms/providers/deepseek_provider.py (new file)
@@ -0,0 +1,11 @@
from .openai_provider import OpenAIProvider


class DeepSeekProvider(OpenAIProvider):
    """DeepSeek provider implementation based on openai-compatible API"""

    DEFAULT_BASE_URL = "https://api.deepseek.com/v1"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
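DeepSeekProvider, like the Groq, InfiniAI, and ModelScope providers further down, only swaps the base URL and adjusts a few request parameters; here it renames the OpenAI-style max_completion_tokens back to max_tokens. A hedged sketch of that subclassing pattern for some other OpenAI-compatible endpoint (the ExampleProvider name and URL are invented for illustration; it assumes OpenAIProvider populates self.completion_params in its own __init__, as the subclasses in this diff rely on):

from yaicli.llms.providers.openai_provider import OpenAIProvider


class ExampleProvider(OpenAIProvider):
    """Hypothetical provider for an OpenAI-compatible endpoint."""

    DEFAULT_BASE_URL = "https://api.example.invalid/v1"  # invented URL

    def __init__(self, config: dict, **kwargs):
        super().__init__(config, **kwargs)
        # Endpoints that predate the max_completion_tokens rename still expect max_tokens.
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")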
yaicli/llms/providers/doubao_provider.py (new file)
@@ -0,0 +1,51 @@
from volcenginesdkarkruntime import Ark

from ...config import cfg
from ...console import get_console
from .openai_provider import OpenAIProvider


class DoubaoProvider(OpenAIProvider):
    """Doubao provider implementation based on openai-compatible API"""

    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"

    def __init__(self, config: dict = cfg, **kwargs):
        self.config = config
        self.enable_function = self.config["ENABLE_FUNCTIONS"]
        # Initialize client params
        self.client_params = {"base_url": self.DEFAULT_BASE_URL}
        if self.config.get("API_KEY", None):
            self.client_params["api_key"] = self.config["API_KEY"]
        if self.config.get("BASE_URL", None):
            self.client_params["base_url"] = self.config["BASE_URL"]
        if self.config.get("AK", None):
            self.client_params["ak"] = self.config["AK"]
        if self.config.get("SK", None):
            self.client_params["sk"] = self.config["SK"]
        if self.config.get("REGION", None):
            self.client_params["region"] = self.config["REGION"]

        # Initialize client
        self.client = Ark(**self.client_params)
        self.console = get_console()

        # Store completion params
        self.completion_params = {
            "model": self.config["MODEL"],
            "temperature": self.config["TEMPERATURE"],
            "top_p": self.config["TOP_P"],
            "max_tokens": self.config["MAX_TOKENS"],
            "timeout": self.config["TIMEOUT"],
        }
        # Add extra headers if set
        if self.config.get("EXTRA_HEADERS", None):
            self.completion_params["extra_headers"] = {
                **self.config["EXTRA_HEADERS"],
                "X-Title": self.APP_NAME,
                "HTTP-Referer": self.APPA_REFERER,
            }

        # Add extra body params if set
        if self.config.get("EXTRA_BODY", None):
            self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
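DoubaoProvider rebuilds its client params from scratch because the Ark SDK accepts either an API key or an AK/SK credential pair with a region. A sketch of the two config shapes implied by the branches above, not part of the diff (the key names are the config keys read in __init__; model id and credential values are placeholders):

# Option 1: plain API key against the default Ark endpoint.
api_key_config = {
    "API_KEY": "ak-...",             # placeholder
    "MODEL": "doubao-pro",           # placeholder model id
    "TEMPERATURE": 0.5,
    "TOP_P": 1.0,
    "MAX_TOKENS": 1024,
    "TIMEOUT": 60,
    "ENABLE_FUNCTIONS": False,
}

# Option 2: AK/SK credentials, which __init__ forwards as ak/sk/region to Ark(...).
ak_sk_config = {
    **api_key_config,
    "API_KEY": "",                   # falsy, so no api_key is passed
    "AK": "AKLT...",                 # placeholder access key
    "SK": "secret",                  # placeholder secret key
    "REGION": "cn-beijing",
}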
yaicli/llms/providers/groq_provider.py (new file)
@@ -0,0 +1,14 @@
from .openai_provider import OpenAIProvider


class GroqProvider(OpenAIProvider):
    """Groq provider implementation based on openai-compatible API"""

    DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        if self.config.get("EXTRA_BODY") and "N" in self.config["EXTRA_BODY"] and self.config["EXTRA_BODY"]["N"] != 1:
            self.console.print("Groq does not support N parameter, setting N to 1 as Groq default", style="yellow")
            if "extra_body" in self.completion_params:
                self.completion_params["extra_body"]["N"] = 1
yaicli/llms/providers/infiniai_provider.py (new file)
@@ -0,0 +1,14 @@
from .openai_provider import OpenAIProvider


class InfiniAIProvider(OpenAIProvider):
    """InfiniAI provider implementation based on openai-compatible API"""

    DEFAULT_BASE_URL = "https://cloud.infini-ai.com/maas/v1"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        if self.enable_function:
            self.console.print("InfiniAI does not support functions, disabled", style="yellow")
            self.enable_function = False
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
yaicli/llms/providers/modelscope_provider.py (new file)
@@ -0,0 +1,11 @@
from .openai_provider import OpenAIProvider


class ModelScopeProvider(OpenAIProvider):
    """ModelScope provider implementation based on openai-compatible API"""

    DEFAULT_BASE_URL = "https://api-inference.modelscope.cn/v1/"

    def __init__(self, config: dict = ..., **kwargs):
        super().__init__(config, **kwargs)
        self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
yaicli/llms/providers/ollama_provider.py (new file)
@@ -0,0 +1,187 @@
import json
import time
from typing import Any, Dict, Generator, List

import ollama

from ...config import cfg
from ...console import get_console
from ...schemas import ChatMessage, LLMResponse, ToolCall
from ...tools import get_openai_schemas
from ...utils import str2bool
from ..provider import Provider


class OllamaProvider(Provider):
    """Ollama provider implementation based on ollama Python library"""

    DEFAULT_BASE_URL = "http://localhost:11434"
    OPTION_KEYS = (
        ("SEED", "seed"),
        ("NUM_PREDICT", "num_predict"),
        ("NUM_CTX", "num_ctx"),
        ("NUM_BATCH", "num_batch"),
        ("NUM_GPU", "num_gpu"),
        ("MAIN_GPU", "main_gpu"),
        ("LOW_VRAM", "low_vram"),
        ("F16_KV", "f16_kv"),
        ("LOGITS_ALL", "logits_all"),
        ("VOCAB_ONLY", "vocab_only"),
        ("USE_MMAP", "use_mmap"),
        ("USE_MLOCK", "use_mlock"),
        ("NUM_THREAD", "num_thread"),
    )

    def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
        self.config = config
        self.enable_function = self.config.get("ENABLE_FUNCTIONS", False)
        self.verbose = verbose
        self.think = str2bool(self.config.get("THINK", False))

        # Initialize client params - Ollama host support
        self.host = self.config.get("BASE_URL") or self.DEFAULT_BASE_URL

        # Initialize console
        self.console = get_console()

        self.client = ollama.Client(host=self.host, timeout=self.config["TIMEOUT"])

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """Convert a list of ChatMessage objects to a list of Ollama message dicts."""
        converted_messages = []
        for msg in messages:
            message = {"role": msg.role, "content": msg.content or ""}

            if msg.name:
                message["name"] = msg.name

            # Handle tool calls - Ollama now supports the OpenAI format directly
            if msg.role == "assistant" and msg.tool_calls:
                message["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.name, "arguments": json.loads(tc.arguments)},
                    }
                    for tc in msg.tool_calls
                ]

            # Handle tool responses - Ollama supports tool_call_id directly
            if msg.role == "tool" and msg.tool_call_id:
                message["tool_call_id"] = msg.tool_call_id

            converted_messages.append(message)

        return converted_messages

    def completion(
        self,
        messages: List[ChatMessage],
        stream: bool = False,
    ) -> Generator[LLMResponse, None, None]:
        """Send messages to Ollama and get response"""
        # Convert message format
        ollama_messages = self._convert_messages(messages)
        if self.verbose:
            self.console.print("Messages:")
            self.console.print(ollama_messages)
        options = {"temperature": self.config["TEMPERATURE"], "top_p": self.config["TOP_P"]}
        for k, v in self.OPTION_KEYS:
            if self.config.get(k, None) is not None:
                options[v] = self.config[k]

        # Prepare parameters
        params = {
            "model": self.config.get("MODEL", "llama3"),
            "messages": ollama_messages,
            "stream": stream,
            "think": self.think,
            "options": options,
        }

        # Add tools if enabled
        if self.enable_function:
            params["tools"] = get_openai_schemas()

        if self.verbose:
            self.console.print("Ollama API params:")
            self.console.print(params)
        try:
            if stream:
                response_generator = self.client.chat(**params)
                yield from self._handle_stream_response(response_generator)
            else:
                response = self.client.chat(**params)
                yield from self._handle_normal_response(response)

        except Exception as e:
            self.console.print(f"Ollama API error: {e}", style="red")
            yield LLMResponse(content=f"Error calling Ollama API: {str(e)}")

    def _handle_normal_response(self, response: Dict[str, Any]) -> Generator[LLMResponse, None, None]:
        """Handle normal (non-streaming) response"""
        content = response.message.content or ""
        reasoning = response.message.thinking or ""

        # Check for tool calls in the response
        tool_call = None
        tool_calls = response.message.tool_calls or []

        if tool_calls and self.enable_function:
            # Get the first tool call
            tc = tool_calls[0]
            function_data = tc.get("function", {})

            # Create tool call with appropriate data type handling
            arguments = function_data.get("arguments", "")
            if isinstance(arguments, dict):
                arguments = json.dumps(arguments)

            tool_call = ToolCall(
                id=tc.get("id", f"tc_{hash(function_data.get('name', ''))}_{int(time.time())}"),
                name=function_data.get("name", ""),
                arguments=arguments,
            )

        yield LLMResponse(content=content, reasoning=reasoning, tool_call=tool_call)

    def _handle_stream_response(self, response_generator) -> Generator[LLMResponse, None, None]:
        """Handle streaming response"""
        accumulated_content = ""
        tool_call = None

        for chunk in response_generator:
            # Extract content from the current chunk
            message = chunk.message
            content = message.content or ""
            reasoning = message.thinking or ""

            if content or reasoning:
                accumulated_content += content
                yield LLMResponse(content=content, reasoning=reasoning)

            # Check for tool calls in the chunk
            tool_calls = message.tool_calls or []
            if tool_calls and self.enable_function:
                # Only handle the first tool call for now
                tc = tool_calls[0]
                function_data = tc.get("function", {})

                # Create tool call with appropriate data type handling
                arguments = function_data.get("arguments", "")
                if isinstance(arguments, dict):
                    arguments = json.dumps(arguments)

                tool_call = ToolCall(
                    id=tc.get("id", None) or f"tc_{hash(function_data.get('name', ''))}_{int(time.time())}",
                    name=function_data.get("name", ""),
                    arguments=arguments,
                )

        # After streaming is complete, if we found a tool call, yield it
        if tool_call:
            yield LLMResponse(tool_call=tool_call)

    def detect_tool_role(self) -> str:
        """Return the role to be used for tool responses"""
        return "tool"
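A minimal driver for the Ollama provider above, not part of the diff. It assumes a local Ollama server on the default port, that ChatMessage/LLMResponse expose the fields used in this file, and that the model name is a placeholder for any locally pulled model; the THINK option only matters for models that emit a thinking field:

from yaicli.llms.providers.ollama_provider import OllamaProvider
from yaicli.schemas import ChatMessage

config = {
    "BASE_URL": "",                # empty -> http://localhost:11434
    "TIMEOUT": 120,
    "MODEL": "qwen3:8b",           # placeholder model name
    "TEMPERATURE": 0.7,
    "TOP_P": 1.0,
    "ENABLE_FUNCTIONS": False,
    "THINK": True,                 # parsed by str2bool and passed as the `think` param
}

provider = OllamaProvider(config=config)
history = [ChatMessage(role="user", content="Explain what `du -sh *` does.")]

for resp in provider.completion(history, stream=True):
    if resp.reasoning:
        print(resp.reasoning, end="", flush=True)   # thinking tokens, if the model emits them
    if resp.content:
        print(resp.content, end="", flush=True)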