nexaai 1.0.4rc16__cp310-cp310-macosx_13_0_x86_64.whl → 1.0.6__cp310-cp310-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

Binary file
nexaai/_version.py CHANGED
@@ -1,4 +1,4 @@
 # This file is generated by CMake from _version.py.in
 # Do not modify this file manually - it will be overwritten
 
-__version__ = "1.0.4-rc16"
+__version__ = "1.0.6"
Binary files changed (5 entries)
nexaai/common.py CHANGED
@@ -59,10 +59,45 @@ class ModelConfig:
 
 @dataclass(frozen=True) # Read-only
 class ProfilingData:
-    start_time: int
-    end_time: int
-    prompt_start_time: int = None
-    prompt_end_time: int = None
-    decode_start_time: int = None
-    decode_ent_time: int = None
-    first_token_time: int = None
+    """Profiling data structure for LLM/VLM performance metrics."""
+    ttft: int = 0                  # Time to first token (us)
+    prompt_time: int = 0           # Prompt processing time (us)
+    decode_time: int = 0           # Token generation time (us)
+    prompt_tokens: int = 0         # Number of prompt tokens
+    generated_tokens: int = 0      # Number of generated tokens
+    audio_duration: int = 0        # Audio duration (us)
+    prefill_speed: float = 0.0     # Prefill speed (tokens/sec)
+    decoding_speed: float = 0.0    # Decoding speed (tokens/sec)
+    real_time_factor: float = 0.0  # Real-Time Factor (RTF)
+    stop_reason: str = ""          # Stop reason: "eos", "length", "user", "stop_sequence"
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "ProfilingData":
+        """Create ProfilingData from dictionary."""
+        return cls(
+            ttft=data.get("ttft", 0),
+            prompt_time=data.get("prompt_time", 0),
+            decode_time=data.get("decode_time", 0),
+            prompt_tokens=data.get("prompt_tokens", 0),
+            generated_tokens=data.get("generated_tokens", 0),
+            audio_duration=data.get("audio_duration", 0),
+            prefill_speed=data.get("prefill_speed", 0.0),
+            decoding_speed=data.get("decoding_speed", 0.0),
+            real_time_factor=data.get("real_time_factor", 0.0),
+            stop_reason=data.get("stop_reason", "")
+        )
+
+    def to_dict(self) -> dict:
+        """Convert to dictionary."""
+        return {
+            "ttft": self.ttft,
+            "prompt_time": self.prompt_time,
+            "decode_time": self.decode_time,
+            "prompt_tokens": self.prompt_tokens,
+            "generated_tokens": self.generated_tokens,
+            "audio_duration": self.audio_duration,
+            "prefill_speed": self.prefill_speed,
+            "decoding_speed": self.decoding_speed,
+            "real_time_factor": self.real_time_factor,
+            "stop_reason": self.stop_reason
+        }
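
A minimal sketch of how the new dataclass round-trips through a dict; the field values below are made up, and missing keys fall back to the declared defaults:

# Hedged usage sketch for the new ProfilingData helpers (illustrative values only).
from nexaai.common import ProfilingData

raw = {"ttft": 125000, "prompt_tokens": 42, "generated_tokens": 128, "stop_reason": "eos"}
profile = ProfilingData.from_dict(raw)           # absent keys default to 0 / 0.0 / ""
print(profile.ttft, profile.decoding_speed)      # -> 125000 0.0
assert profile.to_dict()["stop_reason"] == "eos"
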
@@ -3,7 +3,7 @@ import numpy as np
 
 from nexaai.common import PluginID
 from nexaai.embedder import Embedder, EmbeddingConfig
-from nexaai.mlx_backend.embedding.interface import Embedder as MLXEmbedderInterface
+from nexaai.mlx_backend.embedding.interface import create_embedder
 from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
 
 
@@ -27,11 +27,12 @@ class MLXEmbedderImpl(Embedder):
         MLXEmbedderImpl instance
         """
         try:
-            # MLX interface is already imported
-
-            # Create instance and load MLX embedder
+            # Create instance
             instance = cls()
-            instance._mlx_embedder = MLXEmbedderInterface(
+
+            # Use the factory function to create the appropriate embedder based on model type
+            # This will automatically detect if it's JinaV2 or generic model and route correctly
+            instance._mlx_embedder = create_embedder(
                 model_path=model_path,
                 tokenizer_path=tokenizer_file
             )
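
The call site above only shows that create_embedder takes model_path and tokenizer_path and routes between the Jina V2 path and the generic mlx_embeddings path. The helper below is an illustrative sketch of that routing decision, not the package's actual implementation; it mirrors the detect_model_type logic that appears later in this diff.

# Sketch only - assumes config.json "architectures" drives the routing, as the
# comments in this diff suggest.
import json
import os

def choose_embedder_backend(model_path: str) -> str:
    """Return "jina_v2" or "generic" depending on the model's config.json."""
    config_file = os.path.join(model_path, "config.json")
    if os.path.exists(config_file):
        with open(config_file) as f:
            if "JinaBertModel" in json.load(f).get("architectures", []):
                return "jina_v2"
    return "generic"
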
nexaai/llm.py CHANGED
@@ -4,7 +4,7 @@ import queue
 import threading
 
 from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
-from nexaai.base import BaseModel
+from nexaai.base import BaseModel, ProfilingData
 
 class LLM(BaseModel):
     def __init__(self, m_cfg: ModelConfig = ModelConfig()):
@@ -63,6 +63,10 @@ class LLM(BaseModel):
         """
         pass
 
+    def get_profiling_data(self) -> Optional[ProfilingData]:
+        """Get profiling data from the last generation."""
+        pass
+
    @abstractmethod
    def save_kv_cache(self, path: str):
        """
@@ -1,5 +1,6 @@
 from typing import Generator, Optional, Any, Sequence, Union
 
+from nexaai.base import ProfilingData
 from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
 from nexaai.llm import LLM
 from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
@@ -215,6 +216,12 @@ class MLXLLMImpl(LLM):
         except Exception as e:
             raise RuntimeError(f"Failed to generate text: {str(e)}")
 
+    def get_profiling_data(self) -> Optional[ProfilingData]:
+        """Get profiling data from the last generation."""
+        if not self._mlx_llm:
+            raise RuntimeError("MLX LLM not loaded")
+        return self._mlx_llm.get_profiling_data()
+
     def save_kv_cache(self, path: str):
         """
         Save the key-value cache to the file.
@@ -2,6 +2,7 @@ from typing import Generator, Optional, Union
 import queue
 import threading
 
+from nexaai.base import ProfilingData
 from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
 from nexaai.binds import llm_bind, common_bind
 from nexaai.runtime import _ensure_runtime
@@ -13,6 +14,7 @@ class PyBindLLMImpl(LLM):
         """Private constructor, should not be called directly."""
         super().__init__(m_cfg)
         self._handle = handle # This is a py::capsule
+        self._profiling_data = None
 
     @classmethod
     def _load_from(cls,
@@ -97,13 +99,14 @@ class PyBindLLMImpl(LLM):
         # Run generation in thread
         def generate():
             try:
-                llm_bind.ml_llm_generate(
+                result = llm_bind.ml_llm_generate(
                     handle=self._handle,
                     prompt=prompt,
                     config=config,
                     on_token=on_token,
                     user_data=None
                 )
+                self._profiling_data = ProfilingData.from_dict(result.get("profile_data", {}))
             except Exception as e:
                 exception_container[0] = e
             finally:
@@ -145,8 +148,14 @@ class PyBindLLMImpl(LLM):
             on_token=None, # No callback for non-streaming
             user_data=None
         )
+
+        self._profiling_data = ProfilingData.from_dict(result.get("profile_data", {}))
         return result.get("text", "")
 
+    def get_profiling_data(self) -> Optional[ProfilingData]:
+        """Get profiling data."""
+        return self._profiling_data
+
     def save_kv_cache(self, path: str):
         """
         Save the key-value cache to the file.
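
For the pybind path, the binding result is now expected to carry a "profile_data" dict whose keys line up with ProfilingData.from_dict. The shape below is inferred from the call sites in this diff, with made-up values; it is not a documented contract of llm_bind.

# Inferred result shape (assumption based on result.get("text") / result.get("profile_data") above).
from nexaai.common import ProfilingData

result = {
    "text": "generated text ...",
    "profile_data": {
        "ttft": 98000,             # microseconds
        "prompt_time": 120000,
        "decode_time": 850000,
        "prompt_tokens": 37,
        "generated_tokens": 256,
        "prefill_speed": 308.3,    # tokens/sec
        "decoding_speed": 301.2,
        "stop_reason": "eos",
    },
}
profiling = ProfilingData.from_dict(result.get("profile_data", {}))
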
@@ -23,11 +23,46 @@ from .modeling.nexa_jina_v2 import Model, ModelArgs
 from tokenizers import Tokenizer
 from huggingface_hub import snapshot_download
 
-def load_model(model_id):
+# Try to import mlx_embeddings for general embedding support
+try:
+    import mlx_embeddings
+    MLX_EMBEDDINGS_AVAILABLE = True
+except ImportError:
+    MLX_EMBEDDINGS_AVAILABLE = False
+    # Suppress warning during import to avoid interfering with C++ tests
+    # The warning will be shown when actually trying to use mlx_embeddings functionality
+    pass
+
+def detect_model_type(model_path):
+    """Detect if the model is Jina V2 or generic mlx_embeddings model."""
+    config_path = os.path.join(model_path, "config.json") if os.path.isdir(model_path) else f"{model_path}/config.json"
+
+    if not os.path.exists(config_path):
+        # Try default modelfiles directory
+        config_path = f"{curr_dir}/modelfiles/config.json"
+        if not os.path.exists(config_path):
+            return "generic"
+
+    try:
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        # Check if it's a Jina V2 model
+        architectures = config.get("architectures", [])
+        if "JinaBertModel" in architectures:
+            return "jina_v2"
+
+        return "generic"
+    except Exception:
+        return "generic"
+
+# ========== Jina V2 Direct Implementation ==========
+
+def load_jina_model(model_id):
     """Initialize and load the Jina V2 model with FP16 weights."""
     # Load configuration from config.json
     if not os.path.exists(f"{curr_dir}/modelfiles/config.json"):
-        print(f"📥 Downloading model {model_id}...")
+        print(f"📥 Downloading Jina V2 model {model_id}...")
 
         # Ensure modelfiles directory exists
         os.makedirs(f"{curr_dir}/modelfiles", exist_ok=True)
@@ -82,15 +117,15 @@ def load_model(model_id):
 
     return model
 
-def load_tokenizer():
-    """Load and configure the tokenizer."""
+def load_jina_tokenizer():
+    """Load and configure the tokenizer for Jina V2."""
     tokenizer = Tokenizer.from_file(f"{curr_dir}/modelfiles/tokenizer.json")
     tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
     tokenizer.enable_truncation(max_length=512)
     return tokenizer
 
-def encode_text(model, tokenizer, text):
-    """Encode a single text and return its embedding."""
+def encode_jina_text(model, tokenizer, text):
+    """Encode a single text using Jina V2 and return its embedding."""
     # Tokenize the text
     encoding = tokenizer.encode(text)
@@ -113,18 +148,186 @@ def encode_text(model, tokenizer, text):
 
     return embeddings
 
+# ========== MLX Embeddings Direct Implementation ==========
+
+def load_mlx_embeddings_model(model_id):
+    """Load model using mlx_embeddings package."""
+    if not MLX_EMBEDDINGS_AVAILABLE:
+        print("Warning: mlx_embeddings not available. Please install it to use general embedding models.")
+        raise ImportError("mlx_embeddings package is not available. Please install it first.")
+
+    # Download model if needed
+    model_path = f"{curr_dir}/modelfiles"
+
+    if not os.path.exists(f"{model_path}/config.json"):
+        print(f"📥 Downloading model {model_id}...")
+        os.makedirs(model_path, exist_ok=True)
+
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=model_path,
+                resume_download=True,
+                local_dir_use_symlinks=False
+            )
+            print("✅ Model download completed!")
+        except Exception as e:
+            print(f"❌ Failed to download model: {e}")
+            raise
+
+    # Load model and tokenizer using mlx_embeddings
+    model, tokenizer = mlx_embeddings.load(model_path)
+    return model, tokenizer
+
+def encode_mlx_embeddings_text(model, tokenizer, texts, model_path=None):
+    """Generate embeddings using mlx_embeddings."""
+    if isinstance(texts, str):
+        texts = [texts]
+
+    # Check if this is a Gemma3TextModel by checking config
+    # WORKAROUND: Gemma3TextModel has a bug where it expects 'inputs' as positional arg
+    # but mlx_embeddings.generate passes 'input_ids' as keyword arg
+    # See: https://github.com/ml-explore/mlx-examples/issues/... (bug report pending)
+    is_gemma = False
+    if model_path:
+        config_path = os.path.join(model_path, "config.json") if os.path.isdir(model_path) else f"{model_path}/config.json"
+    else:
+        config_path = f"{curr_dir}/modelfiles/config.json"
+
+    if os.path.exists(config_path):
+        try:
+            with open(config_path, "r") as f:
+                config = json.load(f)
+            architectures = config.get("architectures", [])
+            is_gemma = "Gemma3TextModel" in architectures
+        except Exception:
+            pass
+
+    if is_gemma:
+        # HARDCODED WORKAROUND for Gemma3TextModel bug
+        # Use direct tokenization and model call instead of mlx_embeddings.generate
+        # This avoids the bug where generate passes 'input_ids' as keyword arg
+        # but Gemma3TextModel.__call__ expects 'inputs' as positional arg
+
+        # Tokenize using batch_encode_plus for Gemma models
+        encoded_input = tokenizer.batch_encode_plus(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors='mlx',
+            max_length=512
+        )
+
+        # Get input tensors
+        input_ids = encoded_input['input_ids']
+        attention_mask = encoded_input.get('attention_mask', None)
+
+        # Call model with positional input_ids and keyword attention_mask
+        # This matches Gemma3TextModel's expected signature:
+        # def __call__(self, inputs: mx.array, attention_mask: Optional[mx.array] = None)
+        output = model(input_ids, attention_mask=attention_mask)
+
+        # Get the normalized embeddings
+        return output.text_embeds
+    else:
+        # Normal path for non-Gemma models
+        # Use standard mlx_embeddings.generate approach
+        output = mlx_embeddings.generate(
+            model,
+            tokenizer,
+            texts=texts,
+            max_length=512,
+            padding=True,
+            truncation=True
+        )
+
+        return output.text_embeds
+
 def main(model_id):
     """Main function to handle user input and generate embeddings."""
 
-    # Load model and tokenizer
-    model = load_model(model_id)
-    tokenizer = load_tokenizer()
-    user_input = "Hello, how are you?"
-    embedding = encode_text(model, tokenizer, user_input)
-    print(f"Embedding shape: {embedding.shape}")
-    print(f"Embedding sample values: {embedding.flatten()[:5].tolist()}")
-    print(f"Embedding min: {embedding.min()}, Max: {embedding.max()}, Mean: {embedding.mean()}, Std: {embedding.std()}")
+    print(f"🔍 Loading model: {model_id}")
+
+    # Detect model type
+    model_type = detect_model_type(f"{curr_dir}/modelfiles")
+
+    # First try to download/check if model exists
+    if not os.path.exists(f"{curr_dir}/modelfiles/config.json"):
+        # Download the model first to detect its type
+        print(f"Model not found locally. Downloading...")
+        os.makedirs(f"{curr_dir}/modelfiles", exist_ok=True)
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=f"{curr_dir}/modelfiles",
+                resume_download=True,
+                local_dir_use_symlinks=False
+            )
+            print("✅ Model download completed!")
+            # Re-detect model type after download
+            model_type = detect_model_type(f"{curr_dir}/modelfiles")
+        except Exception as e:
+            print(f"❌ Failed to download model: {e}")
+            raise
+
+    print(f"📦 Detected model type: {model_type}")
+
+    # Test texts
+    test_texts = [
+        "Hello, how are you?",
+        "What is machine learning?",
+        "The weather is nice today."
+    ]
+
+    if model_type == "jina_v2":
+        print("Using Jina V2 direct implementation")
+
+        # Load Jina V2 model
+        model = load_jina_model(model_id)
+        tokenizer = load_jina_tokenizer()
+
+        print("\nGenerating embeddings for test texts:")
+        for text in test_texts:
+            embedding = encode_jina_text(model, tokenizer, text)
+            print(f"\nText: '{text}'")
+            print(f" Embedding shape: {embedding.shape}")
+            print(f" Sample values (first 5): {embedding.flatten()[:5].tolist()}")
+            print(f" Stats - Min: {embedding.min():.4f}, Max: {embedding.max():.4f}, Mean: {embedding.mean():.4f}")
+
+    else:
+        print("Using mlx_embeddings direct implementation")
+
+        if not MLX_EMBEDDINGS_AVAILABLE:
+            print("❌ mlx_embeddings is not installed. Please install it to use generic models.")
+            return
+
+        # Load generic model using mlx_embeddings
+        model, tokenizer = load_mlx_embeddings_model(model_id)
+
+        print("\nGenerating embeddings for test texts:")
+        # Pass model_path to handle Gemma workaround if needed
+        embeddings = encode_mlx_embeddings_text(model, tokenizer, test_texts, model_path=f"{curr_dir}/modelfiles")
+
+        for i, text in enumerate(test_texts):
+            embedding = embeddings[i]
+            print(f"\nText: '{text}'")
+            print(f" Embedding shape: {embedding.shape}")
+            print(f" Sample values (first 5): {embedding[:5].tolist()}")
+
+            # Calculate stats
+            emb_array = mx.array(embedding) if not isinstance(embedding, mx.array) else embedding
+            print(f" Stats - Min: {emb_array.min():.4f}, Max: {emb_array.max():.4f}, Mean: {emb_array.mean():.4f}")
+
+    print("\n✅ Direct embedding generation completed!")
 
 if __name__ == "__main__":
-    model_id = "nexaml/jina-v2-fp16-mlx"
-    main(model_id)
+    import argparse
+    parser = argparse.ArgumentParser(description="Generate embeddings using direct implementation")
+    parser.add_argument(
+        "--model_id",
+        type=str,
+        default="nexaml/jina-v2-fp16-mlx",
+        help="Model ID from Hugging Face Hub (e.g., 'nexaml/jina-v2-fp16-mlx' or 'mlx-community/embeddinggemma-300m-bf16')"
+    )
+    args = parser.parse_args()
+    main(args.model_id)
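
With the argparse entry point, the same script now serves both model families; the calls below are the Python-level equivalents of passing --model_id on the command line (the second ID is the example from the help text):

# Equivalents of the new CLI (the script's filename is not shown in this diff):
main("nexaml/jina-v2-fp16-mlx")                  # default: Jina V2 direct path
main("mlx-community/embeddinggemma-300m-bf16")   # generic mlx_embeddings path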