lollms-client 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic. Click here for more details.
- examples/simple_text_gen_with_image_test.py +21 -9
- examples/text_gen.py +3 -1
- examples/text_gen_system_prompt.py +2 -1
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/llamacpp/__init__.py +1041 -0
- lollms_client/llm_bindings/ollama/__init__.py +3 -3
- lollms_client/llm_bindings/openllm/__init__.py +547 -0
- lollms_client/llm_bindings/pythonllamacpp/__init__.py +591 -0
- lollms_client/llm_bindings/transformers/__init__.py +660 -251
- lollms_client/lollms_core.py +21 -22
- lollms_client/lollms_llm_binding.py +1 -5
- {lollms_client-0.14.0.dist-info → lollms_client-0.15.0.dist-info}/METADATA +1 -1
- {lollms_client-0.14.0.dist-info → lollms_client-0.15.0.dist-info}/RECORD +16 -13
- {lollms_client-0.14.0.dist-info → lollms_client-0.15.0.dist-info}/WHEEL +1 -1
- {lollms_client-0.14.0.dist-info → lollms_client-0.15.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.14.0.dist-info → lollms_client-0.15.0.dist-info}/top_level.txt +0 -0
|
@@ -1,287 +1,696 @@
|
|
|
1
|
-
# bindings/
|
|
2
|
-
import requests
|
|
1
|
+
# lollms_client/llm_bindings/transformers/__init__.py (HuggingFaceHubBinding)
|
|
3
2
|
import json
|
|
3
|
+
import os
|
|
4
|
+
import pprint
|
|
5
|
+
import re
|
|
6
|
+
import socket # Not used directly for server, but good to keep for consistency if needed elsewhere
|
|
7
|
+
import subprocess # Not used for server
|
|
8
|
+
import sys
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Optional, Callable, List, Union, Dict, Any, Set
|
|
13
|
+
import base64 # For potential image data handling, though PIL.Image is primary
|
|
14
|
+
import requests # Not used for server, but for consistency
|
|
15
|
+
|
|
4
16
|
from lollms_client.lollms_llm_binding import LollmsLLMBinding
|
|
5
|
-
from lollms_client.lollms_types import MSG_TYPE
|
|
6
|
-
from lollms_client.lollms_utilities import encode_image
|
|
7
|
-
from lollms_client.lollms_types import ELF_COMPLETION_FORMAT
|
|
8
|
-
from typing import Optional, Callable, List, Union
|
|
9
|
-
from ascii_colors import ASCIIColors
|
|
17
|
+
from lollms_client.lollms_types import MSG_TYPE, ELF_COMPLETION_FORMAT
|
|
10
18
|
|
|
19
|
+
from ascii_colors import ASCIIColors, trace_exception
|
|
11
20
|
import pipmaster as pm
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
21
|
+
|
|
22
|
+
# --- Pipmaster: Ensure dependencies ---
pm.ensure_packages([
    "torch",
    "transformers",
    "accelerate",             # For device_map="auto" and advanced model loading
    "bitsandbytes",           # For 4-bit/8-bit quantization (works best on CUDA)
    "sentence_transformers",  # For robust embedding generation
    "pillow",                 # For image handling (vision models)
])

try:
    import torch
    # Bind the top-level module too: the binding refers to `transformers` by
    # name, and `from transformers import ...` alone does not create that name.
    import transformers
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer,
        BitsAndBytesConfig, AutoConfig, GenerationConfig,
        AutoProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration,  # Common LLaVA models
        StoppingCriteria, StoppingCriteriaList,
    )
    from sentence_transformers import SentenceTransformer
    from PIL import Image
except ImportError as e:
    ASCIIColors.error(f"Failed to import core libraries: {e}")
    ASCIIColors.error("Please ensure torch, transformers, accelerate, bitsandbytes, sentence_transformers, and pillow are installed.")
    trace_exception(e)
    # The rest of this module subclasses and annotates with these names
    # (StoppingCriteria, torch.LongTensor, ...), so it cannot even be defined
    # without them. Fail here with a clear ImportError instead of a confusing
    # NameError/AttributeError further down the module.
    raise ImportError(
        "HuggingFaceHubBinding requires torch, transformers, accelerate, "
        "bitsandbytes, sentence_transformers and pillow."
    ) from e
|
|
21
51
|
|
|
22
|
-
if not pm.is_installed("transformers"):
|
|
23
|
-
pm.install_or_update("transformers")
|
|
24
52
|
|
|
25
|
-
|
|
53
|
+
# --- Custom Stopping Criteria for Hugging Face generate ---
|
|
54
|
+
class StopOnWords(StoppingCriteria):
    """Stopping criterion that ends generation once the output ends with a stop word.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Two checks run at every generation step:

    * token level: the trailing token ids equal the ids of a stop word encoded
      on its own (``add_special_tokens=False``);
    * text level: the decoded tail of the sequence ends with a stop word. Many
      tokenizers encode a word differently in context than in isolation (for
      example a word-boundary / leading-space variant of the same piece), so the
      token-level check alone can miss stop words the model actually produced.
    """

    def __init__(self, tokenizer, stop_words: List[str]):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_words: List[str] = []
        self.stop_sequences_token_ids = []
        for word in stop_words:
            # Encode stop words without adding special tokens to get their raw token IDs
            token_ids = tokenizer.encode(word, add_special_tokens=False)
            if token_ids:  # words that encode to nothing can never be matched
                self.stop_words.append(word)
                self.stop_sequences_token_ids.append(torch.tensor(token_ids))
        # Decode a couple more tokens than the longest stop word so a word that
        # splits into more pieces in context than in isolation is still covered.
        longest = max((ids.shape[0] for ids in self.stop_sequences_token_ids), default=0)
        self._tail_length = longest + 2
        # Stop sequences already copied to a given device, so each is moved once.
        self._ids_by_device: Dict[Any, List[Any]] = {}

    def _stop_ids_on(self, device) -> List[Any]:
        """Return the stop sequences as tensors on ``device`` (cached per device)."""
        if device not in self._ids_by_device:
            self._ids_by_device[device] = [ids.to(device) for ids in self.stop_sequences_token_ids]
        return self._ids_by_device[device]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if not self.stop_sequences_token_ids:
            return False
        sequence = input_ids[0]
        # Token-level check: the end of the sequence equals a stop sequence.
        for stop_seq_ids in self._stop_ids_on(sequence.device):
            n = stop_seq_ids.shape[0]
            if sequence.shape[0] >= n and torch.equal(sequence[-n:], stop_seq_ids):
                return True
        # Text-level check on the decoded tail (see class docstring).
        tail_text = self.tokenizer.decode(sequence[-self._tail_length:], skip_special_tokens=False)
        return any(tail_text.endswith(word) for word in self.stop_words)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Name of the binding class defined in this module; HuggingFaceHubBinding.__init__
# also passes it to the base class as `binding_name`.
# NOTE(review): presumably the lollms loader reads this module-level name to
# locate the binding class — confirm against the loader.
BindingName = "HuggingFaceHubBinding"
|
|
75
|
+
|
|
76
|
+
class HuggingFaceHubBinding(LollmsLLMBinding):
|
|
77
|
+
DEFAULT_CONFIG_ARGS = dict(
    # --- Hardware / precision ---
    device="auto",                # "auto", "cuda", "mps" or "cpu"
    quantize=False,               # False, "8bit" or "4bit"; only applied on CUDA (bitsandbytes)
    torch_dtype="auto",           # "auto", "float16", "bfloat16" or "float32"
    # --- Sampling defaults, used when generate_text gets no per-call value ---
    max_new_tokens=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
    # --- Model loading ---
    trust_remote_code=False,      # required by hub models that ship custom modelling code
    use_flash_attention_2=False,  # only considered when the device is CUDA
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # model used by embed()
    # --- Generation control ---
    generation_timeout=300,       # seconds to wait for the generation thread after streaming
    stop_words=[],                # strings that end generation early
)
|
|
30
92
|
|
|
31
|
-
class TransformersBinding(LollmsLLMBinding):
|
|
32
|
-
"""Transformers-specific binding implementation"""
|
|
33
|
-
|
|
34
93
|
def __init__(self,
             model_name_or_id: str,  # Can be HF Hub ID or local folder name relative to models_path
             models_path: Union[str, Path],
             config: Optional[Dict[str, Any]] = None,
             default_completion_format: ELF_COMPLETION_FORMAT = ELF_COMPLETION_FORMAT.Chat,
             **kwargs  # Overrides for config_args
             ):
    """Create the binding and try to load ``model_name_or_id`` immediately.

    Args:
        model_name_or_id: Hugging Face Hub id, absolute model directory, or the
            name of a model folder inside ``models_path``.
        models_path: Directory searched for local model folders.
        config: Overrides for ``DEFAULT_CONFIG_ARGS``.
        default_completion_format: ``Chat`` makes generate_text apply the
            tokenizer's chat template by default; other values send raw prompts.
        **kwargs: Further config overrides; they take precedence over ``config``.

    Raises:
        ImportError: if the core libraries (torch / transformers) are unavailable.

    A failed initial model load is logged but does not raise; the binding then
    has no model until ``load_model`` succeeds.
    """
    super().__init__(binding_name=BindingName)

    # `transformers` is not guaranteed to be bound as a module-level name (only
    # names are imported from it), so testing it could itself raise NameError.
    # The import guard sets `torch` to None whenever the core imports fail.
    if torch is None:
        raise ImportError("Core libraries (torch, transformers) not available. Binding cannot function.")

    self.models_path = Path(models_path)
    self.config = {**self.DEFAULT_CONFIG_ARGS, **(config or {}), **kwargs}
    # The merge above is shallow: give this instance its own stop-word list so
    # editing it never mutates the class-level default. A bare string is one word.
    configured_stop_words = self.config.get("stop_words") or []
    if isinstance(configured_stop_words, str):
        self.config["stop_words"] = [configured_stop_words]
    else:
        self.config["stop_words"] = list(configured_stop_words)
    self.default_completion_format = default_completion_format

    self.model_identifier: Optional[str] = None  # resolved path or hub id passed to from_pretrained
    self.model_name: Optional[str] = None  # User-friendly name (folder name or hub id)
    self.model: Optional[Union[AutoModelForCausalLM, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration]] = None
    self.tokenizer: Optional[AutoTokenizer] = None
    self.processor: Optional[AutoProcessor] = None  # For vision models
    self.embedding_model: Optional[SentenceTransformer] = None
    self.device: Optional[str] = None
    self.torch_dtype: Optional[torch.dtype] = None
    self.supports_vision: bool = False

    # Attempt to load the model during initialization
    if not self.load_model(model_name_or_id):
        # load_model has already logged the details.
        ASCIIColors.error(f"Initial model load failed for {model_name_or_id}. Binding may not be functional.")
        # Not raised: the user may still select another model later.
|
|
125
|
+
|
|
126
|
+
def _resolve_model_path_or_id(self, model_name_or_id: str) -> str:
    """Translate a user-supplied model reference into what ``from_pretrained`` loads.

    Resolution order:
      1. an absolute path to a directory containing ``config.json``;
      2. ``self.models_path / model_name_or_id`` if that directory has ``config.json``;
      3. otherwise the string is returned unchanged and treated as a Hub id.
    """
    def holds_model(folder: Path) -> bool:
        # A usable local model folder is a directory with a config.json inside.
        return folder.is_dir() and (folder / "config.json").exists()

    as_given = Path(model_name_or_id)
    if as_given.is_absolute() and holds_model(as_given):
        ASCIIColors.info(f"Using absolute model path: {as_given}")
        return str(as_given)

    under_models_path = self.models_path / model_name_or_id
    if holds_model(under_models_path):
        ASCIIColors.info(f"Found local model in models_path: {under_models_path}")
        return str(under_models_path)

    ASCIIColors.info(f"Assuming '{model_name_or_id}' is a Hugging Face Hub ID.")
    return model_name_or_id
|
|
142
|
+
|
|
143
|
+
def load_model(self, model_name_or_id: str) -> bool:
    """Load tokenizer, model (text or LLaVA-style vision) and embedding model.

    Any previously loaded model is unloaded first. The following all come from
    ``self.config``:
    - device;
    - dtype;
    - bitsandbytes quantization;
    - trust_remote_code;
    - Flash Attention 2;
    - embedding model.

    Returns:
        True on success. False on failure: the error is logged and any partially
        loaded objects are released.
    """
    if self.model is not None:
        self.unload_model()

    self.model_identifier = self._resolve_model_path_or_id(model_name_or_id)
    self.model_name = Path(self.model_identifier).name  # User-friendly name

    # --- Device Selection ---
    device_pref = self.config.get("device", "auto")
    if device_pref == "auto":
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():  # Apple Silicon
            self.device = "mps"
        else:
            self.device = "cpu"
    else:
        self.device = device_pref
    ASCIIColors.info(f"Using device: {self.device}")

    # --- Dtype Selection ---
    dtype_pref = self.config.get("torch_dtype", "auto")
    if dtype_pref == "auto":
        # float16 on CUDA (bfloat16 is better on Ampere+ but must be requested);
        # MPS and CPU default to float32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
    elif dtype_pref == "float16":
        self.torch_dtype = torch.float16
    elif dtype_pref == "bfloat16":
        self.torch_dtype = torch.bfloat16
    else:
        self.torch_dtype = torch.float32
    ASCIIColors.info(f"Using DType: {self.torch_dtype}")

    # --- Quantization (bitsandbytes, CUDA only) ---
    quantize_mode = self.config.get("quantize", False)
    bnb_config = None
    if self.device == "cuda":
        if quantize_mode == "8bit":
            # Passing `load_in_8bit=True` straight to from_pretrained is deprecated;
            # express it through BitsAndBytesConfig like the 4-bit path.
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            ASCIIColors.info("Quantizing model to 8-bit.")
        elif quantize_mode == "4bit":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=self.torch_dtype,
            )
            ASCIIColors.info("Quantizing model to 4-bit.")
    elif quantize_mode:
        ASCIIColors.warning(f"Quantization ('{quantize_mode}') is selected but device is '{self.device}'. bitsandbytes works best on CUDA. Proceeding without quantization.")
    quantized = bnb_config is not None

    # --- Model Loading Arguments ---
    trust_remote_code = self.config.get("trust_remote_code", False)
    model_load_args = {
        "trust_remote_code": trust_remote_code,
        # When quantizing, the compute dtype lives in BitsAndBytesConfig.
        "torch_dtype": None if quantized else self.torch_dtype,
    }
    if quantized:
        model_load_args["quantization_config"] = bnb_config

    if self.config.get("use_flash_attention_2", False) and self.device == "cuda":
        try:
            from transformers.utils import is_flash_attn_2_available
            flash_attn_ok = is_flash_attn_2_available()
        except ImportError:  # transformers predates Flash Attention 2 support
            flash_attn_ok = False
        if not flash_attn_ok:
            ASCIIColors.warning("Flash Attention 2 requested but not available (flash-attn not installed or transformers too old). Using default.")
        elif not quantized and self.torch_dtype == torch.float32:
            ASCIIColors.warning("Flash Attention 2 requires float16 or bfloat16 weights but torch_dtype is float32. Using default.")
        else:
            model_load_args["attn_implementation"] = "flash_attention_2"
            ASCIIColors.info("Attempting to use Flash Attention 2.")

    # device_map="auto" for multi-GPU or when quantizing on CUDA
    if self.device == "cuda" and (quantized or torch.cuda.device_count() > 1):
        model_load_args["device_map"] = "auto"
        ASCIIColors.info("Using device_map='auto'.")

    try:
        ASCIIColors.info(f"Loading tokenizer for '{self.model_identifier}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_identifier,
            trust_remote_code=trust_remote_code,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            ASCIIColors.info("Tokenizer `pad_token` was None, set to `eos_token`.")

        # --- Determine if it's a LLaVA-like vision model ---
        model_config_hf = AutoConfig.from_pretrained(self.model_identifier, trust_remote_code=trust_remote_code)
        # `architectures` exists on every config but is None when unspecified.
        architectures = getattr(model_config_hf, "architectures", None) or []
        model_type = (getattr(model_config_hf, "model_type", "") or "").lower()
        self.supports_vision = (
            "llava" in model_type
            or any("Llava" in arch for arch in architectures)
            or "vision_tower" in model_config_hf.to_dict()  # Common LLaVA config key
        )

        if self.supports_vision:
            ASCIIColors.info(f"Detected LLaVA-like vision model: '{self.model_identifier}'.")
            self.processor = AutoProcessor.from_pretrained(
                self.model_identifier,
                trust_remote_code=trust_remote_code,
            )
            # Choose appropriate LLaVA model class
            identifier_lower = self.model_identifier.lower()
            if "llava-next" in identifier_lower or any("LlavaNext" in arch for arch in architectures):
                ModelClass = LlavaNextForConditionalGeneration
            elif "llava" in identifier_lower or any("LlavaForConditionalGeneration" in arch for arch in architectures):
                ModelClass = LlavaForConditionalGeneration
            else:  # Fallback if specific Llava class not matched by name
                ASCIIColors.warning("Could not determine specific LLaVA class, using AutoModelForCausalLM. Vision capabilities might be limited.")
                ModelClass = AutoModelForCausalLM  # This might not fully work for all LLaVAs
            self.model = ModelClass.from_pretrained(self.model_identifier, **model_load_args)
        else:
            ASCIIColors.info(f"Loading text model '{self.model_identifier}'...")
            self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, **model_load_args)

        # Without a device_map the weights are loaded on CPU; move them explicitly.
        if "device_map" not in model_load_args and self.device != "cpu":
            self.model.to(self.device)

        self.model.eval()  # Set to evaluation mode

        # --- Load Embedding Model ---
        emb_model_name = self.config.get("embedding_model_name")
        if emb_model_name:
            try:
                ASCIIColors.info(f"Loading embedding model: {emb_model_name} on device: {self.device}")
                self.embedding_model = SentenceTransformer(emb_model_name, device=self.device)
            except Exception as e_emb:
                # Best effort: the text model stays usable without embeddings.
                ASCIIColors.warning(f"Failed to load embedding model '{emb_model_name}': {e_emb}. Embeddings will not be available.")
                self.embedding_model = None
        else:
            ASCIIColors.info("No embedding_model_name configured. Skipping embedding model load.")
            self.embedding_model = None

        ASCIIColors.green(f"Model '{self.model_identifier}' loaded successfully.")
        return True

    except Exception as e:
        ASCIIColors.error(f"Failed to load model '{self.model_identifier}': {e}")
        trace_exception(e)
        self.unload_model()  # Ensure partial loads are cleaned up
        return False
|
|
279
|
+
|
|
280
|
+
def unload_model(self):
    """Release the model, tokenizer, processor and embedding model.

    Also frees cached GPU memory and clears the model identity fields and the
    vision flag. Safe to call when nothing is loaded.
    """
    import gc  # local: only needed here

    self.model = None
    self.tokenizer = None
    self.processor = None
    self.embedding_model = None

    # Collect first: the allocator can only release blocks once no Python object
    # (including ones stuck in reference cycles) still holds the tensors.
    gc.collect()
    if self.device == "cuda":
        torch.cuda.empty_cache()
    elif self.device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()

    self.model_identifier = None
    self.model_name = None
    self.supports_vision = False
    ASCIIColors.info("Hugging Face model unloaded.")
|
|
301
|
+
|
|
302
|
+
def generate_text(self,
                  prompt: str,
                  images: Optional[List[str]] = None,
                  system_prompt: str = "",
                  n_predict: Optional[int] = None,
                  stream: bool = False,
                  temperature: float = None,
                  top_k: int = None,
                  top_p: float = None,
                  repeat_penalty: float = None,
                  seed: Optional[int] = None,
                  stop_words: Optional[List[str]] = None,  # Added custom stop_words
                  streaming_callback: Optional[Callable[[str, int], bool]] = None,
                  use_chat_format_override: Optional[bool] = None,
                  **generation_kwargs
                  ) -> Union[str, Dict[str, Any]]:
    """Generate a completion for ``prompt``.

    Args:
        prompt: User prompt.
        images: Image file paths. Used only with a loaded LLaVA-style vision
            model; otherwise a warning is logged and they are ignored.
        system_prompt: Optional system prompt.
        n_predict: Per-call override of config ``max_new_tokens``.
        temperature: Per-call override of config ``temperature``.
            ``temperature <= 0`` selects greedy decoding.
        top_k: Per-call override of config ``top_k``.
        top_p: Per-call override of config ``top_p``.
        repeat_penalty: Per-call override of config ``repetition_penalty``.
        seed: Seeds torch (and all CUDA devices) before generating.
        stop_words: Strings that end generation; defaults to config ``stop_words``.
        stream: Stream chunks to ``streaming_callback``. Ignored without a callback.
        streaming_callback: Called as ``callback(chunk, MSG_TYPE.MSG_TYPE_CHUNK)``.
            Returning a falsy value stops generation.
        use_chat_format_override: Force (True) or disable (False) the tokenizer
            chat template. Defaults to ``default_completion_format == Chat``.
        **generation_kwargs: Extra ``GenerationConfig`` attributes. Names that
            GenerationConfig does not have are ignored.

    Returns:
        The generated text, or ``{"status": False, "error": ...}`` on failure.
    """
    import copy  # local: used to copy the model's generation config per call

    if self.model is None or self.tokenizer is None:
        return {"status": False, "error": "Model not loaded."}

    if seed is not None:
        torch.manual_seed(seed)
        if self.device == "cuda":
            torch.cuda.manual_seed_all(seed)

    _use_chat_format = use_chat_format_override if use_chat_format_override is not None \
        else (self.default_completion_format == ELF_COMPLETION_FORMAT.Chat)

    # --- Prepare Inputs ---
    inputs_dict = {}
    use_vision = bool(self.supports_vision and self.processor and images)
    if images and not use_vision:
        ASCIIColors.warning("Images were provided but the loaded model has no vision processor; images are ignored.")

    if use_vision:
        try:
            processed_images = []
            for img_path in images:
                with Image.open(img_path) as img:  # close the file handle promptly
                    processed_images.append(img.convert("RGB"))
            # Keep the system prompt on the vision path too.
            vision_text = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
            # Floating tensors (pixel_values) must match the model dtype (e.g. fp16);
            # BatchFeature.to casts only floating-point entries to the dtype.
            inputs_dict = self.processor(text=vision_text, images=processed_images, return_tensors="pt") \
                .to(self.model.device, self.model.dtype)
            ASCIIColors.debug("Processed inputs with LLaVA processor.")
        except Exception as e_img:
            ASCIIColors.error(f"Error processing images for LLaVA: {e_img}")
            return {"status": False, "error": f"Image processing error: {e_img}"}

    elif _use_chat_format and hasattr(self.tokenizer, 'apply_chat_template'):
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            # The rendered template already contains the special tokens it needs
            # (e.g. BOS); adding them again while tokenizing would duplicate them.
            inputs_dict = self.tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(self.model.device)
            ASCIIColors.debug("Applied chat template.")
        except Exception as e_tmpl:  # Some tokenizers might fail if template is complex or not well-defined
            ASCIIColors.warning(f"Failed to apply chat template ({e_tmpl}). Falling back to raw prompt.")

    # Raw prompt: chat format disabled/unsupported or the template failed.
    # Inputs already prepared above (including vision inputs) are never replaced.
    if not inputs_dict:
        full_prompt_text = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
        inputs_dict = self.tokenizer(full_prompt_text, return_tensors="pt").to(self.model.device)
        ASCIIColors.debug("Using raw prompt format.")

    input_ids = inputs_dict.get("input_ids")
    if input_ids is None:
        return {"status": False, "error": "Failed to tokenize prompt."}
    current_input_length = input_ids.shape[1]

    # --- Generation Parameters ---
    # Start from the model's own generation defaults (generation_config.json),
    # copied so per-call settings never leak into later calls. This keeps e.g.
    # several EOS ids, which chat models use to end a turn.
    base_conf = getattr(self.model, "generation_config", None)
    gen_conf = copy.deepcopy(base_conf) if base_conf is not None \
        else GenerationConfig.from_model_config(self.model.config)

    gen_conf.max_new_tokens = n_predict if n_predict is not None else self.config.get("max_new_tokens")
    gen_conf.temperature = temperature if temperature is not None else self.config.get("temperature")
    gen_conf.top_k = top_k if top_k is not None else self.config.get("top_k")
    gen_conf.top_p = top_p if top_p is not None else self.config.get("top_p")
    gen_conf.repetition_penalty = repeat_penalty if repeat_penalty is not None else self.config.get("repetition_penalty")
    # Without do_sample, temperature/top_k/top_p are ignored (greedy decoding).
    gen_conf.do_sample = gen_conf.temperature is not None and gen_conf.temperature > 0
    if gen_conf.pad_token_id is None:
        gen_conf.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None \
            else self.tokenizer.eos_token_id
    if gen_conf.eos_token_id is None:
        gen_conf.eos_token_id = self.tokenizer.eos_token_id

    # Apply any other valid GenerationConfig parameters from generation_kwargs
    for key, value in generation_kwargs.items():
        if hasattr(gen_conf, key):
            setattr(gen_conf, key, value)

    # --- Stopping Criteria ---
    stopping_criteria_list = StoppingCriteriaList()
    effective_stop_words = stop_words if stop_words is not None else self.config.get("stop_words", [])
    if effective_stop_words:
        stopping_criteria_list.append(StopOnWords(self.tokenizer, effective_stop_words))

    # --- Generation ---
    try:
        if stream and streaming_callback:
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            stop_requested = threading.Event()

            class _StopOnRequest(StoppingCriteria):
                """Lets the consumer loop below cancel the background generate() call."""
                def __call__(self, input_ids, scores, **kwargs) -> bool:
                    return stop_requested.is_set()

            stream_criteria = StoppingCriteriaList(list(stopping_criteria_list))
            stream_criteria.append(_StopOnRequest())
            generation_errors = []

            def _run_generation():
                try:
                    self.model.generate(
                        **inputs_dict,  # input_ids, attention_mask, pixel_values (if vision)
                        generation_config=gen_conf,
                        streamer=streamer,
                        stopping_criteria=stream_criteria,
                    )
                except Exception as e_gen:
                    generation_errors.append(e_gen)
                    # Unblock the consumer loop, which would otherwise wait forever.
                    streamer.end()

            thread = threading.Thread(target=_run_generation, daemon=True)
            thread.start()

            full_response_text = ""
            for new_text_chunk in streamer:
                if streaming_callback(new_text_chunk, MSG_TYPE.MSG_TYPE_CHUNK):
                    full_response_text += new_text_chunk
                else:  # Callback requested stop
                    ASCIIColors.info("Streaming callback requested stop.")
                    stop_requested.set()  # generate() sees this at its next step
                    break
            thread.join(timeout=self.config.get("generation_timeout", 300))
            if thread.is_alive():
                ASCIIColors.warning("Generation thread did not finish in time after streaming.")
            if generation_errors:
                raise generation_errors[0]
            return full_response_text
        else:  # Non-streaming
            outputs = self.model.generate(
                **inputs_dict,
                generation_config=gen_conf,
                stopping_criteria=stopping_criteria_list if effective_stop_words else None,
            )
            # outputs contains the full sequence (prompt + new tokens)
            generated_tokens = outputs[0][current_input_length:]
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            return generated_text.strip()

    except Exception as e:
        ASCIIColors.error(f"Error during text generation: {e}")
        trace_exception(e)
        return {"status": False, "error": str(e)}
|
|
244
435
|
|
|
245
|
-
def tokenize(self, text: str) ->
|
|
246
|
-
|
|
247
|
-
return
|
|
248
|
-
|
|
249
|
-
def detokenize(self, tokens: list) -> str:
|
|
250
|
-
"""Convert a list of tokens back to text."""
|
|
251
|
-
return "".join(tokens)
|
|
252
|
-
|
|
253
|
-
def count_tokens(self, text: str) -> int:
|
|
254
|
-
"""
|
|
255
|
-
Count tokens from a text.
|
|
436
|
+
def tokenize(self, text: str) -> List[int]:
|
|
437
|
+
if self.tokenizer is None: raise RuntimeError("Tokenizer not loaded.")
|
|
438
|
+
return self.tokenizer.encode(text)
|
|
256
439
|
|
|
257
|
-
|
|
258
|
-
|
|
440
|
+
def detokenize(self, tokens: List[int]) -> str:
|
|
441
|
+
if self.tokenizer is None: raise RuntimeError("Tokenizer not loaded.")
|
|
442
|
+
return self.tokenizer.decode(tokens)
|
|
259
443
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
444
|
+
def count_tokens(self, text: str) -> int:
|
|
445
|
+
if self.tokenizer is None: raise RuntimeError("Tokenizer not loaded.")
|
|
446
|
+
return len(self.tokenizer.encode(text))
|
|
447
|
+
|
|
448
|
+
def embed(self, text: Union[str, List[str]], **kwargs) -> Union[List[float], List[List[float]]]:
|
|
449
|
+
if self.embedding_model is None:
|
|
450
|
+
raise RuntimeError("Embedding model not loaded. Configure 'embedding_model_name'.")
|
|
451
|
+
try:
|
|
452
|
+
# SentenceTransformer's encode can take a string or list of strings
|
|
453
|
+
embeddings_np = self.embedding_model.encode(text, **kwargs)
|
|
454
|
+
if isinstance(text, str): # Single text input
|
|
455
|
+
return embeddings_np.tolist()
|
|
456
|
+
else: # List of texts input
|
|
457
|
+
return [emb.tolist() for emb in embeddings_np]
|
|
458
|
+
except Exception as e:
|
|
459
|
+
ASCIIColors.error(f"Embedding generation failed: {e}")
|
|
460
|
+
trace_exception(e)
|
|
461
|
+
raise
|
|
264
462
|
|
|
265
|
-
|
|
266
|
-
def embed(self, text: str, **kwargs) -> list:
|
|
267
|
-
"""Get embeddings for the input text (placeholder)."""
|
|
268
|
-
pass
|
|
269
|
-
|
|
270
463
|
def get_model_info(self) -> dict:
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
"
|
|
274
|
-
"
|
|
275
|
-
"
|
|
276
|
-
"
|
|
464
|
+
info = {
|
|
465
|
+
"binding_name": self.binding_name,
|
|
466
|
+
"model_name": self.model_name,
|
|
467
|
+
"model_identifier": self.model_identifier,
|
|
468
|
+
"loaded": self.model is not None,
|
|
469
|
+
"config": self.config, # Binding's own config
|
|
470
|
+
"device": self.device,
|
|
471
|
+
"torch_dtype": str(self.torch_dtype),
|
|
472
|
+
"supports_vision": self.supports_vision,
|
|
473
|
+
"embedding_model_name": self.config.get("embedding_model_name") if self.embedding_model else None,
|
|
277
474
|
}
|
|
475
|
+
if self.model and hasattr(self.model, 'config'):
|
|
476
|
+
model_hf_config = self.model.config.to_dict()
|
|
477
|
+
info["model_hf_config"] = {k: str(v)[:200] for k,v in model_hf_config.items()} # Truncate long values
|
|
478
|
+
info["max_model_len"] = getattr(self.model.config, "max_position_embeddings", "N/A")
|
|
479
|
+
|
|
480
|
+
info["supports_structured_output"] = False # HF models don't inherently support grammar like llama.cpp server
|
|
481
|
+
# (unless using external libraries like outlines)
|
|
482
|
+
return info
|
|
483
|
+
|
|
484
|
+
def listModels(self) -> List[Dict[str, str]]:
|
|
485
|
+
models_found = []
|
|
486
|
+
unique_model_names = set()
|
|
487
|
+
|
|
488
|
+
if self.models_path.exists() and self.models_path.is_dir():
|
|
489
|
+
for item in self.models_path.iterdir():
|
|
490
|
+
if item.is_dir(): # HF models are directories
|
|
491
|
+
# Basic check for a config file to qualify as a model dir
|
|
492
|
+
if (item / "config.json").exists():
|
|
493
|
+
model_name = item.name
|
|
494
|
+
if model_name not in unique_model_names:
|
|
495
|
+
try:
|
|
496
|
+
# Calculating size can be slow for large model repos
|
|
497
|
+
# total_size = sum(f.stat().st_size for f in item.rglob('*') if f.is_file())
|
|
498
|
+
# size_gb_str = f"{total_size / (1024**3):.2f} GB"
|
|
499
|
+
size_gb_str = "N/A (size calculation disabled for speed)"
|
|
500
|
+
except Exception:
|
|
501
|
+
size_gb_str = "N/A"
|
|
502
|
+
|
|
503
|
+
models_found.append({
|
|
504
|
+
"model_name": model_name, # This is the folder name
|
|
505
|
+
"path_hint": str(item.relative_to(self.models_path.parent) if item.is_relative_to(self.models_path.parent) else item),
|
|
506
|
+
"size_gb": size_gb_str
|
|
507
|
+
})
|
|
508
|
+
unique_model_names.add(model_name)
|
|
509
|
+
|
|
510
|
+
ASCIIColors.info("Tip: You can also use any Hugging Face Hub model ID directly (e.g., 'mistralai/Mistral-7B-Instruct-v0.1').")
|
|
511
|
+
return models_found
|
|
278
512
|
|
|
279
|
-
def
|
|
280
|
-
|
|
281
|
-
|
|
513
|
+
def __del__(self):
|
|
514
|
+
self.unload_model()
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
if __name__ == '__main__':
|
|
518
|
+
global full_streamed_text
|
|
519
|
+
ASCIIColors.yellow("Testing HuggingFaceHubBinding...")
|
|
520
|
+
|
|
521
|
+
# --- Configuration ---
|
|
522
|
+
# For testing, you might need to download a model first or use a small Hub ID.
|
|
523
|
+
# Option 1: Use a small model from Hugging Face Hub
|
|
524
|
+
# test_model_name = "gpt2" # Small, good for quick tests
|
|
525
|
+
test_model_name = "microsoft/phi-2" # Small, good quality, requires trust_remote_code=True
|
|
526
|
+
# test_model_name = "HuggingFaceH4/zephyr-7b-beta" # Larger, powerful
|
|
527
|
+
|
|
528
|
+
# Option 2: Path to a local model folder (if you have one)
|
|
529
|
+
# Replace 'path/to/your/models' with the PARENT directory of your HF model folders.
|
|
530
|
+
# And 'your-local-model-folder' with the actual folder name.
|
|
531
|
+
# Example:
|
|
532
|
+
# test_models_parent_path = Path.home() / "lollms_models" # Example path
|
|
533
|
+
# test_model_name = "phi-2" # if "phi-2" folder is inside test_models_parent_path
|
|
534
|
+
|
|
535
|
+
# For local testing, models_path should be where your HF model *folders* are.
|
|
536
|
+
# If using a Hub ID like "gpt2", models_path is less critical unless you expect
|
|
537
|
+
# the binding to *only* look there (which it doesn't, it prioritizes Hub IDs).
|
|
538
|
+
# Let's use a dummy path for models_path for Hub ID testing.
|
|
282
539
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
540
|
+
# Adjust current_directory for local model testing if needed
|
|
541
|
+
# For this test, we'll assume a Hub ID. `models_path` is where `listModels` would scan.
|
|
542
|
+
test_models_parent_path = Path("./test_hf_models_dir") # Create a dummy for listModels scan
|
|
543
|
+
test_models_parent_path.mkdir(exist_ok=True)
|
|
544
|
+
|
|
545
|
+
binding_config = {
|
|
546
|
+
"device": "auto", # "cuda", "mps", "cpu"
|
|
547
|
+
"quantize": False, # False, "4bit", "8bit" (requires CUDA & bitsandbytes for 4/8 bit)
|
|
548
|
+
"torch_dtype": "auto", # "float16" or "bfloat16" on CUDA for speed
|
|
549
|
+
"max_new_tokens": 100, # Limit generation length for tests
|
|
550
|
+
"trust_remote_code": True, # Needed for models like Phi-2
|
|
551
|
+
"stop_words": ["\nHuman:", "\nUSER:"], # Example stop words
|
|
552
|
+
# "embedding_model_name": "sentence-transformers/paraphrase-MiniLM-L3-v2" # Smaller embedding model
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
active_binding = None
|
|
556
|
+
try:
|
|
557
|
+
ASCIIColors.cyan("\n--- Initializing HuggingFaceHubBinding ---")
|
|
558
|
+
active_binding = HuggingFaceHubBinding(
|
|
559
|
+
model_name_or_id=test_model_name,
|
|
560
|
+
models_path=test_models_parent_path,
|
|
561
|
+
config=binding_config
|
|
562
|
+
)
|
|
563
|
+
if not active_binding.model:
|
|
564
|
+
raise RuntimeError(f"Model '{test_model_name}' failed to load.")
|
|
565
|
+
|
|
566
|
+
ASCIIColors.green(f"Binding initialized. Model '{active_binding.model_name}' loaded on {active_binding.device}.")
|
|
567
|
+
ASCIIColors.info(f"Model Info: {json.dumps(active_binding.get_model_info(), indent=2, default=str)}")
|
|
568
|
+
|
|
569
|
+
# --- List Models (scans configured models_path) ---
|
|
570
|
+
ASCIIColors.cyan("\n--- Listing Models (from models_path) ---")
|
|
571
|
+
# To make this test useful, you could manually place a model folder in `test_hf_models_dir`
|
|
572
|
+
# e.g., download "gpt2" and put it in `test_hf_models_dir/gpt2`
|
|
573
|
+
# For now, it will likely be empty unless you do that.
|
|
574
|
+
listed_models = active_binding.listModels()
|
|
575
|
+
if listed_models:
|
|
576
|
+
ASCIIColors.green(f"Found {len(listed_models)} potential model folders. First 5:")
|
|
577
|
+
for m in listed_models[:5]: print(m)
|
|
578
|
+
else: ASCIIColors.warning(f"No model folders found in '{test_models_parent_path}'. This is normal if it's empty.")
|
|
579
|
+
|
|
580
|
+
# --- Tokenize/Detokenize ---
|
|
581
|
+
ASCIIColors.cyan("\n--- Tokenize/Detokenize ---")
|
|
582
|
+
sample_text = "Hello, Hugging Face world!"
|
|
583
|
+
tokens = active_binding.tokenize(sample_text)
|
|
584
|
+
ASCIIColors.green(f"Tokens for '{sample_text}': {tokens[:10]}...")
|
|
585
|
+
token_count = active_binding.count_tokens(sample_text)
|
|
586
|
+
ASCIIColors.green(f"Token count: {token_count}")
|
|
587
|
+
if tokens:
|
|
588
|
+
detokenized_text = active_binding.detokenize(tokens)
|
|
589
|
+
ASCIIColors.green(f"Detokenized text: {detokenized_text}")
|
|
590
|
+
else: ASCIIColors.warning("Tokenization returned empty list.")
|
|
591
|
+
|
|
592
|
+
# --- Text Generation (Non-Streaming, Chat Format if supported) ---
|
|
593
|
+
ASCIIColors.cyan("\n--- Text Generation (Non-Streaming) ---")
|
|
594
|
+
prompt_text = "What is the capital of France?"
|
|
595
|
+
# For Phi-2, system prompt might need specific formatting if not using apply_chat_template strictly
|
|
596
|
+
# For models like Zephyr, system_prompt is part of chat template
|
|
597
|
+
system_prompt_text = "You are a helpful AI assistant."
|
|
598
|
+
generated_text = active_binding.generate_text(
|
|
599
|
+
prompt_text, system_prompt=system_prompt_text, stream=False,
|
|
600
|
+
n_predict=30 # Override default max_new_tokens for this call
|
|
601
|
+
)
|
|
602
|
+
if isinstance(generated_text, str): ASCIIColors.green(f"Generated text: {generated_text}")
|
|
603
|
+
else: ASCIIColors.error(f"Generation failed: {generated_text}")
|
|
604
|
+
|
|
605
|
+
# --- Text Generation (Streaming) ---
|
|
606
|
+
ASCIIColors.cyan("\n--- Text Generation (Streaming) ---")
|
|
607
|
+
full_streamed_text = ""
|
|
608
|
+
def stream_callback(chunk: str, msg_type: int):
|
|
609
|
+
global full_streamed_text
|
|
610
|
+
ASCIIColors.green(f"{chunk}", end="", flush=True)
|
|
611
|
+
full_streamed_text += chunk
|
|
612
|
+
return True # Continue streaming
|
|
613
|
+
|
|
614
|
+
result = active_binding.generate_text(
|
|
615
|
+
"Tell me a short story about a brave robot.",
|
|
616
|
+
stream=True,
|
|
617
|
+
streaming_callback=stream_callback,
|
|
618
|
+
n_predict=70
|
|
619
|
+
)
|
|
620
|
+
print("\n--- End of Stream ---")
|
|
621
|
+
if isinstance(result, str): ASCIIColors.green(f"Full streamed text collected: {result}")
|
|
622
|
+
else: ASCIIColors.error(f"Streaming generation failed: {result}")
|
|
623
|
+
|
|
624
|
+
# --- Embeddings ---
|
|
625
|
+
if active_binding.embedding_model:
|
|
626
|
+
ASCIIColors.cyan("\n--- Embeddings ---")
|
|
627
|
+
embedding_text = "This is a test sentence for Hugging Face embeddings."
|
|
628
|
+
try:
|
|
629
|
+
embedding_vector = active_binding.embed(embedding_text)
|
|
630
|
+
ASCIIColors.green(f"Embedding for '{embedding_text}' (first 3 dims): {embedding_vector[:3]}...")
|
|
631
|
+
ASCIIColors.info(f"Embedding vector dimension: {len(embedding_vector)}")
|
|
632
|
+
|
|
633
|
+
# Test batch embedding
|
|
634
|
+
batch_texts = ["First sentence.", "Second sentence, quite different."]
|
|
635
|
+
batch_embeddings = active_binding.embed(batch_texts)
|
|
636
|
+
ASCIIColors.green(f"Batch embeddings generated for {len(batch_texts)} texts.")
|
|
637
|
+
ASCIIColors.info(f"First batch embedding (first 3 dims): {batch_embeddings[0][:3]}...")
|
|
638
|
+
|
|
639
|
+
except Exception as e_emb: ASCIIColors.warning(f"Could not get embedding: {e_emb}")
|
|
640
|
+
else: ASCIIColors.yellow("\n--- Embeddings Skipped (no embedding model loaded) ---")
|
|
641
|
+
|
|
642
|
+
# --- LLaVA Vision Test (Conceptual - requires a LLaVA model and an image) ---
|
|
643
|
+
# To test LLaVA properly:
|
|
644
|
+
# 1. Set `test_model_name` to a LLaVA model, e.g., "llava-hf/llava-1.5-7b-hf" (very large!)
|
|
645
|
+
# or a smaller one like "unum-cloud/uform-gen2-qwen-500m" (check its specific prompting style).
|
|
646
|
+
# 2. Ensure `trust_remote_code=True` might be needed.
|
|
647
|
+
# 3. Provide a real image path.
|
|
648
|
+
if active_binding.supports_vision:
|
|
649
|
+
ASCIIColors.cyan("\n--- LLaVA Vision Test ---")
|
|
650
|
+
dummy_image_path = Path("test_dummy_image.png")
|
|
651
|
+
try:
|
|
652
|
+
# Create a dummy image for testing
|
|
653
|
+
img = Image.new('RGB', (200, 100), color = ('skyblue'))
|
|
654
|
+
from PIL import ImageDraw
|
|
655
|
+
d = ImageDraw.Draw(img)
|
|
656
|
+
d.text((10,10), "Hello LLaVA from HF!", fill=('black'))
|
|
657
|
+
img.save(dummy_image_path)
|
|
658
|
+
ASCIIColors.info(f"Created dummy image: {dummy_image_path}")
|
|
659
|
+
|
|
660
|
+
llava_prompt = "Describe this image." # LLaVA models often use "<image>\nUSER: <prompt>\nASSISTANT:"
|
|
661
|
+
# or just the prompt if processor handles template.
|
|
662
|
+
# For AutoProcessor, often just the text part of the prompt.
|
|
663
|
+
llava_response = active_binding.generate_text(
|
|
664
|
+
prompt=llava_prompt,
|
|
665
|
+
images=[str(dummy_image_path)],
|
|
666
|
+
n_predict=50,
|
|
667
|
+
stream=False
|
|
668
|
+
)
|
|
669
|
+
if isinstance(llava_response, str): ASCIIColors.green(f"LLaVA response: {llava_response}")
|
|
670
|
+
else: ASCIIColors.error(f"LLaVA generation failed: {llava_response}")
|
|
671
|
+
|
|
672
|
+
except ImportError: ASCIIColors.warning("Pillow's ImageDraw not found for dummy image text.")
|
|
673
|
+
except Exception as e_llava: ASCIIColors.error(f"LLaVA test error: {e_llava}"); trace_exception(e_llava)
|
|
674
|
+
finally:
|
|
675
|
+
if dummy_image_path.exists(): dummy_image_path.unlink()
|
|
676
|
+
else:
|
|
677
|
+
ASCIIColors.yellow("\n--- LLaVA Vision Test Skipped (model does not support vision or not configured for it) ---")
|
|
678
|
+
|
|
679
|
+
except ImportError as e_imp:
|
|
680
|
+
ASCIIColors.error(f"Import error: {e_imp}. Check installations.")
|
|
681
|
+
except RuntimeError as e_rt:
|
|
682
|
+
ASCIIColors.error(f"Runtime error: {e_rt}")
|
|
683
|
+
except Exception as e_main:
|
|
684
|
+
ASCIIColors.error(f"An unexpected error occurred: {e_main}")
|
|
685
|
+
trace_exception(e_main)
|
|
686
|
+
finally:
|
|
687
|
+
if active_binding:
|
|
688
|
+
ASCIIColors.cyan("\n--- Unloading Model ---")
|
|
689
|
+
active_binding.unload_model()
|
|
690
|
+
ASCIIColors.green("Model unloaded.")
|
|
691
|
+
if test_models_parent_path.exists() and not any(test_models_parent_path.iterdir()): # cleanup dummy dir if empty
|
|
692
|
+
try: os.rmdir(test_models_parent_path)
|
|
693
|
+
except: pass
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
ASCIIColors.yellow("\nHuggingFaceHubBinding test finished.")
|