lollms-client 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of lollms-client might be problematic.

@@ -9,6 +9,7 @@ from lollms_client.lollms_discussion import LollmsDiscussion
  from typing import Optional, Callable, List, Union
  from ascii_colors import ASCIIColors, trace_exception
  from typing import List, Dict
+ import math
 
  import pipmaster as pm
 
@@ -56,189 +57,182 @@ class OpenAIBinding(LollmsLLMBinding):
  self.client = openai.OpenAI(api_key=self.service_key, base_url=None if host_address is None else host_address if len(host_address)>0 else None)
  self.completion_format = ELF_COMPLETION_FORMAT.Chat
 
+ def _build_openai_params(self, messages: list, **kwargs) -> dict:
+ model = kwargs.get("model", self.model_name)
+ if "n_predict" in kwargs:
+ kwargs["max_tokens"] = kwargs.pop("n_predict")
+
+ restricted_families = [
+ "gpt-5",
+ "gpt-4o",
+ "o1",
+ "o3",
+ "o4"
+ ]
+
+ allowed_params = {
+ "model", "messages", "temperature", "top_p", "n",
+ "stop", "max_tokens", "presence_penalty", "frequency_penalty",
+ "logit_bias", "stream", "user", "max_completion_tokens"
+ }
+
+ params = {
+ "model": model,
+ "messages": messages,
+ }
+
+ for k, v in kwargs.items():
+ if k in allowed_params and v is not None:
+ params[k] = v
+ else:
+ if v is not None:
+ ASCIIColors.warning(f"Removed unsupported OpenAI param '{k}'")
+
+ model_lower = model.lower()
+ if any(fam in model_lower for fam in restricted_families):
+ if "temperature" in params and params["temperature"] != 1:
+ ASCIIColors.warning(f"{model} does not support temperature != 1. Overriding to 1.")
+ params["temperature"] = 1
+ if "top_p" in params:
+ ASCIIColors.warning(f"{model} does not support top_p. Removing it.")
+ params.pop("top_p")
+
+ return params
+
+
+
+
+
 
  def generate_text(self,
- prompt: str,
- images: Optional[List[str]] = None,
- system_prompt: str = "",
- n_predict: Optional[int] = None,
- stream: Optional[bool] = None,
- temperature: float = 0.7,
- top_k: int = 40,
- top_p: float = 0.9,
- repeat_penalty: float = 1.1,
- repeat_last_n: int = 64,
- seed: Optional[int] = None,
- n_threads: Optional[int] = None,
- ctx_size: int | None = None,
- streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
- split:Optional[bool]=False, # put to true if the prompt is a discussion
- user_keyword:Optional[str]="!@>user:",
- ai_keyword:Optional[str]="!@>assistant:",
- ) -> Union[str, dict]:
- """
- Generate text using the active LLM binding, using instance defaults if parameters are not provided.
-
- Args:
- prompt (str): The input prompt for text generation.
- images (Optional[List[str]]): List of image file paths for multimodal generation.
- n_predict (Optional[int]): Maximum number of tokens to generate. Uses instance default if None.
- stream (Optional[bool]): Whether to stream the output. Uses instance default if None.
- temperature (Optional[float]): Sampling temperature. Uses instance default if None.
- top_k (Optional[int]): Top-k sampling parameter. Uses instance default if None.
- top_p (Optional[float]): Top-p sampling parameter. Uses instance default if None.
- repeat_penalty (Optional[float]): Penalty for repeated tokens. Uses instance default if None.
- repeat_last_n (Optional[int]): Number of previous tokens to consider for repeat penalty. Uses instance default if None.
- seed (Optional[int]): Random seed for generation. Uses instance default if None.
- n_threads (Optional[int]): Number of threads to use. Uses instance default if None.
- ctx_size (int | None): Context size override for this generation.
- streaming_callback (Optional[Callable[[str, str], None]]): Callback function for streaming output.
- - First parameter (str): The chunk of text received.
- - Second parameter (str): The message type (e.g., MSG_TYPE.MSG_TYPE_CHUNK).
- split:Optional[bool]: put to true if the prompt is a discussion
- user_keyword:Optional[str]: when splitting we use this to extract user prompt
- ai_keyword:Optional[str]": when splitting we use this to extract ai prompt
+ prompt: str,
+ images: Optional[List[str]] = None,
+ system_prompt: str = "",
+ n_predict: Optional[int] = None,
+ stream: Optional[bool] = None,
+ temperature: float = 0.7,
+ top_k: int = 40,
+ top_p: float = 0.9,
+ repeat_penalty: float = 1.1,
+ repeat_last_n: int = 64,
+ seed: Optional[int] = None,
+ n_threads: Optional[int] = None,
+ ctx_size: int | None = None,
+ streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+ split: Optional[bool] = False,
+ user_keyword: Optional[str] = "!@>user:",
+ ai_keyword: Optional[str] = "!@>assistant:"
+ ) -> Union[str, dict]:
 
- Returns:
- Union[str, dict]: Generated text or error dictionary if failed.
- """
  count = 0
  output = ""
- messages = [
- {
- "role": "system",
- "content": system_prompt or "You are a helpful assistant.",
- }
- ]
+ messages = [{"role": "system", "content": system_prompt or "You are a helpful assistant."}]
 
- # Prepare messages based on whether images are provided
  if images:
  if split:
- messages += self.split_discussion(prompt,user_keyword=user_keyword, ai_keyword=ai_keyword)
- if images:
- messages[-1]["content"] = [
- {
- "type": "text",
- "text": messages[-1]["content"]
- }
- ]+[
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{encode_image(image_path)}"
- }
- }
- for image_path in images
- ]
+ messages += self.split_discussion(prompt, user_keyword=user_keyword, ai_keyword=ai_keyword)
+ messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}] + [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}}
+ for path in images
+ ]
  else:
  messages.append({
- 'role': 'user',
- 'content': [
- {
- "type": "text",
- "text": prompt
- }
- ] + [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{encode_image(image_path)}"
- }
- }
- for image_path in images
- ]
- }
- )
-
+ 'role': 'user',
+ 'content': [{"type": "text", "text": prompt}] + [
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}}
+ for path in images
+ ]
+ })
  else:
-
  if split:
- messages += self.split_discussion(prompt,user_keyword=user_keyword, ai_keyword=ai_keyword)
- if images:
- messages[-1]["content"] = [
- {
- "type": "text",
- "text": messages[-1]["content"]
- }
- ]
+ messages += self.split_discussion(prompt, user_keyword=user_keyword, ai_keyword=ai_keyword)
  else:
- messages.append({
- 'role': 'user',
- 'content': [
- {
- "type": "text",
- "text": prompt
- }
- ]
- }
- )
-
- # Generate text using the OpenAI API
- if self.completion_format == ELF_COMPLETION_FORMAT.Chat:
- try:
- chat_completion = self.client.chat.completions.create(
- model=self.model_name, # Choose the engine according to your OpenAI plan
- messages=messages,
- max_tokens=n_predict, # Adjust the desired length of the generated response
- n=1, # Specify the number of responses you want
- temperature=temperature, # Adjust the temperature for more or less randomness in the output
- stream=stream
- )
- except Exception as ex:
- # exception for new openai models
- chat_completion = self.client.chat.completions.create(
- model=self.model_name, # Choose the engine according to your OpenAI plan
- messages=messages,
- max_completion_tokens=n_predict, # Adjust the desired length of the generated response
- n=1, # Specify the number of responses you want
- temperature=1, # Adjust the temperature for more or less randomness in the output
- stream=stream
- )
+ messages.append({'role': 'user', 'content': [{"type": "text", "text": prompt}]})
 
- if stream:
- for resp in chat_completion:
- if count >= n_predict:
- break
- try:
- word = resp.choices[0].delta.content
- except Exception as ex:
- word = ""
- if streaming_callback is not None:
- if not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+ try:
+ if self.completion_format == ELF_COMPLETION_FORMAT.Chat:
+ params = self._build_openai_params(messages=messages,
+ n_predict=n_predict,
+ stream=stream,
+ temperature=temperature,
+ top_p=top_p,
+ repeat_penalty=repeat_penalty,
+ seed=seed)
+ try:
+ chat_completion = self.client.chat.completions.create(**params)
+ except Exception as ex:
+ # exception for new openai models
+ params["max_completion_tokens"]=params["max_tokens"]
+ params["temperature"]=1
+ try: del params["max_tokens"]
+ except Exception: pass
+ try: del params["top_p"]
+ except Exception: pass
+ try: del params["frequency_penalty"]
+ except Exception: pass
+
+ chat_completion = self.client.chat.completions.create(**params)
+
+ if stream:
+ for resp in chat_completion:
+ if count >= (n_predict or float('inf')):
  break
- if word:
- output += word
- count += 1
+ word = getattr(resp.choices[0].delta, "content", "") or ""
+ if streaming_callback and not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+ break
+ if word:
+ output += word
+ count += 1
+ else:
+ output = chat_completion.choices[0].message.content
+
  else:
- output = chat_completion.choices[0].message.content
- else:
- completion = self.client.completions.create(
- model=self.model_name, # Choose the engine according to your OpenAI plan
- prompt=prompt,
- max_tokens=n_predict, # Adjust the desired length of the generated response
- n=1, # Specify the number of responses you want
- temperature=temperature, # Adjust the temperature for more or less randomness in the output
- stream=stream
- )
+ params = self._build_openai_params(prompt=prompt,
+ n_predict=n_predict,
+ stream=stream,
+ temperature=temperature,
+ top_p=top_p,
+ repeat_penalty=repeat_penalty,
+ seed=seed)
+ try:
+ completion = self.client.completions.create(**params)
+ except Exception as ex:
+ # exception for new openai models
+ params["max_completion_tokens"]=params["max_tokens"]
+ params["temperature"]=1
+ try: del params["max_tokens"]
+ except Exception: pass
+ try: del params["top_p"]
+ except Exception: pass
+ try: del params["frequency_penalty"]
+ except Exception: pass
 
- if stream:
- for resp in completion:
- if count >= n_predict:
- break
- try:
- word = resp.choices[0].text
- except Exception as ex:
- word = ""
- if streaming_callback is not None:
- if not streaming_callback(word, "MSG_TYPE_CHUNK"):
+
+ completion = self.client.completions.create(**params)
+
+ if stream:
+ for resp in completion:
+ if count >= (n_predict or float('inf')):
  break
- if word:
- output += word
- count += 1
- else:
- output = completion.choices[0].text
+ word = getattr(resp.choices[0], "text", "") or ""
+ if streaming_callback and not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+ break
+ if word:
+ output += word
+ count += 1
+ else:
+ output = completion.choices[0].text
+
+ except Exception as e:
+ trace_exception(e)
+ err_msg = f"An error occurred with the OpenAI API: {e}"
+ if streaming_callback:
+ streaming_callback(err_msg, MSG_TYPE.MSG_TYPE_EXCEPTION)
+ return {"status": "error", "message": err_msg}
 
  return output
+
+
 
  def generate_from_messages(self,
  messages: List[Dict],
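For orientation only (not part of the diff): a minimal sketch of how the new `_build_openai_params` helper behaves, assuming `binding` is an already-constructed `OpenAIBinding` instance; the sample values are illustrative.

    # Hypothetical usage sketch for the parameter builder added above.
    params = binding._build_openai_params(
        messages=[{"role": "user", "content": "Hello"}],
        n_predict=128,      # renamed internally to max_tokens
        temperature=0.2,
        top_k=40,           # not in allowed_params, dropped with a warning
        stream=False,
    )
    # For restricted families such as "gpt-4o" or "o1", temperature is forced
    # to 1 and top_p is removed before the request is sent.
    response = binding.client.chat.completions.create(**params)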
@@ -282,9 +276,13 @@ class OpenAIBinding(LollmsLLMBinding):
  # exception for new openai models
  params["max_completion_tokens"]=params["max_tokens"]
  params["temperature"]=1
- del params["max_tokens"]
- del params["top_p"]
- del params["frequency_penalty"]
+ try: del params["max_tokens"]
+ except Exception: pass
+ try: del params["top_p"]
+ except Exception: pass
+ try: del params["frequency_penalty"]
+ except Exception: pass
+
 
  completion = self.client.chat.completions.create(**params)
  if stream:
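A caller-side sketch (not from the package) of the reworked generation path shown in the hunks above: streamed chunks go through the callback, and API failures now come back as an error dict rather than an uncaught exception. `binding` is an assumed `OpenAIBinding` instance.

    # Hypothetical usage sketch; returning False from the callback stops the stream.
    def on_chunk(chunk, msg_type):
        print(chunk, end="", flush=True)
        return True

    result = binding.generate_text(
        "Summarize the latest changes.",
        n_predict=256,
        stream=True,
        streaming_callback=on_chunk,
    )
    if isinstance(result, dict) and result.get("status") == "error":
        print(result["message"])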
@@ -308,155 +306,163 @@ class OpenAIBinding(LollmsLLMBinding):
  return {"status": "error", "message": error_message}
 
  return output
-
+
  def chat(self,
- discussion: LollmsDiscussion,
- branch_tip_id: Optional[str] = None,
- n_predict: Optional[int] = None,
- stream: Optional[bool] = None,
- temperature: float = 0.7,
- top_k: int = 40,
- top_p: float = 0.9,
- repeat_penalty: float = 1.1,
- repeat_last_n: int = 64,
- seed: Optional[int] = None,
- n_threads: Optional[int] = None,
- ctx_size: Optional[int] = None,
- streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None
- ) -> Union[str, dict]:
- """
- Conduct a chat session with the OpenAI model using a LollmsDiscussion object.
+ discussion: LollmsDiscussion,
+ branch_tip_id: Optional[str] = None,
+ n_predict: Optional[int] = None,
+ stream: Optional[bool] = None,
+ temperature: float = 0.7,
+ top_k: int = 40,
+ top_p: float = 0.9,
+ repeat_penalty: float = 1.1,
+ repeat_last_n: int = 64,
+ seed: Optional[int] = None,
+ n_threads: Optional[int] = None,
+ ctx_size: Optional[int] = None,
+ streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None
+ ) -> Union[str, dict]:
 
- Args:
- discussion (LollmsDiscussion): The discussion object containing the conversation history.
- branch_tip_id (Optional[str]): The ID of the message to use as the tip of the conversation branch. Defaults to the active branch.
- n_predict (Optional[int]): Maximum number of tokens to generate.
- stream (Optional[bool]): Whether to stream the output.
- temperature (float): Sampling temperature.
- top_k (int): Top-k sampling parameter (Note: not all OpenAI models use this).
- top_p (float): Top-p sampling parameter.
- repeat_penalty (float): Frequency penalty for repeated tokens.
- seed (Optional[int]): Random seed for generation.
- streaming_callback (Optional[Callable[[str, MSG_TYPE], None]]): Callback for streaming output.
-
- Returns:
- Union[str, dict]: The generated text or an error dictionary.
- """
- # 1. Export the discussion to the OpenAI chat format
- # This handles system prompts, user/assistant roles, and multi-modal content automatically.
  messages = discussion.export("openai_chat", branch_tip_id)
+ params = self._build_openai_params(messages=messages,
+ n_predict=n_predict,
+ stream=stream,
+ temperature=temperature,
+ top_p=top_p,
+ repeat_penalty=repeat_penalty,
+ seed=seed)
 
- # Build the request parameters
- params = {
- "model": self.model_name,
- "messages": messages,
- "max_tokens": n_predict,
- "n": 1,
- "temperature": temperature,
- "top_p": top_p,
- "frequency_penalty": repeat_penalty,
- "stream": stream
- }
- # Add seed if available, as it's supported by newer OpenAI models
- if seed is not None:
- params["seed"] = seed
-
- # Remove None values, as the API expects them to be absent
- params = {k: v for k, v in params.items() if v is not None}
-
  output = ""
- # 2. Call the API
  try:
- # Check if we should use the chat completions or legacy completions endpoint
  if self.completion_format == ELF_COMPLETION_FORMAT.Chat:
- try:
- completion = self.client.chat.completions.create(**params)
- except Exception as ex:
- # exception for new openai models
- params["max_completion_tokens"]=params["max_tokens"]
- params["temperature"]=1
- del params["max_tokens"]
- del params["top_p"]
- del params["frequency_penalty"]
-
- completion = self.client.chat.completions.create(**params)
-
+ completion = self.client.chat.completions.create(**params)
  if stream:
  for chunk in completion:
- # The streaming response for chat has a different structure
  delta = chunk.choices[0].delta
  if delta.content:
  word = delta.content
- if streaming_callback is not None:
- if not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
- break
+ if streaming_callback and not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+ break
  output += word
  else:
  output = completion.choices[0].message.content
-
- else: # Fallback to legacy completion format (not recommended for chat)
- # We need to format the messages list into a single string prompt
+ else:
  legacy_prompt = discussion.export("openai_completion", branch_tip_id)
- legacy_params = {
- "model": self.model_name,
- "prompt": legacy_prompt,
- "max_tokens": n_predict,
- "n": 1,
- "temperature": temperature,
- "top_p": top_p,
- "frequency_penalty": repeat_penalty,
- "stream": stream
- }
+ legacy_params = self._build_openai_params(prompt=legacy_prompt,
+ n_predict=n_predict,
+ stream=stream,
+ temperature=temperature,
+ top_p=top_p,
+ repeat_penalty=repeat_penalty,
+ seed=seed)
  completion = self.client.completions.create(**legacy_params)
-
  if stream:
  for chunk in completion:
  word = chunk.choices[0].text
- if streaming_callback is not None:
- if not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
- break
+ if streaming_callback and not streaming_callback(word, MSG_TYPE.MSG_TYPE_CHUNK):
+ break
  output += word
  else:
  output = completion.choices[0].text
-
+
  except Exception as e:
- # Handle API errors gracefully
- error_message = f"An error occurred with the OpenAI API: {e}"
+ err = f"An error occurred with the OpenAI API: {e}"
  if streaming_callback:
- streaming_callback(error_message, MSG_TYPE.MSG_TYPE_EXCEPTION)
- return {"status": "error", "message": error_message}
-
- return output
- def tokenize(self, text: str) -> list:
+ streaming_callback(err, MSG_TYPE.MSG_TYPE_EXCEPTION)
+ return {"status": "error", "message": err}
+
+ return output
+
+ def _get_encoding(self, model_name: str | None = None):
  """
- Tokenize the input text into a list of characters.
+ Get the tiktoken encoding for a given model.
+ Falls back to 'cl100k_base' if model is unknown.
+ """
+ if model_name is None:
+ model_name = self.model_name
+ try:
+ return tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ return tiktoken.get_encoding("cl100k_base")
+
+ def tokenize(self, text: str) -> list[int]:
+ """
+ Tokenize text into a list of token IDs.
 
  Args:
  text (str): The text to tokenize.
 
  Returns:
- list: List of individual characters.
+ list[int]: List of token IDs.
  """
- try:
- return tiktoken.model.encoding_for_model(self.model_name).encode(text)
- except:
- return tiktoken.model.encoding_for_model("gpt-3.5-turbo").encode(text)
-
- def detokenize(self, tokens: list) -> str:
+ encoding = self._get_encoding()
+ return encoding.encode(text)
+
+ def detokenize(self, tokens: list[int]) -> str:
  """
- Convert a list of tokens back to text.
+ Convert a list of token IDs back to text.
 
  Args:
- tokens (list): List of tokens (characters) to detokenize.
+ tokens (list[int]): List of tokens.
 
  Returns:
- str: Detokenized text.
+ str: The decoded text.
  """
- try:
- return tiktoken.model.encoding_for_model(self.model_name).decode(tokens)
- except:
- return tiktoken.model.encoding_for_model("gpt-3.5-turbo").decode(tokens)
+ encoding = self._get_encoding()
+ return encoding.decode(tokens)
 
+ def get_input_tokens_price(self, model_name: str | None = None) -> float:
+ """
+ Get the price per input token for a given model (USD).
+
+ Args:
+ model_name (str | None): Model name. Defaults to self.model_name.
+
+ Returns:
+ float: Price per input token in USD.
+ """
+ if model_name is None:
+ model_name = self.model_name
+
+ price_map = {
+ "gpt-4o": 5e-6,
+ "gpt-4o-mini": 1.5e-6,
+ "gpt-3.5-turbo": 1.5e-6,
+ "o1": 15e-6,
+ "o3": 15e-6,
+ }
+
+ for key, price in price_map.items():
+ if model_name.lower().startswith(key):
+ return price
+ return 0.0 # Unknown → treat as free
+
+ def get_output_tokens_price(self, model_name: str | None = None) -> float:
+ """
+ Get the price per output token for a given model (USD).
+
+ Args:
+ model_name (str | None): Model name. Defaults to self.model_name.
+
+ Returns:
+ float: Price per output token in USD.
+ """
+ if model_name is None:
+ model_name = self.model_name
+
+ price_map = {
+ "gpt-4o": 15e-6,
+ "gpt-4o-mini": 6e-6,
+ "gpt-3.5-turbo": 2e-6,
+ "o1": 60e-6,
+ "o3": 60e-6,
+ }
+
+ for key, price in price_map.items():
+ if model_name.lower().startswith(key):
+ return price
+ return 0.0
+
  def count_tokens(self, text: str) -> int:
  """
  Count tokens from a text.
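For the new token and pricing helpers above, a small sketch of the intended round trip and cost estimate (illustrative only; the per-token prices come from the hard-coded map in this binding, not from a live price list, and `binding` is an assumed `OpenAIBinding` instance):

    # Hypothetical usage sketch for the tiktoken-backed helpers.
    ids = binding.tokenize("Hello, world!")        # list of token IDs
    assert binding.detokenize(ids) == "Hello, world!"
    # Rough cost estimate, assuming 500 completion tokens.
    est = (len(ids) * binding.get_input_tokens_price("gpt-4o")
           + 500 * binding.get_output_tokens_price("gpt-4o"))
    print(f"{len(ids)} prompt tokens, estimated cost ~ ${est:.6f}")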
@@ -470,41 +476,113 @@ class OpenAIBinding(LollmsLLMBinding):
  return len(self.tokenize(text))
 
 
- def embed(self, text: str, **kwargs) -> list:
+
+ def embed(self, text: str | list[str], normalize: bool = False, **kwargs) -> list:
  """
- Get embeddings for the input text using OpenAI API.
+ Get embeddings for input text(s) using OpenAI API.
 
  Args:
- text (str): Input text to embed.
+ text (str | list[str]): Input text or list of texts to embed.
+ normalize (bool): Whether to normalize the resulting vector(s) to unit length.
  **kwargs: Additional arguments. The 'model' argument can be used
- to specify the embedding model (e.g., "text-embedding-3-small").
- Defaults to "text-embedding-ada-002".
+ to specify the embedding model (e.g., "text-embedding-3-small").
+ Defaults to "text-embedding-3-small".
 
  Returns:
- list: The embedding vector as a list of floats, or an empty list on failure.
+ list: A single embedding vector (list of floats) if input is str,
+ or a list of embedding vectors if input is list[str].
+ Returns empty list on failure.
  """
- # Determine the embedding model, prioritizing kwargs, with a default
+ # Determine the embedding model
  embedding_model = kwargs.get("model", self.model_name)
-
+ if not embedding_model.startswith("text-embedding"):
+ embedding_model = "text-embedding-3-small"
+
+ # Ensure input is a list of strings
+ is_single_input = isinstance(text, str)
+ input_texts = [text] if is_single_input else text
+
+ # Optional safety: truncate if too many tokens for embedding model
+ max_tokens_map = {
+ "text-embedding-3-small": 8191,
+ "text-embedding-3-large": 8191,
+ "text-embedding-ada-002": 8191
+ }
+ max_tokens = max_tokens_map.get(embedding_model, None)
+ if max_tokens is not None:
+ input_texts = [
+ self.detokenize(self.tokenize(t)[:max_tokens])
+ for t in input_texts
+ ]
+
  try:
- # The OpenAI API expects the input to be a list of strings
  response = self.client.embeddings.create(
  model=embedding_model,
- input=[text] # Wrap the single text string in a list
+ input=input_texts
  )
-
- # Extract the embedding from the response
- if response.data and len(response.data) > 0:
- return response.data[0].embedding
- else:
- ASCIIColors.warning("OpenAI API returned no data for the embedding request.")
+
+ if not response.data:
+ ASCIIColors.warning(f"OpenAI API returned no data for the embedding request (model: {embedding_model}).")
  return []
-
+
+ embeddings = [item.embedding for item in response.data]
+
+ # Normalize if requested
+ if normalize:
+ embeddings = [
+ [v / math.sqrt(sum(x*x for x in emb)) for v in emb]
+ for emb in embeddings
+ ]
+
+ return embeddings[0] if is_single_input else embeddings
+
  except Exception as e:
- ASCIIColors.error(f"Failed to generate embeddings using OpenAI API: {e}")
+ ASCIIColors.error(f"Failed to generate embeddings using model '{embedding_model}': {e}")
  trace_exception(e)
  return []
 
+
+ def get_ctx_size(self, model_name: str | None = None) -> int:
+ """
+ Get the context size for a given model.
+ If model_name is None, use the instance's model_name.
+
+ Args:
+ model_name (str | None): The model name to check.
+
+ Returns:
+ int: The context window size in tokens.
+ """
+ if model_name is None:
+ model_name = self.model_name
+
+ # Default context sizes (update as needed)
+ context_map = {
+ # GPT-4 family
+ "gpt-4": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4o": 128000,
+ "gpt-4o-mini": 128000,
+ # GPT-3.5 family
+ "gpt-3.5-turbo": 16385,
+ "gpt-3.5-turbo-16k": 16385,
+ # GPT-5 and o-series
+ "gpt-5": 200000,
+ "o1": 200000,
+ "o3": 200000,
+ "o4": 200000,
+ }
+
+ # Try to find the best match
+ model_name_lower = model_name.lower()
+ for key, size in context_map.items():
+ if model_name_lower.startswith(key):
+ return size
+
+ # Fallback: default safe value
+ return 8192
+
+
  def get_model_info(self) -> dict:
  """
  Return information about the current OpenAI model.