lollms-client 0.9.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/__init__.py +1 -0
- lollms_client/llm_bindings/lollms/__init__.py +302 -0
- lollms_client/llm_bindings/ollama/__init__.py +297 -0
- lollms_client/llm_bindings/openai/__init__.py +261 -0
- lollms_client/llm_bindings/transformers/__init__.py +277 -0
- lollms_client/lollms_core.py +451 -1456
- lollms_client/lollms_llm_binding.py +210 -0
- lollms_client/lollms_tasks.py +42 -109
- lollms_client/lollms_tts.py +7 -3
- lollms_client/lollms_types.py +19 -1
- lollms_client/stt_bindings/__init__.py +0 -0
- lollms_client/stt_bindings/lollms/__init__.py +0 -0
- lollms_client/tti_bindings/__init__.py +0 -0
- lollms_client/tti_bindings/lollms/__init__.py +0 -0
- lollms_client/tts_bindings/__init__.py +0 -0
- lollms_client/tts_bindings/lollms/__init__.py +0 -0
- lollms_client/ttv_bindings/__init__.py +0 -0
- lollms_client/ttv_bindings/lollms/__init__.py +0 -0
- {lollms_client-0.9.2.dist-info → lollms_client-0.11.0.dist-info}/METADATA +26 -13
- lollms_client-0.11.0.dist-info/RECORD +34 -0
- {lollms_client-0.9.2.dist-info → lollms_client-0.11.0.dist-info}/WHEEL +1 -1
- lollms_client-0.9.2.dist-info/RECORD +0 -20
- {lollms_client-0.9.2.dist-info → lollms_client-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {lollms_client-0.9.2.dist-info → lollms_client-0.11.0.dist-info}/top_level.txt +0 -0
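
The headline change in 0.11.0 is a move to per-backend LLM bindings: lollms_core.py shrinks by roughly a thousand lines while new lollms, ollama, openai and transformers binding packages are added, alongside empty placeholder packages for stt/tts/tti/ttv bindings. Two of the added binding files are reproduced below. As a rough usage sketch only (not taken from the diff; the host, key and model name are placeholders), a binding exposing the generate_text signature shown in the added OpenAI file could be driven as follows. Note that, as published, generate_text compares an undefined completion_format name (presumably self.default_completion_format was intended), so the sketch assumes that is resolved:

import os
from lollms_client.llm_bindings.openai import OpenAIBinding

def on_chunk(chunk: str, msg_type) -> bool:
    # The binding stops streaming when the callback returns False.
    print(chunk, end="", flush=True)
    return True

os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder key, read in __init__

binding = OpenAIBinding(
    host_address="https://api.openai.com/v1",  # assumed OpenAI-compatible endpoint
    model_name="gpt-4o-mini",                  # placeholder chat model name
)
result = binding.generate_text(
    "Say hello in one short sentence.",
    n_predict=64,
    stream=True,
    streaming_callback=on_chunk,
)
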
lollms_client/llm_bindings/openai/__init__.py
@@ -0,0 +1,261 @@
+# bindings/OpenAI/binding.py
+import requests
+import json
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from lollms_client.lollms_utilities import encode_image
+from lollms_client.lollms_types import ELF_COMPLETION_FORMAT
+from typing import Optional, Callable, List, Union
+from ascii_colors import ASCIIColors, trace_exception
+import pipmaster as pm
+if not pm.is_installed("openai"):
+    pm.install("openai")
+if not pm.is_installed("tiktoken"):
+    pm.install("tiktoken")
+import openai
+import tiktoken
+import os
+
+BindingName = "OpenAIBinding"
+
+
+class OpenAIBinding(LollmsLLMBinding):
+    """OpenAI-specific binding implementation"""
+
+
+    def __init__(self,
+                 host_address: str = None,
+                 model_name: str = "",
+                 service_key: str = None,
+                 verify_ssl_certificate: bool = True,
+                 default_completion_format: ELF_COMPLETION_FORMAT = ELF_COMPLETION_FORMAT.Chat):
+        """
+        Initialize the OpenAI binding.
+
+        Args:
+            host_address (str): Host address for the OpenAI service. Defaults to DEFAULT_HOST_ADDRESS.
+            model_name (str): Name of the model to use. Defaults to empty string.
+            service_key (str): Authentication key for the service. Defaults to None.
+            verify_ssl_certificate (bool): Whether to verify SSL certificates. Defaults to True.
+            personality (Optional[int]): Ignored parameter for compatibility with LollmsLLMBinding.
+        """
+        super().__init__(
+            host_address=host_address if host_address is not None else self.DEFAULT_HOST_ADDRESS,
+            model_name=model_name,
+            service_key=service_key,
+            verify_ssl_certificate=verify_ssl_certificate,
+            default_completion_format=default_completion_format
+        )
+        self.service_key = os.getenv("OPENAI_API_KEY","")
+        self.client = openai.OpenAI(base_url=host_address)
+
+
+    def generate_text(self,
+                      prompt: str,
+                      images: Optional[List[str]] = None,
+                      n_predict: Optional[int] = None,
+                      stream: bool = False,
+                      temperature: float = 0.1,
+                      top_k: int = 50,
+                      top_p: float = 0.95,
+                      repeat_penalty: float = 0.8,
+                      repeat_last_n: int = 40,
+                      seed: Optional[int] = None,
+                      n_threads: int = 8,
+                      ctx_size: int | None = None,
+                      streaming_callback: Optional[Callable[[str, str], None]] = None) -> str:
+        """
+        Generate text based on the provided prompt and parameters.
+
+        Args:
+            prompt (str): The input prompt for text generation.
+            images (Optional[List[str]]): List of image file paths for multimodal generation.
+            n_predict (Optional[int]): Maximum number of tokens to generate.
+            stream (bool): Whether to stream the output. Defaults to False.
+            temperature (float): Sampling temperature. Defaults to 0.1.
+            top_k (int): Top-k sampling parameter. Defaults to 50.
+            top_p (float): Top-p sampling parameter. Defaults to 0.95.
+            repeat_penalty (float): Penalty for repeated tokens. Defaults to 0.8.
+            repeat_last_n (int): Number of previous tokens to consider for repeat penalty. Defaults to 40.
+            seed (Optional[int]): Random seed for generation.
+            n_threads (int): Number of threads to use. Defaults to 8.
+            streaming_callback (Optional[Callable[[str, str], None]]): Callback function for streaming output.
+                - First parameter (str): The chunk of text received.
+                - Second parameter (str): The message type (e.g., MSG_TYPE.MSG_TYPE_CHUNK).
+
+        Returns:
+            str: Generated text or error dictionary if failed.
+        """
+        count = 0
+        output = ""
+
+        # Prepare messages based on whether images are provided
+        if images:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": prompt
+                        }
+                    ] + [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{encode_image(image_path)}"
+                            }
+                        }
+                        for image_path in images
+                    ]
+                }
+            ]
+        else:
+            messages = [{"role": "user", "content": prompt}]
+
+        # Generate text using the OpenAI API
+        if completion_format == ELF_COMPLETION_FORMAT.Chat:
+            chat_completion = self.client.chat.completions.create(
+                model=self.model_name, # Choose the engine according to your OpenAI plan
+                messages=messages,
+                max_tokens=n_predict, # Adjust the desired length of the generated response
+                n=1, # Specify the number of responses you want
+                temperature=temperature, # Adjust the temperature for more or less randomness in the output
+                stream=stream
+            )
+
+            if stream:
+                for resp in chat_completion:
+                    if count >= n_predict:
+                        break
+                    try:
+                        word = resp.choices[0].delta.content
+                    except Exception as ex:
+                        word = ""
+                    if streaming_callback is not None:
+                        if not streaming_callback(word, "MSG_TYPE_CHUNK"):
+                            break
+                    if word:
+                        output += word
+                        count += 1
+            else:
+                output = chat_completion.choices[0].message.content
+        else:
+            completion = self.client.completions.create(
+                model=self.model_name, # Choose the engine according to your OpenAI plan
+                prompt=prompt,
+                max_tokens=n_predict, # Adjust the desired length of the generated response
+                n=1, # Specify the number of responses you want
+                temperature=temperature, # Adjust the temperature for more or less randomness in the output
+                stream=stream
+            )
+
+            if stream:
+                for resp in completion:
+                    if count >= n_predict:
+                        break
+                    try:
+                        word = resp.choices[0].text
+                    except Exception as ex:
+                        word = ""
+                    if streaming_callback is not None:
+                        if not streaming_callback(word, "MSG_TYPE_CHUNK"):
+                            break
+                    if word:
+                        output += word
+                        count += 1
+            else:
+                output = completion.choices[0].text
+
+        return output
+
+    def tokenize(self, text: str) -> list:
+        """
+        Tokenize the input text into a list of characters.
+
+        Args:
+            text (str): The text to tokenize.
+
+        Returns:
+            list: List of individual characters.
+        """
+        try:
+            return tiktoken.model.encoding_for_model(self.model_name).encode(text)
+        except:
+            return tiktoken.model.encoding_for_model("gpt-3.5-turbo").encode(text)
+
+    def detokenize(self, tokens: list) -> str:
+        """
+        Convert a list of tokens back to text.
+
+        Args:
+            tokens (list): List of tokens (characters) to detokenize.
+
+        Returns:
+            str: Detokenized text.
+        """
+        try:
+            return tiktoken.model.encoding_for_model(self.model_name).decode(tokens)
+        except:
+            return tiktoken.model.encoding_for_model("gpt-3.5-turbo").decode(tokens)
+
+    def embed(self, text: str, **kwargs) -> list:
+        """
+        Get embeddings for the input text using OpenAI API
+
+        Args:
+            text (str or List[str]): Input text to embed
+            **kwargs: Additional arguments like model, truncate, options, keep_alive
+
+        Returns:
+            dict: Response containing embeddings
+        """
+        pass
+    def get_model_info(self) -> dict:
+        """
+        Return information about the current OpenAI model.
+
+        Returns:
+            dict: Dictionary containing model name, version, and host address.
+        """
+        return {
+            "name": "OpenAI",
+            "version": "2.0",
+            "host_address": self.host_address,
+            "model_name": self.model_name
+        }
+    def listModels(self):
+        """ Lists available models """
+        url = f'{self.host_address}/v1/models'
+        headers = {
+            'accept': 'application/json',
+            'Authorization': f'Bearer {self.service_key}'
+        }
+        response = requests.get(url, headers=headers, verify= self.verify_ssl_certificate)
+        try:
+            data = response.json()
+            model_info = []
+
+            for model in data["data"]:
+                model_name = model['id']
+                owned_by = model['owned_by']
+                created_datetime = model["created"]
+                model_info.append({'model_name': model_name, 'owned_by': owned_by, 'created_datetime': created_datetime})
+
+            return model_info
+        except Exception as ex:
+            trace_exception(ex)
+            return []
+    def load_model(self, model_name: str) -> bool:
+        """
+        Load a specific model into the OpenAI binding.
+
+        Args:
+            model_name (str): Name of the model to load.
+
+        Returns:
+            bool: True if model loaded successfully.
+        """
+        self.model = model_name
+        self.model_name = model_name
+        return True
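
The tokenize/detokenize pair above delegates to tiktoken and silently falls back to the gpt-3.5-turbo encoding when the model name is unknown (tiktoken.encoding_for_model is the same function the file reaches through tiktoken.model). A minimal standalone sketch of that fallback pattern, with a helper name of my own that is not part of the package:

import tiktoken

def encode_with_fallback(model_name: str, text: str) -> list[int]:
    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown model names raise KeyError; reuse a well-known encoding instead.
        enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return enc.encode(text)

tokens = encode_with_fallback("my-custom-model", "Hello from lollms-client 0.11.0")
print(len(tokens), tokens[:5])
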
lollms_client/llm_bindings/transformers/__init__.py
@@ -0,0 +1,277 @@
+# bindings/ollama/binding.py
+import requests
+import json
+from lollms_client.lollms_llm_binding import LollmsLLMBinding
+from lollms_client.lollms_types import MSG_TYPE
+from lollms_client.lollms_utilities import encode_image
+from lollms_client.lollms_types import ELF_COMPLETION_FORMAT
+from typing import Optional, Callable, List, Union
+from ascii_colors import ASCIIColors
+
+import pipmaster as pm
+if not pm.is_installed("torch"):
+    ASCIIColors.yellow("Diffusers: Torch not found. Installing it")
+    pm.install_multiple(["torch", "torchvision", "torchaudio"], "https://download.pytorch.org/whl/cu121", force_reinstall=True)
+
+import torch
+if not torch.cuda.is_available():
+    ASCIIColors.yellow("Diffusers: Torch not using cuda. Reinstalling it")
+    pm.install_multiple(["torch", "torchvision", "torchaudio"], "https://download.pytorch.org/whl/cu121", force_reinstall=True)
+    import torch
+
+if not pm.is_installed("transformers"):
+    pm.install_or_update("transformers")
+
+BindingName = "TransformersBinding"
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, BitsAndBytesConfig
+from packaging import version
+import transformers
+
+class TransformersBinding(LollmsLLMBinding):
+    """Transformers-specific binding implementation"""
+
+    def __init__(self,
+                 host_address: str = None,
+                 model_name: str = "",
+                 service_key: str = None,
+                 verify_ssl_certificate: bool = True,
+                 default_completion_format: ELF_COMPLETION_FORMAT = ELF_COMPLETION_FORMAT.Chat,
+                 prompt_template: Optional[str] = None):
+        """
+        Initialize the Transformers binding.
+
+        Args:
+            host_address (str): Host address for the service. Defaults to None.
+            model_name (str): Name of the model to use. Defaults to empty string.
+            service_key (str): Authentication key for the service. Defaults to None.
+            verify_ssl_certificate (bool): Whether to verify SSL certificates. Defaults to True.
+            default_completion_format (ELF_COMPLETION_FORMAT): Default format for completions.
+            prompt_template (Optional[str]): Custom prompt template. If None, inferred from model.
+        """
+        super().__init__(
+            host_address=host_address,
+            model_name=model_name,
+            service_key=service_key,
+            verify_ssl_certificate=verify_ssl_certificate,
+            default_completion_format=default_completion_format
+        )
+
+        # Configure 4-bit quantization
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True
+        )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            str(model_name),
+            trust_remote_code=False
+        )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            str(model_name),
+            device_map="auto",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16
+        )
+
+        self.generation_config = GenerationConfig.from_pretrained(str(model_name))
+
+        # Infer or set prompt template
+        self.prompt_template = prompt_template if prompt_template else self._infer_prompt_template(model_name)
+
+        # Display device information
+        device = next(self.model.parameters()).device
+        device_type = "CPU" if device.type == "cpu" else "GPU"
+        device_str = f"Running on {device}"
+
+        ASCIIColors.multicolor(
+            ["Model loaded - ", device_str],
+            [ASCIIColors.color_green, ASCIIColors.color_blue if device_type == "GPU" else ASCIIColors.color_red]
+        )
+
+    def _infer_prompt_template(self, model_name: str) -> str:
+        """
+        Infer the prompt template based on the model name.
+
+        Args:
+            model_name (str): Name of the model.
+
+        Returns:
+            str: The inferred prompt template format string.
+        """
+        model_name = model_name.lower()
+        if "llama-2" in model_name or "llama" in model_name:
+            return "[INST] <<SYS>> {system_prompt} <</SYS>> {user_prompt} [/INST]"
+        elif "gpt" in model_name:
+            return "{system_prompt}\n{user_prompt}" # Simple concatenation for GPT-style models
+        else:
+            # Default to a basic chat format
+            ASCIIColors.yellow(f"Warning: No specific template found for {model_name}. Using default chat format.")
+            return "[INST] {system_prompt}\n{user_prompt} [/INST]"
+
+    def generate_text(self,
+                      prompt: str,
+                      images: Optional[List[str]] = None,
+                      n_predict: Optional[int] = None,
+                      stream: bool = False,
+                      temperature: float = 0.1,
+                      top_k: int = 50,
+                      top_p: float = 0.95,
+                      repeat_penalty: float = 0.8,
+                      repeat_last_n: int = 40,
+                      seed: Optional[int] = None,
+                      n_threads: int = 8,
+                      ctx_size: int | None = None,
+                      streaming_callback: Optional[Callable[[str, str], None]] = None,
+                      return_legacy_cache: bool = False,
+                      system_prompt: str = "You are a helpful assistant.") -> Union[str, dict]:
+        """
+        Generate text using the Transformers model, with optional image support.
+
+        Args:
+            prompt (str): The input prompt for text generation (user prompt).
+            images (Optional[List[str]]): List of image file paths for multimodal generation.
+            n_predict (Optional[int]): Maximum number of tokens to generate.
+            stream (bool): Whether to stream the output. Defaults to False.
+            temperature (float): Sampling temperature. Defaults to 0.1.
+            top_k (int): Top-k sampling parameter. Defaults to 50.
+            top_p (float): Top-p sampling parameter. Defaults to 0.95.
+            repeat_penalty (float): Penalty for repeated tokens. Defaults to 0.8.
+            repeat_last_n (int): Number of previous tokens to consider for repeat penalty. Defaults to 40.
+            seed (Optional[int]): Random seed for generation.
+            n_threads (int): Number of threads to use. Defaults to 8.
+            streaming_callback (Optional[Callable[[str, str], None]]): Callback for streaming output.
+            return_legacy_cache (bool): Whether to use legacy cache format (pre-v4.47). Defaults to False.
+            system_prompt (str): System prompt to set model behavior. Defaults to "You are a helpful assistant."
+
+        Returns:
+            Union[str, dict]: Generated text if successful, or a dictionary with status and error if failed.
+        """
+        try:
+            if not self.model or not self.tokenizer:
+                return {"status": "error", "error": "Model or tokenizer not loaded"}
+
+            # Set seed if provided
+            if seed is not None:
+                torch.manual_seed(seed)
+
+            # Apply the prompt template
+            formatted_prompt = self.prompt_template.format(
+                system_prompt=system_prompt,
+                user_prompt=prompt
+            )
+
+            # Prepare generation config
+            self.generation_config.max_new_tokens = n_predict if n_predict else 2048
+            self.generation_config.temperature = temperature
+            self.generation_config.top_k = top_k
+            self.generation_config.top_p = top_p
+            self.generation_config.repetition_penalty = repeat_penalty
+            self.generation_config.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+
+            # Tokenize input with attention mask
+            inputs = self.tokenizer(formatted_prompt, return_tensors="pt", padding=True)
+            input_ids = inputs.input_ids.to(self.model.device)
+            attention_mask = inputs.attention_mask.to(self.model.device)
+
+            # Handle image input if provided (basic implementation)
+            if images and len(images) > 0:
+                ASCIIColors.yellow("Warning: Image processing not fully implemented in this binding")
+                formatted_prompt += "\n[Image content not processed]"
+
+            # Check transformers version for cache handling
+            use_legacy_cache = return_legacy_cache or version.parse(transformers.__version__) < version.parse("4.47.0")
+
+            if stream:
+                # Streaming case
+                if not streaming_callback:
+                    return {"status": "error", "error": "Streaming callback required for stream mode"}
+
+                generated_text = ""
+                # Generate with streaming
+                for output in self.model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    generation_config=self.generation_config,
+                    do_sample=True,
+                    return_dict_in_generate=True,
+                    output_scores=False,
+                    return_legacy_cache=use_legacy_cache
+                ):
+                    # Handle different output formats based on version/cache setting
+                    if use_legacy_cache:
+                        sequences = output[0]
+                    else:
+                        sequences = output.sequences
+
+                    # Decode the new tokens
+                    new_tokens = sequences[:, -1:] # Get the last generated token
+                    chunk = self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
+                    generated_text += chunk
+
+                    # Send chunk through callback
+                    streaming_callback(chunk, MSG_TYPE.MSG_TYPE_CHUNK)
+
+                return generated_text
+
+            else:
+                # Non-streaming case
+                outputs = self.model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    generation_config=self.generation_config,
+                    do_sample=True,
+                    return_dict_in_generate=True,
+                    output_scores=False,
+                    return_legacy_cache=use_legacy_cache
+                )
+
+                # Handle different output formats
+                sequences = outputs[0] if use_legacy_cache else outputs.sequences
+
+                # Decode the full sequence, removing the input prompt
+                generated_text = self.tokenizer.decode(
+                    sequences[0][input_ids.shape[-1]:],
+                    skip_special_tokens=True
+                )
+
+                return generated_text
+
+        except Exception as e:
+            error_msg = f"Error generating text: {str(e)}"
+            ASCIIColors.red(error_msg)
+            return {"status": "error", "error": error_msg}
+
+    def tokenize(self, text: str) -> list:
+        """Tokenize the input text into a list of characters."""
+        return list(text)
+
+    def detokenize(self, tokens: list) -> str:
+        """Convert a list of tokens back to text."""
+        return "".join(tokens)
+
+    def embed(self, text: str, **kwargs) -> list:
+        """Get embeddings for the input text (placeholder)."""
+        pass
+
+    def get_model_info(self) -> dict:
+        """Return information about the current model."""
+        return {
+            "name": "transformers",
+            "version": transformers.__version__,
+            "host_address": self.host_address,
+            "model_name": self.model_name
+        }
+
+    def listModels(self):
+        """Lists available models (placeholder)."""
+        pass
+
+    def load_model(self, model_name: str) -> bool:
+        """Load a specific model into the binding."""
+        self.model = model_name
+        self.model_name = model_name
+        return True
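
The TransformersBinding above drives its models through a plain Python format string chosen by _infer_prompt_template and filled with a system prompt and a user prompt. A minimal illustration of that mechanism, reusing the Llama-2 template string from the file (the constant name is mine, not part of the package):

# Format a system/user pair with the Llama-2 instruct template inferred above.
LLAMA2_TEMPLATE = "[INST] <<SYS>> {system_prompt} <</SYS>> {user_prompt} [/INST]"

formatted = LLAMA2_TEMPLATE.format(
    system_prompt="You are a helpful assistant.",
    user_prompt="Explain what an LLM binding does in lollms-client.",
)
print(formatted)
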