nexaai 1.0.6rc1__cp310-cp310-macosx_14_0_universal2.whl → 1.0.7__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

@@ -20,11 +20,16 @@ import mlx.core as mx
20
20
  import numpy as np
21
21
  from pathlib import Path
22
22
  from typing import Any, List, Optional, Sequence
23
+ from abc import ABC, abstractmethod
23
24
 
24
25
  # Import necessary modules
25
26
  from tokenizers import Tokenizer
26
27
 
27
28
  # Import from ml.py for API alignment
29
+ import sys
30
+ from pathlib import Path as PathLib
31
+ sys.path.insert(0, str(PathLib(__file__).parent.parent))
32
+
28
33
  from ml import (
29
34
  Embedder as BaseEmbedder,
30
35
  EmbeddingConfig,
@@ -34,13 +39,24 @@ from ml import (
34
39
  # Import profiling module
35
40
  from profiling import ProfilingMixin, StopReason
36
41
 
37
- # Import the model implementation
38
- from .modeling.nexa_jina_v2 import Model, ModelArgs
42
+ # Import the model implementation for Jina
43
+ try:
44
+ from .modeling.nexa_jina_v2 import Model, ModelArgs
45
+ except ImportError:
46
+ # Fallback for when module is run directly
47
+ from modeling.nexa_jina_v2 import Model, ModelArgs
48
+
49
+ # Import mlx_embeddings for general embedding support
50
+ try:
51
+ import mlx_embeddings
52
+ MLX_EMBEDDINGS_AVAILABLE = True
53
+ except ImportError:
54
+ MLX_EMBEDDINGS_AVAILABLE = False
39
55
 
40
56
 
41
- class Embedder(BaseEmbedder, ProfilingMixin):
57
+ class BaseMLXEmbedder(BaseEmbedder, ProfilingMixin, ABC):
42
58
  """
43
- Embedder interface for MLX embedding models.
59
+ Abstract base embedder interface for MLX embedding models.
44
60
  API aligned with ml.py Embedder abstract base class.
45
61
  """
46
62
 
@@ -64,7 +80,7 @@ class Embedder(BaseEmbedder, ProfilingMixin):
64
80
 
65
81
  self.model_path = model_path
66
82
  self.tokenizer_path = tokenizer_path
67
- self.device = device if device is not None else "cpu" # TODO: This device field is never used
83
+ self.device = device if device is not None else "cpu"
68
84
 
69
85
  # Initialize model and tokenizer as None
70
86
  self.model = None
@@ -78,6 +94,69 @@ class Embedder(BaseEmbedder, ProfilingMixin):
78
94
  self.config = None
79
95
  self.reset_profiling()
80
96
 
97
+ @abstractmethod
98
+ def load_model(self, model_path: PathType) -> bool:
99
+ """Load model from path."""
100
+ pass
101
+
102
+ def close(self) -> None:
103
+ """Close the model."""
104
+ self.destroy()
105
+
106
+ @abstractmethod
107
+ def embed(
108
+ self,
109
+ texts: Sequence[str],
110
+ config: Optional[EmbeddingConfig] = None,
111
+ clear_cache: bool = True,
112
+ ) -> List[List[float]]:
113
+ """Generate embeddings for texts."""
114
+ pass
115
+
116
+ @abstractmethod
117
+ def embedding_dim(self) -> int:
118
+ """Get embedding dimension."""
119
+ pass
120
+
121
+ def set_lora(self, lora_id: int) -> None:
122
+ """Set active LoRA adapter. (Disabled for embedding models)"""
123
+ raise NotImplementedError("LoRA is not supported for embedding models")
124
+
125
+ def add_lora(self, lora_path: PathType) -> int:
126
+ """Add LoRA adapter and return its ID. (Disabled for embedding models)"""
127
+ raise NotImplementedError("LoRA is not supported for embedding models")
128
+
129
+ def remove_lora(self, lora_id: int) -> None:
130
+ """Remove LoRA adapter. (Disabled for embedding models)"""
131
+ raise NotImplementedError("LoRA is not supported for embedding models")
132
+
133
+ def list_loras(self) -> List[int]:
134
+ """List available LoRA adapters. (Disabled for embedding models)"""
135
+ raise NotImplementedError("LoRA is not supported for embedding models")
136
+
137
def _normalize_embedding(self, embedding: List[float], method: str) -> List[float]:
    """Normalize a single embedding vector.

    Args:
        embedding: The vector components as a flat list of numbers.
        method: Normalization to apply.
            ``"none"`` returns ``embedding`` itself, untouched.
            ``"l2"`` divides by the Euclidean norm; a vector whose norm is
            zero is left as is to avoid dividing by zero.
            ``"mean"`` subtracts the arithmetic mean from every component.
            Any other value leaves the component values unchanged.

    Returns:
        The normalized components as a plain Python list.
    """
    if method == "none":
        return embedding

    vector = np.array(embedding)

    if method == "l2":
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
    elif method == "mean" and vector.size:
        # np.mean() of an empty array yields NaN and emits a RuntimeWarning;
        # an empty vector has nothing to centre, so skip it.
        vector = vector - np.mean(vector)

    return vector.tolist()
153
+
154
+
155
+ class JinaV2Embedder(BaseMLXEmbedder):
156
+ """
157
+ Embedder implementation specifically for Jina V2 models.
158
+ """
159
+
81
160
  def load_model(self, model_path: PathType) -> bool:
82
161
  """Load model from path."""
83
162
  try:
@@ -97,10 +176,6 @@ class Embedder(BaseEmbedder, ProfilingMixin):
97
176
  print(f"Failed to load model: {e}")
98
177
  return False
99
178
 
100
- def close(self) -> None:
101
- """Close the model."""
102
- self.destroy()
103
-
104
179
  def embed(
105
180
  self,
106
181
  texts: Sequence[str],
@@ -158,22 +233,6 @@ class Embedder(BaseEmbedder, ProfilingMixin):
158
233
  return 768 # Default dimension for Jina v2
159
234
  return self.config.hidden_size
160
235
 
161
- def set_lora(self, lora_id: int) -> None:
162
- """Set active LoRA adapter. (Disabled for embedding models)"""
163
- raise NotImplementedError("LoRA is not supported for embedding models")
164
-
165
- def add_lora(self, lora_path: PathType) -> int:
166
- """Add LoRA adapter and return its ID. (Disabled for embedding models)"""
167
- raise NotImplementedError("LoRA is not supported for embedding models")
168
-
169
- def remove_lora(self, lora_id: int) -> None:
170
- """Remove LoRA adapter. (Disabled for embedding models)"""
171
- raise NotImplementedError("LoRA is not supported for embedding models")
172
-
173
- def list_loras(self) -> List[int]:
174
- """List available LoRA adapters. (Disabled for embedding models)"""
175
- raise NotImplementedError("LoRA is not supported for embedding models")
176
-
177
236
  def _load_jina_model(self, model_dir: str) -> Model:
178
237
  """Initialize and load the Jina V2 model with FP16 weights."""
179
238
 
@@ -281,22 +340,267 @@ class Embedder(BaseEmbedder, ProfilingMixin):
281
340
 
282
341
  return embedding_list
283
342
 
284
- def _normalize_embedding(self, embedding: List[float], method: str) -> List[float]:
285
- """Normalize embedding using specified method."""
286
- if method == "none":
287
- return embedding
343
+
344
+ class MlxEmbeddingEmbedder(BaseMLXEmbedder):
345
+ """
346
+ Embedder implementation using mlx_embeddings package for general embedding models.
347
+ """
348
+
349
+ def load_model(self, model_path: PathType) -> bool:
350
+ """Load model from path using mlx_embeddings."""
351
+ if not MLX_EMBEDDINGS_AVAILABLE:
352
+ print("Warning: mlx_embeddings not available. Please install it to use general embedding models.")
353
+ raise ImportError("mlx_embeddings package is not available. Please install it first.")
288
354
 
289
- embedding_array = np.array(embedding)
355
+ try:
356
+ # Use the provided model_path or fall back to instance path
357
+ if model_path:
358
+ if os.path.isfile(model_path):
359
+ model_path = os.path.dirname(model_path)
360
+ self.model_path = model_path
361
+
362
+ # Load model and tokenizer using mlx_embeddings
363
+ self.model, self.tokenizer = mlx_embeddings.load(self.model_path)
364
+
365
+ # Load config to get dimensions
366
+ config_path = os.path.join(self.model_path, "config.json")
367
+ if os.path.exists(config_path):
368
+ with open(config_path, "r") as f:
369
+ self.config = json.load(f)
370
+
371
+ return True
372
+ except Exception as e:
373
+ print(f"Failed to load model: {e}")
374
+ return False
375
+
376
+ def embed(
377
+ self,
378
+ texts: Sequence[str],
379
+ config: Optional[EmbeddingConfig] = None,
380
+ clear_cache: bool = True,
381
+ ) -> List[List[float]]:
382
+ """Generate embeddings for texts using mlx_embeddings."""
383
+ if self.model is None or self.tokenizer is None:
384
+ raise RuntimeError("Model not loaded. Call load_model() first.")
290
385
 
291
- if method == "l2":
292
- norm = np.linalg.norm(embedding_array)
293
- if norm > 0:
294
- embedding_array = embedding_array / norm
295
- elif method == "mean":
296
- mean_val = np.mean(embedding_array)
297
- embedding_array = embedding_array - mean_val
386
+ if config is None:
387
+ config = EmbeddingConfig()
298
388
 
299
- return embedding_array.tolist()
389
+ # Start profiling
390
+ self._start_profiling()
391
+
392
+ try:
393
+ # Calculate total tokens for profiling
394
+ if hasattr(self.tokenizer, 'encode'):
395
+ total_tokens = sum(len(self.tokenizer.encode(text)) for text in texts)
396
+ else:
397
+ # For tokenizers that don't have simple encode method
398
+ total_tokens = len(texts) * 50 # Rough estimate
399
+
400
+ self._update_prompt_tokens(total_tokens)
401
+
402
+ # End prompt processing, start decode
403
+ self._prompt_end()
404
+ self._decode_start()
405
+
406
+ # Check if this is a Gemma3TextModel
407
+ # WORKAROUND: Gemma3TextModel has a bug where it expects 'inputs' as positional arg
408
+ # but mlx_embeddings.generate passes 'input_ids' as keyword arg
409
+ # See: https://github.com/ml-explore/mlx-examples/issues/... (bug report pending)
410
+ is_gemma = False
411
+ if self.config and "architectures" in self.config:
412
+ architectures = self.config.get("architectures", [])
413
+ is_gemma = "Gemma3TextModel" in architectures
414
+
415
+ if is_gemma:
416
+ # HARDCODED WORKAROUND for Gemma3TextModel bug
417
+ # Use direct tokenization and model call instead of mlx_embeddings.generate
418
+ max_length = config.max_length if hasattr(config, 'max_length') else 512
419
+
420
+ # Tokenize using batch_encode_plus
421
+ encoded_input = self.tokenizer.batch_encode_plus(
422
+ list(texts),
423
+ padding=True,
424
+ truncation=True,
425
+ return_tensors='mlx',
426
+ max_length=max_length
427
+ )
428
+
429
+ # Get input tensors
430
+ input_ids = encoded_input['input_ids']
431
+ attention_mask = encoded_input.get('attention_mask', None)
432
+
433
+ # Call model with positional input_ids and keyword attention_mask
434
+ # This matches Gemma3TextModel's expected signature
435
+ output = self.model(input_ids, attention_mask=attention_mask)
436
+
437
+ # Extract embeddings
438
+ embeddings_tensor = output.text_embeds
439
+ else:
440
+ # Normal path for non-Gemma models
441
+ # Generate embeddings using mlx_embeddings standard approach
442
+ output = mlx_embeddings.generate(
443
+ self.model,
444
+ self.tokenizer,
445
+ texts=list(texts),
446
+ max_length=config.max_length if hasattr(config, 'max_length') else 512,
447
+ padding=True,
448
+ truncation=True
449
+ )
450
+
451
+ # Extract embeddings
452
+ embeddings_tensor = output.text_embeds
453
+
454
+ # Convert to list format
455
+ embeddings = []
456
+ for i in range(embeddings_tensor.shape[0]):
457
+ embedding = embeddings_tensor[i].tolist()
458
+
459
+ # Apply normalization if requested
460
+ if config.normalize:
461
+ embedding = self._normalize_embedding(embedding, config.normalize_method)
462
+
463
+ embeddings.append(embedding)
464
+
465
+ if clear_cache:
466
+ mx.clear_cache()
467
+
468
+ # End timing and finalize profiling data
469
+ self._update_generated_tokens(0) # No generation in embedding
470
+ self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
471
+ self._decode_end()
472
+ self._end_profiling()
473
+
474
+ return embeddings
475
+
476
+ except Exception as e:
477
+ self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
478
+ self._decode_end()
479
+ self._end_profiling()
480
+ raise RuntimeError(f"Error generating embeddings: {str(e)}")
481
+
482
def embedding_dim(self) -> int:
    """Return the embedding width recorded in the loaded config.

    The config is searched for ``hidden_size``, then ``d_model``, then
    ``dim``; the first key present wins. When no config has been loaded, or
    none of those keys exists, 768 is returned as the default.
    """
    loaded_config = self.config
    if loaded_config is None:
        return 768  # Default dimension

    # Different model families store the width under different keys.
    for candidate_key in ("hidden_size", "d_model", "dim"):
        if candidate_key in loaded_config:
            return loaded_config[candidate_key]

    return 768  # Fallback default
496
+
497
+
498
class MLXEmbedder(BaseMLXEmbedder):
    """
    Concrete embedder class that routes to the appropriate implementation.

    This class can be instantiated directly (for C++ compatibility) and will
    automatically delegate to JinaV2Embedder or MlxEmbeddingEmbedder based on
    the model type detected from the model directory.
    """

    def __init__(
        self,
        model_path: PathType,
        tokenizer_path: PathType,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the embedder; the delegate is created lazily on first use."""
        super().__init__(model_path, tokenizer_path, device)
        self._impl = None  # Will hold the actual implementation

    def _get_implementation(self) -> BaseMLXEmbedder:
        """Return the delegate, creating it on the first call.

        The delegate class is picked by ``_detect_model_type``: ``"jina_v2"``
        selects JinaV2Embedder, anything else selects MlxEmbeddingEmbedder.
        """
        if self._impl is None:
            # NOTE(review): routing is decided from self.model_path (the
            # constructor argument), not from the path later handed to
            # load_model() - confirm callers always pass the same path.
            model_type = _detect_model_type(self.model_path)

            if model_type == "jina_v2":
                impl_cls = JinaV2Embedder
            else:
                impl_cls = MlxEmbeddingEmbedder
            self._impl = impl_cls(self.model_path, self.tokenizer_path, self.device)

            # Hand over any state already set on the wrapper so the delegate
            # does not start from scratch.
            if self.model is not None:
                self._impl.model = self.model
            if self.tokenizer is not None:
                self._impl.tokenizer = self.tokenizer
            if self.config is not None:
                self._impl.config = self.config

        return self._impl

    def load_model(self, model_path: PathType) -> bool:
        """Load the model through the delegate and mirror its state here.

        Returns:
            Whatever the delegate's ``load_model`` returns.
        """
        impl = self._get_implementation()
        result = impl.load_model(model_path)

        # Sync state back so the wrapper's attributes reflect the delegate's,
        # whether or not loading succeeded.
        self.model = impl.model
        self.tokenizer = impl.tokenizer
        self.config = impl.config

        return result

    def embed(
        self,
        texts: Sequence[str],
        config: Optional[EmbeddingConfig] = None,
        clear_cache: bool = True,
    ) -> List[List[float]]:
        """Generate embeddings for ``texts`` by delegating to the implementation."""
        # NOTE(review): the delegate's embed() drives profiling on the
        # delegate object; confirm whether callers read profiling data from
        # this wrapper instead.
        impl = self._get_implementation()
        return impl.embed(texts, config, clear_cache)

    def embedding_dim(self) -> int:
        """Return the embedding dimension reported by the implementation."""
        impl = self._get_implementation()
        return impl.embedding_dim()

    def destroy(self) -> None:
        """Destroy the model and free resources, including the delegate's."""
        try:
            super().destroy()
        finally:
            # Release the delegate even if the base-class teardown raised,
            # and drop the reference first so a failing delegate teardown
            # cannot leave a half-destroyed object behind.
            impl, self._impl = self._impl, None
            if impl is not None:
                impl.destroy()
572
+
573
+
574
+ # Backward compatibility alias
575
+ Embedder = MLXEmbedder
576
+
577
+
578
def _detect_model_type(model_path: "PathType") -> str:
    """Detect the model type from the ``config.json`` next to the model.

    Args:
        model_path: A model directory, or a file inside it (the file's
            parent directory is then used).

    Returns:
        ``"jina_v2"`` when the config's ``architectures`` list contains
        ``"JinaBertModel"``; ``"generic"`` in every other case, including a
        missing, unreadable or malformed ``config.json``.
    """
    if os.path.isfile(model_path):
        model_path = os.path.dirname(model_path)

    config_path = os.path.join(model_path, "config.json")

    if not os.path.exists(config_path):
        # If no config.json, assume it's a generic model
        return "generic"

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not parse config.json: {e}")
        return "generic"

    if not isinstance(config, dict):
        print("Warning: Could not parse config.json: top-level value is not a JSON object")
        return "generic"

    # "architectures" may be absent, null, or (in hand-written configs) a
    # bare string; normalise so the membership test is an exact name match.
    architectures = config.get("architectures") or []
    if isinstance(architectures, str):
        architectures = [architectures]
    elif not isinstance(architectures, (list, tuple)):
        architectures = []

    if "JinaBertModel" in architectures:
        return "jina_v2"

    # Default to generic mlx_embeddings for other models
    return "generic"
300
604
 
301
605
 
302
606
  # Factory function for creating embedder instances
@@ -304,9 +608,10 @@ def create_embedder(
304
608
  model_path: PathType,
305
609
  tokenizer_path: Optional[PathType] = None,
306
610
  device: Optional[str] = None,
307
- ) -> Embedder:
308
- """Create and return an Embedder instance."""
611
+ ) -> MLXEmbedder:
612
+ """Create and return an MLXEmbedder instance that automatically routes to the appropriate implementation."""
309
613
  if tokenizer_path is None:
310
614
  tokenizer_path = model_path
311
615
 
312
- return Embedder(model_path, tokenizer_path, device)
616
+ # Return the concrete MLXEmbedder which will handle routing internally
617
+ return MLXEmbedder(model_path, tokenizer_path, device)
@@ -12,71 +12,162 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .interface import create_embedder, EmbeddingConfig
15
+ import os
16
+ import sys
17
+ import numpy as np
18
+ from pathlib import Path
16
19
 
20
+ # Add parent path for imports
21
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
22
 
18
- def test_embedding(model_path):
19
- """Test embedding model functionality."""
23
+ # Import from interface (uses the factory pattern with routing)
24
+ from .interface import create_embedder
25
+ from .interface import EmbeddingConfig
26
+ from huggingface_hub import snapshot_download
27
+
28
+
29
def download_model_if_needed(model_id, local_dir):
    """Fetch ``model_id`` from the Hugging Face Hub unless it is already local.

    The only presence check is whether ``local_dir`` contains a
    ``config.json``. If it does, nothing happens. Otherwise ``local_dir`` is
    created and a snapshot of the repository is downloaded into it. A failed
    download is reported on stdout and the exception is re-raised.
    """
    config_marker = os.path.join(local_dir, "config.json")
    if os.path.exists(config_marker):
        return

    print(f"📥 Model not found locally. Downloading {model_id}...")
    os.makedirs(local_dir, exist_ok=True)

    # NOTE(review): resume_download and local_dir_use_symlinks appear to be
    # deprecated in recent huggingface_hub releases - confirm against the
    # pinned version before removing them.
    download_options = {
        "repo_id": model_id,
        "local_dir": local_dir,
        "resume_download": True,
        "local_dir_use_symlinks": False,
    }
    try:
        snapshot_download(**download_options)
        print("✅ Model download completed!")
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        raise
45
+
46
+
47
+ def test_embedding_interface(model_path, is_local=False):
48
+ """Test embedding model functionality using the interface."""
49
+
50
+ print("=" * 70)
51
+ print("TESTING EMBEDDING MODEL VIA INTERFACE")
52
+ print("=" * 70)
53
+
54
+ # Handle model path - download if it's a HF model ID
55
+ if not is_local and "/" in model_path:
56
+ # It's a HuggingFace model ID
57
+ local_dir = f"./modelfiles/{model_path.replace('/', '_')}"
58
+ download_model_if_needed(model_path, local_dir)
59
+ model_path = local_dir
60
+
61
+ # Create embedder using factory function (will auto-detect model type)
62
+ print(f"\n🔍 Creating embedder for: {model_path}")
20
63
  embedder = create_embedder(model_path=model_path)
64
+ print(f"✅ Created embedder type: {type(embedder).__name__}")
21
65
 
22
66
  # Load the model
23
- print("Loading embedding model...")
67
+ print("\n📚 Loading embedding model...")
24
68
  success = embedder.load_model(model_path)
25
69
 
26
70
  if not success:
27
- print("Failed to load model!")
71
+ print("Failed to load model!")
28
72
  return
29
73
 
30
74
  print("✅ Model loaded successfully!")
31
- print(f"Embedding dimension: {embedder.embedding_dim()}")
75
+ print(f"📏 Embedding dimension: {embedder.embedding_dim()}")
32
76
 
33
77
  # Test texts
34
78
  test_texts = [
35
79
  "Hello, how are you?",
36
80
  "What is machine learning?",
37
81
  "The weather is nice today.",
38
- "Python is a programming language."
82
+ "Python is a programming language.",
83
+ "Artificial intelligence is changing the world."
39
84
  ]
40
85
 
41
- # Configure embedding
42
- config = EmbeddingConfig(
43
- batch_size=2,
44
- normalize=True,
45
- normalize_method="l2"
46
- )
86
+ # Configure embedding with different settings
87
+ configs = [
88
+ EmbeddingConfig(batch_size=2, normalize=True, normalize_method="l2"),
89
+ EmbeddingConfig(batch_size=3, normalize=False),
90
+ ]
91
+
92
+ for config_idx, config in enumerate(configs):
93
+ print(f"\n{'='*50}")
94
+ print(f"TEST {config_idx + 1}: Config - Batch: {config.batch_size}, "
95
+ f"Normalize: {config.normalize}, Method: {config.normalize_method}")
96
+ print('='*50)
97
+
98
+ # Generate embeddings
99
+ embeddings = embedder.embed(test_texts, config)
100
+
101
+ # Display results
102
+ print(f"\n📊 Generated {len(embeddings)} embeddings")
103
+
104
+ for i, (text, embedding) in enumerate(zip(test_texts[:3], embeddings[:3])):
105
+ print(f"\n Text {i+1}: '{text}'")
106
+ print(f" Dimension: {len(embedding)}")
107
+ print(f" First 5 values: {[f'{v:.4f}' for v in embedding[:5]]}")
108
+
109
+ # Calculate magnitude
110
+ magnitude = np.linalg.norm(embedding)
111
+ print(f" Magnitude: {magnitude:.6f}")
47
112
 
48
- print(f"\nGenerating embeddings for {len(test_texts)} texts...")
113
+ # Compute similarity matrix for normalized embeddings
114
+ print("\n" + "="*50)
115
+ print("SIMILARITY MATRIX (L2 Normalized)")
116
+ print("="*50)
49
117
 
50
- # Generate embeddings
118
+ config = EmbeddingConfig(batch_size=len(test_texts), normalize=True, normalize_method="l2")
51
119
  embeddings = embedder.embed(test_texts, config)
52
120
 
53
- # Display results
54
- print("\nEmbedding Results:")
55
- print("=" * 50)
121
+ # Convert to numpy for easier computation
122
+ embeddings_np = np.array(embeddings)
123
+ similarity_matrix = np.dot(embeddings_np, embeddings_np.T)
56
124
 
57
- for i, (text, embedding) in enumerate(zip(test_texts, embeddings)):
58
- print(f"\nText {i+1}: '{text}'")
59
- print(f"Embedding shape: {len(embedding)}")
60
- print(f"First 5 values: {embedding[:5]}")
61
-
62
- # Calculate magnitude for normalized embeddings
63
- magnitude = sum(x*x for x in embedding) ** 0.5
64
- print(f"Magnitude: {magnitude:.6f}")
125
+ print("\nTexts:")
126
+ for i, text in enumerate(test_texts):
127
+ print(f" [{i}] {text[:30]}...")
128
+
129
+ print("\nSimilarity Matrix:")
130
+ print(" ", end="")
131
+ for i in range(len(test_texts)):
132
+ print(f" [{i}] ", end="")
133
+ print()
65
134
 
66
- # Test similarity between first two embeddings
67
- if len(embeddings) >= 2:
68
- emb1, emb2 = embeddings[0], embeddings[1]
69
- similarity = sum(a*b for a, b in zip(emb1, emb2))
70
- print(f"\nCosine similarity between text 1 and 2: {similarity:.6f}")
135
+ for i in range(len(test_texts)):
136
+ print(f" [{i}]", end="")
137
+ for j in range(len(test_texts)):
138
+ print(f" {similarity_matrix[i, j]:5.2f}", end="")
139
+ print()
140
+
141
+ # Find most similar pairs
142
+ print("\n🔍 Most Similar Pairs (excluding self-similarity):")
143
+ similarities = []
144
+ for i in range(len(test_texts)):
145
+ for j in range(i+1, len(test_texts)):
146
+ similarities.append((similarity_matrix[i, j], i, j))
147
+
148
+ similarities.sort(reverse=True)
149
+ for sim, i, j in similarities[:3]:
150
+ print(f" • Texts [{i}] and [{j}]: {sim:.4f}")
71
151
 
72
152
  # Cleanup
73
153
  embedder.close()
74
- print("\n✅ Embedding test completed!")
154
+ print("\n✅ Interface test completed successfully!")
75
155
 
76
156
 
77
157
  if __name__ == "__main__":
78
158
  import argparse
79
- parser = argparse.ArgumentParser()
80
- parser.add_argument("--model_path", type=str, default="nexaml/jina-v2-fp16-mlx")
159
+ parser = argparse.ArgumentParser(description="Test embedding models via interface")
160
+ parser.add_argument(
161
+ "--model_path",
162
+ type=str,
163
+ default="nexaml/jina-v2-fp16-mlx",
164
+ help="Model path (local) or HuggingFace model ID"
165
+ )
166
+ parser.add_argument(
167
+ "--local",
168
+ action="store_true",
169
+ help="Indicate if model_path is a local directory"
170
+ )
81
171
  args = parser.parse_args()
82
- test_embedding(args.model_path)
172
+
173
+ test_embedding_interface(args.model_path, args.local)