lollms-client 0.24.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lollms-client might be problematic. Click here for more details.

@@ -0,0 +1,536 @@
1
+ import base64
2
+ import os
3
+ import json
4
+ import requests
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ from typing import Optional, Callable, List, Union, Dict
8
+
9
+ from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
10
+ from lollms_client.lollms_llm_binding import LollmsLLMBinding
11
+ from lollms_client.lollms_types import MSG_TYPE
12
+ from ascii_colors import ASCIIColors, trace_exception
13
+
14
+ import pipmaster as pm
15
+
16
+ # Ensure the required packages are installed
17
+ pm.ensure_packages(["requests", "pillow", "tiktoken"])
18
+
19
+ from PIL import Image, ImageDraw
20
+ import tiktoken
21
+
# Name under which this binding registers itself with the lollms base class.
BindingName = "GrokBinding"

# Root URL of the xAI REST API; endpoint paths are appended to it.
GROK_API_BASE_URL = "https://api.x.ai/v1"

# Static model catalogue used when the live model listing cannot be fetched.
_FALLBACK_MODELS = [
    {
        'model_name': 'grok-1',
        'display_name': 'Grok 1',
        'description': 'The flagship conversational model from xAI.',
        'owned_by': 'xAI',
    },
    {
        'model_name': 'grok-1.5',
        'display_name': 'Grok 1.5',
        'description': 'The latest multimodal model from xAI.',
        'owned_by': 'xAI',
    },
    {
        'model_name': 'grok-1.5-vision-preview',
        'display_name': 'Grok 1.5 Vision (Preview)',
        'description': 'Multimodal model with vision capabilities (preview).',
        'owned_by': 'xAI',
    },
]
33
+
def is_image_path(path_str: str) -> bool:
    """Tell whether ``path_str`` names an existing file with an image extension.

    Recognised extensions (case-insensitive): png, jpg, jpeg, gif, bmp, webp.
    Any error raised while building or probing the path is treated as
    "not an image path" and yields False.
    """
    image_suffixes = ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp']
    try:
        candidate = Path(path_str)
        if not candidate.is_file():
            return False
        return candidate.suffix.lower() in image_suffixes
    except Exception:
        return False
41
+
def get_media_type_for_uri(image_path: Union[str, Path]) -> str:
    """Return the media type to use in a base64 data URI for ``image_path``.

    The decision is made from the file extension only (case-insensitive).
    Extensions other than jpg/jpeg/png/gif/webp fall back to "image/png".
    """
    media_type_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    suffix = Path(image_path).suffix.lower()
    # Default to PNG as it's lossless and widely supported
    return media_type_by_suffix.get(suffix, "image/png")
57
+
58
+
class GrokBinding(LollmsLLMBinding):
    """xAI Grok-specific binding implementation.

    Every generation request is POSTed to ``{base_url}/chat/completions`` with
    ``"stream": True``; callers that did not ask for streaming simply receive
    the buffered text once the stream has ended.
    """

    def __init__(self,
                 host_address: str = None,  # Ignored, for compatibility
                 model_name: str = "grok-1.5-vision-preview",
                 service_key: str = None,
                 verify_ssl_certificate: bool = True,  # Ignored, for compatibility
                 **kwargs
                 ):
        """
        Initialize the Grok binding.

        Args:
            host_address (str): Ignored; accepted for signature compatibility.
            model_name (str): Name of the Grok model to use.
            service_key (str): xAI API key. When empty, the XAI_API_KEY
                environment variable is used instead.
            verify_ssl_certificate (bool): Ignored; accepted for compatibility.
            **kwargs: ``base_url`` overrides the default API endpoint.

        Raises:
            ValueError: If no API key is passed and XAI_API_KEY is not set.
        """
        super().__init__(binding_name=BindingName)
        self.model_name = model_name
        self.service_key = service_key
        self.base_url = kwargs.get("base_url", GROK_API_BASE_URL)
        # Filled lazily by listModels(); None means "not fetched yet".
        self._cached_models: Optional[List[Dict[str, str]]] = None

        if not self.service_key:
            self.service_key = os.getenv("XAI_API_KEY")

        if not self.service_key:
            raise ValueError("xAI API key is required. Please set it via the 'service_key' parameter or the XAI_API_KEY environment variable.")

        self.headers = {
            "Authorization": f"Bearer {self.service_key}",
            "Content-Type": "application/json"
        }

    def _construct_parameters(self,
                              temperature: float,
                              top_p: float,
                              n_predict: int) -> Dict[str, object]:
        """Builds a parameters dictionary for the Grok API.

        Parameters passed as None are omitted from the result so the API
        applies its own defaults.
        """
        params = {"stream": True}  # Always stream from the API
        if temperature is not None:
            params['temperature'] = float(temperature)
        if top_p is not None:
            params['top_p'] = top_p
        # Grok has a model-specific max_tokens, but we can request less.
        if n_predict is not None:
            params['max_tokens'] = n_predict
        return params

    @staticmethod
    def _build_image_part(image_data: str) -> Dict[str, object]:
        """Turn one image reference into an ``image_url`` content part.

        ``image_data`` may be a path to an image file, a complete ``data:``
        URI (passed through unchanged), or a raw base64 string (assumed PNG).

        Raises:
            OSError: If an image file cannot be read.
        """
        if is_image_path(image_data):
            media_type = get_media_type_for_uri(image_data)
            with open(image_data, "rb") as image_file:
                b64_data = base64.b64encode(image_file.read()).decode('utf-8')
            url = f"data:{media_type};base64,{b64_data}"
        elif isinstance(image_data, str) and image_data.startswith("data:"):
            # Already a data URI: prefixing it again would corrupt it.
            url = image_data
        else:  # Assume it's a raw base64 string holding a PNG
            url = f"data:image/png;base64,{image_data}"
        return {"type": "image_url", "image_url": {"url": url}}

    @staticmethod
    def _format_request_error(ex: Exception) -> str:
        """Build the error text for a failed HTTP request.

        When the exception carries a response, its body is appended (best
        effort) and the response is closed, because the request was opened in
        streaming mode and would otherwise keep its connection checked out.
        """
        error_message = f"Grok API request failed: {str(ex)}"
        response = getattr(ex, "response", None)
        if response is not None:
            try:  # Try to get more info from the response body
                error_message += f"\nResponse: {response.text}"
            except Exception:
                pass  # Body unreadable; the base message is still useful.
            finally:
                response.close()
        return error_message

    def _post_chat_completion(self,
                              messages: List[Dict[str, object]],
                              api_params: Dict[str, object],
                              stream: bool,
                              streaming_callback: Optional[Callable[[str, MSG_TYPE], None]]
                              ) -> Union[str, dict]:
        """Send ``messages`` to the chat completions endpoint and consume the reply.

        Raises:
            requests.exceptions.RequestException: On connection failure or a
                non-success HTTP status.
        """
        payload = {
            "model": self.model_name,
            "messages": messages,
            **api_params
        }
        response = requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload,
            stream=True  # We always use the streaming endpoint
        )
        response.raise_for_status()
        return self._process_and_handle_stream(response, stream, streaming_callback)

    def _process_and_handle_stream(self,
                                   response: requests.Response,
                                   stream: bool,
                                   streaming_callback: Optional[Callable[[str, MSG_TYPE], None]]
                                   ) -> Union[str, dict]:
        """Helper to process streaming responses from the API.

        Reads server-sent-event lines, concatenates every ``delta.content``
        fragment and, when ``stream`` is true and a callback is given, forwards
        each fragment to the callback. A falsy callback result stops reading.
        The response is always closed before returning.

        Returns:
            The accumulated text, or ``{"status": False, "error": ...}`` if an
            unexpected exception occurs while reading the stream.
        """
        full_response_text = ""

        try:
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                # The space after "data:" is optional in server-sent events.
                if not decoded_line.startswith('data:'):
                    continue
                json_str = decoded_line[len('data:'):].strip()
                if json_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(json_str)
                except json.JSONDecodeError:
                    ASCIIColors.warning(f"Could not decode JSON chunk: {json_str}")
                    continue

                # Chunks without choices carry no text; skip them instead of
                # failing the whole generation on a missing key.
                choices = chunk.get('choices') if isinstance(chunk, dict) else None
                if not choices:
                    continue
                delta = choices[0].get('delta') or {}
                content = delta.get('content') or ''
                if not content:
                    continue

                full_response_text += content
                if stream and streaming_callback:
                    if not streaming_callback(content, MSG_TYPE.MSG_TYPE_CHUNK):
                        # Stop streaming if the callback returns False
                        return full_response_text

            # This handles both cases:
            # - If stream=True, we have already sent chunks. We return the full string.
            # - If stream=False, we have buffered the whole response and now return it.
            return full_response_text

        except Exception as ex:
            error_message = f"An unexpected error occurred while processing the Grok stream: {str(ex)}"
            trace_exception(ex)
            return {"status": False, "error": error_message}
        finally:
            # Release the connection even when the stream is abandoned early.
            response.close()

    def generate_text(self,
                      prompt: str,
                      images: Optional[List[str]] = None,
                      system_prompt: str = "",
                      n_predict: Optional[int] = 2048,
                      stream: Optional[bool] = False,
                      temperature: float = 0.7,
                      top_p: float = 0.9,
                      repeat_penalty: float = 1.1,  # Not supported
                      repeat_last_n: int = 64,  # Not supported
                      seed: Optional[int] = None,  # Not supported
                      n_threads: Optional[int] = None,  # Not applicable
                      ctx_size: Optional[int] = None,  # Determined by model
                      streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
                      **kwargs
                      ) -> Union[str, dict]:
        """
        Generate text using the Grok model.

        Args:
            prompt (str): User prompt; blank prompts are not sent.
            images (Optional[List[str]]): Image file paths, data URIs or raw
                base64 strings to attach to the user message.
            system_prompt (str): Optional system message.
            n_predict (Optional[int]): Sent as ``max_tokens`` when not None.
            stream (Optional[bool]): Forward chunks to ``streaming_callback``.
            temperature (float): Sampling temperature.
            top_p (float): Nucleus sampling value.
            streaming_callback: Called with (chunk, MSG_TYPE) per fragment;
                returning a falsy value stops generation.
            repeat_penalty, repeat_last_n, seed, n_threads, ctx_size, **kwargs:
                Accepted for interface compatibility and ignored.

        Returns:
            The generated text, ``""`` when there is neither prompt nor image,
            or ``{"status": False, "error": ...}`` on failure.
        """
        if not self.service_key:
            return {"status": False, "error": "xAI API key not configured."}

        api_params = self._construct_parameters(temperature, top_p, n_predict)

        messages = []
        if system_prompt and system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt})

        user_content = []
        if prompt and prompt.strip():
            user_content.append({"type": "text", "text": prompt})

        for image_data in images or []:
            try:
                user_content.append(self._build_image_part(image_data))
            except Exception as e:
                error_msg = f"Failed to process image: {e}"
                ASCIIColors.error(error_msg)
                return {"status": False, "error": error_msg}

        if not user_content:
            # Nothing to send: report an empty, finished message.
            if stream and streaming_callback:
                streaming_callback("", MSG_TYPE.MSG_TYPE_FINISHED_MESSAGE)
            return ""

        messages.append({"role": "user", "content": user_content})

        try:
            return self._post_chat_completion(messages, api_params, stream, streaming_callback)
        except requests.exceptions.RequestException as ex:
            error_message = self._format_request_error(ex)
            trace_exception(ex)
            return {"status": False, "error": error_message}
        except Exception as ex:
            error_message = f"An unexpected error occurred with Grok API: {str(ex)}"
            trace_exception(ex)
            return {"status": False, "error": error_message}

    def chat(self,
             discussion: LollmsDiscussion,
             branch_tip_id: Optional[str] = None,
             n_predict: Optional[int] = 2048,
             stream: Optional[bool] = False,
             temperature: float = 0.7,
             top_p: float = 0.9,
             streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
             **kwargs
             ) -> Union[str, dict]:
        """
        Conduct a chat session with the Grok model using a LollmsDiscussion object.

        Messages whose ``sender_type`` is "assistant" are sent as plain-text
        assistant turns; every other message is sent as a user turn whose
        images are attached when they are readable image file paths.

        Returns:
            The generated text. Failures detected in this method are reported
            as ``{"status": "error", "message": ...}``; a failure while reading
            the response stream is reported as ``{"status": False, "error": ...}``.
        """
        if not self.service_key:
            return {"status": "error", "message": "xAI API key not configured."}

        system_prompt = discussion.system_prompt
        discussion_messages = discussion.get_messages(branch_tip_id)

        messages = []
        if system_prompt and system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt})

        for msg in discussion_messages:
            role = 'assistant' if msg.sender_type == "assistant" else 'user'

            content_parts = []
            if msg.content and msg.content.strip():
                content_parts.append({"type": "text", "text": msg.content})

            # NOTE(review): entries that are not existing image files are
            # skipped silently here — confirm msg.images never holds base64.
            for file_path in msg.images or []:
                if not is_image_path(file_path):
                    continue
                try:
                    content_parts.append(self._build_image_part(file_path))
                except Exception as e:
                    ASCIIColors.warning(f"Could not load image {file_path}: {e}")

            # Grok API expects content to be a string for assistant, or list for user.
            if role == 'user':
                messages.append({'role': role, 'content': content_parts})
            else:  # assistant
                # Assistants can't send images, so we just extract the text.
                text_content = next((part['text'] for part in content_parts if part['type'] == 'text'), "")
                if text_content:
                    messages.append({'role': role, 'content': text_content})

        if not messages or messages[-1]['role'] != 'user':
            return {"status": "error", "message": "Cannot start chat without a user message."}

        api_params = self._construct_parameters(temperature, top_p, n_predict)

        try:
            return self._post_chat_completion(messages, api_params, stream, streaming_callback)
        except requests.exceptions.RequestException as ex:
            error_message = self._format_request_error(ex)
            trace_exception(ex)
            return {"status": "error", "message": error_message}
        except Exception as ex:
            error_message = f"An unexpected error occurred with Grok API: {str(ex)}"
            trace_exception(ex)
            return {"status": "error", "message": error_message}

    def tokenize(self, text: str) -> list:
        """
        Tokenize the input text.
        Note: Grok doesn't expose a public tokenizer API.
        Using tiktoken's cl100k_base for a reasonable estimate; if tiktoken
        fails, the UTF-8 byte values of the text are returned instead.
        """
        try:
            encoding = tiktoken.get_encoding("cl100k_base")
            return encoding.encode(text)
        except Exception:
            return list(text.encode('utf-8'))

    def detokenize(self, tokens: list) -> str:
        """
        Detokenize a list of tokens.
        Note: Based on the placeholder tokenizer; if tiktoken fails, the
        tokens are interpreted as UTF-8 byte values.
        """
        try:
            encoding = tiktoken.get_encoding("cl100k_base")
            return encoding.decode(tokens)
        except Exception:
            return bytes(tokens).decode('utf-8', errors='ignore')

    def count_tokens(self, text: str) -> int:
        """
        Count tokens from a text using the fallback tokenizer.
        """
        return len(self.tokenize(text))

    def embed(self, text: str, **kwargs) -> List[float]:
        """
        Get embeddings for the input text.
        Note: xAI does not provide a dedicated embedding model API.

        Raises:
            NotImplementedError: Always.
        """
        ASCIIColors.warning("xAI does not offer a public embedding API. This method is not implemented.")
        raise NotImplementedError("Grok binding does not support embeddings.")

    def get_model_info(self) -> dict:
        """Return information about the current Grok model setup."""
        return {
            "name": self.binding_name,
            "host_address": self.base_url,
            "model_name": self.model_name,
            "supports_structured_output": False,
            # Vision is inferred from the model name only.
            "supports_vision": "vision" in self.model_name or "grok-1.5" == self.model_name,
        }

    def listModels(self) -> List[Dict[str, str]]:
        """
        Lists available models from the xAI API.
        Caches the result to avoid repeated API calls.
        Falls back to a static list if the API call fails; the fallback is
        cached as well, so the API is not queried again by this instance.
        """
        if self._cached_models is not None:
            return self._cached_models

        if not self.service_key:
            ASCIIColors.warning("Cannot fetch models without an API key. Using fallback list.")
            # Copy so callers cannot mutate the module-level fallback.
            self._cached_models = [dict(model) for model in _FALLBACK_MODELS]
            return self._cached_models

        try:
            ASCIIColors.info("Fetching available models from xAI API...")
            response = requests.get(f"{self.base_url}/models", headers=self.headers, timeout=15)
            response.raise_for_status()

            data = response.json()

            if "data" not in data or not isinstance(data["data"], list):
                raise ValueError("API response is malformed.")

            formatted_models = []
            for model in data["data"]:
                model_id = model.get("id")
                if not model_id:
                    continue

                formatted_models.append({
                    'model_name': model_id,
                    'display_name': model_id.replace("-", " ").title(),
                    'description': f"Context: {model.get('context_window', 'N/A')} tokens.",
                    'owned_by': model.get('owned_by', 'xAI')
                })

            self._cached_models = formatted_models
            ASCIIColors.green(f"Successfully fetched {len(self._cached_models)} models.")
            return self._cached_models

        except Exception as e:
            ASCIIColors.error(f"Failed to fetch models from xAI API: {e}")
            ASCIIColors.warning("Using hardcoded fallback list of models.")
            trace_exception(e)
            self._cached_models = [dict(model) for model in _FALLBACK_MODELS]
            return self._cached_models

    def load_model(self, model_name: str) -> bool:
        """Set the model name for subsequent operations. Always returns True."""
        self.model_name = model_name
        ASCIIColors.info(f"Grok model set to: {model_name}. It will be used on the next API call.")
        return True
428
+
429
+
if __name__ == '__main__':
    # Manual smoke test of the binding against the live xAI API.
    # Example Usage (requires XAI_API_KEY environment variable)
    if 'XAI_API_KEY' not in os.environ:
        ASCIIColors.red("Error: XAI_API_KEY environment variable not set.")
        print("Please get your key from xAI and set it as an environment variable.")
        # SystemExit works even when the site-provided exit() helper is absent.
        raise SystemExit(1)

    ASCIIColors.yellow("--- Testing GrokBinding ---")

    # --- Configuration ---
    test_model_name = "grok-1"
    test_vision_model_name = "grok-1.5-vision-preview"

    try:
        # --- Initialization ---
        ASCIIColors.cyan("\n--- Initializing Binding ---")
        binding = GrokBinding(model_name=test_model_name)
        ASCIIColors.green("Binding initialized successfully.")

        # --- List Models ---
        ASCIIColors.cyan("\n--- Listing Models (dynamic) ---")
        models = binding.listModels()
        if models:
            ASCIIColors.green(f"Found {len(models)} models.")
            for m in models:
                print(f"- {m['model_name']} ({m['display_name']})")
        else:
            ASCIIColors.error("Failed to list models.")

        # --- Count Tokens ---
        ASCIIColors.cyan("\n--- Counting Tokens ---")
        sample_text = "Hello, world! This is a test from the Grok binding."
        token_count = binding.count_tokens(sample_text)
        ASCIIColors.green(f"Token count for '{sample_text}': {token_count} (using tiktoken)")

        # --- Text Generation (Non-Streaming) ---
        ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
        prompt_text = "Explain who Elon Musk is in one sentence."
        ASCIIColors.info(f"Prompt: {prompt_text}")
        generated_text = binding.generate_text(prompt_text, n_predict=100, stream=False, system_prompt="Be very concise.")
        if isinstance(generated_text, str):
            ASCIIColors.green(f"Generated text:\n{generated_text}")
        else:
            ASCIIColors.error(f"Generation failed: {generated_text}")

        # --- Text Generation (Streaming) ---
        ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")

        # Chunks are collected in a list: rebinding a module-level string from
        # inside the callback would raise UnboundLocalError without `global`.
        streamed_chunks = []

        def stream_callback(chunk: str, msg_type: int):
            ASCIIColors.green(chunk, end="", flush=True)
            streamed_chunks.append(chunk)
            return True

        ASCIIColors.info(f"Prompt: {prompt_text}")
        result = binding.generate_text(prompt_text, n_predict=150, stream=True, streaming_callback=stream_callback)
        full_streamed_text = "".join(streamed_chunks)
        print("\n--- End of Stream ---")
        ASCIIColors.green(f"Full streamed text (for verification): {result}")
        assert result == full_streamed_text

        # --- Embeddings ---
        ASCIIColors.cyan("\n--- Embeddings ---")
        try:
            binding.embed("This should fail.")
        except NotImplementedError as e:
            ASCIIColors.green(f"Successfully caught expected error for embeddings: {e}")

        # --- Vision Model Test ---
        dummy_image_path = "grok_dummy_test_image.png"
        try:
            available_model_names = [m['model_name'] for m in models]
            if test_vision_model_name not in available_model_names:
                ASCIIColors.warning(f"Vision test model '{test_vision_model_name}' not available. Skipping vision test.")
            else:
                img = Image.new('RGB', (250, 60), color=('red'))
                d = ImageDraw.Draw(img)
                d.text((10, 10), "This is a test image for Grok", fill=('white'))
                img.save(dummy_image_path)
                ASCIIColors.info(f"Created dummy image: {dummy_image_path}")

                ASCIIColors.cyan(f"\n--- Vision Generation (using {test_vision_model_name}) ---")
                binding.load_model(test_vision_model_name)
                vision_prompt = "Describe this image. What does the text say?"
                ASCIIColors.info(f"Vision Prompt: {vision_prompt} with image {dummy_image_path}")

                vision_response = binding.generate_text(
                    prompt=vision_prompt,
                    images=[dummy_image_path],
                    n_predict=100,
                    stream=False
                )
                if isinstance(vision_response, str):
                    ASCIIColors.green(f"Vision model response: {vision_response}")
                else:
                    ASCIIColors.error(f"Vision generation failed: {vision_response}")
        except Exception as e:
            ASCIIColors.error(f"Error during vision test: {e}")
            trace_exception(e)
        finally:
            # Remove the temporary image whether or not the test succeeded.
            if os.path.exists(dummy_image_path):
                os.remove(dummy_image_path)

    except Exception as e:
        ASCIIColors.error(f"An error occurred during testing: {e}")
        trace_exception(e)

    ASCIIColors.yellow("\nGrokBinding test finished.")