lollms-client 0.24.2__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic.
- lollms_client/__init__.py +3 -2
- lollms_client/llm_bindings/azure_openai/__init__.py +364 -0
- lollms_client/llm_bindings/claude/__init__.py +549 -0
- lollms_client/llm_bindings/gemini/__init__.py +501 -0
- lollms_client/llm_bindings/grok/__init__.py +536 -0
- lollms_client/llm_bindings/groq/__init__.py +292 -0
- lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +307 -0
- lollms_client/llm_bindings/litellm/__init__.py +201 -0
- lollms_client/llm_bindings/lollms/__init__.py +2 -0
- lollms_client/llm_bindings/mistral/__init__.py +298 -0
- lollms_client/llm_bindings/open_router/__init__.py +304 -0
- lollms_client/llm_bindings/openai/__init__.py +30 -9
- lollms_client/lollms_core.py +338 -162
- lollms_client/lollms_discussion.py +135 -37
- lollms_client/lollms_llm_binding.py +4 -0
- lollms_client/lollms_types.py +9 -1
- lollms_client/lollms_utilities.py +68 -0
- lollms_client/mcp_bindings/remote_mcp/__init__.py +82 -4
- lollms_client-0.27.0.dist-info/METADATA +604 -0
- {lollms_client-0.24.2.dist-info → lollms_client-0.27.0.dist-info}/RECORD +23 -14
- lollms_client-0.24.2.dist-info/METADATA +0 -239
- {lollms_client-0.24.2.dist-info → lollms_client-0.27.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.24.2.dist-info → lollms_client-0.27.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.24.2.dist-info → lollms_client-0.27.0.dist-info}/top_level.txt +0 -0
lollms_client/llm_bindings/litellm/__init__.py

@@ -0,0 +1,201 @@
+# bindings/LiteLLM/binding.py
+import requests
+import json
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from lollms_client.lollms_discussion import LollmsDiscussion
+from lollms_client.lollms_utilities import encode_image
+from typing import Optional, Callable, List, Union, Dict
+from ascii_colors import ASCIIColors, trace_exception
+
+# Use pipmaster to ensure required packages are installed
+try:
+    import pipmaster as pm
+except ImportError:
+    print("Pipmaster not found. Please install it using 'pip install pipmaster'")
+    raise
+
+# Ensure requests and tiktoken are installed
+pm.ensure_packages(["requests", "tiktoken"])
+
+import tiktoken
+
+BindingName = "LiteLLMBinding"
+
+def get_icon_path(model_name: str) -> str:
+    model_name = model_name.lower()
+    if 'gpt' in model_name: return '/bindings/openai/logo.png'
+    if 'mistral' in model_name or 'mixtral' in model_name: return '/bindings/mistral/logo.png'
+    if 'claude' in model_name: return '/bindings/anthropic/logo.png'
+    return '/bindings/litellm/logo.png'
+
+class LiteLLMBinding(LollmsLLMBinding):
+    """
+    A binding for the LiteLLM proxy using direct HTTP requests.
+    This version includes detailed logging, a fallback for listing models,
+    and correct payload formatting for both streaming and non-streaming modes.
+    """
+
+    def __init__(self, host_address: str, model_name: str, service_key: str = "anything", verify_ssl_certificate: bool = True, **kwargs):
+        super().__init__(binding_name="litellm")
+        self.host_address = host_address.rstrip('/')
+        self.model_name = model_name
+        self.service_key = service_key
+        self.verify_ssl_certificate = verify_ssl_certificate
+
+    def _perform_generation(self, messages: List[Dict], n_predict: Optional[int], stream: bool, temperature: float, top_p: float, repeat_penalty: float, seed: Optional[int], streaming_callback: Optional[Callable[[str, MSG_TYPE], None]]) -> Union[str, dict]:
+        url = f'{self.host_address}/v1/chat/completions'
+        headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.service_key}'}
+        payload = {
+            "model": self.model_name, "messages": messages, "max_tokens": n_predict,
+            "temperature": temperature, "top_p": top_p, "frequency_penalty": repeat_penalty,
+            "stream": stream
+        }
+        if seed is not None: payload["seed"] = seed
+
+        payload = {k: v for k, v in payload.items() if v is not None}
+        output = ""
+        try:
+            response = requests.post(url, headers=headers, data=json.dumps(payload), stream=stream, verify=self.verify_ssl_certificate)
+            response.raise_for_status()
+
+            if stream:
+                for line in response.iter_lines():
+                    if line:
+                        decoded_line = line.decode('utf-8')
+                        if decoded_line.startswith('data: '):
+                            if '[DONE]' in decoded_line: break
+                            json_data_string = decoded_line[6:]
+                            try:
+                                chunk_data = json.loads(json_data_string)
+                                delta = chunk_data.get('choices', [{}])[0].get('delta', {})
+                                if 'content' in delta and delta['content'] is not None:
+                                    word = delta['content']
+                                    if streaming_callback and not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+                                        return output
+                                    output += word
+                            except json.JSONDecodeError: continue
+            else:
+                full_response = response.json()
+                output = full_response['choices'][0]['message']['content']
+                if streaming_callback:
+                    streaming_callback(output, MSG_TYPE.MSG_TYPE_CHUNK)
+        except Exception as e:
+            error_message = f"An error occurred: {e}\nResponse: {response.text if 'response' in locals() else 'No response'}"
+            trace_exception(e)
+            if streaming_callback: streaming_callback(error_message, MSG_TYPE.MSG_TYPE_EXCEPTION)
+            return {"status": "error", "message": error_message}
+        return output
+
+    def generate_text(self, prompt: str, images: Optional[List[str]] = None, system_prompt: str = "", n_predict: Optional[int] = None, stream: Optional[bool] = None, temperature: float = 0.7, top_p: float = 0.9, repeat_penalty: float = 1.1, seed: Optional[int] = None, streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None, **kwargs) -> Union[str, dict]:
+        """Generates text from a prompt, correctly formatting for text-only and multi-modal cases."""
+        is_streaming = stream if stream is not None else (streaming_callback is not None)
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # --- THIS IS THE CRITICAL FIX ---
+        if images:
+            # If images are present, use the multi-modal list format for content
+            user_content = [{"type": "text", "text": prompt}]
+            for image_path in images:
+                base64_image = encode_image(image_path)
+                user_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
+            messages.append({"role": "user", "content": user_content})
+        else:
+            # If no images, use a simple string for content to avoid the API error
+            messages.append({"role": "user", "content": prompt})
+        # --- END OF FIX ---
+
+        return self._perform_generation(messages, n_predict, is_streaming, temperature, top_p, repeat_penalty, seed, streaming_callback)
+
+    def chat(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None, n_predict: Optional[int] = None, stream: Optional[bool] = None, temperature: float = 0.7, top_p: float = 0.9, repeat_penalty: float = 1.1, seed: Optional[int] = None, streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None, **kwargs) -> Union[str, dict]:
+        is_streaming = stream if stream is not None else (streaming_callback is not None)
+        messages = discussion.export("openai_chat", branch_tip_id)
+        return self._perform_generation(messages, n_predict, is_streaming, temperature, top_p, repeat_penalty, seed, streaming_callback)
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        url = f'{self.host_address}/v1/embeddings'
+        headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.service_key}'}
+        payload = {"model": self.model_name, "input": text}
+        try:
+            response = requests.post(url, headers=headers, data=json.dumps(payload), verify=self.verify_ssl_certificate)
+            response.raise_for_status()
+            return response.json()['data'][0]['embedding']
+        except Exception as e:
+            trace_exception(e)
+            return []
+
+    def tokenize(self, text: str) -> list:
+        return tiktoken.model.encoding_for_model("gpt-3.5-turbo").encode(text)
+
+    def detokenize(self, tokens: list) -> str:
+        return tiktoken.model.encoding_for_model("gpt-3.5-turbo").decode(tokens)
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.tokenize(text))
+
+    def _list_models_openai_fallback(self) -> List[Dict]:
+        ASCIIColors.warning("--- [LiteLLM Binding] Falling back to /v1/models endpoint. Rich metadata will be unavailable.")
+        url = f'{self.host_address}/v1/models'
+        headers = {'Authorization': f'Bearer {self.service_key}'}
+        entries = []
+        try:
+            response = requests.get(url, headers=headers, verify=self.verify_ssl_certificate)
+            response.raise_for_status()
+            models_data = response.json().get('data', [])
+            for model in models_data:
+                model_name = model.get('id')
+                entries.append({
+                    "category": "api", "datasets": "unknown", "icon": get_icon_path(model_name),
+                    "license": "unknown", "model_creator": model.get('owned_by', 'unknown'),
+                    "name": model_name, "provider": "litellm", "rank": "1.0", "type": "api",
+                    "variants": [{"name": model_name, "size": -1}]
+                })
+        except Exception as e:
+            ASCIIColors.error(f"--- [LiteLLM Binding] Fallback method failed: {e}")
+        return entries
+
+    def listModels(self) -> List[Dict]:
+        url = f'{self.host_address}/model/info'
+        headers = {'Authorization': f'Bearer {self.service_key}'}
+        entries = []
+        ASCIIColors.yellow(f"--- [LiteLLM Binding] Attempting to list models from: {url}")
+        try:
+            response = requests.get(url, headers=headers, verify=self.verify_ssl_certificate)
+            if response.status_code == 404:
+                ASCIIColors.warning("--- [LiteLLM Binding] /model/info endpoint not found (404).")
+                return self._list_models_openai_fallback()
+            response.raise_for_status()
+            models_data = response.json().get('data', [])
+            ASCIIColors.info(f"--- [LiteLLM Binding] Successfully parsed {len(models_data)} models from primary endpoint.")
+            for model in models_data:
+                model_name = model.get('model_name')
+                if not model_name: continue
+                model_info = model.get('model_info', {})
+                context_size = model_info.get('max_tokens', model_info.get('max_input_tokens', 4096))
+                entries.append({
+                    "category": "api", "datasets": "unknown", "icon": get_icon_path(model_name),
+                    "license": "unknown", "model_creator": model_info.get('owned_by', 'unknown'),
+                    "name": model_name, "provider": "litellm", "rank": "1.0", "type": "api",
+                    "variants": [{
+                        "name": model_name, "size": context_size,
+                        "input_cost_per_token": model_info.get('input_cost_per_token', 0),
+                        "output_cost_per_token": model_info.get('output_cost_per_token', 0),
+                        "max_output_tokens": model_info.get('max_output_tokens', 0),
+                    }]
+                })
+        except requests.exceptions.RequestException as e:
+            ASCIIColors.error(f"--- [LiteLLM Binding] Network error when trying to list models: {e}")
+            if "404" in str(e): return self._list_models_openai_fallback()
+        except Exception as e:
+            ASCIIColors.error(f"--- [LiteLLM Binding] An unexpected error occurred while listing models: {e}")
+        return entries
+
+    def get_model_info(self) -> dict:
+        return {"name": "LiteLLM", "host_address": self.host_address, "model_name": self.model_name}
+
+    def load_model(self, model_name: str) -> bool:
+        self.model_name = model_name
+        return True
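For orientation, here is a minimal usage sketch of the LiteLLMBinding added above. It is illustrative only: the proxy address, model name, and prompts are placeholders rather than values shipped in this release, and the import path simply follows the file location listed at the top of the diff.

    # Illustrative sketch: direct use of the LiteLLMBinding defined in the hunk above.
    # host_address and model_name are hypothetical; use whatever your LiteLLM proxy routes.
    from lollms_client.llm_bindings.litellm import LiteLLMBinding

    binding = LiteLLMBinding(
        host_address="http://localhost:4000",  # hypothetical LiteLLM proxy endpoint
        model_name="gpt-4o-mini",              # placeholder model name
        service_key="anything",                # the binding's default bearer token
    )

    # Non-streaming call: returns the full completion as a string (or an error dict).
    print(binding.generate_text("Say hello in one short sentence.", stream=False))

    # Streaming call: the callback receives chunks; returning False stops the stream early.
    binding.generate_text(
        "Count from 1 to 5.",
        stream=True,
        streaming_callback=lambda chunk, msg_type: (print(chunk, end="", flush=True) or True),
    )

Both calls go through the same _perform_generation helper; the only difference is whether the HTTP request is issued with stream=True and parsed as server-sent events.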
@@ -4,6 +4,7 @@ from lollms_client.lollms_llm_binding import LollmsLLMBinding
 from lollms_client.lollms_types import MSG_TYPE
 from lollms_client.lollms_utilities import encode_image
 from lollms_client.lollms_types import ELF_COMPLETION_FORMAT
+from lollms_client.lollms_discussion import LollmsDiscussion
 from ascii_colors import ASCIIColors, trace_exception
 from typing import Optional, Callable, List, Union
 import json

@@ -280,6 +281,7 @@ class LollmsLLMBinding(LollmsLLMBinding):
         Returns:
             list: List of tokens.
         """
+        response=None
         try:
             # Prepare the request payload
             payload = {
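The one-line response=None addition above follows a common defensive pattern: bind the name before the try block so the exception handler can inspect whatever response exists without risking a NameError. A generic illustration of the pattern, not the surrounding lollms code (the URL is a placeholder):

    import requests

    response = None  # pre-bind so the except block can reference it safely
    try:
        response = requests.post("http://localhost:9600/api", json={"text": "hello"}, timeout=5)  # placeholder URL
        response.raise_for_status()
        data = response.json()
    except Exception as e:
        detail = response.text if response is not None else "no response received"
        print(f"Request failed: {e} ({detail})")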
lollms_client/llm_bindings/mistral/__init__.py

@@ -0,0 +1,298 @@
+import os
+from typing import Optional, Callable, List, Union, Dict
+
+from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from ascii_colors import ASCIIColors, trace_exception
+
+import pipmaster as pm
+
+# Ensure the required packages are installed
+pm.ensure_packages(["mistralai", "pillow", "tiktoken"])
+
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
+from PIL import Image, ImageDraw
+import tiktoken
+
+BindingName = "MistralBinding"
+
+class MistralBinding(LollmsLLMBinding):
+    """
+    Mistral AI API binding implementation.
+
+    This binding allows communication with Mistral's API for both their
+    open-weight and proprietary models.
+    """
+
+    def __init__(self,
+                 model_name: str = "mistral-large-latest",
+                 mistral_api_key: str = None,
+                 **kwargs
+                 ):
+        """
+        Initialize the MistralBinding.
+
+        Args:
+            model_name (str): The name of the Mistral model to use.
+            mistral_api_key (str): The API key for the Mistral service.
+        """
+        super().__init__(binding_name=BindingName)
+        self.model_name = model_name
+        self.mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
+
+        if not self.mistral_api_key:
+            raise ValueError("Mistral API key is required. Set it via 'mistral_api_key' or MISTRAL_API_KEY env var.")
+
+        try:
+            self.client = MistralClient(api_key=self.mistral_api_key)
+        except Exception as e:
+            ASCIIColors.error(f"Failed to configure Mistral client: {e}")
+            self.client = None
+            raise ConnectionError(f"Could not configure Mistral client: {e}") from e
+
+    def _construct_parameters(self,
+                              temperature: float,
+                              top_p: float,
+                              n_predict: int,
+                              seed: Optional[int]) -> Dict[str, any]:
+        """Builds a parameters dictionary for the Mistral API."""
+        params = {}
+        if temperature is not None: params['temperature'] = float(temperature)
+        if top_p is not None: params['top_p'] = top_p
+        if n_predict is not None: params['max_tokens'] = n_predict
+        if seed is not None: params['random_seed'] = seed # Mistral uses 'random_seed'
+        return params
+
+    def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[ChatMessage]:
+        """Prepares the message list for the Mistral API from a LollmsDiscussion."""
+        history = []
+        if discussion.system_prompt:
+            # Mistral prefers the system prompt as the first message with a user/assistant turn.
+            # A lone system message is not ideal. We will prepend it to the first user message.
+            # However, for API consistency, we will treat it as a separate message if it exists.
+            # The official client will likely handle this.
+            history.append(ChatMessage(role="system", content=discussion.system_prompt))
+
+        for msg in discussion.get_messages(branch_tip_id):
+            role = 'user' if msg.sender_type == "user" else 'assistant'
+            # Note: Mistral API currently does not support image inputs via the chat endpoint.
+            if msg.content:
+                history.append(ChatMessage(role=role, content=msg.content))
+        return history
+
+    def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+        """
+        Generate text using Mistral. This is a wrapper around the chat method.
+        """
+        temp_discussion = LollmsDiscussion.from_messages([
+            LollmsMessage.new_message(sender_type="user", content=prompt)
+        ])
+        if kwargs.get("system_prompt"):
+            temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+        return self.chat(temp_discussion, **kwargs)
+
+    def chat(self,
+             discussion: LollmsDiscussion,
+             branch_tip_id: Optional[str] = None,
+             n_predict: Optional[int] = 2048,
+             stream: Optional[bool] = False,
+             temperature: float = 0.7,
+             top_p: float = 0.9,
+             seed: Optional[int] = None,
+             streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+             **kwargs
+             ) -> Union[str, dict]:
+        """
+        Conduct a chat session with a Mistral model.
+        """
+        if not self.client:
+            return {"status": "error", "message": "Mistral client not initialized."}
+
+        messages = self._prepare_messages(discussion, branch_tip_id)
+        api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+        full_response_text = ""
+
+        try:
+            if stream:
+                response = self.client.chat_stream(
+                    model=self.model_name,
+                    messages=messages,
+                    **api_params
+                )
+                for chunk in response:
+                    delta = chunk.choices[0].delta.content
+                    if delta:
+                        full_response_text += delta
+                        if streaming_callback:
+                            if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                break
+                return full_response_text
+            else:
+                response = self.client.chat(
+                    model=self.model_name,
+                    messages=messages,
+                    **api_params
+                )
+                return response.choices[0].message.content
+
+        except Exception as ex:
+            error_message = f"An unexpected error occurred with Mistral API: {str(ex)}"
+            trace_exception(ex)
+            return {"status": "error", "message": error_message}
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize text using tiktoken as a fallback."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.encode(text)
+        except Exception:
+            return list(text.encode('utf-8'))
+
+    def detokenize(self, tokens: list) -> str:
+        """Detokenize tokens using tiktoken."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.decode(tokens)
+        except Exception:
+            return bytes(tokens).decode('utf-8', errors='ignore')
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens in a text using the fallback tokenizer."""
+        return len(self.tokenize(text))
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        """
+        Get embeddings for the input text using the Mistral embedding API.
+        """
+        if not self.client:
+            raise Exception("Mistral client not initialized.")
+
+        # Default to the recommended embedding model
+        model_to_use = kwargs.get("model", "mistral-embed")
+
+        try:
+            response = self.client.embeddings(
+                model=model_to_use,
+                input=[text] # API expects a list of strings
+            )
+            return response.data[0].embedding
+        except Exception as ex:
+            trace_exception(ex)
+            raise Exception(f"Mistral embedding failed: {str(ex)}") from ex
+
+    def get_model_info(self) -> dict:
+        """Return information about the current Mistral setup."""
+        return {
+            "name": self.binding_name,
+            "version": "unknown", # mistralai library doesn't expose a version attribute easily
+            "host_address": "https://api.mistral.ai",
+            "model_name": self.model_name,
+            "supports_structured_output": False,
+            "supports_vision": False, # Mistral API does not currently support vision
+        }
+
+    def listModels(self) -> List[Dict[str, str]]:
+        """Lists available models from the Mistral service."""
+        if not self.client:
+            ASCIIColors.error("Mistral client not initialized. Cannot list models.")
+            return []
+        try:
+            ASCIIColors.debug("Listing Mistral models...")
+            models = self.client.list_models()
+            model_info_list = []
+            for m in models.data:
+                model_info_list.append({
+                    'model_name': m.id,
+                    'display_name': m.id.replace('-', ' ').title(),
+                    'description': f"Owned by: {m.owned_by}",
+                    'owned_by': m.owned_by
+                })
+            return model_info_list
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+
+    def load_model(self, model_name: str) -> bool:
+        """Sets the model name for subsequent operations."""
+        self.model_name = model_name
+        ASCIIColors.info(f"Mistral model set to: {model_name}. It will be used on the next API call.")
+        return True
+
+if __name__ == '__main__':
+    # Environment variable to set for testing:
+    # MISTRAL_API_KEY: Your Mistral API key
+
+    if "MISTRAL_API_KEY" not in os.environ:
+        ASCIIColors.red("Error: MISTRAL_API_KEY environment variable not set.")
+        print("Please get your key from https://console.mistral.ai/api-keys/ and set it.")
+        exit(1)
+
+    ASCIIColors.yellow("--- Testing MistralBinding ---")
+
+    test_model_name = "mistral-small-latest" # Use a smaller, faster model for testing
+    test_embedding_model = "mistral-embed"
+
+    try:
+        # --- Initialization ---
+        ASCIIColors.cyan("\n--- Initializing Binding ---")
+        binding = MistralBinding(model_name=test_model_name)
+        ASCIIColors.green("Binding initialized successfully.")
+
+        # --- List Models ---
+        ASCIIColors.cyan("\n--- Listing Models ---")
+        models = binding.listModels()
+        if models:
+            ASCIIColors.green(f"Found {len(models)} models on Mistral. Available models:")
+            for m in models:
+                print(f"- {m['model_name']}")
+        else:
+            ASCIIColors.warning("No models found or failed to list models.")
+
+        # --- Text Generation (Non-Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
+        prompt_text = "Who developed the transformer architecture and in what paper?"
+        generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False)
+        if isinstance(generated_text, str):
+            ASCIIColors.green(f"Generated text:\n{generated_text}")
+        else:
+            ASCIIColors.error(f"Generation failed: {generated_text}")
+
+        # --- Text Generation (Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")
+        full_streamed_text = ""
+        def stream_callback(chunk: str, msg_type: int):
+            nonlocal full_streamed_text
+            ASCIIColors.green(chunk, end="", flush=True)
+            full_streamed_text += chunk
+            return True
+
+        result = binding.generate_text(prompt_text, n_predict=150, stream=True, streaming_callback=stream_callback)
+        print("\n--- End of Stream ---")
+        ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+        # --- Embeddings Test ---
+        ASCIIColors.cyan("\n--- Embeddings ---")
+        try:
+            embedding_text = "Mistral AI is based in Paris."
+            embedding_vector = binding.embed(embedding_text, model=test_embedding_model)
+            ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+            ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+        except Exception as e:
+            ASCIIColors.error(f"Embedding test failed: {e}")
+
+        # --- Vision Test (should be unsupported) ---
+        ASCIIColors.cyan("\n--- Vision Test (Expecting No Support) ---")
+        model_info = binding.get_model_info()
+        if not model_info.get("supports_vision"):
+            ASCIIColors.green("Binding correctly reports no support for vision.")
+        else:
+            ASCIIColors.warning("Binding reports support for vision, which is unexpected for Mistral.")
+
+    except Exception as e:
+        ASCIIColors.error(f"An error occurred during testing: {e}")
+        trace_exception(e)
+
+    ASCIIColors.yellow("\nMistralBinding test finished.")
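Beyond the self-test above, which only exercises generate_text, the chat path can be driven with a LollmsDiscussion directly. The sketch below is illustrative only: it reuses the same discussion helpers the binding itself calls (LollmsDiscussion.from_messages, LollmsMessage.new_message), the model name is a placeholder, and MISTRAL_API_KEY must be set in the environment.

    # Illustrative sketch; relies on the helpers used inside the mistral hunk above.
    import os
    from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
    from lollms_client.llm_bindings.mistral import MistralBinding

    binding = MistralBinding(model_name="mistral-small-latest",  # placeholder model name
                             mistral_api_key=os.environ["MISTRAL_API_KEY"])

    discussion = LollmsDiscussion.from_messages([
        LollmsMessage.new_message(sender_type="user",
                                  content="Summarize the transformer architecture in two sentences.")
    ])
    discussion.system_prompt = "You are a concise assistant."

    # Non-streaming chat over the discussion; returns a string or an error dict.
    print(binding.chat(discussion, n_predict=128, stream=False))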