lollms-client 0.25.5__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/azure_openai/__init__.py +364 -0
- lollms_client/llm_bindings/claude/__init__.py +549 -0
- lollms_client/llm_bindings/groq/__init__.py +292 -0
- lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +307 -0
- lollms_client/llm_bindings/lollms/__init__.py +1 -0
- lollms_client/llm_bindings/mistral/__init__.py +298 -0
- lollms_client/llm_bindings/open_router/__init__.py +304 -0
- lollms_client/lollms_core.py +2 -2
- lollms_client/lollms_discussion.py +16 -20
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/METADATA +366 -1
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/RECORD +15 -9
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/top_level.txt +0 -0
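The headline change in 0.26.0 is the set of new provider bindings (Azure OpenAI, Claude, Groq, Hugging Face Inference API, Mistral, OpenRouter). The two bindings reproduced in full below share the same surface: construct with a model name and an API key, then call `generate_text` or `chat`. The minimal sketch below is adapted from the `__main__` test shipped inside the new Mistral binding; the direct import path and the model name are illustrative only, and the binding falls back to the MISTRAL_API_KEY environment variable when no key is passed.

```python
# Minimal usage sketch for the new MistralBinding, adapted from the __main__
# test in lollms_client/llm_bindings/mistral/__init__.py (shown below).
# Assumes MISTRAL_API_KEY is set; the direct import is for illustration only.
from lollms_client.llm_bindings.mistral import MistralBinding

binding = MistralBinding(model_name="mistral-small-latest")

# Returns a str on success, or a {"status": "error", "message": ...} dict on failure.
result = binding.generate_text("Say hello in one short sentence.", n_predict=64, stream=False)
if isinstance(result, str):
    print(result)
else:
    print("Generation failed:", result["message"])
```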
lollms_client/llm_bindings/mistral/__init__.py ADDED

@@ -0,0 +1,298 @@
+import os
+from typing import Optional, Callable, List, Union, Dict
+
+from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from ascii_colors import ASCIIColors, trace_exception
+
+import pipmaster as pm
+
+# Ensure the required packages are installed
+pm.ensure_packages(["mistralai", "pillow", "tiktoken"])
+
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
+from PIL import Image, ImageDraw
+import tiktoken
+
+BindingName = "MistralBinding"
+
+class MistralBinding(LollmsLLMBinding):
+    """
+    Mistral AI API binding implementation.
+
+    This binding allows communication with Mistral's API for both their
+    open-weight and proprietary models.
+    """
+
+    def __init__(self,
+                 model_name: str = "mistral-large-latest",
+                 mistral_api_key: str = None,
+                 **kwargs
+                 ):
+        """
+        Initialize the MistralBinding.
+
+        Args:
+            model_name (str): The name of the Mistral model to use.
+            mistral_api_key (str): The API key for the Mistral service.
+        """
+        super().__init__(binding_name=BindingName)
+        self.model_name = model_name
+        self.mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
+
+        if not self.mistral_api_key:
+            raise ValueError("Mistral API key is required. Set it via 'mistral_api_key' or MISTRAL_API_KEY env var.")
+
+        try:
+            self.client = MistralClient(api_key=self.mistral_api_key)
+        except Exception as e:
+            ASCIIColors.error(f"Failed to configure Mistral client: {e}")
+            self.client = None
+            raise ConnectionError(f"Could not configure Mistral client: {e}") from e
+
+    def _construct_parameters(self,
+                              temperature: float,
+                              top_p: float,
+                              n_predict: int,
+                              seed: Optional[int]) -> Dict[str, any]:
+        """Builds a parameters dictionary for the Mistral API."""
+        params = {}
+        if temperature is not None: params['temperature'] = float(temperature)
+        if top_p is not None: params['top_p'] = top_p
+        if n_predict is not None: params['max_tokens'] = n_predict
+        if seed is not None: params['random_seed'] = seed # Mistral uses 'random_seed'
+        return params
+
+    def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[ChatMessage]:
+        """Prepares the message list for the Mistral API from a LollmsDiscussion."""
+        history = []
+        if discussion.system_prompt:
+            # Mistral prefers the system prompt as the first message with a user/assistant turn.
+            # A lone system message is not ideal. We will prepend it to the first user message.
+            # However, for API consistency, we will treat it as a separate message if it exists.
+            # The official client will likely handle this.
+            history.append(ChatMessage(role="system", content=discussion.system_prompt))
+
+        for msg in discussion.get_messages(branch_tip_id):
+            role = 'user' if msg.sender_type == "user" else 'assistant'
+            # Note: Mistral API currently does not support image inputs via the chat endpoint.
+            if msg.content:
+                history.append(ChatMessage(role=role, content=msg.content))
+        return history
+
+    def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+        """
+        Generate text using Mistral. This is a wrapper around the chat method.
+        """
+        temp_discussion = LollmsDiscussion.from_messages([
+            LollmsMessage.new_message(sender_type="user", content=prompt)
+        ])
+        if kwargs.get("system_prompt"):
+            temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+        return self.chat(temp_discussion, **kwargs)
+
+    def chat(self,
+             discussion: LollmsDiscussion,
+             branch_tip_id: Optional[str] = None,
+             n_predict: Optional[int] = 2048,
+             stream: Optional[bool] = False,
+             temperature: float = 0.7,
+             top_p: float = 0.9,
+             seed: Optional[int] = None,
+             streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+             **kwargs
+             ) -> Union[str, dict]:
+        """
+        Conduct a chat session with a Mistral model.
+        """
+        if not self.client:
+            return {"status": "error", "message": "Mistral client not initialized."}
+
+        messages = self._prepare_messages(discussion, branch_tip_id)
+        api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+        full_response_text = ""
+
+        try:
+            if stream:
+                response = self.client.chat_stream(
+                    model=self.model_name,
+                    messages=messages,
+                    **api_params
+                )
+                for chunk in response:
+                    delta = chunk.choices[0].delta.content
+                    if delta:
+                        full_response_text += delta
+                        if streaming_callback:
+                            if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                break
+                return full_response_text
+            else:
+                response = self.client.chat(
+                    model=self.model_name,
+                    messages=messages,
+                    **api_params
+                )
+                return response.choices[0].message.content
+
+        except Exception as ex:
+            error_message = f"An unexpected error occurred with Mistral API: {str(ex)}"
+            trace_exception(ex)
+            return {"status": "error", "message": error_message}
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize text using tiktoken as a fallback."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.encode(text)
+        except Exception:
+            return list(text.encode('utf-8'))
+
+    def detokenize(self, tokens: list) -> str:
+        """Detokenize tokens using tiktoken."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.decode(tokens)
+        except Exception:
+            return bytes(tokens).decode('utf-8', errors='ignore')
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens in a text using the fallback tokenizer."""
+        return len(self.tokenize(text))
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        """
+        Get embeddings for the input text using the Mistral embedding API.
+        """
+        if not self.client:
+            raise Exception("Mistral client not initialized.")
+
+        # Default to the recommended embedding model
+        model_to_use = kwargs.get("model", "mistral-embed")
+
+        try:
+            response = self.client.embeddings(
+                model=model_to_use,
+                input=[text] # API expects a list of strings
+            )
+            return response.data[0].embedding
+        except Exception as ex:
+            trace_exception(ex)
+            raise Exception(f"Mistral embedding failed: {str(ex)}") from ex
+
+    def get_model_info(self) -> dict:
+        """Return information about the current Mistral setup."""
+        return {
+            "name": self.binding_name,
+            "version": "unknown", # mistralai library doesn't expose a version attribute easily
+            "host_address": "https://api.mistral.ai",
+            "model_name": self.model_name,
+            "supports_structured_output": False,
+            "supports_vision": False, # Mistral API does not currently support vision
+        }
+
+    def listModels(self) -> List[Dict[str, str]]:
+        """Lists available models from the Mistral service."""
+        if not self.client:
+            ASCIIColors.error("Mistral client not initialized. Cannot list models.")
+            return []
+        try:
+            ASCIIColors.debug("Listing Mistral models...")
+            models = self.client.list_models()
+            model_info_list = []
+            for m in models.data:
+                model_info_list.append({
+                    'model_name': m.id,
+                    'display_name': m.id.replace('-', ' ').title(),
+                    'description': f"Owned by: {m.owned_by}",
+                    'owned_by': m.owned_by
+                })
+            return model_info_list
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+
+    def load_model(self, model_name: str) -> bool:
+        """Sets the model name for subsequent operations."""
+        self.model_name = model_name
+        ASCIIColors.info(f"Mistral model set to: {model_name}. It will be used on the next API call.")
+        return True
+
+if __name__ == '__main__':
+    # Environment variable to set for testing:
+    # MISTRAL_API_KEY: Your Mistral API key
+
+    if "MISTRAL_API_KEY" not in os.environ:
+        ASCIIColors.red("Error: MISTRAL_API_KEY environment variable not set.")
+        print("Please get your key from https://console.mistral.ai/api-keys/ and set it.")
+        exit(1)
+
+    ASCIIColors.yellow("--- Testing MistralBinding ---")
+
+    test_model_name = "mistral-small-latest" # Use a smaller, faster model for testing
+    test_embedding_model = "mistral-embed"
+
+    try:
+        # --- Initialization ---
+        ASCIIColors.cyan("\n--- Initializing Binding ---")
+        binding = MistralBinding(model_name=test_model_name)
+        ASCIIColors.green("Binding initialized successfully.")
+
+        # --- List Models ---
+        ASCIIColors.cyan("\n--- Listing Models ---")
+        models = binding.listModels()
+        if models:
+            ASCIIColors.green(f"Found {len(models)} models on Mistral. Available models:")
+            for m in models:
+                print(f"- {m['model_name']}")
+        else:
+            ASCIIColors.warning("No models found or failed to list models.")
+
+        # --- Text Generation (Non-Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
+        prompt_text = "Who developed the transformer architecture and in what paper?"
+        generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False)
+        if isinstance(generated_text, str):
+            ASCIIColors.green(f"Generated text:\n{generated_text}")
+        else:
+            ASCIIColors.error(f"Generation failed: {generated_text}")
+
+        # --- Text Generation (Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")
+        full_streamed_text = ""
+        def stream_callback(chunk: str, msg_type: int):
+            nonlocal full_streamed_text
+            ASCIIColors.green(chunk, end="", flush=True)
+            full_streamed_text += chunk
+            return True
+
+        result = binding.generate_text(prompt_text, n_predict=150, stream=True, streaming_callback=stream_callback)
+        print("\n--- End of Stream ---")
+        ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+        # --- Embeddings Test ---
+        ASCIIColors.cyan("\n--- Embeddings ---")
+        try:
+            embedding_text = "Mistral AI is based in Paris."
+            embedding_vector = binding.embed(embedding_text, model=test_embedding_model)
+            ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+            ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+        except Exception as e:
+            ASCIIColors.error(f"Embedding test failed: {e}")
+
+        # --- Vision Test (should be unsupported) ---
+        ASCIIColors.cyan("\n--- Vision Test (Expecting No Support) ---")
+        model_info = binding.get_model_info()
+        if not model_info.get("supports_vision"):
+            ASCIIColors.green("Binding correctly reports no support for vision.")
+        else:
+            ASCIIColors.warning("Binding reports support for vision, which is unexpected for Mistral.")
+
+    except Exception as e:
+        ASCIIColors.error(f"An error occurred during testing: {e}")
+        trace_exception(e)
+
+    ASCIIColors.yellow("\nMistralBinding test finished.")
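Both of the new bindings stream through the same callback contract visible in the `chat` method above: each text chunk is passed to `streaming_callback` together with `MSG_TYPE.MSG_TYPE_CHUNK`, and a falsy return value breaks out of the stream loop. The sketch below illustrates that contract with a callback that cuts the stream after a character budget; the budget and the printing behaviour are illustrative, not part of the package.

```python
from lollms_client.lollms_types import MSG_TYPE

received: list[str] = []

def capped_callback(chunk: str, msg_type: MSG_TYPE) -> bool:
    """Collect chunks and stop the stream after roughly 500 characters."""
    received.append(chunk)
    print(chunk, end="", flush=True)
    # Returning False makes the binding break out of its streaming loop.
    return sum(len(c) for c in received) < 500

# With a binding instance already constructed (see the sketch above):
# text = binding.generate_text("Tell me a long story.", stream=True,
#                              streaming_callback=capped_callback)
```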
lollms_client/llm_bindings/open_router/__init__.py ADDED

@@ -0,0 +1,304 @@
+import os
+from typing import Optional, Callable, List, Union, Dict
+
+from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from ascii_colors import ASCIIColors, trace_exception
+
+import pipmaster as pm
+
+# Ensure the required packages are installed
+pm.ensure_packages(["openai", "pillow", "tiktoken"])
+
+import openai
+from PIL import Image, ImageDraw
+import tiktoken
+
+BindingName = "OpenRouterBinding"
+
+class OpenRouterBinding(LollmsLLMBinding):
+    """
+    OpenRouter API binding implementation.
+
+    This binding allows communication with the OpenRouter service, which acts as a
+    aggregator for a vast number of AI models from different providers. It uses
+    an OpenAI-compatible API structure.
+    """
+    BASE_URL = "https://openrouter.ai/api/v1"
+
+    def __init__(self,
+                 model_name: str = "google/gemini-flash-1.5", # A good, fast default
+                 open_router_api_key: str = None,
+                 **kwargs
+                 ):
+        """
+        Initialize the OpenRouterBinding.
+
+        Args:
+            model_name (str): The name of the model to use from OpenRouter (e.g., 'anthropic/claude-3-haiku-20240307').
+            open_router_api_key (str): The API key for the OpenRouter service.
+        """
+        super().__init__(binding_name=BindingName)
+        self.model_name = model_name
+        self.api_key = open_router_api_key or os.getenv("OPENROUTER_API_KEY")
+
+        if not self.api_key:
+            raise ValueError("OpenRouter API key is required. Set it via 'open_router_api_key' or OPENROUTER_API_KEY env var.")
+
+        try:
+            self.client = openai.OpenAI(
+                base_url=self.BASE_URL,
+                api_key=self.api_key,
+            )
+        except Exception as e:
+            ASCIIColors.error(f"Failed to configure OpenRouter client: {e}")
+            self.client = None
+            raise ConnectionError(f"Could not configure OpenRouter client: {e}") from e
+
+    def _construct_parameters(self,
+                              temperature: float,
+                              top_p: float,
+                              n_predict: int,
+                              seed: Optional[int]) -> Dict[str, any]:
+        """Builds a parameters dictionary for the API."""
+        params = {}
+        if temperature is not None: params['temperature'] = float(temperature)
+        if top_p is not None: params['top_p'] = top_p
+        if n_predict is not None: params['max_tokens'] = n_predict
+        if seed is not None: params['seed'] = seed
+        return params
+
+    def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[Dict[str, any]]:
+        """Prepares the message list for the API from a LollmsDiscussion."""
+        history = []
+        if discussion.system_prompt:
+            history.append({"role": "system", "content": discussion.system_prompt})
+
+        for msg in discussion.get_messages(branch_tip_id):
+            role = 'user' if msg.sender_type == "user" else 'assistant'
+            # Note: Vision support depends on the specific model being called via OpenRouter.
+            # We will not implement it in this generic binding to avoid complexity,
+            # as different models might expect different formats.
+            if msg.content:
+                history.append({'role': role, 'content': msg.content})
+        return history
+
+    def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+        """
+        Generate text using OpenRouter. This is a wrapper around the chat method.
+        """
+        temp_discussion = LollmsDiscussion.from_messages([
+            LollmsMessage.new_message(sender_type="user", content=prompt)
+        ])
+        if kwargs.get("system_prompt"):
+            temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+        return self.chat(temp_discussion, **kwargs)
+
+    def chat(self,
+             discussion: LollmsDiscussion,
+             branch_tip_id: Optional[str] = None,
+             n_predict: Optional[int] = 2048,
+             stream: Optional[bool] = False,
+             temperature: float = 0.7,
+             top_p: float = 0.9,
+             seed: Optional[int] = None,
+             streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+             **kwargs
+             ) -> Union[str, dict]:
+        """
+        Conduct a chat session with a model via OpenRouter.
+        """
+        if not self.client:
+            return {"status": "error", "message": "OpenRouter client not initialized."}
+
+        messages = self._prepare_messages(discussion, branch_tip_id)
+        api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+        full_response_text = ""
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                stream=stream,
+                **api_params
+            )
+
+            if stream:
+                for chunk in response:
+                    delta = chunk.choices[0].delta.content
+                    if delta:
+                        full_response_text += delta
+                        if streaming_callback:
+                            if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                break
+                return full_response_text
+            else:
+                return response.choices[0].message.content
+
+        except Exception as ex:
+            error_message = f"An unexpected error occurred with OpenRouter API: {str(ex)}"
+            trace_exception(ex)
+            return {"status": "error", "message": error_message}
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize text using tiktoken as a general-purpose fallback."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.encode(text)
+        except Exception:
+            return list(text.encode('utf-8'))
+
+    def detokenize(self, tokens: list) -> str:
+        """Detokenize tokens using tiktoken."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.decode(tokens)
+        except Exception:
+            return bytes(tokens).decode('utf-8', errors='ignore')
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens in a text using the fallback tokenizer."""
+        return len(self.tokenize(text))
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        """
+        Get embeddings for the input text using an OpenRouter embedding model.
+        """
+        if not self.client:
+            raise Exception("OpenRouter client not initialized.")
+
+        # User must specify an embedding model, e.g., 'text-embedding-ada-002'
+        embedding_model = kwargs.get("model")
+        if not embedding_model:
+            raise ValueError("An embedding model name must be provided via the 'model' kwarg for the embed method.")
+
+        try:
+            # The client is already configured for OpenRouter's base URL
+            response = self.client.embeddings.create(
+                model=embedding_model,
+                input=text
+            )
+            return response.data[0].embedding
+        except Exception as ex:
+            trace_exception(ex)
+            raise Exception(f"OpenRouter embedding failed: {str(ex)}") from ex
+
+    def get_model_info(self) -> dict:
+        """Return information about the current OpenRouter setup."""
+        return {
+            "name": self.binding_name,
+            "version": openai.__version__,
+            "host_address": self.BASE_URL,
+            "model_name": self.model_name,
+            "supports_structured_output": False,
+            "supports_vision": "Depends on the specific model selected. This generic binding does not support vision.",
+        }
+
+    def listModels(self) -> List[Dict[str, str]]:
+        """Lists available models from the OpenRouter service."""
+        if not self.client:
+            ASCIIColors.error("OpenRouter client not initialized. Cannot list models.")
+            return []
+        try:
+            ASCIIColors.debug("Listing OpenRouter models...")
+            models = self.client.models.list()
+            model_info_list = []
+            for m in models.data:
+                model_info_list.append({
+                    'model_name': m.id,
+                    'display_name': m.name if hasattr(m, 'name') else m.id,
+                    'description': m.description if hasattr(m, 'description') else "No description available.",
+                    'owned_by': m.id.split('/')[0] # Heuristic to get the provider
+                })
+            return model_info_list
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+
+    def load_model(self, model_name: str) -> bool:
+        """Sets the model name for subsequent operations."""
+        self.model_name = model_name
+        ASCIIColors.info(f"OpenRouter model set to: {model_name}. It will be used on the next API call.")
+        return True
+
+if __name__ == '__main__':
+    # Environment variable to set for testing:
+    # OPENROUTER_API_KEY: Your OpenRouter API key (starts with sk-or-...)
+
+    if "OPENROUTER_API_KEY" not in os.environ:
+        ASCIIColors.red("Error: OPENROUTER_API_KEY environment variable not set.")
+        print("Please get your key from https://openrouter.ai/keys and set it.")
+        exit(1)
+
+    ASCIIColors.yellow("--- Testing OpenRouterBinding ---")
+
+    try:
+        # --- Initialization ---
+        ASCIIColors.cyan("\n--- Initializing Binding ---")
+        # Initialize with a fast, cheap, and well-known model
+        binding = OpenRouterBinding(model_name="mistralai/mistral-7b-instruct")
+        ASCIIColors.green("Binding initialized successfully.")
+
+        # --- List Models ---
+        ASCIIColors.cyan("\n--- Listing Models ---")
+        models = binding.listModels()
+        if models:
+            ASCIIColors.green(f"Successfully fetched {len(models)} models from OpenRouter.")
+            ASCIIColors.info("Sample of available models:")
+            # Print a few examples from different providers
+            providers_seen = set()
+            count = 0
+            for m in models:
+                provider = m['owned_by']
+                if provider not in providers_seen:
+                    print(f"- {m['model_name']}")
+                    providers_seen.add(provider)
+                    count += 1
+                if count >= 5:
+                    break
+        else:
+            ASCIIColors.warning("No models found or failed to list models.")
+
+        # --- Text Generation (Testing with a Claude model) ---
+        ASCIIColors.cyan("\n--- Text Generation (Claude via OpenRouter) ---")
+        binding.load_model("anthropic/claude-3-haiku-20240307")
+        prompt_text = "Why is Claude Haiku a good choice for fast-paced chat applications?"
+        generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False)
+        if isinstance(generated_text, str):
+            ASCIIColors.green(f"Generated text:\n{generated_text}")
+        else:
+            ASCIIColors.error(f"Generation failed: {generated_text}")
+
+        # --- Text Generation (Streaming with a Groq model) ---
+        ASCIIColors.cyan("\n--- Text Generation (Llama3 on Groq via OpenRouter) ---")
+        binding.load_model("meta-llama/llama-3-8b-instruct:free") # Use the free tier on OpenRouter
+        full_streamed_text = ""
+        def stream_callback(chunk: str, msg_type: int):
+            nonlocal full_streamed_text
+            ASCIIColors.green(chunk, end="", flush=True)
+            full_streamed_text += chunk
+            return True
+
+        stream_prompt = "Write a very short, 3-line poem about the speed of Groq."
+        result = binding.generate_text(stream_prompt, n_predict=50, stream=True, streaming_callback=stream_callback)
+        print("\n--- End of Stream ---")
+        ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+        # --- Embeddings Test ---
+        ASCIIColors.cyan("\n--- Embeddings (OpenAI model via OpenRouter) ---")
+        try:
+            embedding_model = "openai/text-embedding-ada-002"
+            embedding_text = "OpenRouter simplifies everything."
+            embedding_vector = binding.embed(embedding_text, model=embedding_model)
+            ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+            ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+        except Exception as e:
+            ASCIIColors.error(f"Embedding test failed: {e}")
+
+    except Exception as e:
+        ASCIIColors.error(f"An error occurred during testing: {e}")
+        trace_exception(e)
+
+    ASCIIColors.yellow("\nOpenRouterBinding test finished.")
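A usage note on the OpenRouter binding above: models are addressed by provider-prefixed ids, `load_model` only records the name for the next API call, and `embed` refuses to guess an embedding model, so one must be passed via the `model` kwarg. The sketch below works under those assumptions; the model ids are the ones used in the file's own `__main__` test, and availability on OpenRouter may change.

```python
# Assumes OPENROUTER_API_KEY is set; the direct import is for illustration only.
from lollms_client.llm_bindings.open_router import OpenRouterBinding

binding = OpenRouterBinding(model_name="mistralai/mistral-7b-instruct")

# Switch the target model; it is applied on the next API call.
binding.load_model("anthropic/claude-3-haiku-20240307")
answer = binding.generate_text("Summarize OpenRouter in one sentence.", n_predict=60)
print(answer)

# embed() raises ValueError unless an embedding model is named explicitly.
vector = binding.embed("OpenRouter routes requests to many providers.",
                       model="openai/text-embedding-ada-002")
print(len(vector))
```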
lollms_client/lollms_core.py CHANGED

@@ -1595,7 +1595,7 @@ Provide your response as a single JSON object inside a JSON markdown tag. Use th
         formatted_tools_list += "\n**request_clarification**:\nUse if the user's request is ambiguous and you can not infer a clear idea of his intent. this tool has no parameters."
         formatted_tools_list += "\n**final_answer**:\nUse when you are ready to respond to the user. this tool has no parameters."

-        if discovery_step_id: log_event("Discovering tools",MSG_TYPE.MSG_TYPE_STEP_END, event_id=discovery_step_id)
+        if discovery_step_id: log_event("**Discovering tools**",MSG_TYPE.MSG_TYPE_STEP_END, event_id=discovery_step_id)

         # --- 2. Dynamic Reasoning Loop ---
         for i in range(max_reasoning_steps):

@@ -1755,7 +1755,7 @@ Provide your response as a single JSON object inside a JSON markdown tag. Use th

                 tool_calls_this_turn.append({"name": tool_name, "params": tool_params, "result": tool_result})
                 current_scratchpad += f"\n\n### Step {i+1}: Observation\n- **Action:** Called `{tool_name}`\n- **Result:**\n{observation_text}"
-                log_event(f"Observation
+                log_event(f"**Observation**: Result from `{tool_name}`:\n{dict_to_markdown(sanitized_result)}", MSG_TYPE.MSG_TYPE_OBSERVATION)

                 if reasoning_step_id: log_event(f"**Reasoning Step {i+1}/{max_reasoning_steps}**", MSG_TYPE.MSG_TYPE_STEP_END, event_id = reasoning_step_id)
             except Exception as ex:
lollms_client/lollms_discussion.py CHANGED

@@ -423,33 +423,29 @@ class LollmsDiscussion:
         else:
             return cls(lollmsClient=lollms_client, discussion_id=kwargs.get('id'), **init_args)

-    def get_messages(self, branch_id: Optional[str] = None) ->
+    def get_messages(self, branch_id: Optional[str] = None) -> Optional[List[LollmsMessage]]:
         """
-        Returns messages from
+        Returns a list of messages forming a branch, from root to a specific leaf.

-        - If no branch_id is provided, it returns
-
-        - If a branch_id is provided, it returns the
-        (
+        - If no branch_id is provided, it returns the full message list of the
+          currently active branch.
+        - If a branch_id is provided, it returns the list of all messages from the
+          root up to (and including) the message with that ID.

         Args:
-            branch_id: The ID of the leaf message
-
+            branch_id: The ID of the leaf message of the desired branch.
+                       If None, the active branch's leaf is used.

         Returns:
-            A list of LollmsMessage objects for the
-
+            A list of LollmsMessage objects for the specified branch, ordered
+            from root to leaf, or None if the branch_id does not exist.
         """
-
-
-
-
-
-
-        if branch_id in self._message_index:
-            return LollmsMessage(self, self._message_index[branch_id])
-        else:
-            return None
+        # Determine which leaf message ID to use
+        leaf_id = branch_id if branch_id is not None else self.active_branch_id
+
+        # Return the full branch leading to that leaf
+        # We assume self.get_branch() correctly handles non-existent IDs by returning None or an empty list.
+        return self.get_branch(leaf_id)


     def __getattr__(self, name: str) -> Any:
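The rewritten `get_messages` now returns a whole branch (root to leaf) by delegating to `get_branch`, instead of a single message object. A small helper sketch under that contract follows; the helper itself is hypothetical and not part of the package.

```python
from typing import Optional
from lollms_client.lollms_discussion import LollmsDiscussion

def print_branch(discussion: LollmsDiscussion, branch_id: Optional[str] = None) -> None:
    """Print a branch: the active branch when branch_id is None, else root up to that id."""
    branch = discussion.get_messages(branch_id)
    if not branch:
        print("Unknown or empty branch")
        return
    for msg in branch:
        print(f"{msg.sender_type}: {msg.content}")
```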
|