nexaai-1.0.5-cp310-cp310-macosx_13_0_x86_64.whl → nexaai-1.0.6-cp310-cp310-macosx_13_0_x86_64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of nexaai has been flagged by the registry; see its advisory page for details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +6 -5
- nexaai/mlx_backend/embedding/generate.py +219 -16
- nexaai/mlx_backend/embedding/interface.py +346 -41
- nexaai/mlx_backend/embedding/main.py +126 -35
- nexaai/utils/model_manager.py +0 -8
- nexaai/utils/progress_tracker.py +10 -6
- {nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/METADATA +2 -1
- {nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/RECORD +19 -19
- {nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/WHEEL +0 -0
- {nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/top_level.txt +0 -0
nexaai/_stub.cpython-310-darwin.so
Binary file

nexaai/_version.py
CHANGED (+1 -1; diff not rendered)

nexaai/binds/libnexa_bridge.dylib
Binary file

nexaai/binds/nexa_llama_cpp/libggml-base.dylib
Binary file

nexaai/binds/nexa_llama_cpp/libggml-cpu.so
Binary file

nexaai/binds/nexa_llama_cpp/libggml-metal.so
Binary file

nexaai/binds/nexa_llama_cpp/libllama.dylib
Binary file

nexaai/binds/nexa_llama_cpp/libmtmd.dylib
Binary file

nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib
Binary file
nexaai/embedder_impl/mlx_embedder_impl.py
CHANGED

@@ -3,7 +3,7 @@ import numpy as np

 from nexaai.common import PluginID
 from nexaai.embedder import Embedder, EmbeddingConfig
-from nexaai.mlx_backend.embedding.interface import
+from nexaai.mlx_backend.embedding.interface import create_embedder
 from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig


@@ -27,11 +27,12 @@ class MLXEmbedderImpl(Embedder):
             MLXEmbedderImpl instance
         """
         try:
-            #
-
-            # Create instance and load MLX embedder
+            # Create instance
             instance = cls()
-
+
+            # Use the factory function to create the appropriate embedder based on model type
+            # This will automatically detect if it's JinaV2 or generic model and route correctly
+            instance._mlx_embedder = create_embedder(
                 model_path=model_path,
                 tokenizer_path=tokenizer_file
             )
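For orientation, the net effect of this change is that the impl layer now defers model-type routing to the interface factory. A minimal usage sketch, assuming a local MLX model directory (the path below is a hypothetical placeholder, not from the diff):

    # Sketch only: illustrates the new call path through the factory.
    from nexaai.mlx_backend.embedding.interface import create_embedder

    model_dir = "/models/jina-v2-fp16-mlx"   # hypothetical local checkout
    embedder = create_embedder(model_path=model_dir, tokenizer_path=model_dir)
    if embedder.load_model(model_dir):
        vectors = embedder.embed(["hello world"])   # -> List[List[float]]
        print(len(vectors[0]), embedder.embedding_dim())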
nexaai/mlx_backend/embedding/generate.py
CHANGED

@@ -23,11 +23,46 @@ from .modeling.nexa_jina_v2 import Model, ModelArgs
 from tokenizers import Tokenizer
 from huggingface_hub import snapshot_download

-def load_model(model_id):
+# Try to import mlx_embeddings for general embedding support
+try:
+    import mlx_embeddings
+    MLX_EMBEDDINGS_AVAILABLE = True
+except ImportError:
+    MLX_EMBEDDINGS_AVAILABLE = False
+    # Suppress warning during import to avoid interfering with C++ tests
+    # The warning will be shown when actually trying to use mlx_embeddings functionality
+    pass
+
+def detect_model_type(model_path):
+    """Detect if the model is Jina V2 or generic mlx_embeddings model."""
+    config_path = os.path.join(model_path, "config.json") if os.path.isdir(model_path) else f"{model_path}/config.json"
+
+    if not os.path.exists(config_path):
+        # Try default modelfiles directory
+        config_path = f"{curr_dir}/modelfiles/config.json"
+        if not os.path.exists(config_path):
+            return "generic"
+
+    try:
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        # Check if it's a Jina V2 model
+        architectures = config.get("architectures", [])
+        if "JinaBertModel" in architectures:
+            return "jina_v2"
+
+        return "generic"
+    except Exception:
+        return "generic"
+
+# ========== Jina V2 Direct Implementation ==========
+
+def load_jina_model(model_id):
     """Initialize and load the Jina V2 model with FP16 weights."""
     # Load configuration from config.json
     if not os.path.exists(f"{curr_dir}/modelfiles/config.json"):
-        print(f"📥 Downloading model {model_id}...")
+        print(f"📥 Downloading Jina V2 model {model_id}...")

         # Ensure modelfiles directory exists
         os.makedirs(f"{curr_dir}/modelfiles", exist_ok=True)
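For reference, the routing key detect_model_type relies on is the "architectures" list in config.json; a small self-contained sketch of the same check (the config content here is a hypothetical minimal example):

    import json, os, tempfile

    # Hypothetical minimal config.json for a Jina V2 checkout
    cfg = {"architectures": ["JinaBertModel"], "hidden_size": 768}
    d = tempfile.mkdtemp()
    with open(os.path.join(d, "config.json"), "w") as f:
        json.dump(cfg, f)

    # Same decision the new detect_model_type performs
    with open(os.path.join(d, "config.json")) as f:
        arch = json.load(f).get("architectures", [])
    print("jina_v2" if "JinaBertModel" in arch else "generic")  # -> jina_v2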
@@ -82,15 +117,15 @@ def load_model(model_id):

     return model

-def
-    """Load and configure the tokenizer."""
+def load_jina_tokenizer():
+    """Load and configure the tokenizer for Jina V2."""
     tokenizer = Tokenizer.from_file(f"{curr_dir}/modelfiles/tokenizer.json")
     tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
     tokenizer.enable_truncation(max_length=512)
     return tokenizer

-def encode_text(model, tokenizer, text):
-    """Encode a single text and return its embedding."""
+def encode_jina_text(model, tokenizer, text):
+    """Encode a single text using Jina V2 and return its embedding."""
     # Tokenize the text
     encoding = tokenizer.encode(text)

@@ -113,18 +148,186 @@ def encode_text(model, tokenizer, text):

     return embeddings

+# ========== MLX Embeddings Direct Implementation ==========
+
+def load_mlx_embeddings_model(model_id):
+    """Load model using mlx_embeddings package."""
+    if not MLX_EMBEDDINGS_AVAILABLE:
+        print("Warning: mlx_embeddings not available. Please install it to use general embedding models.")
+        raise ImportError("mlx_embeddings package is not available. Please install it first.")
+
+    # Download model if needed
+    model_path = f"{curr_dir}/modelfiles"
+
+    if not os.path.exists(f"{model_path}/config.json"):
+        print(f"📥 Downloading model {model_id}...")
+        os.makedirs(model_path, exist_ok=True)
+
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=model_path,
+                resume_download=True,
+                local_dir_use_symlinks=False
+            )
+            print("✅ Model download completed!")
+        except Exception as e:
+            print(f"❌ Failed to download model: {e}")
+            raise
+
+    # Load model and tokenizer using mlx_embeddings
+    model, tokenizer = mlx_embeddings.load(model_path)
+    return model, tokenizer
+
+def encode_mlx_embeddings_text(model, tokenizer, texts, model_path=None):
+    """Generate embeddings using mlx_embeddings."""
+    if isinstance(texts, str):
+        texts = [texts]
+
+    # Check if this is a Gemma3TextModel by checking config
+    # WORKAROUND: Gemma3TextModel has a bug where it expects 'inputs' as positional arg
+    # but mlx_embeddings.generate passes 'input_ids' as keyword arg
+    # See: https://github.com/ml-explore/mlx-examples/issues/... (bug report pending)
+    is_gemma = False
+    if model_path:
+        config_path = os.path.join(model_path, "config.json") if os.path.isdir(model_path) else f"{model_path}/config.json"
+    else:
+        config_path = f"{curr_dir}/modelfiles/config.json"
+
+    if os.path.exists(config_path):
+        try:
+            with open(config_path, "r") as f:
+                config = json.load(f)
+            architectures = config.get("architectures", [])
+            is_gemma = "Gemma3TextModel" in architectures
+        except Exception:
+            pass
+
+    if is_gemma:
+        # HARDCODED WORKAROUND for Gemma3TextModel bug
+        # Use direct tokenization and model call instead of mlx_embeddings.generate
+        # This avoids the bug where generate passes 'input_ids' as keyword arg
+        # but Gemma3TextModel.__call__ expects 'inputs' as positional arg
+
+        # Tokenize using batch_encode_plus for Gemma models
+        encoded_input = tokenizer.batch_encode_plus(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors='mlx',
+            max_length=512
+        )
+
+        # Get input tensors
+        input_ids = encoded_input['input_ids']
+        attention_mask = encoded_input.get('attention_mask', None)
+
+        # Call model with positional input_ids and keyword attention_mask
+        # This matches Gemma3TextModel's expected signature:
+        # def __call__(self, inputs: mx.array, attention_mask: Optional[mx.array] = None)
+        output = model(input_ids, attention_mask=attention_mask)
+
+        # Get the normalized embeddings
+        return output.text_embeds
+    else:
+        # Normal path for non-Gemma models
+        # Use standard mlx_embeddings.generate approach
+        output = mlx_embeddings.generate(
+            model,
+            tokenizer,
+            texts=texts,
+            max_length=512,
+            padding=True,
+            truncation=True
+        )
+
+        return output.text_embeds
+
 def main(model_id):
     """Main function to handle user input and generate embeddings."""

-
-
-
-
-
-
-
-
+    print(f"🔍 Loading model: {model_id}")
+
+    # Detect model type
+    model_type = detect_model_type(f"{curr_dir}/modelfiles")
+
+    # First try to download/check if model exists
+    if not os.path.exists(f"{curr_dir}/modelfiles/config.json"):
+        # Download the model first to detect its type
+        print(f"Model not found locally. Downloading...")
+        os.makedirs(f"{curr_dir}/modelfiles", exist_ok=True)
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=f"{curr_dir}/modelfiles",
+                resume_download=True,
+                local_dir_use_symlinks=False
+            )
+            print("✅ Model download completed!")
+            # Re-detect model type after download
+            model_type = detect_model_type(f"{curr_dir}/modelfiles")
+        except Exception as e:
+            print(f"❌ Failed to download model: {e}")
+            raise
+
+    print(f"📦 Detected model type: {model_type}")
+
+    # Test texts
+    test_texts = [
+        "Hello, how are you?",
+        "What is machine learning?",
+        "The weather is nice today."
+    ]
+
+    if model_type == "jina_v2":
+        print("Using Jina V2 direct implementation")
+
+        # Load Jina V2 model
+        model = load_jina_model(model_id)
+        tokenizer = load_jina_tokenizer()
+
+        print("\nGenerating embeddings for test texts:")
+        for text in test_texts:
+            embedding = encode_jina_text(model, tokenizer, text)
+            print(f"\nText: '{text}'")
+            print(f"  Embedding shape: {embedding.shape}")
+            print(f"  Sample values (first 5): {embedding.flatten()[:5].tolist()}")
+            print(f"  Stats - Min: {embedding.min():.4f}, Max: {embedding.max():.4f}, Mean: {embedding.mean():.4f}")
+
+    else:
+        print("Using mlx_embeddings direct implementation")
+
+        if not MLX_EMBEDDINGS_AVAILABLE:
+            print("❌ mlx_embeddings is not installed. Please install it to use generic models.")
+            return
+
+        # Load generic model using mlx_embeddings
+        model, tokenizer = load_mlx_embeddings_model(model_id)
+
+        print("\nGenerating embeddings for test texts:")
+        # Pass model_path to handle Gemma workaround if needed
+        embeddings = encode_mlx_embeddings_text(model, tokenizer, test_texts, model_path=f"{curr_dir}/modelfiles")
+
+        for i, text in enumerate(test_texts):
+            embedding = embeddings[i]
+            print(f"\nText: '{text}'")
+            print(f"  Embedding shape: {embedding.shape}")
+            print(f"  Sample values (first 5): {embedding[:5].tolist()}")
+
+            # Calculate stats
+            emb_array = mx.array(embedding) if not isinstance(embedding, mx.array) else embedding
+            print(f"  Stats - Min: {emb_array.min():.4f}, Max: {emb_array.max():.4f}, Mean: {emb_array.mean():.4f}")
+
+    print("\n✅ Direct embedding generation completed!")

 if __name__ == "__main__":
-
-
+    import argparse
+    parser = argparse.ArgumentParser(description="Generate embeddings using direct implementation")
+    parser.add_argument(
+        "--model_id",
+        type=str,
+        default="nexaml/jina-v2-fp16-mlx",
+        help="Model ID from Hugging Face Hub (e.g., 'nexaml/jina-v2-fp16-mlx' or 'mlx-community/embeddinggemma-300m-bf16')"
+    )
+    args = parser.parse_args()
+    main(args.model_id)
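The Gemma3TextModel comments above describe a plain Python calling-convention mismatch; a toy sketch of the failure mode the workaround avoids (the class name here is illustrative, not the real model):

    # Toy illustration of why the workaround calls the model positionally.
    class Gemma3Like:
        def __call__(self, inputs, attention_mask=None):
            return inputs

    m = Gemma3Like()
    m([1, 2, 3])                   # OK: 'inputs' passed positionally
    try:
        m(input_ids=[1, 2, 3])     # what a generic generate() path would do
    except TypeError as e:
        print(e)                   # unexpected keyword argument 'input_ids'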
nexaai/mlx_backend/embedding/interface.py
CHANGED

@@ -20,11 +20,16 @@ import mlx.core as mx
 import numpy as np
 from pathlib import Path
 from typing import Any, List, Optional, Sequence
+from abc import ABC, abstractmethod

 # Import necessary modules
 from tokenizers import Tokenizer

 # Import from ml.py for API alignment
+import sys
+from pathlib import Path as PathLib
+sys.path.insert(0, str(PathLib(__file__).parent.parent))
+
 from ml import (
     Embedder as BaseEmbedder,
     EmbeddingConfig,
@@ -34,13 +39,24 @@ from ml import (
 # Import profiling module
 from profiling import ProfilingMixin, StopReason

-# Import the model implementation
-
+# Import the model implementation for Jina
+try:
+    from .modeling.nexa_jina_v2 import Model, ModelArgs
+except ImportError:
+    # Fallback for when module is run directly
+    from modeling.nexa_jina_v2 import Model, ModelArgs
+
+# Import mlx_embeddings for general embedding support
+try:
+    import mlx_embeddings
+    MLX_EMBEDDINGS_AVAILABLE = True
+except ImportError:
+    MLX_EMBEDDINGS_AVAILABLE = False


-class Embedder(BaseEmbedder, ProfilingMixin):
+class BaseMLXEmbedder(BaseEmbedder, ProfilingMixin, ABC):
     """
-
+    Abstract base embedder interface for MLX embedding models.
     API aligned with ml.py Embedder abstract base class.
     """

@@ -64,7 +80,7 @@ class Embedder(BaseEmbedder, ProfilingMixin):

         self.model_path = model_path
         self.tokenizer_path = tokenizer_path
-        self.device = device if device is not None else "cpu"
+        self.device = device if device is not None else "cpu"

         # Initialize model and tokenizer as None
         self.model = None
@@ -78,6 +94,69 @@ class Embedder(BaseEmbedder, ProfilingMixin):
         self.config = None
         self.reset_profiling()

+    @abstractmethod
+    def load_model(self, model_path: PathType) -> bool:
+        """Load model from path."""
+        pass
+
+    def close(self) -> None:
+        """Close the model."""
+        self.destroy()
+
+    @abstractmethod
+    def embed(
+        self,
+        texts: Sequence[str],
+        config: Optional[EmbeddingConfig] = None,
+        clear_cache: bool = True,
+    ) -> List[List[float]]:
+        """Generate embeddings for texts."""
+        pass
+
+    @abstractmethod
+    def embedding_dim(self) -> int:
+        """Get embedding dimension."""
+        pass
+
+    def set_lora(self, lora_id: int) -> None:
+        """Set active LoRA adapter. (Disabled for embedding models)"""
+        raise NotImplementedError("LoRA is not supported for embedding models")
+
+    def add_lora(self, lora_path: PathType) -> int:
+        """Add LoRA adapter and return its ID. (Disabled for embedding models)"""
+        raise NotImplementedError("LoRA is not supported for embedding models")
+
+    def remove_lora(self, lora_id: int) -> None:
+        """Remove LoRA adapter. (Disabled for embedding models)"""
+        raise NotImplementedError("LoRA is not supported for embedding models")
+
+    def list_loras(self) -> List[int]:
+        """List available LoRA adapters. (Disabled for embedding models)"""
+        raise NotImplementedError("LoRA is not supported for embedding models")
+
+    def _normalize_embedding(self, embedding: List[float], method: str) -> List[float]:
+        """Normalize embedding using specified method."""
+        if method == "none":
+            return embedding
+
+        embedding_array = np.array(embedding)
+
+        if method == "l2":
+            norm = np.linalg.norm(embedding_array)
+            if norm > 0:
+                embedding_array = embedding_array / norm
+        elif method == "mean":
+            mean_val = np.mean(embedding_array)
+            embedding_array = embedding_array - mean_val
+
+        return embedding_array.tolist()
+
+
+class JinaV2Embedder(BaseMLXEmbedder):
+    """
+    Embedder implementation specifically for Jina V2 models.
+    """
+
     def load_model(self, model_path: PathType) -> bool:
         """Load model from path."""
         try:
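The _normalize_embedding helper added above implements two standard schemes; a small worked example of what each does:

    # Worked example mirroring _normalize_embedding's two methods.
    import numpy as np

    v = np.array([3.0, 4.0])
    l2 = v / np.linalg.norm(v)      # "l2": unit length -> [0.6, 0.8]
    centered = v - np.mean(v)       # "mean": zero mean  -> [-0.5, 0.5]
    print(np.linalg.norm(l2))       # 1.0
    print(centered.sum())           # 0.0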
@@ -97,10 +176,6 @@ class Embedder(BaseEmbedder, ProfilingMixin):
             print(f"Failed to load model: {e}")
             return False

-    def close(self) -> None:
-        """Close the model."""
-        self.destroy()
-
     def embed(
         self,
         texts: Sequence[str],
@@ -158,22 +233,6 @@ class Embedder(BaseEmbedder, ProfilingMixin):
             return 768  # Default dimension for Jina v2
         return self.config.hidden_size

-    def set_lora(self, lora_id: int) -> None:
-        """Set active LoRA adapter. (Disabled for embedding models)"""
-        raise NotImplementedError("LoRA is not supported for embedding models")
-
-    def add_lora(self, lora_path: PathType) -> int:
-        """Add LoRA adapter and return its ID. (Disabled for embedding models)"""
-        raise NotImplementedError("LoRA is not supported for embedding models")
-
-    def remove_lora(self, lora_id: int) -> None:
-        """Remove LoRA adapter. (Disabled for embedding models)"""
-        raise NotImplementedError("LoRA is not supported for embedding models")
-
-    def list_loras(self) -> List[int]:
-        """List available LoRA adapters. (Disabled for embedding models)"""
-        raise NotImplementedError("LoRA is not supported for embedding models")
-
     def _load_jina_model(self, model_dir: str) -> Model:
         """Initialize and load the Jina V2 model with FP16 weights."""

@@ -281,22 +340,267 @@ class Embedder(BaseEmbedder, ProfilingMixin):

         return embedding_list

-
-
-
-
+
+class MlxEmbeddingEmbedder(BaseMLXEmbedder):
+    """
+    Embedder implementation using mlx_embeddings package for general embedding models.
+    """
+
+    def load_model(self, model_path: PathType) -> bool:
+        """Load model from path using mlx_embeddings."""
+        if not MLX_EMBEDDINGS_AVAILABLE:
+            print("Warning: mlx_embeddings not available. Please install it to use general embedding models.")
+            raise ImportError("mlx_embeddings package is not available. Please install it first.")

-
+        try:
+            # Use the provided model_path or fall back to instance path
+            if model_path:
+                if os.path.isfile(model_path):
+                    model_path = os.path.dirname(model_path)
+                self.model_path = model_path
+
+            # Load model and tokenizer using mlx_embeddings
+            self.model, self.tokenizer = mlx_embeddings.load(self.model_path)
+
+            # Load config to get dimensions
+            config_path = os.path.join(self.model_path, "config.json")
+            if os.path.exists(config_path):
+                with open(config_path, "r") as f:
+                    self.config = json.load(f)
+
+            return True
+        except Exception as e:
+            print(f"Failed to load model: {e}")
+            return False
+
+    def embed(
+        self,
+        texts: Sequence[str],
+        config: Optional[EmbeddingConfig] = None,
+        clear_cache: bool = True,
+    ) -> List[List[float]]:
+        """Generate embeddings for texts using mlx_embeddings."""
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError("Model not loaded. Call load_model() first.")

-        if
-
-            if norm > 0:
-                embedding_array = embedding_array / norm
-        elif method == "mean":
-            mean_val = np.mean(embedding_array)
-            embedding_array = embedding_array - mean_val
+        if config is None:
+            config = EmbeddingConfig()

-
+        # Start profiling
+        self._start_profiling()
+
+        try:
+            # Calculate total tokens for profiling
+            if hasattr(self.tokenizer, 'encode'):
+                total_tokens = sum(len(self.tokenizer.encode(text)) for text in texts)
+            else:
+                # For tokenizers that don't have simple encode method
+                total_tokens = len(texts) * 50  # Rough estimate
+
+            self._update_prompt_tokens(total_tokens)
+
+            # End prompt processing, start decode
+            self._prompt_end()
+            self._decode_start()
+
+            # Check if this is a Gemma3TextModel
+            # WORKAROUND: Gemma3TextModel has a bug where it expects 'inputs' as positional arg
+            # but mlx_embeddings.generate passes 'input_ids' as keyword arg
+            # See: https://github.com/ml-explore/mlx-examples/issues/... (bug report pending)
+            is_gemma = False
+            if self.config and "architectures" in self.config:
+                architectures = self.config.get("architectures", [])
+                is_gemma = "Gemma3TextModel" in architectures
+
+            if is_gemma:
+                # HARDCODED WORKAROUND for Gemma3TextModel bug
+                # Use direct tokenization and model call instead of mlx_embeddings.generate
+                max_length = config.max_length if hasattr(config, 'max_length') else 512
+
+                # Tokenize using batch_encode_plus
+                encoded_input = self.tokenizer.batch_encode_plus(
+                    list(texts),
+                    padding=True,
+                    truncation=True,
+                    return_tensors='mlx',
+                    max_length=max_length
+                )
+
+                # Get input tensors
+                input_ids = encoded_input['input_ids']
+                attention_mask = encoded_input.get('attention_mask', None)
+
+                # Call model with positional input_ids and keyword attention_mask
+                # This matches Gemma3TextModel's expected signature
+                output = self.model(input_ids, attention_mask=attention_mask)
+
+                # Extract embeddings
+                embeddings_tensor = output.text_embeds
+            else:
+                # Normal path for non-Gemma models
+                # Generate embeddings using mlx_embeddings standard approach
+                output = mlx_embeddings.generate(
+                    self.model,
+                    self.tokenizer,
+                    texts=list(texts),
+                    max_length=config.max_length if hasattr(config, 'max_length') else 512,
+                    padding=True,
+                    truncation=True
+                )
+
+                # Extract embeddings
+                embeddings_tensor = output.text_embeds
+
+            # Convert to list format
+            embeddings = []
+            for i in range(embeddings_tensor.shape[0]):
+                embedding = embeddings_tensor[i].tolist()
+
+                # Apply normalization if requested
+                if config.normalize:
+                    embedding = self._normalize_embedding(embedding, config.normalize_method)
+
+                embeddings.append(embedding)
+
+            if clear_cache:
+                mx.clear_cache()
+
+            # End timing and finalize profiling data
+            self._update_generated_tokens(0)  # No generation in embedding
+            self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
+            self._decode_end()
+            self._end_profiling()
+
+            return embeddings
+
+        except Exception as e:
+            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
+            self._decode_end()
+            self._end_profiling()
+            raise RuntimeError(f"Error generating embeddings: {str(e)}")
+
+    def embedding_dim(self) -> int:
+        """Get embedding dimension."""
+        if self.config is None:
+            return 768  # Default dimension
+
+        # Try different config keys that might contain the dimension
+        if "hidden_size" in self.config:
+            return self.config["hidden_size"]
+        elif "d_model" in self.config:
+            return self.config["d_model"]
+        elif "dim" in self.config:
+            return self.config["dim"]
+        else:
+            return 768  # Fallback default
+
+
+class MLXEmbedder(BaseMLXEmbedder):
+    """
+    Concrete embedder class that routes to the appropriate implementation.
+    This class can be instantiated directly (for C++ compatibility) and will
+    automatically delegate to JinaV2Embedder or MlxEmbeddingEmbedder based on model type.
+    """
+
+    def __init__(
+        self,
+        model_path: PathType,
+        tokenizer_path: PathType,
+        device: Optional[str] = None,
+    ) -> None:
+        """Initialize the Embedder model."""
+        super().__init__(model_path, tokenizer_path, device)
+        self._impl = None  # Will hold the actual implementation
+
+    def _get_implementation(self) -> BaseMLXEmbedder:
+        """Get or create the appropriate implementation based on model type."""
+        if self._impl is None:
+            # Detect model type and create appropriate implementation
+            model_type = _detect_model_type(self.model_path)
+
+            if model_type == "jina_v2":
+                self._impl = JinaV2Embedder(self.model_path, self.tokenizer_path, self.device)
+            else:
+                self._impl = MlxEmbeddingEmbedder(self.model_path, self.tokenizer_path, self.device)
+
+            # Copy over any existing state
+            if self.model is not None:
+                self._impl.model = self.model
+            if self.tokenizer is not None:
+                self._impl.tokenizer = self.tokenizer
+            if self.config is not None:
+                self._impl.config = self.config
+
+        return self._impl
+
+    def load_model(self, model_path: PathType) -> bool:
+        """Load model from path."""
+        # Get the appropriate implementation and delegate
+        impl = self._get_implementation()
+        result = impl.load_model(model_path)
+
+        # Sync state back
+        self.model = impl.model
+        self.tokenizer = impl.tokenizer
+        self.config = impl.config
+
+        return result
+
+    def embed(
+        self,
+        texts: Sequence[str],
+        config: Optional[EmbeddingConfig] = None,
+        clear_cache: bool = True,
+    ) -> List[List[float]]:
+        """Generate embeddings for texts."""
+        # Get the appropriate implementation and delegate
+        impl = self._get_implementation()
+        return impl.embed(texts, config, clear_cache)
+
+    def embedding_dim(self) -> int:
+        """Get embedding dimension."""
+        # Get the appropriate implementation and delegate
+        impl = self._get_implementation()
+        return impl.embedding_dim()
+
+    def destroy(self) -> None:
+        """Destroy the model and free resources."""
+        super().destroy()
+        if self._impl is not None:
+            self._impl.destroy()
+            self._impl = None
+
+
+# Backward compatibility alias
+Embedder = MLXEmbedder
+
+
+def _detect_model_type(model_path: PathType) -> str:
+    """Detect the model type from config.json."""
+    if os.path.isfile(model_path):
+        model_path = os.path.dirname(model_path)
+
+    config_path = os.path.join(model_path, "config.json")
+
+    if not os.path.exists(config_path):
+        # If no config.json, assume it's a generic model
+        return "generic"
+
+    try:
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        # Check architectures field for JinaBertModel
+        architectures = config.get("architectures", [])
+        if "JinaBertModel" in architectures:
+            return "jina_v2"
+
+        # Default to generic mlx_embeddings for other models
+        return "generic"
+
+    except Exception as e:
+        print(f"Warning: Could not parse config.json: {e}")
+        return "generic"


 # Factory function for creating embedder instances
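MLXEmbedder above is a lazily routing facade: the concrete backend is chosen on first use and every call delegates to it. A stripped-down sketch of the same pattern (all names here are illustrative, not part of the package):

    # Minimal facade/delegation sketch mirroring _get_implementation's shape.
    class JinaSketch:
        def embed(self, texts): return [[1.0]] * len(texts)

    class GenericSketch:
        def embed(self, texts): return [[2.0]] * len(texts)

    class RouterSketch:
        def __init__(self, kind):
            self.kind = kind
            self._impl = None            # created lazily, like MLXEmbedder

        def _get_impl(self):
            if self._impl is None:
                cls = JinaSketch if self.kind == "jina_v2" else GenericSketch
                self._impl = cls()
            return self._impl

        def embed(self, texts):
            return self._get_impl().embed(texts)

    print(RouterSketch("jina_v2").embed(["x"]))   # [[1.0]]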
@@ -304,9 +608,10 @@ def create_embedder(
     model_path: PathType,
     tokenizer_path: Optional[PathType] = None,
     device: Optional[str] = None,
-) ->
-    """Create and return an
+) -> MLXEmbedder:
+    """Create and return an MLXEmbedder instance that automatically routes to the appropriate implementation."""
     if tokenizer_path is None:
         tokenizer_path = model_path

-
+    # Return the concrete MLXEmbedder which will handle routing internally
+    return MLXEmbedder(model_path, tokenizer_path, device)
nexaai/mlx_backend/embedding/main.py
CHANGED

@@ -12,71 +12,162 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import os
+import sys
+import numpy as np
+from pathlib import Path

+# Add parent path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent))

-
-
+# Import from interface (uses the factory pattern with routing)
+from .interface import create_embedder
+from .interface import EmbeddingConfig
+from huggingface_hub import snapshot_download
+
+
+def download_model_if_needed(model_id, local_dir):
+    """Download model from Hugging Face Hub if not present locally."""
+    if not os.path.exists(os.path.join(local_dir, "config.json")):
+        print(f"📥 Model not found locally. Downloading {model_id}...")
+        os.makedirs(local_dir, exist_ok=True)
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=local_dir,
+                resume_download=True,
+                local_dir_use_symlinks=False
+            )
+            print("✅ Model download completed!")
+        except Exception as e:
+            print(f"❌ Failed to download model: {e}")
+            raise
+
+
+def test_embedding_interface(model_path, is_local=False):
+    """Test embedding model functionality using the interface."""
+
+    print("=" * 70)
+    print("TESTING EMBEDDING MODEL VIA INTERFACE")
+    print("=" * 70)
+
+    # Handle model path - download if it's a HF model ID
+    if not is_local and "/" in model_path:
+        # It's a HuggingFace model ID
+        local_dir = f"./modelfiles/{model_path.replace('/', '_')}"
+        download_model_if_needed(model_path, local_dir)
+        model_path = local_dir
+
+    # Create embedder using factory function (will auto-detect model type)
+    print(f"\n🔍 Creating embedder for: {model_path}")
     embedder = create_embedder(model_path=model_path)
+    print(f"✅ Created embedder type: {type(embedder).__name__}")

     # Load the model
-    print("Loading embedding model...")
+    print("\n📚 Loading embedding model...")
     success = embedder.load_model(model_path)

     if not success:
-        print("Failed to load model!")
+        print("❌ Failed to load model!")
         return

     print("✅ Model loaded successfully!")
-    print(f"Embedding dimension: {embedder.embedding_dim()}")
+    print(f"📏 Embedding dimension: {embedder.embedding_dim()}")

     # Test texts
     test_texts = [
         "Hello, how are you?",
         "What is machine learning?",
         "The weather is nice today.",
-        "Python is a programming language."
+        "Python is a programming language.",
+        "Artificial intelligence is changing the world."
     ]

-    # Configure embedding
-
-        batch_size=2,
-        normalize=
-
-
+    # Configure embedding with different settings
+    configs = [
+        EmbeddingConfig(batch_size=2, normalize=True, normalize_method="l2"),
+        EmbeddingConfig(batch_size=3, normalize=False),
+    ]
+
+    for config_idx, config in enumerate(configs):
+        print(f"\n{'='*50}")
+        print(f"TEST {config_idx + 1}: Config - Batch: {config.batch_size}, "
+              f"Normalize: {config.normalize}, Method: {config.normalize_method}")
+        print('='*50)
+
+        # Generate embeddings
+        embeddings = embedder.embed(test_texts, config)
+
+        # Display results
+        print(f"\n📊 Generated {len(embeddings)} embeddings")
+
+        for i, (text, embedding) in enumerate(zip(test_texts[:3], embeddings[:3])):
+            print(f"\n  Text {i+1}: '{text}'")
+            print(f"  Dimension: {len(embedding)}")
+            print(f"  First 5 values: {[f'{v:.4f}' for v in embedding[:5]]}")
+
+            # Calculate magnitude
+            magnitude = np.linalg.norm(embedding)
+            print(f"  Magnitude: {magnitude:.6f}")

-
+    # Compute similarity matrix for normalized embeddings
+    print("\n" + "="*50)
+    print("SIMILARITY MATRIX (L2 Normalized)")
+    print("="*50)

-
+    config = EmbeddingConfig(batch_size=len(test_texts), normalize=True, normalize_method="l2")
     embeddings = embedder.embed(test_texts, config)

-    #
-
-
+    # Convert to numpy for easier computation
+    embeddings_np = np.array(embeddings)
+    similarity_matrix = np.dot(embeddings_np, embeddings_np.T)

-
-
-    print(f"
-
-
-
-
-    print(f"
+    print("\nTexts:")
+    for i, text in enumerate(test_texts):
+        print(f"  [{i}] {text[:30]}...")
+
+    print("\nSimilarity Matrix:")
+    print("     ", end="")
+    for i in range(len(test_texts)):
+        print(f"  [{i}] ", end="")
+    print()

-
-
-
-
-    print(
+    for i in range(len(test_texts)):
+        print(f"  [{i}]", end="")
+        for j in range(len(test_texts)):
+            print(f" {similarity_matrix[i, j]:5.2f}", end="")
+        print()
+
+    # Find most similar pairs
+    print("\n🔍 Most Similar Pairs (excluding self-similarity):")
+    similarities = []
+    for i in range(len(test_texts)):
+        for j in range(i+1, len(test_texts)):
+            similarities.append((similarity_matrix[i, j], i, j))
+
+    similarities.sort(reverse=True)
+    for sim, i, j in similarities[:3]:
+        print(f"  • Texts [{i}] and [{j}]: {sim:.4f}")

     # Cleanup
     embedder.close()
-    print("\n✅
+    print("\n✅ Interface test completed successfully!")


 if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
+    parser = argparse.ArgumentParser(description="Test embedding models via interface")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="nexaml/jina-v2-fp16-mlx",
+        help="Model path (local) or HuggingFace model ID"
+    )
+    parser.add_argument(
+        "--local",
+        action="store_true",
+        help="Indicate if model_path is a local directory"
+    )
     args = parser.parse_args()
-
+
+    test_embedding_interface(args.model_path, args.local)
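A note on the similarity computation in the test above: because the embeddings are L2-normalized first, a plain dot product of the matrix with its transpose already yields cosine similarities. A compact sketch:

    # After L2 normalization, dot product == cosine similarity.
    import numpy as np

    e = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
    e = e / np.linalg.norm(e, axis=1, keepdims=True)   # normalize rows
    sim = e @ e.T                                      # similarity matrix
    print(np.round(sim, 2))        # diagonal is 1.0; rows 0 and 1 ~ 0.71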
nexaai/utils/model_manager.py
CHANGED
@@ -820,14 +820,6 @@ class HuggingFaceDownloader:
         # Create a subdirectory for this specific repo
         repo_local_dir = self._create_repo_directory(local_dir, repo_id)

-        # Check if repository already exists (basic check for directory existence)
-        if not force_download and os.path.exists(repo_local_dir) and os.listdir(repo_local_dir):
-            print(f"✓ Repository already exists, skipping: {repo_id}")
-            # Stop progress tracking
-            if progress_tracker:
-                progress_tracker.stop_tracking()
-            return repo_local_dir
-
         try:
             download_kwargs = {
                 'repo_id': repo_id,
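The removed block above short-circuited a download whenever the target directory was merely non-empty, which presumably also skipped repos whose downloads had been interrupted partway. A sketch of why a non-empty check is too weak (the directory and file name are hypothetical):

    # A half-finished download looks "present" to an os.listdir check.
    import os, tempfile

    repo_dir = tempfile.mkdtemp()
    open(os.path.join(repo_dir, "model.safetensors.partial"), "w").close()

    # This is the condition the removed guard relied on - True even here:
    print(os.path.exists(repo_dir) and bool(os.listdir(repo_dir)))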
nexaai/utils/progress_tracker.py
CHANGED
@@ -107,7 +107,7 @@ class DownloadProgressTracker:
         time_diff = current_time - self.last_time

         # Only calculate if we have a meaningful time difference (avoid division by very small numbers)
-        if time_diff > 0.
+        if time_diff > 0.1:  # At least 100ms between measurements
             bytes_diff = current_downloaded - self.last_downloaded

             # Only calculate speed if bytes actually changed
@@ -118,6 +118,14 @@ class DownloadProgressTracker:
                 self.speed_history.append(speed)
                 if len(self.speed_history) > self.max_speed_history:
                     self.speed_history.pop(0)
+
+            # Update tracking variables when we actually calculate speed
+            self.last_downloaded = current_downloaded
+            self.last_time = current_time
+        else:
+            # First measurement - initialize tracking variables
+            self.last_downloaded = current_downloaded
+            self.last_time = current_time

         # Return the average of historical speeds if we have any
         # This ensures we show the last known speed even when skipping updates
@@ -157,13 +165,9 @@ class DownloadProgressTracker:
                 total_file_sizes += data['total']
                 active_file_count += 1

-        # Calculate speed
+        # Calculate speed (tracking variables are updated internally)
         speed = self.calculate_speed(total_downloaded)

-        # Update tracking variables
-        self.last_downloaded = total_downloaded
-        self.last_time = time.time()
-
         # Determine total size - prioritize pre-fetched repo size, then aggregate file sizes
         if self.total_repo_size > 0:
             # Use pre-fetched repository info if available
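The two hunks above move the last_downloaded/last_time bookkeeping inside calculate_speed, so a speed sample is only taken when at least 100 ms have elapsed, and callers no longer reset the baseline on every poll. A standalone sketch of the corrected logic (simplified, not the exact class):

    import time

    class SpeedTrackerSketch:
        def __init__(self):
            self.last_time = None
            self.last_downloaded = 0
            self.speed_history = []

        def calculate_speed(self, current_downloaded):
            now = time.time()
            if self.last_time is None:
                # First measurement: just initialize tracking state
                self.last_time = now
                self.last_downloaded = current_downloaded
            elif now - self.last_time > 0.1:   # at least 100 ms apart
                diff = current_downloaded - self.last_downloaded
                if diff > 0:
                    self.speed_history.append(diff / (now - self.last_time))
                # Update baseline only when a sample is actually taken
                self.last_time = now
                self.last_downloaded = current_downloaded
            # Average of history: shows last known speed between samples
            return (sum(self.speed_history) / len(self.speed_history)
                    if self.speed_history else 0.0)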
{nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nexaai
-Version: 1.0.5
+Version: 1.0.6
 Summary: Python bindings for NexaSDK C-lib backend
 Author-email: "Nexa AI, Inc." <dev@nexa.ai>
 Project-URL: Homepage, https://github.com/NexaAI/nexasdk-bridge
@@ -21,6 +21,7 @@ Provides-Extra: mlx
 Requires-Dist: mlx; extra == "mlx"
 Requires-Dist: mlx-lm; extra == "mlx"
 Requires-Dist: mlx-vlm; extra == "mlx"
+Requires-Dist: mlx-embeddings; extra == "mlx"
 Requires-Dist: tokenizers; extra == "mlx"
 Requires-Dist: safetensors; extra == "mlx"
 Requires-Dist: Pillow; extra == "mlx"
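Packaging note: the new mlx-embeddings requirement is declared under the existing mlx extra, so it is installed together with the rest of the MLX stack via pip install "nexaai[mlx]".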
{nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
 nexaai/__init__.py,sha256=jXdC4vv6DBK1fVewYTYSUhOOYfvf_Mk81UIeMGGIKUg,2029
-nexaai/_stub.cpython-310-darwin.so,sha256=
-nexaai/_version.py,sha256=
+nexaai/_stub.cpython-310-darwin.so,sha256=B2vHaw7BNJAnmEiz8ExY-vseaXkK_yDGI950dxdbRQ4,49832
+nexaai/_version.py,sha256=jSTrJA7aO22uaIWjegon7bkFN2WNt7W4are37-C1cqg,138
 nexaai/asr.py,sha256=NljMXDErwPNMOPaRkJZMEDka9Nk8xyur7L8i924TStY,2054
 nexaai/base.py,sha256=N8PRgDFA-XPku2vWnQIofQ7ipz3pPlO6f8YZGnuhquE,982
 nexaai/common.py,sha256=yBnIbqYaQYnfrl7IczOBh6MDibYZVxwaRJEglYcKgGs,3422
@@ -19,21 +19,21 @@ nexaai/binds/__init__.py,sha256=T9Ua7SzHNglSeEqXlfH5ymYXRyXhNKkC9z_y_bWCNMo,80
 nexaai/binds/common_bind.cpython-310-darwin.so,sha256=FF5WuJj0fNCim_HjseBQu38vL-1M5zI_7EVTD7Bs-Bc,233960
 nexaai/binds/embedder_bind.cpython-310-darwin.so,sha256=mU6hP0SyH8vcmPpC2GIr7ioK7539dsg_YbmrBdmj7l0,202032
 nexaai/binds/libcrypto.dylib,sha256=ysW8ydmDPnnNRy3AHESjJwMTFfmGDKU9eLIaiR37ca0,5091432
-nexaai/binds/libnexa_bridge.dylib,sha256=
+nexaai/binds/libnexa_bridge.dylib,sha256=oCAhLh6CIncKS4oO5_xAcvtLniIY8inRuN0S7iRfCM4,250712
 nexaai/binds/libssl.dylib,sha256=JHPTSbRFnImmoWDO9rFdiKb0lJMT3q78VEsx-5-S0sk,889520
 nexaai/binds/llm_bind.cpython-310-darwin.so,sha256=aYqMs5VhC07RNZZgyS9JeYJJgWCl-toZOmt6vXu5yp0,183008
-nexaai/binds/nexa_llama_cpp/libggml-base.dylib,sha256=
-nexaai/binds/nexa_llama_cpp/libggml-cpu.so,sha256=
-nexaai/binds/nexa_llama_cpp/libggml-metal.so,sha256=
+nexaai/binds/nexa_llama_cpp/libggml-base.dylib,sha256=oikz7Qxzx6A0mPROq7uHTUwWn66LvvOjcdVstG-M8Fw,629528
+nexaai/binds/nexa_llama_cpp/libggml-cpu.so,sha256=WepzOOeElmdOlsoMv7loLHsj8-Qx2O9ZJPlNnX11KJI,1039800
+nexaai/binds/nexa_llama_cpp/libggml-metal.so,sha256=ssn3Bqmnu7YA_FKL513Y18gbxG8WP9Udw71DNKV34eo,713680
 nexaai/binds/nexa_llama_cpp/libggml.dylib,sha256=Z2ZvkyEEpPtHhMYap-44p9Q0M6TXJbLcMy-smR2X5sk,58336
-nexaai/binds/nexa_llama_cpp/libllama.dylib,sha256=
-nexaai/binds/nexa_llama_cpp/libmtmd.dylib,sha256=
-nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib,sha256
+nexaai/binds/nexa_llama_cpp/libllama.dylib,sha256=QZBn_w32g8NAJLE1unC_qx1BCVM531LeqTUqWipt9ks,1982280
+nexaai/binds/nexa_llama_cpp/libmtmd.dylib,sha256=F1QLNlfjiECRssUtEZeuqNqej-8COYcQjMZKPAB0CGk,701504
+nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib,sha256=csMdM6l21qpj-3_4z0xGsYM1snOBg4cJPfLXOQ8oTcI,2644752
 nexaai/cv_impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nexaai/cv_impl/mlx_cv_impl.py,sha256=gKECQOv8iaWwG3bl7xeqVy2NN_9K7tYerIFzfn4eLo4,3228
 nexaai/cv_impl/pybind_cv_impl.py,sha256=uSmwBste4cT7c8DQmXzRLmzwDf773PAbXNYWW1UzVls,1064
 nexaai/embedder_impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nexaai/embedder_impl/mlx_embedder_impl.py,sha256=
+nexaai/embedder_impl/mlx_embedder_impl.py,sha256=dTjOC1VJ9ypIgCvkK_jKNSWpswbg132rDcTzWcL5oFA,4482
 nexaai/embedder_impl/pybind_embedder_impl.py,sha256=Ga1JYauVkRq6jwAGL7Xx5HDaIx483_v9gZVoTyd3xNU,3495
 nexaai/image_gen_impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nexaai/image_gen_impl/mlx_image_gen_impl.py,sha256=BuDkksvXyb4J02GsdnbGAmYckfUU0Eah6BimoMD3QqY,11219
@@ -53,9 +53,9 @@ nexaai/mlx_backend/cv/interface.py,sha256=qE51ApUETEZxDMPZB4VdV098fsXcIiEg4Hj9za
 nexaai/mlx_backend/cv/main.py,sha256=hYaF2C36hKTyy7kGMNkzLrdczPiFVS73H320klzzpVM,2856
 nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py,sha256=Vpa-QTy7N5oFfGI7Emldx1dOYJWv_4nAFNRDz_5vHBI,58593
 nexaai/mlx_backend/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nexaai/mlx_backend/embedding/generate.py,sha256=
-nexaai/mlx_backend/embedding/interface.py,sha256=
-nexaai/mlx_backend/embedding/main.py,sha256=
+nexaai/mlx_backend/embedding/generate.py,sha256=leZA0Ir78-5GV3jloPKYSAKgb04Wr5jORFJlSSVyKs0,12855
+nexaai/mlx_backend/embedding/interface.py,sha256=M7AGiq_UVLNIi2Ie6H08ySnMxIjIhUlNgmV9I_rKYt4,22742
+nexaai/mlx_backend/embedding/main.py,sha256=xKRebBcooKuf8DzWKwCicftes3MAcYAd1QvcT9_AAPQ,6003
 nexaai/mlx_backend/embedding/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py,sha256=F9Z_9r-Dh0wNThiMp5W5hqE2dt5bf4ps5_c6h4BuWGw,15218
 nexaai/mlx_backend/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -358,12 +358,12 @@ nexaai/tts_impl/mlx_tts_impl.py,sha256=i_uNPdvlXYtL3e01oKjDlP9jgkWCRt1bBHsExaaiJ
 nexaai/tts_impl/pybind_tts_impl.py,sha256=mpn44r6pfYLIl-NrEy2dXHjGtWtNCmM7HRyxiANxUI4,1444
 nexaai/utils/avatar_fetcher.py,sha256=bWy8ujgbOiTHFCjFxTwkn3uXbZ84PgEGUkXkR3MH4bI,3821
 nexaai/utils/decode.py,sha256=61n4Zf6c5QLyqGoctEitlI9BX3tPlP2a5aaKNHbw3T4,404
-nexaai/utils/model_manager.py,sha256=
-nexaai/utils/progress_tracker.py,sha256=
+nexaai/utils/model_manager.py,sha256=QJlE-VvkAy38-8hhPYDu8NVHZqpF_HygXU189qd4Wkg,48157
+nexaai/utils/progress_tracker.py,sha256=mTw7kaKH8BkmecYm7iBMqRHd9uUH4Ch0S8CzbpARDCk,15404
 nexaai/vlm_impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nexaai/vlm_impl/mlx_vlm_impl.py,sha256=od1R1mRoIgPG3NHC7JiDlcB_YJY8aklX8Em3ZkeHNpE,10734
 nexaai/vlm_impl/pybind_vlm_impl.py,sha256=5ZMFgDATthmMzjrd-vE5KX5ZAMoWPYbF_FTLz8DBKIk,8908
-nexaai-1.0.
-nexaai-1.0.
-nexaai-1.0.
-nexaai-1.0.
+nexaai-1.0.6.dist-info/METADATA,sha256=A27rUPomOEG4fxgVqRitNGDw71-HoGLvlphrhjkebrU,1197
+nexaai-1.0.6.dist-info/WHEEL,sha256=0KYp5feZ1CMUhsfFXKpSQTbSmQbXy4mv6yPPVBXg2EM,110
+nexaai-1.0.6.dist-info/top_level.txt,sha256=LRE2YERlrZk2vfuygnSzsEeqSknnZbz3Z1MHyNmBU4w,7
+nexaai-1.0.6.dist-info/RECORD,,
{nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/WHEEL
File without changes

{nexaai-1.0.5.dist-info → nexaai-1.0.6.dist-info}/top_level.txt
File without changes