lollms-client 0.25.6__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lollms-client might be problematic.

@@ -0,0 +1,298 @@
+ import os
+ from typing import Optional, Callable, List, Union, Dict, Any
+
+ from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+ from lollms_client.lollms_llm_binding import LollmsLLMBinding
+ from lollms_client.lollms_types import MSG_TYPE
+ from ascii_colors import ASCIIColors, trace_exception
+
+ import pipmaster as pm
+
+ # Ensure the required packages are installed
+ pm.ensure_packages(["mistralai", "pillow", "tiktoken"])
+
+ from mistralai.client import MistralClient  # pre-1.0 mistralai client API (renamed in the 1.x SDK)
+ from mistralai.models.chat_completion import ChatMessage
+ from PIL import Image, ImageDraw
+ import tiktoken
+
+ BindingName = "MistralBinding"
+
+ class MistralBinding(LollmsLLMBinding):
+     """
+     Mistral AI API binding implementation.
+
+     This binding allows communication with Mistral's API for both their
+     open-weight and proprietary models.
+     """
+
+     def __init__(self,
+                  model_name: str = "mistral-large-latest",
+                  mistral_api_key: Optional[str] = None,
+                  **kwargs
+                  ):
+         """
+         Initialize the MistralBinding.
+
+         Args:
+             model_name (str): The name of the Mistral model to use.
+             mistral_api_key (str): The API key for the Mistral service.
+         """
+         super().__init__(binding_name=BindingName)
+         self.model_name = model_name
+         self.mistral_api_key = mistral_api_key or os.getenv("MISTRAL_API_KEY")
+
+         if not self.mistral_api_key:
+             raise ValueError("Mistral API key is required. Set it via 'mistral_api_key' or MISTRAL_API_KEY env var.")
+
+         try:
+             self.client = MistralClient(api_key=self.mistral_api_key)
+         except Exception as e:
+             ASCIIColors.error(f"Failed to configure Mistral client: {e}")
+             self.client = None
+             raise ConnectionError(f"Could not configure Mistral client: {e}") from e
+
+     def _construct_parameters(self,
+                               temperature: float,
+                               top_p: float,
+                               n_predict: int,
+                               seed: Optional[int]) -> Dict[str, Any]:
+         """Builds a parameters dictionary for the Mistral API."""
+         params = {}
+         if temperature is not None: params['temperature'] = float(temperature)
+         if top_p is not None: params['top_p'] = top_p
+         if n_predict is not None: params['max_tokens'] = n_predict
+         if seed is not None: params['random_seed'] = seed  # Mistral uses 'random_seed'
+         return params
+
+     def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[ChatMessage]:
+         """Prepares the message list for the Mistral API from a LollmsDiscussion."""
+         history = []
+         if discussion.system_prompt:
+             # The Mistral chat API accepts a leading system-role message, so the
+             # system prompt is passed through as its own message rather than being
+             # merged into the first user turn; the official client forwards it
+             # to the API unchanged.
+             history.append(ChatMessage(role="system", content=discussion.system_prompt))
+
+         for msg in discussion.get_messages(branch_tip_id):
+             role = 'user' if msg.sender_type == "user" else 'assistant'
+             # Note: Mistral API currently does not support image inputs via the chat endpoint.
+             if msg.content:
+                 history.append(ChatMessage(role=role, content=msg.content))
+         return history
+
+     def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+         """
+         Generate text using Mistral. This is a wrapper around the chat method.
+         """
+         temp_discussion = LollmsDiscussion.from_messages([
+             LollmsMessage.new_message(sender_type="user", content=prompt)
+         ])
+         if kwargs.get("system_prompt"):
+             temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+         return self.chat(temp_discussion, **kwargs)
+
+     def chat(self,
+              discussion: LollmsDiscussion,
+              branch_tip_id: Optional[str] = None,
+              n_predict: Optional[int] = 2048,
+              stream: Optional[bool] = False,
+              temperature: float = 0.7,
+              top_p: float = 0.9,
+              seed: Optional[int] = None,
+              streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+              **kwargs
+              ) -> Union[str, dict]:
+         """
+         Conduct a chat session with a Mistral model.
+         """
+         if not self.client:
+             return {"status": "error", "message": "Mistral client not initialized."}
+
+         messages = self._prepare_messages(discussion, branch_tip_id)
+         api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+         full_response_text = ""
+
+         try:
+             if stream:
+                 response = self.client.chat_stream(
+                     model=self.model_name,
+                     messages=messages,
+                     **api_params
+                 )
+                 for chunk in response:
+                     delta = chunk.choices[0].delta.content
+                     if delta:
+                         full_response_text += delta
+                         if streaming_callback:
+                             if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                 break
+                 return full_response_text
+             else:
+                 response = self.client.chat(
+                     model=self.model_name,
+                     messages=messages,
+                     **api_params
+                 )
+                 return response.choices[0].message.content
+
+         except Exception as ex:
+             error_message = f"An unexpected error occurred with Mistral API: {str(ex)}"
+             trace_exception(ex)
+             return {"status": "error", "message": error_message}
+
+     def tokenize(self, text: str) -> list:
+         """Tokenize text using tiktoken as a fallback."""
+         try:
+             encoding = tiktoken.get_encoding("cl100k_base")
+             return encoding.encode(text)
+         except Exception:
+             return list(text.encode('utf-8'))
+
+     def detokenize(self, tokens: list) -> str:
+         """Detokenize tokens using tiktoken."""
+         try:
+             encoding = tiktoken.get_encoding("cl100k_base")
+             return encoding.decode(tokens)
+         except Exception:
+             return bytes(tokens).decode('utf-8', errors='ignore')
+
+     def count_tokens(self, text: str) -> int:
+         """Count tokens in a text using the fallback tokenizer."""
+         return len(self.tokenize(text))
+
+     def embed(self, text: str, **kwargs) -> List[float]:
+         """
+         Get embeddings for the input text using the Mistral embedding API.
+         """
+         if not self.client:
+             raise Exception("Mistral client not initialized.")
+
+         # Default to the recommended embedding model
+         model_to_use = kwargs.get("model", "mistral-embed")
+
+         try:
+             response = self.client.embeddings(
+                 model=model_to_use,
+                 input=[text]  # API expects a list of strings
+             )
+             return response.data[0].embedding
+         except Exception as ex:
+             trace_exception(ex)
+             raise Exception(f"Mistral embedding failed: {str(ex)}") from ex
+
+     def get_model_info(self) -> dict:
+         """Return information about the current Mistral setup."""
+         return {
+             "name": self.binding_name,
+             "version": "unknown",  # mistralai library doesn't expose a version attribute easily
+             "host_address": "https://api.mistral.ai",
+             "model_name": self.model_name,
+             "supports_structured_output": False,
+             "supports_vision": False,  # Mistral API does not currently support vision
+         }
+
+     def listModels(self) -> List[Dict[str, str]]:
+         """Lists available models from the Mistral service."""
+         if not self.client:
+             ASCIIColors.error("Mistral client not initialized. Cannot list models.")
+             return []
+         try:
+             ASCIIColors.debug("Listing Mistral models...")
+             models = self.client.list_models()
+             model_info_list = []
+             for m in models.data:
+                 model_info_list.append({
+                     'model_name': m.id,
+                     'display_name': m.id.replace('-', ' ').title(),
+                     'description': f"Owned by: {m.owned_by}",
+                     'owned_by': m.owned_by
+                 })
+             return model_info_list
+         except Exception as ex:
+             trace_exception(ex)
+             return []
+
+     def load_model(self, model_name: str) -> bool:
+         """Sets the model name for subsequent operations."""
+         self.model_name = model_name
+         ASCIIColors.info(f"Mistral model set to: {model_name}. It will be used on the next API call.")
+         return True
+
+ if __name__ == '__main__':
+     # Environment variable to set for testing:
+     # MISTRAL_API_KEY: Your Mistral API key
+
+     if "MISTRAL_API_KEY" not in os.environ:
+         ASCIIColors.red("Error: MISTRAL_API_KEY environment variable not set.")
+         print("Please get your key from https://console.mistral.ai/api-keys/ and set it.")
+         exit(1)
+
+     ASCIIColors.yellow("--- Testing MistralBinding ---")
+
+     test_model_name = "mistral-small-latest"  # Use a smaller, faster model for testing
+     test_embedding_model = "mistral-embed"
+
+     try:
+         # --- Initialization ---
+         ASCIIColors.cyan("\n--- Initializing Binding ---")
+         binding = MistralBinding(model_name=test_model_name)
+         ASCIIColors.green("Binding initialized successfully.")
+
+         # --- List Models ---
+         ASCIIColors.cyan("\n--- Listing Models ---")
+         models = binding.listModels()
+         if models:
+             ASCIIColors.green(f"Found {len(models)} models on Mistral. Available models:")
+             for m in models:
+                 print(f"- {m['model_name']}")
+         else:
+             ASCIIColors.warning("No models found or failed to list models.")
+
+         # --- Text Generation (Non-Streaming) ---
+         ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
+         prompt_text = "Who developed the transformer architecture and in what paper?"
+         generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False)
+         if isinstance(generated_text, str):
+             ASCIIColors.green(f"Generated text:\n{generated_text}")
+         else:
+             ASCIIColors.error(f"Generation failed: {generated_text}")
+
+         # --- Text Generation (Streaming) ---
+         ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")
+         full_streamed_text = ""
+         def stream_callback(chunk: str, msg_type: int):
+             global full_streamed_text  # module-level scope here, so 'global' (not 'nonlocal') is required
+             ASCIIColors.green(chunk, end="", flush=True)
+             full_streamed_text += chunk
+             return True
+
+         result = binding.generate_text(prompt_text, n_predict=150, stream=True, streaming_callback=stream_callback)
+         print("\n--- End of Stream ---")
+         ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+         # --- Embeddings Test ---
+         ASCIIColors.cyan("\n--- Embeddings ---")
+         try:
+             embedding_text = "Mistral AI is based in Paris."
+             embedding_vector = binding.embed(embedding_text, model=test_embedding_model)
+             ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+             ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+         except Exception as e:
+             ASCIIColors.error(f"Embedding test failed: {e}")
+
+         # --- Vision Test (should be unsupported) ---
+         ASCIIColors.cyan("\n--- Vision Test (Expecting No Support) ---")
+         model_info = binding.get_model_info()
+         if not model_info.get("supports_vision"):
+             ASCIIColors.green("Binding correctly reports no support for vision.")
+         else:
+             ASCIIColors.warning("Binding reports support for vision, which is unexpected for Mistral.")
+
+     except Exception as e:
+         ASCIIColors.error(f"An error occurred during testing: {e}")
+         trace_exception(e)
+
+     ASCIIColors.yellow("\nMistralBinding test finished.")
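
For orientation, here is a minimal usage sketch of the new Mistral binding, exercising only the methods defined in the file above; it assumes the module is importable and that MISTRAL_API_KEY is set (the model name and prompts are placeholders).

    # Hypothetical quick-start; mirrors what the __main__ test above does.
    binding = MistralBinding(model_name="mistral-small-latest")
    print(binding.generate_text("Summarize Mistral AI in one sentence.", n_predict=64, stream=False))

    def on_chunk(chunk, msg_type):
        print(chunk, end="", flush=True)
        return True  # returning False stops the stream

    binding.generate_text("Count to five.", n_predict=32, stream=True, streaming_callback=on_chunk)
    vector = binding.embed("Paris is the capital of France.")  # defaults to the "mistral-embed" model
    print(len(vector))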
@@ -0,0 +1,304 @@
+ import os
+ from typing import Optional, Callable, List, Union, Dict, Any
+
+ from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+ from lollms_client.lollms_llm_binding import LollmsLLMBinding
+ from lollms_client.lollms_types import MSG_TYPE
+ from ascii_colors import ASCIIColors, trace_exception
+
+ import pipmaster as pm
+
+ # Ensure the required packages are installed
+ pm.ensure_packages(["openai", "pillow", "tiktoken"])
+
+ import openai
+ from PIL import Image, ImageDraw
+ import tiktoken
+
+ BindingName = "OpenRouterBinding"
+
+ class OpenRouterBinding(LollmsLLMBinding):
+     """
+     OpenRouter API binding implementation.
+
+     This binding allows communication with the OpenRouter service, which acts as an
+     aggregator for a vast number of AI models from different providers. It uses
+     an OpenAI-compatible API structure.
+     """
+     BASE_URL = "https://openrouter.ai/api/v1"
+
+     def __init__(self,
+                  model_name: str = "google/gemini-flash-1.5",  # A good, fast default
+                  open_router_api_key: Optional[str] = None,
+                  **kwargs
+                  ):
+         """
+         Initialize the OpenRouterBinding.
+
+         Args:
+             model_name (str): The name of the model to use from OpenRouter (e.g., 'anthropic/claude-3-haiku-20240307').
+             open_router_api_key (str): The API key for the OpenRouter service.
+         """
+         super().__init__(binding_name=BindingName)
+         self.model_name = model_name
+         self.api_key = open_router_api_key or os.getenv("OPENROUTER_API_KEY")
+
+         if not self.api_key:
+             raise ValueError("OpenRouter API key is required. Set it via 'open_router_api_key' or OPENROUTER_API_KEY env var.")
+
+         try:
+             self.client = openai.OpenAI(
+                 base_url=self.BASE_URL,
+                 api_key=self.api_key,
+             )
+         except Exception as e:
+             ASCIIColors.error(f"Failed to configure OpenRouter client: {e}")
+             self.client = None
+             raise ConnectionError(f"Could not configure OpenRouter client: {e}") from e
+
+     def _construct_parameters(self,
+                               temperature: float,
+                               top_p: float,
+                               n_predict: int,
+                               seed: Optional[int]) -> Dict[str, Any]:
+         """Builds a parameters dictionary for the API."""
+         params = {}
+         if temperature is not None: params['temperature'] = float(temperature)
+         if top_p is not None: params['top_p'] = top_p
+         if n_predict is not None: params['max_tokens'] = n_predict
+         if seed is not None: params['seed'] = seed
+         return params
+
+     def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[Dict[str, Any]]:
+         """Prepares the message list for the API from a LollmsDiscussion."""
+         history = []
+         if discussion.system_prompt:
+             history.append({"role": "system", "content": discussion.system_prompt})
+
+         for msg in discussion.get_messages(branch_tip_id):
+             role = 'user' if msg.sender_type == "user" else 'assistant'
+             # Note: Vision support depends on the specific model being called via OpenRouter.
+             # We do not implement it in this generic binding to avoid complexity,
+             # as different models might expect different image formats.
+             if msg.content:
+                 history.append({'role': role, 'content': msg.content})
+         return history
+
+     def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+         """
+         Generate text using OpenRouter. This is a wrapper around the chat method.
+         """
+         temp_discussion = LollmsDiscussion.from_messages([
+             LollmsMessage.new_message(sender_type="user", content=prompt)
+         ])
+         if kwargs.get("system_prompt"):
+             temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+         return self.chat(temp_discussion, **kwargs)
+
+     def chat(self,
+              discussion: LollmsDiscussion,
+              branch_tip_id: Optional[str] = None,
+              n_predict: Optional[int] = 2048,
+              stream: Optional[bool] = False,
+              temperature: float = 0.7,
+              top_p: float = 0.9,
+              seed: Optional[int] = None,
+              streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+              **kwargs
+              ) -> Union[str, dict]:
+         """
+         Conduct a chat session with a model via OpenRouter.
+         """
+         if not self.client:
+             return {"status": "error", "message": "OpenRouter client not initialized."}
+
+         messages = self._prepare_messages(discussion, branch_tip_id)
+         api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+         full_response_text = ""
+
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model_name,
+                 messages=messages,
+                 stream=stream,
+                 **api_params
+             )
+
+             if stream:
+                 for chunk in response:
+                     delta = chunk.choices[0].delta.content
+                     if delta:
+                         full_response_text += delta
+                         if streaming_callback:
+                             if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                 break
+                 return full_response_text
+             else:
+                 return response.choices[0].message.content
+
+         except Exception as ex:
+             error_message = f"An unexpected error occurred with OpenRouter API: {str(ex)}"
+             trace_exception(ex)
+             return {"status": "error", "message": error_message}
+
+     def tokenize(self, text: str) -> list:
+         """Tokenize text using tiktoken as a general-purpose fallback."""
+         try:
+             encoding = tiktoken.get_encoding("cl100k_base")
+             return encoding.encode(text)
+         except Exception:
+             return list(text.encode('utf-8'))
+
+     def detokenize(self, tokens: list) -> str:
+         """Detokenize tokens using tiktoken."""
+         try:
+             encoding = tiktoken.get_encoding("cl100k_base")
+             return encoding.decode(tokens)
+         except Exception:
+             return bytes(tokens).decode('utf-8', errors='ignore')
+
+     def count_tokens(self, text: str) -> int:
+         """Count tokens in a text using the fallback tokenizer."""
+         return len(self.tokenize(text))
+
+     def embed(self, text: str, **kwargs) -> List[float]:
+         """
+         Get embeddings for the input text using an OpenRouter embedding model.
+         """
+         if not self.client:
+             raise Exception("OpenRouter client not initialized.")
+
+         # User must specify an embedding model, e.g., 'text-embedding-ada-002'
+         embedding_model = kwargs.get("model")
+         if not embedding_model:
+             raise ValueError("An embedding model name must be provided via the 'model' kwarg for the embed method.")
+
+         try:
+             # The client is already configured for OpenRouter's base URL
+             response = self.client.embeddings.create(
+                 model=embedding_model,
+                 input=text
+             )
+             return response.data[0].embedding
+         except Exception as ex:
+             trace_exception(ex)
+             raise Exception(f"OpenRouter embedding failed: {str(ex)}") from ex
+
+     def get_model_info(self) -> dict:
+         """Return information about the current OpenRouter setup."""
+         return {
+             "name": self.binding_name,
+             "version": openai.__version__,
+             "host_address": self.BASE_URL,
+             "model_name": self.model_name,
+             "supports_structured_output": False,
+             "supports_vision": False,  # depends on the selected model; this generic binding does not implement vision inputs
+         }
+
+     def listModels(self) -> List[Dict[str, str]]:
+         """Lists available models from the OpenRouter service."""
+         if not self.client:
+             ASCIIColors.error("OpenRouter client not initialized. Cannot list models.")
+             return []
+         try:
+             ASCIIColors.debug("Listing OpenRouter models...")
+             models = self.client.models.list()
+             model_info_list = []
+             for m in models.data:
+                 model_info_list.append({
+                     'model_name': m.id,
+                     'display_name': m.name if hasattr(m, 'name') else m.id,
+                     'description': m.description if hasattr(m, 'description') else "No description available.",
+                     'owned_by': m.id.split('/')[0]  # Heuristic to get the provider
+                 })
+             return model_info_list
+         except Exception as ex:
+             trace_exception(ex)
+             return []
+
+     def load_model(self, model_name: str) -> bool:
+         """Sets the model name for subsequent operations."""
+         self.model_name = model_name
+         ASCIIColors.info(f"OpenRouter model set to: {model_name}. It will be used on the next API call.")
+         return True
+
+ if __name__ == '__main__':
+     # Environment variable to set for testing:
+     # OPENROUTER_API_KEY: Your OpenRouter API key (starts with sk-or-...)
+
+     if "OPENROUTER_API_KEY" not in os.environ:
+         ASCIIColors.red("Error: OPENROUTER_API_KEY environment variable not set.")
+         print("Please get your key from https://openrouter.ai/keys and set it.")
+         exit(1)
+
+     ASCIIColors.yellow("--- Testing OpenRouterBinding ---")
+
+     try:
+         # --- Initialization ---
+         ASCIIColors.cyan("\n--- Initializing Binding ---")
+         # Initialize with a fast, cheap, and well-known model
+         binding = OpenRouterBinding(model_name="mistralai/mistral-7b-instruct")
+         ASCIIColors.green("Binding initialized successfully.")
+
+         # --- List Models ---
+         ASCIIColors.cyan("\n--- Listing Models ---")
+         models = binding.listModels()
+         if models:
+             ASCIIColors.green(f"Successfully fetched {len(models)} models from OpenRouter.")
+             ASCIIColors.info("Sample of available models:")
+             # Print a few examples from different providers
+             providers_seen = set()
+             count = 0
+             for m in models:
+                 provider = m['owned_by']
+                 if provider not in providers_seen:
+                     print(f"- {m['model_name']}")
+                     providers_seen.add(provider)
+                     count += 1
+                 if count >= 5:
+                     break
+         else:
+             ASCIIColors.warning("No models found or failed to list models.")
+
+         # --- Text Generation (Testing with a Claude model) ---
+         ASCIIColors.cyan("\n--- Text Generation (Claude via OpenRouter) ---")
+         binding.load_model("anthropic/claude-3-haiku-20240307")
+         prompt_text = "Why is Claude Haiku a good choice for fast-paced chat applications?"
+         generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False)
+         if isinstance(generated_text, str):
+             ASCIIColors.green(f"Generated text:\n{generated_text}")
+         else:
+             ASCIIColors.error(f"Generation failed: {generated_text}")
+
+         # --- Text Generation (Streaming with a Groq model) ---
+         ASCIIColors.cyan("\n--- Text Generation (Llama3 on Groq via OpenRouter) ---")
+         binding.load_model("meta-llama/llama-3-8b-instruct:free")  # Use the free tier on OpenRouter
+         full_streamed_text = ""
+         def stream_callback(chunk: str, msg_type: int):
+             global full_streamed_text  # module-level scope here, so 'global' (not 'nonlocal') is required
+             ASCIIColors.green(chunk, end="", flush=True)
+             full_streamed_text += chunk
+             return True
+
+         stream_prompt = "Write a very short, 3-line poem about the speed of Groq."
+         result = binding.generate_text(stream_prompt, n_predict=50, stream=True, streaming_callback=stream_callback)
+         print("\n--- End of Stream ---")
+         ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+         # --- Embeddings Test ---
+         ASCIIColors.cyan("\n--- Embeddings (OpenAI model via OpenRouter) ---")
+         try:
+             embedding_model = "openai/text-embedding-ada-002"
+             embedding_text = "OpenRouter simplifies everything."
+             embedding_vector = binding.embed(embedding_text, model=embedding_model)
+             ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+             ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+         except Exception as e:
+             ASCIIColors.error(f"Embedding test failed: {e}")
+
+     except Exception as e:
+         ASCIIColors.error(f"An error occurred during testing: {e}")
+         trace_exception(e)
+
+     ASCIIColors.yellow("\nOpenRouterBinding test finished.")
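
Likewise, a minimal sketch for the OpenRouter binding, assuming OPENROUTER_API_KEY is set; the model identifiers are the ones used in the test block above and may change on OpenRouter's side.

    # Hypothetical quick-start for OpenRouterBinding.
    binding = OpenRouterBinding(model_name="mistralai/mistral-7b-instruct")
    print(binding.generate_text("Say hello in French.", n_predict=32, stream=False))

    binding.load_model("anthropic/claude-3-haiku-20240307")  # switch models without re-creating the client
    for m in binding.listModels()[:5]:
        print(m["model_name"], "-", m["owned_by"])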
@@ -423,33 +423,29 @@ class LollmsDiscussion:
          else:
              return cls(lollmsClient=lollms_client, discussion_id=kwargs.get('id'), **init_args)

-     def get_messages(self, branch_id: Optional[str] = None) -> Union[List[LollmsMessage], Optional[LollmsMessage]]:
+     def get_messages(self, branch_id: Optional[str] = None) -> Optional[List[LollmsMessage]]:
          """
-         Returns messages from the discussion with branch-aware logic.
+         Returns a list of messages forming a branch, from root to a specific leaf.

-         - If no branch_id is provided, it returns a list of all messages
-           in the currently active branch, ordered from root to leaf.
-         - If a branch_id is provided, it returns the single message object
-           (the "leaf") corresponding to that ID.
+         - If no branch_id is provided, it returns the full message list of the
+           currently active branch.
+         - If a branch_id is provided, it returns the list of all messages from the
+           root up to (and including) the message with that ID.

          Args:
-             branch_id: The ID of the leaf message. If provided, only this
-                        message is returned. If None, the full active branch is returned.
+             branch_id: The ID of the leaf message of the desired branch.
+                        If None, the active branch's leaf is used.

          Returns:
-             A list of LollmsMessage objects for the active branch, or a single
-             LollmsMessage if a branch_id is specified, or None if the ID is not found.
+             A list of LollmsMessage objects for the specified branch, ordered
+             from root to leaf, or None if the branch_id does not exist.
          """
-         if branch_id is None:
-             # Case 1: No ID, return the current active branch as a list of messages
-             leaf_id = self.active_branch_id
-             return self.get_branch(leaf_id)
-         else:
-             # Case 2: ID provided, return just the single leaf message
-             if branch_id in self._message_index:
-                 return LollmsMessage(self, self._message_index[branch_id])
-             else:
-                 return None
+         # Determine which leaf message ID to use
+         leaf_id = branch_id if branch_id is not None else self.active_branch_id
+
+         # Return the full branch leading to that leaf
+         # We assume self.get_branch() correctly handles non-existent IDs by returning None or an empty list.
+         return self.get_branch(leaf_id)


      def __getattr__(self, name: str) -> Any:
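
The change above is behavioral, not just cosmetic: get_messages(branch_id) used to return the single leaf LollmsMessage, whereas it now returns the whole branch ending at that leaf. A short sketch of how calling code adapts (discussion and msg_id are placeholders):

    branch = discussion.get_messages(msg_id)   # 0.26.0: list of messages from root to msg_id
    leaf = branch[-1] if branch else None      # equivalent of the old single-message return
    active = discussion.get_messages()         # unchanged: the currently active branch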