lollms-client 0.25.5__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/azure_openai/__init__.py +364 -0
- lollms_client/llm_bindings/claude/__init__.py +549 -0
- lollms_client/llm_bindings/groq/__init__.py +292 -0
- lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +307 -0
- lollms_client/llm_bindings/lollms/__init__.py +1 -0
- lollms_client/llm_bindings/mistral/__init__.py +298 -0
- lollms_client/llm_bindings/open_router/__init__.py +304 -0
- lollms_client/lollms_core.py +2 -2
- lollms_client/lollms_discussion.py +16 -20
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/METADATA +366 -1
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/RECORD +15 -9
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.25.5.dist-info → lollms_client-0.26.0.dist-info}/top_level.txt +0 -0
lollms_client/llm_bindings/groq/__init__.py
@@ -0,0 +1,292 @@
+import base64
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import Optional, Callable, List, Union, Dict
+
+from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from ascii_colors import ASCIIColors, trace_exception
+
+import pipmaster as pm
+
+# Ensure the required packages are installed
+pm.ensure_packages(["groq", "pillow", "tiktoken"])
+
+import groq
+from PIL import Image, ImageDraw
+import tiktoken
+
+BindingName = "GroqBinding"
+
+class GroqBinding(LollmsLLMBinding):
+    """
+    Groq API binding implementation.
+
+    This binding allows communication with Groq's LPU-powered inference service,
+    known for its high-speed generation. It uses an OpenAI-compatible API structure.
+    """
+
+    def __init__(self,
+                 model_name: str = "llama3-8b-8192",
+                 groq_api_key: str = None,
+                 **kwargs
+                 ):
+        """
+        Initialize the GroqBinding.
+
+        Args:
+            model_name (str): The name of the Groq model to use.
+            groq_api_key (str): The API key for the Groq service.
+        """
+        super().__init__(binding_name=BindingName)
+        self.model_name = model_name
+        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
+
+        if not self.groq_api_key:
+            raise ValueError("Groq API key is required. Set it via 'groq_api_key' or GROQ_API_KEY env var.")
+
+        try:
+            self.client = groq.Groq(api_key=self.groq_api_key)
+        except Exception as e:
+            ASCIIColors.error(f"Failed to configure Groq client: {e}")
+            self.client = None
+            raise ConnectionError(f"Could not configure Groq client: {e}") from e
+
+    def _construct_parameters(self,
+                              temperature: float,
+                              top_p: float,
+                              n_predict: int,
+                              seed: Optional[int]) -> Dict[str, any]:
+        """Builds a parameters dictionary for the Groq API."""
+        params = {}
+        # Groq API mirrors OpenAI's parameters
+        if temperature is not None: params['temperature'] = float(temperature)
+        if top_p is not None: params['top_p'] = top_p
+        if n_predict is not None: params['max_tokens'] = n_predict
+        if seed is not None: params['seed'] = seed
+        return params
+
+    def _prepare_messages(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> List[Dict[str, any]]:
+        """Prepares the message list for the Groq API from a LollmsDiscussion."""
+        history = []
+        if discussion.system_prompt:
+            history.append({"role": "system", "content": discussion.system_prompt})
+
+        for msg in discussion.get_messages(branch_tip_id):
+            role = 'user' if msg.sender_type == "user" else 'assistant'
+            # Note: Groq models currently do not support image inputs.
+            # We only process the text content.
+            if msg.content:
+                history.append({'role': role, 'content': msg.content})
+        return history
+
+    def generate_text(self, prompt: str, **kwargs) -> Union[str, dict]:
+        """
+        Generate text using Groq. This is a wrapper around the chat method.
+        """
+        # Create a temporary discussion to leverage the `chat` method's logic
+        temp_discussion = LollmsDiscussion.from_messages([
+            LollmsMessage.new_message(sender_type="user", content=prompt)
+        ])
+        if kwargs.get("system_prompt"):
+            temp_discussion.system_prompt = kwargs.get("system_prompt")
+
+        return self.chat(temp_discussion, **kwargs)
+
+    def chat(self,
+             discussion: LollmsDiscussion,
+             branch_tip_id: Optional[str] = None,
+             n_predict: Optional[int] = 2048,
+             stream: Optional[bool] = False,
+             temperature: float = 0.7,
+             top_p: float = 0.9,
+             seed: Optional[int] = None,
+             streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+             **kwargs
+             ) -> Union[str, dict]:
+        """
+        Conduct a chat session with a Groq model.
+        """
+        if not self.client:
+            return {"status": "error", "message": "Groq client not initialized."}
+
+        messages = self._prepare_messages(discussion, branch_tip_id)
+        api_params = self._construct_parameters(temperature, top_p, n_predict, seed)
+        full_response_text = ""
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                stream=stream,
+                **api_params
+            )
+
+            if stream:
+                for chunk in response:
+                    delta = chunk.choices[0].delta.content
+                    if delta:
+                        full_response_text += delta
+                        if streaming_callback:
+                            if not streaming_callback(delta, MSG_TYPE.MSG_TYPE_CHUNK):
+                                break
+                return full_response_text
+            else:
+                return response.choices[0].message.content
+
+        except Exception as ex:
+            error_message = f"An unexpected error occurred with Groq API: {str(ex)}"
+            trace_exception(ex)
+            return {"status": "error", "message": error_message}
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize text using tiktoken for a rough estimate."""
+        try:
+            # Most models on Groq (like Llama) use tokenizers that are
+            # reasonably approximated by cl100k_base.
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.encode(text)
+        except Exception:
+            return list(text.encode('utf-8'))
+
+    def detokenize(self, tokens: list) -> str:
+        """Detokenize tokens using tiktoken."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.decode(tokens)
+        except Exception:
+            return bytes(tokens).decode('utf-8', errors='ignore')
+
+    def count_tokens(self, text: str) -> int:
+        """Count tokens in a text using the fallback tokenizer."""
+        return len(self.tokenize(text))
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        """
+        Groq does not provide an embedding API. This method is not implemented.
+        """
+        ASCIIColors.warning("Groq does not offer a public embedding API. This method is not implemented.")
+        raise NotImplementedError("Groq binding does not support embeddings.")
+
+    def get_model_info(self) -> dict:
+        """Return information about the current Groq setup."""
+        return {
+            "name": self.binding_name,
+            "version": groq.__version__,
+            "host_address": "https://api.groq.com/openai/v1",
+            "model_name": self.model_name,
+            "supports_structured_output": False,
+            "supports_vision": False, # Groq models do not currently support vision
+        }
+
+    def listModels(self) -> List[Dict[str, str]]:
+        """Lists available models from the Groq service."""
+        if not self.client:
+            ASCIIColors.error("Groq client not initialized. Cannot list models.")
+            return []
+        try:
+            ASCIIColors.debug("Listing Groq models...")
+            models = self.client.models.list()
+            model_info_list = []
+            for m in models.data:
+                model_info_list.append({
+                    'model_name': m.id,
+                    'display_name': m.id.replace('-', ' ').title(),
+                    'description': f"Context window: {m.context_window}, Active: {m.active}",
+                    'owned_by': m.owned_by
+                })
+            return model_info_list
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+
+    def load_model(self, model_name: str) -> bool:
+        """Sets the model name for subsequent operations."""
+        self.model_name = model_name
+        ASCIIColors.info(f"Groq model set to: {model_name}. It will be used on the next API call.")
+        return True
+
+if __name__ == '__main__':
+    # Environment variable to set for testing:
+    # GROQ_API_KEY: Your Groq API key
+
+    if "GROQ_API_KEY" not in os.environ:
+        ASCIIColors.red("Error: GROQ_API_KEY environment variable not set.")
+        print("Please get your key from https://console.groq.com/keys and set it.")
+        exit(1)
+
+    ASCIIColors.yellow("--- Testing GroqBinding ---")
+
+    # Use a fast and common model for testing
+    test_model_name = "llama3-8b-8192"
+
+    try:
+        # --- Initialization ---
+        ASCIIColors.cyan("\n--- Initializing Binding ---")
+        binding = GroqBinding(model_name=test_model_name)
+        ASCIIColors.green("Binding initialized successfully.")
+        ASCIIColors.info(f"Using groq library version: {groq.__version__}")
+
+        # --- List Models ---
+        ASCIIColors.cyan("\n--- Listing Models ---")
+        models = binding.listModels()
+        if models:
+            ASCIIColors.green(f"Found {len(models)} models on Groq. Available models:")
+            for m in models:
+                print(f"- {m['model_name']} (owned by {m['owned_by']})")
+        else:
+            ASCIIColors.warning("No models found or failed to list models.")
+
+        # --- Count Tokens ---
+        ASCIIColors.cyan("\n--- Counting Tokens ---")
+        sample_text = "The quick brown fox jumps over the lazy dog."
+        token_count = binding.count_tokens(sample_text)
+        ASCIIColors.green(f"Token count for '{sample_text}': {token_count}")
+
+        # --- Text Generation (Non-Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
+        prompt_text = "What is the capital of France? Be concise."
+        generated_text = binding.generate_text(prompt_text, n_predict=20, stream=False)
+        if isinstance(generated_text, str):
+            ASCIIColors.green(f"Generated text:\n{generated_text}")
+        else:
+            ASCIIColors.error(f"Generation failed: {generated_text}")
+
+        # --- Text Generation (Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")
+        full_streamed_text = ""
+        def stream_callback(chunk: str, msg_type: int):
+            nonlocal full_streamed_text
+            ASCIIColors.green(chunk, end="", flush=True)
+            full_streamed_text += chunk
+            return True
+
+        stream_prompt = "Write a very short, 3-line poem about speed."
+        result = binding.generate_text(stream_prompt, n_predict=50, stream=True, streaming_callback=stream_callback)
+        print("\n--- End of Stream ---")
+        ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+        # --- Embeddings Test ---
+        ASCIIColors.cyan("\n--- Embeddings ---")
+        try:
+            binding.embed("This should fail.")
+        except NotImplementedError as e:
+            ASCIIColors.green(f"Successfully caught expected error for embeddings: {e}")
+        except Exception as e:
+            ASCIIColors.error(f"Caught an unexpected error for embeddings: {e}")
+
+        # --- Vision Test (should be unsupported) ---
+        ASCIIColors.cyan("\n--- Vision Test (Expecting No Support) ---")
+        model_info = binding.get_model_info()
+        if not model_info.get("supports_vision"):
+            ASCIIColors.green("Binding correctly reports no support for vision.")
+        else:
+            ASCIIColors.warning("Binding reports support for vision, which is unexpected for Groq.")
+
+    except Exception as e:
+        ASCIIColors.error(f"An error occurred during testing: {e}")
+        trace_exception(e)
+
+    ASCIIColors.yellow("\nGroqBinding test finished.")
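
The sketch below is illustrative only and not part of the package diff. It mirrors the __main__ test block in the hunk above, assumes GROQ_API_KEY is exported, and uses only the methods defined in the new file. Note that generate_text() simply wraps chat() with a single-message discussion, so both paths hit the same OpenAI-compatible chat completions endpoint.

    from lollms_client.llm_bindings.groq import GroqBinding

    # The binding reads GROQ_API_KEY from the environment when no key is passed explicitly.
    binding = GroqBinding(model_name="llama3-8b-8192")

    # Non-streaming generation via the generate_text() wrapper around chat().
    print(binding.generate_text("What is the capital of France? Be concise.", n_predict=20, stream=False))

    # Streaming generation: the callback receives each text delta; returning False stops the stream.
    def on_chunk(chunk, msg_type):
        print(chunk, end="", flush=True)
        return True

    binding.generate_text("Write a very short poem about speed.", n_predict=50, stream=True, streaming_callback=on_chunk)
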
lollms_client/llm_bindings/hugging_face_inference_api/__init__.py
@@ -0,0 +1,307 @@
+import os
+from typing import Optional, Callable, List, Union, Dict
+
+from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from ascii_colors import ASCIIColors, trace_exception
+
+import pipmaster as pm
+
+# Ensure the required packages are installed
+pm.ensure_packages(["huggingface_hub", "tiktoken"])
+
+from huggingface_hub import HfApi, InferenceClient
+import tiktoken
+
+BindingName = "HuggingFaceInferenceAPIBinding"
+
+class HuggingFaceInferenceAPIBinding(LollmsLLMBinding):
+    """
+    Hugging Face Inference API binding implementation.
+
+    This binding communicates with the Hugging Face serverless Inference API,
+    allowing access to thousands of models hosted on the Hub.
+    """
+
+    def __init__(self,
+                 model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
+                 hf_api_key: str = None,
+                 **kwargs
+                 ):
+        """
+        Initialize the HuggingFaceInferenceAPIBinding.
+
+        Args:
+            model_name (str): The repository ID of the model on the Hugging Face Hub.
+            hf_api_key (str): The Hugging Face API key.
+        """
+        super().__init__(binding_name=BindingName)
+        self.model_name = model_name
+        self.hf_api_key = hf_api_key or os.getenv("HUGGING_FACE_HUB_TOKEN")
+
+        if not self.hf_api_key:
+            raise ValueError("Hugging Face API key is required. Set it via 'hf_api_key' or HUGGING_FACE_HUB_TOKEN env var.")
+
+        try:
+            self.client = InferenceClient(model=self.model_name, token=self.hf_api_key)
+            self.hf_api = HfApi(token=self.hf_api_key)
+        except Exception as e:
+            ASCIIColors.error(f"Failed to configure Hugging Face client: {e}")
+            self.client = None
+            self.hf_api = None
+            raise ConnectionError(f"Could not configure Hugging Face client: {e}") from e
+
+    def _construct_parameters(self,
+                              temperature: float,
+                              top_p: float,
+                              n_predict: int,
+                              repeat_penalty: float,
+                              seed: Optional[int]) -> Dict[str, any]:
+        """Builds a parameters dictionary for the HF Inference API."""
+        params = {"details": False, "do_sample": True}
+        if temperature is not None and temperature > 0:
+            params['temperature'] = float(temperature)
+        else:
+            # A temperature of 0 can cause issues, a small epsilon is better
+            params['temperature'] = 0.001
+            params['do_sample'] = False
+
+
+        if top_p is not None: params['top_p'] = top_p
+        if n_predict is not None: params['max_new_tokens'] = n_predict
+        if repeat_penalty is not None: params['repetition_penalty'] = repeat_penalty
+        if seed is not None: params['seed'] = seed
+        return params
+
+    def _format_chat_prompt(self, discussion: LollmsDiscussion, branch_tip_id: Optional[str] = None) -> str:
+        """
+        Formats a discussion into a single prompt string, attempting to use the model's chat template.
+        """
+        messages = []
+        if discussion.system_prompt:
+            messages.append({"role": "system", "content": discussion.system_prompt})
+
+        for msg in discussion.get_messages(branch_tip_id):
+            role = 'user' if msg.sender_type == "user" else 'assistant'
+            messages.append({"role": role, "content": msg.content})
+
+        try:
+            # This is the preferred way, as it respects the model's specific formatting.
+            return self.client.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        except Exception:
+            # Fallback for models without a chat template or if the client fails to fetch it
+            ASCIIColors.warning("Could not apply chat template. Using generic formatting.")
+            full_prompt = ""
+            if discussion.system_prompt:
+                full_prompt += f"<|system|>\n{discussion.system_prompt}\n"
+            for msg in messages:
+                if msg['role'] == 'user':
+                    full_prompt += f"<|user|>\n{msg['content']}\n"
+                else:
+                    full_prompt += f"<|assistant|>\n{msg['content']}\n"
+            full_prompt += "<|assistant|>\n"
+            return full_prompt
+
+    def generate_text(self,
+                      prompt: str,
+                      n_predict: Optional[int] = 1024,
+                      stream: Optional[bool] = False,
+                      temperature: float = 0.7,
+                      top_p: float = 0.9,
+                      repeat_penalty: float = 1.1,
+                      seed: Optional[int] = None,
+                      streaming_callback: Optional[Callable[[str, MSG_TYPE], None]] = None,
+                      **kwargs
+                      ) -> Union[str, dict]:
+        """
+        Generate text using the Hugging Face Inference API.
+        """
+        if not self.client:
+            return {"status": "error", "message": "HF Inference client not initialized."}
+
+        api_params = self._construct_parameters(temperature, top_p, n_predict, repeat_penalty, seed)
+        full_response_text = ""
+
+        try:
+            if stream:
+                for chunk in self.client.text_generation(prompt, stream=True, **api_params):
+                    full_response_text += chunk
+                    if streaming_callback:
+                        if not streaming_callback(chunk, MSG_TYPE.MSG_TYPE_CHUNK):
+                            break
+                return full_response_text
+            else:
+                return self.client.text_generation(prompt, **api_params)
+
+        except Exception as ex:
+            error_message = f"An unexpected error occurred with HF Inference API: {str(ex)}"
+            trace_exception(ex)
+            return {"status": "error", "message": error_message}
+
+    def chat(self, discussion: LollmsDiscussion, **kwargs) -> Union[str, dict]:
+        """
+        Conduct a chat session using the Inference API by formatting the discussion into a single prompt.
+        """
+        prompt = self._format_chat_prompt(discussion)
+        return self.generate_text(prompt, **kwargs)
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize text using tiktoken as a fallback."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.encode(text)
+        except Exception:
+            return list(text.encode('utf-8'))
+
+    def detokenize(self, tokens: list) -> str:
+        """Detokenize tokens using tiktoken."""
+        try:
+            encoding = tiktoken.get_encoding("cl100k_base")
+            return encoding.decode(tokens)
+        except Exception:
+            return bytes(tokens).decode('utf-8', errors='ignore')
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.tokenize(text))
+
+    def embed(self, text: str, **kwargs) -> List[float]:
+        """
+        Get embeddings using a dedicated sentence-transformer model from the Inference API.
+        """
+        if not self.client:
+            raise Exception("HF Inference client not initialized.")
+
+        # User should specify a sentence-transformer model
+        embedding_model = kwargs.get("model")
+        if not embedding_model:
+            raise ValueError("A sentence-transformer model ID must be provided via the 'model' kwarg.")
+
+        try:
+            # This is a different endpoint on the InferenceClient
+            response = self.client.feature_extraction(text, model=embedding_model)
+            # The output for many models is a nested list, we need the first element.
+            if isinstance(response, list) and isinstance(response[0], list):
+                return response[0]
+            return response
+        except Exception as ex:
+            trace_exception(ex)
+            raise Exception(f"HF Inference API embedding failed: {str(ex)}") from ex
+
+    def get_model_info(self) -> dict:
+        return {
+            "name": self.binding_name,
+            "version": "unknown",
+            "host_address": "https://api-inference.huggingface.co",
+            "model_name": self.model_name,
+            "supports_structured_output": False,
+            "supports_vision": False, # Vision models use a different API call
+        }
+
+    def listModels(self) -> List[Dict[str, str]]:
+        """Lists text-generation models from the Hugging Face Hub."""
+        if not self.hf_api:
+            ASCIIColors.error("HF API client not initialized. Cannot list models.")
+            return []
+        try:
+            ASCIIColors.debug("Listing Hugging Face text-generation models...")
+            # We filter for the 'text-generation' pipeline tag
+            models = self.hf_api.list_models(filter="text-generation", sort="downloads", direction=-1, limit=100)
+            model_info_list = []
+            for m in models:
+                model_info_list.append({
+                    'model_name': m.modelId,
+                    'display_name': m.modelId,
+                    'description': f"Downloads: {m.downloads}, Likes: {m.likes}",
+                    'owned_by': m.author or "Hugging Face Community"
+                })
+            return model_info_list
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+
+    def load_model(self, model_name: str) -> bool:
+        """Sets the model for subsequent operations and re-initializes the client."""
+        self.model_name = model_name
+        try:
+            self.client = InferenceClient(model=self.model_name, token=self.hf_api_key)
+            ASCIIColors.info(f"Hugging Face model set to: {model_name}. It will be used on the next API call.")
+            return True
+        except Exception as e:
+            ASCIIColors.error(f"Failed to re-initialize client for model {model_name}: {e}")
+            self.client = None
+            return False
+
+if __name__ == '__main__':
+    # Environment variable to set for testing:
+    # HUGGING_FACE_HUB_TOKEN: Your Hugging Face API key with read access.
+
+    if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
+        ASCIIColors.red("Error: HUGGING_FACE_HUB_TOKEN environment variable not set.")
+        print("Please get your token from https://huggingface.co/settings/tokens and set it.")
+        exit(1)
+
+    ASCIIColors.yellow("--- Testing HuggingFaceInferenceAPIBinding ---")
+
+    test_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+    test_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
+
+    try:
+        # --- Initialization ---
+        ASCIIColors.cyan("\n--- Initializing Binding ---")
+        binding = HuggingFaceInferenceAPIBinding(model_name=test_model_name)
+        ASCIIColors.green("Binding initialized successfully.")
+
+        # --- List Models ---
+        ASCIIColors.cyan("\n--- Listing Models ---")
+        models = binding.listModels()
+        if models:
+            ASCIIColors.green(f"Successfully fetched {len(models)} text-generation models.")
+            ASCIIColors.info("Top 5 most downloaded models:")
+            for m in models[:5]:
+                print(f"- {m['model_name']} ({m['description']})")
+        else:
+            ASCIIColors.warning("No models found or failed to list models.")
+
+        # --- Text Generation (Non-Streaming) ---
+        ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
+        prompt_text = "In a world where AI companions are common, a detective is assigned a new AI partner. Write the first paragraph of their first meeting."
+        ASCIIColors.info("Waiting for model to load (this might take a moment for cold starts)...")
+        generated_text = binding.generate_text(prompt_text, n_predict=150, stream=False)
+        if isinstance(generated_text, str):
+            ASCIIColors.green(f"Generated text:\n{generated_text}")
+        else:
+            ASCIIColors.error(f"Generation failed: {generated_text}")
+
+        # --- Chat (Streaming) ---
+        ASCIIColors.cyan("\n--- Chat (Streaming) ---")
+        chat_discussion = LollmsDiscussion.from_messages([
+            LollmsMessage.new_message(sender_type="system", content="You are a helpful and pirate-themed assistant named Captain Coder."),
+            LollmsMessage.new_message(sender_type="user", content="Ahoy there! Tell me, what be the best language for a scallywag to learn for data science?"),
+        ])
+        full_streamed_text = ""
+        def stream_callback(chunk: str, msg_type: int):
+            nonlocal full_streamed_text
+            ASCIIColors.green(chunk, end="", flush=True)
+            full_streamed_text += chunk
+            return True
+
+        result = binding.chat(chat_discussion, n_predict=100, stream=True, streaming_callback=stream_callback)
+        print("\n--- End of Stream ---")
+        ASCIIColors.green(f"Full streamed text (for verification): {result}")
+
+        # --- Embeddings Test ---
+        ASCIIColors.cyan("\n--- Embeddings ---")
+        try:
+            embedding_text = "Hugging Face is the home of open-source AI."
+            embedding_vector = binding.embed(embedding_text, model=test_embedding_model)
+            ASCIIColors.green(f"Embedding for '{embedding_text}' (first 5 dims): {embedding_vector[:5]}...")
+            ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
+        except Exception as e:
+            ASCIIColors.error(f"Embedding test failed: {e}")
+
+    except Exception as e:
+        ASCIIColors.error(f"An error occurred during testing: {e}")
+        trace_exception(e)
+
+    ASCIIColors.yellow("\nHuggingFaceInferenceAPIBinding test finished.")
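
A similar illustrative sketch for the new Hugging Face Inference API binding, again not part of the diff. It assumes HUGGING_FACE_HUB_TOKEN is set and reuses the model IDs from the test block above; embed() requires an explicit sentence-transformers model passed through the 'model' kwarg, since the text-generation model cannot serve embeddings.

    from lollms_client.llm_bindings.hugging_face_inference_api import HuggingFaceInferenceAPIBinding
    from lollms_client.lollms_discussion import LollmsDiscussion, LollmsMessage

    binding = HuggingFaceInferenceAPIBinding(model_name="mistralai/Mistral-7B-Instruct-v0.2")

    # chat() flattens the discussion into one prompt (model chat template when available, generic tags otherwise).
    discussion = LollmsDiscussion.from_messages([
        LollmsMessage.new_message(sender_type="user", content="Explain what a serverless inference API is in one sentence."),
    ])
    print(binding.chat(discussion, n_predict=64, stream=False))

    # Embeddings go through feature_extraction() on the InferenceClient.
    vector = binding.embed("Hugging Face is the home of open-source AI.", model="sentence-transformers/all-MiniLM-L6-v2")
    print(f"Embedding dimension: {len(vector)}")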