nexaai-1.0.21rc16-cp312-cp312-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

Files changed (154)
  1. nexaai/__init__.py +95 -0
  2. nexaai/_stub.cp312-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +6 -0
  10. nexaai/binds/asr_bind.cp312-win_arm64.pyd +0 -0
  11. nexaai/binds/common_bind.cp312-win_arm64.pyd +0 -0
  12. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  15. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  16. nexaai/binds/cpu_gpu/libomp140.aarch64.dll +0 -0
  17. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  18. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  19. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  20. nexaai/binds/embedder_bind.cp312-win_arm64.pyd +0 -0
  21. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  22. nexaai/binds/libssl-3-arm64.dll +0 -0
  23. nexaai/binds/llm_bind.cp312-win_arm64.pyd +0 -0
  24. nexaai/binds/nexa_bridge.dll +0 -0
  25. nexaai/binds/npu/FLAC.dll +0 -0
  26. nexaai/binds/npu/convnext-sdk.dll +0 -0
  27. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  28. nexaai/binds/npu/fftw3.dll +0 -0
  29. nexaai/binds/npu/fftw3f.dll +0 -0
  30. nexaai/binds/npu/ggml-base.dll +0 -0
  31. nexaai/binds/npu/ggml-cpu.dll +0 -0
  32. nexaai/binds/npu/ggml-opencl.dll +0 -0
  33. nexaai/binds/npu/ggml.dll +0 -0
  34. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  35. nexaai/binds/npu/granite4-sdk.dll +0 -0
  36. nexaai/binds/npu/htp-files/Genie.dll +0 -0
  37. nexaai/binds/npu/htp-files/PlatformValidatorShared.dll +0 -0
  38. nexaai/binds/npu/htp-files/QnnChrometraceProfilingReader.dll +0 -0
  39. nexaai/binds/npu/htp-files/QnnCpu.dll +0 -0
  40. nexaai/binds/npu/htp-files/QnnCpuNetRunExtensions.dll +0 -0
  41. nexaai/binds/npu/htp-files/QnnDsp.dll +0 -0
  42. nexaai/binds/npu/htp-files/QnnDspNetRunExtensions.dll +0 -0
  43. nexaai/binds/npu/htp-files/QnnDspV66CalculatorStub.dll +0 -0
  44. nexaai/binds/npu/htp-files/QnnDspV66Stub.dll +0 -0
  45. nexaai/binds/npu/htp-files/QnnGenAiTransformer.dll +0 -0
  46. nexaai/binds/npu/htp-files/QnnGenAiTransformerCpuOpPkg.dll +0 -0
  47. nexaai/binds/npu/htp-files/QnnGenAiTransformerModel.dll +0 -0
  48. nexaai/binds/npu/htp-files/QnnGpu.dll +0 -0
  49. nexaai/binds/npu/htp-files/QnnGpuNetRunExtensions.dll +0 -0
  50. nexaai/binds/npu/htp-files/QnnGpuProfilingReader.dll +0 -0
  51. nexaai/binds/npu/htp-files/QnnHtp.dll +0 -0
  52. nexaai/binds/npu/htp-files/QnnHtpNetRunExtensions.dll +0 -0
  53. nexaai/binds/npu/htp-files/QnnHtpOptraceProfilingReader.dll +0 -0
  54. nexaai/binds/npu/htp-files/QnnHtpPrepare.dll +0 -0
  55. nexaai/binds/npu/htp-files/QnnHtpProfilingReader.dll +0 -0
  56. nexaai/binds/npu/htp-files/QnnHtpV68CalculatorStub.dll +0 -0
  57. nexaai/binds/npu/htp-files/QnnHtpV68Stub.dll +0 -0
  58. nexaai/binds/npu/htp-files/QnnHtpV73CalculatorStub.dll +0 -0
  59. nexaai/binds/npu/htp-files/QnnHtpV73Stub.dll +0 -0
  60. nexaai/binds/npu/htp-files/QnnIr.dll +0 -0
  61. nexaai/binds/npu/htp-files/QnnJsonProfilingReader.dll +0 -0
  62. nexaai/binds/npu/htp-files/QnnModelDlc.dll +0 -0
  63. nexaai/binds/npu/htp-files/QnnSaver.dll +0 -0
  64. nexaai/binds/npu/htp-files/QnnSystem.dll +0 -0
  65. nexaai/binds/npu/htp-files/SNPE.dll +0 -0
  66. nexaai/binds/npu/htp-files/SnpeDspV66Stub.dll +0 -0
  67. nexaai/binds/npu/htp-files/SnpeHtpPrepare.dll +0 -0
  68. nexaai/binds/npu/htp-files/SnpeHtpV68Stub.dll +0 -0
  69. nexaai/binds/npu/htp-files/SnpeHtpV73Stub.dll +0 -0
  70. nexaai/binds/npu/htp-files/calculator.dll +0 -0
  71. nexaai/binds/npu/htp-files/calculator_htp.dll +0 -0
  72. nexaai/binds/npu/htp-files/libCalculator_skel.so +0 -0
  73. nexaai/binds/npu/htp-files/libQnnHtpV73.so +0 -0
  74. nexaai/binds/npu/htp-files/libQnnHtpV73QemuDriver.so +0 -0
  75. nexaai/binds/npu/htp-files/libQnnHtpV73Skel.so +0 -0
  76. nexaai/binds/npu/htp-files/libQnnSaver.so +0 -0
  77. nexaai/binds/npu/htp-files/libQnnSystem.so +0 -0
  78. nexaai/binds/npu/htp-files/libSnpeHtpV73Skel.so +0 -0
  79. nexaai/binds/npu/htp-files/libqnnhtpv73.cat +0 -0
  80. nexaai/binds/npu/htp-files/libsnpehtpv73.cat +0 -0
  81. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  82. nexaai/binds/npu/libcrypto-3-arm64.dll +0 -0
  83. nexaai/binds/npu/libmp3lame.DLL +0 -0
  84. nexaai/binds/npu/libomp140.aarch64.dll +0 -0
  85. nexaai/binds/npu/libssl-3-arm64.dll +0 -0
  86. nexaai/binds/npu/liquid-sdk.dll +0 -0
  87. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  88. nexaai/binds/npu/mpg123.dll +0 -0
  89. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  90. nexaai/binds/npu/nexa-sampling.dll +0 -0
  91. nexaai/binds/npu/nexa_plugin.dll +0 -0
  92. nexaai/binds/npu/nexaproc.dll +0 -0
  93. nexaai/binds/npu/ogg.dll +0 -0
  94. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  95. nexaai/binds/npu/openblas.dll +0 -0
  96. nexaai/binds/npu/opus.dll +0 -0
  97. nexaai/binds/npu/paddle-ocr-proc-lib.dll +0 -0
  98. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  99. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  100. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  101. nexaai/binds/npu/phi4-sdk.dll +0 -0
  102. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  103. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  104. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  105. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  106. nexaai/binds/npu/rtaudio.dll +0 -0
  107. nexaai/binds/npu/vorbis.dll +0 -0
  108. nexaai/binds/npu/vorbisenc.dll +0 -0
  109. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  110. nexaai/binds/npu/zlib1.dll +0 -0
  111. nexaai/binds/rerank_bind.cp312-win_arm64.pyd +0 -0
  112. nexaai/binds/vlm_bind.cp312-win_arm64.pyd +0 -0
  113. nexaai/common.py +105 -0
  114. nexaai/cv.py +93 -0
  115. nexaai/cv_impl/__init__.py +0 -0
  116. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  117. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  118. nexaai/embedder.py +73 -0
  119. nexaai/embedder_impl/__init__.py +0 -0
  120. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  121. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  122. nexaai/image_gen.py +141 -0
  123. nexaai/image_gen_impl/__init__.py +0 -0
  124. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  125. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  126. nexaai/llm.py +98 -0
  127. nexaai/llm_impl/__init__.py +0 -0
  128. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  129. nexaai/llm_impl/pybind_llm_impl.py +220 -0
  130. nexaai/log.py +92 -0
  131. nexaai/rerank.py +57 -0
  132. nexaai/rerank_impl/__init__.py +0 -0
  133. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  134. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  135. nexaai/runtime.py +68 -0
  136. nexaai/runtime_error.py +24 -0
  137. nexaai/tts.py +75 -0
  138. nexaai/tts_impl/__init__.py +0 -0
  139. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  140. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  141. nexaai/utils/decode.py +18 -0
  142. nexaai/utils/manifest_utils.py +531 -0
  143. nexaai/utils/model_manager.py +1562 -0
  144. nexaai/utils/model_types.py +49 -0
  145. nexaai/utils/progress_tracker.py +385 -0
  146. nexaai/utils/quantization_utils.py +245 -0
  147. nexaai/vlm.py +130 -0
  148. nexaai/vlm_impl/__init__.py +0 -0
  149. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  150. nexaai/vlm_impl/pybind_vlm_impl.py +256 -0
  151. nexaai-1.0.21rc16.dist-info/METADATA +31 -0
  152. nexaai-1.0.21rc16.dist-info/RECORD +154 -0
  153. nexaai-1.0.21rc16.dist-info/WHEEL +5 -0
  154. nexaai-1.0.21rc16.dist-info/top_level.txt +1 -0
nexaai/llm_impl/mlx_llm_impl.py ADDED
@@ -0,0 +1,271 @@
+ from typing import Generator, Optional, Any, Sequence, Union
+
+ from nexaai.base import ProfilingData
+ from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
+ from nexaai.llm import LLM
+ from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
+ from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
+
+
+ class MLXLLMImpl(LLM):
+     def __init__(self, m_cfg: ModelConfig = ModelConfig()):
+         """Initialize MLX LLM implementation."""
+         super().__init__(m_cfg)
+         self._mlx_llm = None
+
+     @classmethod
+     def _load_from(cls,
+                    local_path: str,
+                    model_name: Optional[str] = None,
+                    tokenizer_path: Optional[str] = None,
+                    m_cfg: ModelConfig = ModelConfig(),
+                    plugin_id: Union[PluginID, str] = PluginID.MLX,
+                    device_id: Optional[str] = None
+                    ) -> 'MLXLLMImpl':
+         """Load model from local path using MLX backend."""
+         try:
+             # MLX interface and configs are already imported
+
+             # Convert our ModelConfig to MLX ModelConfig
+             mlx_config = MLXModelConfig()
+             mlx_config.n_ctx = m_cfg.n_ctx
+             mlx_config.n_threads = m_cfg.n_threads
+             mlx_config.n_threads_batch = m_cfg.n_threads_batch
+             mlx_config.n_batch = m_cfg.n_batch
+             mlx_config.n_ubatch = m_cfg.n_ubatch
+             mlx_config.n_seq_max = m_cfg.n_seq_max
+             mlx_config.chat_template_path = m_cfg.chat_template_path
+             mlx_config.chat_template_content = m_cfg.chat_template_content
+
+             # Create instance and load MLX model
+             instance = cls(m_cfg)
+             instance._mlx_llm = MLXLLMInterface(
+                 model_path=local_path,
+                 # model_name=model_name,  # FIXME: For MLX LLM, model_name is not used
+                 tokenizer_path=tokenizer_path or local_path,
+                 config=mlx_config,
+                 device=device_id
+             )
+
+             return instance
+         except Exception as e:
+             raise RuntimeError(f"Failed to load MLX LLM: {str(e)}")
+
+     def eject(self):
+         """Release the model from memory."""
+         if self._mlx_llm:
+             self._mlx_llm.destroy()
+             self._mlx_llm = None
+
+     def apply_chat_template(
+         self,
+         messages: Sequence[ChatMessage],
+         tools: Optional[str] = None,
+         enable_thinking: bool = True,
+         add_generation_prompt: bool = True
+     ) -> str:
+         """Apply the chat template to messages."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             # Convert to MLX ChatMessage format
+             mlx_messages = []
+             for msg in messages:
+                 # Create a simple object with role and content attributes
+                 class MLXChatMessage:
+                     def __init__(self, role, content):
+                         self.role = role
+                         self.content = content
+
+                 # Handle both dict-style and attribute-style access
+                 if hasattr(msg, 'role') and hasattr(msg, 'content'):
+                     # Message is already an object with attributes
+                     mlx_messages.append(MLXChatMessage(msg.role, msg.content))
+                 else:
+                     # Message is a dict
+                     mlx_messages.append(MLXChatMessage(msg["role"], msg["content"]))
+
+             return self._mlx_llm.apply_chat_template(mlx_messages, tools=tools, enable_thinking=enable_thinking, add_generation_prompt=add_generation_prompt)
+         except Exception as e:
+             raise RuntimeError(f"Failed to apply chat template: {str(e)}")
+
+     def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
+         """Generate text with streaming."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             import queue
+             import threading
+
+             # Convert GenerationConfig to MLX format
+
+             mlx_gen_config = MLXGenerationConfig()
+             mlx_gen_config.max_tokens = g_cfg.max_tokens
+             mlx_gen_config.stop = g_cfg.stop_words
+             mlx_gen_config.image_paths = g_cfg.image_paths
+             mlx_gen_config.audio_paths = g_cfg.audio_paths
+
+             if g_cfg.sampler_config:
+                 mlx_sampler_config = MLXSamplerConfig()
+                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
+                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
+                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
+                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
+                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
+                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
+                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
+                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
+                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
+                 mlx_gen_config.sampler_config = mlx_sampler_config
+
+             # Create a queue for streaming tokens
+             token_queue = queue.Queue()
+             exception_container = [None]
+             self.reset_cancel()  # Reset cancel flag before generation
+
+             def token_callback(token: str, user_data: Any = None) -> bool:
+                 if self._cancel_event.is_set():
+                     token_queue.put(('end', None))
+                     return False
+                 try:
+                     token_queue.put(('token', token))
+                     return True
+                 except Exception as e:
+                     exception_container[0] = e
+                     return False
+
+             # Run generation in a separate thread
+             def generate():
+                 try:
+                     self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
+                 except Exception as e:
+                     exception_container[0] = e
+                 finally:
+                     token_queue.put(('end', None))
+
+             thread = threading.Thread(target=generate)
+             thread.start()
+
+             # Yield tokens as they come from the queue
+             while True:
+                 if exception_container[0]:
+                     raise exception_container[0]
+
+                 try:
+                     msg_type, token = token_queue.get(timeout=0.1)
+                     if msg_type == 'end':
+                         break
+                     elif msg_type == 'token':
+                         yield token
+                 except queue.Empty:
+                     if not thread.is_alive():
+                         break
+                     continue
+
+             thread.join()
+
+             if exception_container[0]:
+                 raise exception_container[0]
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate streaming text: {str(e)}")
+
+     def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
+         """
+         Generate text without streaming.
+
+         Args:
+             prompt (str): The prompt to generate text from.
+             g_cfg (GenerationConfig): Generation configuration.
+
+         Returns:
+             str: The generated text.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             # Convert GenerationConfig to MLX format
+
+             mlx_gen_config = MLXGenerationConfig()
+             mlx_gen_config.max_tokens = g_cfg.max_tokens
+             mlx_gen_config.stop = g_cfg.stop_words
+             mlx_gen_config.image_paths = g_cfg.image_paths
+             mlx_gen_config.audio_paths = g_cfg.audio_paths
+
+             if g_cfg.sampler_config:
+                 mlx_sampler_config = MLXSamplerConfig()
+                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
+                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
+                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
+                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
+                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
+                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
+                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
+                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
+                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
+                 mlx_gen_config.sampler_config = mlx_sampler_config
+
+             # Simple token callback that just continues
+             def token_callback(token: str, user_data: Any = None) -> bool:
+                 return not self._cancel_event.is_set()
+
+             # Use MLX streaming generation and return the full result
+             return self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate text: {str(e)}")
+
+     def get_profiling_data(self) -> Optional[ProfilingData]:
+         """Get profiling data from the last generation."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+         return self._mlx_llm.get_profiling_data()
+
+     def save_kv_cache(self, path: str):
+         """
+         Save the key-value cache to the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             success = self._mlx_llm.save_kv_cache(path)
+             if not success:
+                 raise RuntimeError("Failed to save KV cache")
+         except Exception as e:
+             raise RuntimeError(f"Failed to save KV cache: {str(e)}")
+
+     def load_kv_cache(self, path: str):
+         """
+         Load the key-value cache from the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             success = self._mlx_llm.load_kv_cache(path)
+             if not success:
+                 raise RuntimeError("Failed to load KV cache")
+         except Exception as e:
+             raise RuntimeError(f"Failed to load KV cache: {str(e)}")
+
+     def reset(self):
+         """
+         Reset the LLM model context and KV cache.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             self._mlx_llm.reset()
+         except Exception as e:
+             raise RuntimeError(f"Failed to reset MLX LLM: {str(e)}")
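
For orientation, a minimal usage sketch of the class above (not part of the package diff): the model path and prompt are placeholders, and in normal use loading is expected to go through the public LLM facade in nexaai/llm.py rather than calling _load_from directly.

# Hypothetical sketch; "/path/to/mlx-model" is a placeholder.
from nexaai.common import ModelConfig, GenerationConfig
from nexaai.llm_impl.mlx_llm_impl import MLXLLMImpl

llm = MLXLLMImpl._load_from("/path/to/mlx-model", m_cfg=ModelConfig())
try:
    # Tokens are produced on a worker thread and handed over via a queue,
    # so this loop sees them as they are generated.
    for token in llm.generate_stream("Hello", GenerationConfig()):
        print(token, end="", flush=True)
finally:
    llm.eject()  # release model memory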
nexaai/llm_impl/pybind_llm_impl.py ADDED
@@ -0,0 +1,220 @@
+ from typing import Generator, Optional, Union
+ import queue
+ import threading
+
+ from nexaai.base import ProfilingData
+ from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
+ from nexaai.binds import llm_bind, common_bind
+ from nexaai.runtime import _ensure_runtime
+ from nexaai.llm import LLM
+
+
+ class PyBindLLMImpl(LLM):
+     def __init__(self, handle: any, m_cfg: ModelConfig = ModelConfig()):
+         """Private constructor, should not be called directly."""
+         super().__init__(m_cfg)
+         self._handle = handle  # This is a py::capsule
+         self._profiling_data = None
+
+     @classmethod
+     def _load_from(cls,
+                    local_path: str,
+                    model_name: Optional[str] = None,
+                    tokenizer_path: Optional[str] = None,
+                    m_cfg: ModelConfig = ModelConfig(),
+                    plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
+                    device_id: Optional[str] = None
+                    ) -> 'PyBindLLMImpl':
+         """Load model from local path."""
+         _ensure_runtime()
+
+         config = common_bind.ModelConfig()
+
+         config.n_ctx = m_cfg.n_ctx
+         if m_cfg.n_threads is not None:
+             config.n_threads = m_cfg.n_threads
+         if m_cfg.n_threads_batch is not None:
+             config.n_threads_batch = m_cfg.n_threads_batch
+         if m_cfg.n_batch is not None:
+             config.n_batch = m_cfg.n_batch
+         if m_cfg.n_ubatch is not None:
+             config.n_ubatch = m_cfg.n_ubatch
+         if m_cfg.n_seq_max is not None:
+             config.n_seq_max = m_cfg.n_seq_max
+         if m_cfg.n_gpu_layers is not None:
+             config.n_gpu_layers = m_cfg.n_gpu_layers
+
+         # handle chat template strings
+         if m_cfg.chat_template_path:
+             config.chat_template_path = m_cfg.chat_template_path
+
+         if m_cfg.chat_template_content:
+             config.chat_template_content = m_cfg.chat_template_content
+
+         # Create handle: returns py::capsule with automatic cleanup
+         # Convert enum to string for C++ binding
+         plugin_id_str = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
+         handle = llm_bind.ml_llm_create(
+             model_path=local_path,
+             model_name=model_name,
+             tokenizer_path=tokenizer_path,
+             model_config=config,
+             plugin_id=plugin_id_str,
+             device_id=device_id
+         )
+         return cls(handle, m_cfg)
+
+     def eject(self):
+         """Release the model from memory."""
+         # py::capsule handles cleanup automatically
+         del self._handle
+         self._handle = None
+
+     def apply_chat_template(self, messages: list[ChatMessage], tools: Optional[str] = None, enable_thinking: bool = True, add_generation_prompt: bool = True) -> str:
+         """Apply the chat template to messages."""
+         # Convert TypedDict to list of dicts for binding
+         message_dicts = [
+             {"role": m["role"], "content": m["content"]}
+             for m in messages
+         ]
+         return llm_bind.ml_llm_apply_chat_template(self._handle, message_dicts, tools, enable_thinking)
+
+     def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
+         """Generate text with streaming."""
+         token_queue = queue.Queue()
+         exception_container = [None]
+         self.reset_cancel()  # Reset cancel flag before generation
+
+         def on_token(token: str, user_data) -> bool:
+             if self._cancel_event.is_set():
+                 token_queue.put(('end', None))
+                 return False  # Stop generation
+             try:
+                 token_queue.put(('token', token))
+                 return True  # Continue generation
+             except Exception as e:
+                 exception_container[0] = e
+                 return False  # Stop generation
+
+         config = self._convert_generation_config(g_cfg)
+
+         # Run generation in thread
+         def generate():
+             try:
+                 result = llm_bind.ml_llm_generate(
+                     handle=self._handle,
+                     prompt=prompt,
+                     config=config,
+                     on_token=on_token,
+                     user_data=None
+                 )
+                 self._profiling_data = ProfilingData.from_dict(result.get("profile_data", {}))
+             except Exception as e:
+                 exception_container[0] = e
+             finally:
+                 token_queue.put(('end', None))
+
+         thread = threading.Thread(target=generate)
+         thread.start()
+
+         # Yield tokens as they come
+         try:
+             while True:
+                 msg_type, token = token_queue.get()
+                 if msg_type == 'token':
+                     yield token
+                 elif msg_type in ('error', 'end'):
+                     break
+         finally:
+             thread.join()
+
+         if exception_container[0]:
+             raise exception_container[0]
+
+     def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
+         """
+         Generate text without streaming.
+
+         Args:
+             prompt (str): The prompt to generate text from. For chat models, this is the chat messages after chat template is applied.
+             g_cfg (GenerationConfig): Generation configuration.
+
+         Returns:
+             str: The generated text.
+         """
+         config = self._convert_generation_config(g_cfg)
+         result = llm_bind.ml_llm_generate(
+             handle=self._handle,
+             prompt=prompt,
+             config=config,
+             on_token=None,  # No callback for non-streaming
+             user_data=None
+         )
+
+         self._profiling_data = ProfilingData.from_dict(result.get("profile_data", {}))
+         return result.get("text", "")
+
+     def get_profiling_data(self) -> Optional[ProfilingData]:
+         """Get profiling data."""
+         return self._profiling_data
+
+     def save_kv_cache(self, path: str):
+         """
+         Save the key-value cache to the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         llm_bind.ml_llm_save_kv_cache(self._handle, path)
+
+     def load_kv_cache(self, path: str):
+         """
+         Load the key-value cache from the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         llm_bind.ml_llm_load_kv_cache(self._handle, path)
+
+     def reset(self):
+         """
+         Reset the LLM model context and KV cache. If not reset, the model will skip the number of evaluated tokens and treat tokens after those as the new incremental tokens.
+         If your past chat history changed, or you are starting a new chat, you should always reset the model before running generate.
+         """
+         llm_bind.ml_llm_reset(self._handle)
+
+     def _convert_generation_config(self, g_cfg: GenerationConfig):
+         """Convert GenerationConfig to binding format."""
+         config = common_bind.GenerationConfig()
+
+         # Set basic generation parameters
+         config.max_tokens = g_cfg.max_tokens
+
+         if g_cfg.stop_words:
+             config.stop = g_cfg.stop_words
+
+         if g_cfg.image_paths:
+             config.image_paths = g_cfg.image_paths
+
+         if g_cfg.audio_paths:
+             config.audio_paths = g_cfg.audio_paths
+
+         if g_cfg.sampler_config:
+             sampler = common_bind.SamplerConfig()
+             sampler.temperature = g_cfg.sampler_config.temperature
+             sampler.top_p = g_cfg.sampler_config.top_p
+             sampler.top_k = g_cfg.sampler_config.top_k
+             sampler.repetition_penalty = g_cfg.sampler_config.repetition_penalty
+             sampler.presence_penalty = g_cfg.sampler_config.presence_penalty
+             sampler.frequency_penalty = g_cfg.sampler_config.frequency_penalty
+             sampler.seed = g_cfg.sampler_config.seed
+
+             if g_cfg.sampler_config.grammar_path:
+                 sampler.grammar_path = g_cfg.sampler_config.grammar_path
+
+             if g_cfg.sampler_config.grammar_string:
+                 sampler.grammar_string = g_cfg.sampler_config.grammar_string
+
+             config.sampler_config = sampler
+
+         return config
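
The llama.cpp-backed path exposes the same contract through a pybind11 capsule handle. A sketch of a chat round trip (not part of the package diff), assuming the dict-style ChatMessage entries used in apply_chat_template above; the model path is a placeholder:

# Hypothetical sketch; "/path/to/model.gguf" is a placeholder.
from nexaai.common import ModelConfig, GenerationConfig
from nexaai.llm_impl.pybind_llm_impl import PyBindLLMImpl

llm = PyBindLLMImpl._load_from("/path/to/model.gguf", m_cfg=ModelConfig())
messages = [{"role": "user", "content": "Summarize KV caching in one line."}]
prompt = llm.apply_chat_template(messages)  # render messages with the model's template
text = llm.generate(prompt, GenerationConfig())
print(text)
llm.reset()  # clear context/KV cache before an unrelated chat
llm.eject()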
nexaai/log.py ADDED
@@ -0,0 +1,92 @@
+ """
+ Logging configuration for NexaAI bridge.
+
+ This module provides a minimal API to configure bridge-wide logging
+ to route into Python's logging system.
+ """
+
+ import logging
+ import threading
+ from enum import IntEnum
+ from typing import Optional
+
+ from nexaai.binds import common_bind
+ from nexaai.runtime import is_initialized
+
+
+ class LogLevel(IntEnum):
+     """Log levels matching ml_LogLevel from ml.h"""
+     TRACE = 0
+     DEBUG = 1
+     INFO = 2
+     WARN = 3
+     ERROR = 4
+
+
+ # Module-level state
+ _config_lock = threading.Lock()
+ _current_logger: Optional[logging.Logger] = None
+
+
+ def set_logger(logger: Optional[logging.Logger] = None, *, strict: bool = True) -> None:
+     """
+     Set the process-wide bridge logger.
+
+     Args:
+         logger: Python logger to receive bridge logs. If None, uses "nexaai.ml" logger.
+         strict: If True, raises if called after runtime initialization.
+             If False, attempts to set anyway (best-effort).
+
+     Raises:
+         RuntimeError: If strict=True and runtime is already initialized.
+     """
+     global _current_logger
+
+     with _config_lock:
+         # Check initialization state if strict mode
+         if strict and is_initialized():
+             raise RuntimeError(
+                 "Cannot configure logging after runtime initialization. "
+                 "Call set_logger() before creating any models, or use strict=False for best-effort."
+             )
+
+         # Use default logger if none provided
+         if logger is None:
+             logger = logging.getLogger("nexaai.ml")
+
+         _current_logger = logger
+
+         # Set the C callback
+         common_bind.ml_set_log(_log_callback)
+
+
+ def _log_callback(level: int, message: str) -> None:
+     """Internal callback that forwards bridge logs to Python logger."""
+     if _current_logger is None:
+         return
+
+     # Map bridge log levels to Python logging levels
+     if level == LogLevel.TRACE or level == LogLevel.DEBUG:
+         _current_logger.debug(message)
+     elif level == LogLevel.INFO:
+         _current_logger.info(message)
+     elif level == LogLevel.WARN:
+         _current_logger.warning(message)
+     elif level == LogLevel.ERROR:
+         _current_logger.error(message)
+     else:
+         # Fallback for unknown levels
+         _current_logger.info(f"[Level {level}] {message}")
+
+
+ def get_error_message(error_code: int) -> str:
+     """
+     Get error message string for error code.
+
+     Args:
+         error_code: ML error code (typically negative)
+
+     Returns:
+         Human-readable error message
+     """
+     return common_bind.ml_get_error_message(error_code)
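
A short sketch (not part of the package diff) of how bridge logging might be wired into standard Python logging with this module; the logger name and level here are arbitrary choices, and set_logger must run before any model is created, since strict mode raises after runtime initialization:

# Hypothetical sketch: route bridge logs into Python's logging system.
import logging
from nexaai.log import set_logger

logging.basicConfig(level=logging.DEBUG)
set_logger(logging.getLogger("nexaai.ml"))  # or set_logger() for the default logger
# ... create models afterwards; bridge logs now flow through the logger above.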
nexaai/rerank.py ADDED
@@ -0,0 +1,57 @@
+ from typing import List, Optional, Sequence, Union
+ from abc import abstractmethod
+ from dataclasses import dataclass
+
+ from nexaai.base import BaseModel
+ from nexaai.common import PluginID
+
+
+ @dataclass
+ class RerankConfig:
+     """Configuration for reranking."""
+     batch_size: int = 1
+     normalize: bool = True
+     normalize_method: str = "softmax"  # "softmax" | "min-max" | "none"
+
+
+ class Reranker(BaseModel):
+     """Abstract base class for reranker models."""
+
+     def __init__(self):
+         """Initialize base Reranker class."""
+         pass
+
+     @classmethod
+     def _load_from(cls,
+                    model_path: str,
+                    model_name: str = None,
+                    tokenizer_file: str = "tokenizer.json",
+                    plugin_id: Union[PluginID, str] = PluginID.LLAMA_CPP,
+                    device_id: Optional[str] = None,
+                    **kwargs
+                    ) -> 'Reranker':
+         """Load reranker model from local path, routing to appropriate implementation."""
+         # Check plugin_id value for routing - handle both enum and string
+         plugin_value = plugin_id.value if isinstance(plugin_id, PluginID) else plugin_id
+
+         if plugin_value == "mlx":
+             from nexaai.rerank_impl.mlx_rerank_impl import MLXRerankImpl
+             return MLXRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
+         else:
+             from nexaai.rerank_impl.pybind_rerank_impl import PyBindRerankImpl
+             return PyBindRerankImpl._load_from(model_path, model_name, tokenizer_file, plugin_id, device_id)
+
+     @abstractmethod
+     def load_model(self, model_path: str, extra_data: Optional[str] = None) -> bool:
+         """Load model from path."""
+         pass
+
+     @abstractmethod
+     def rerank(
+         self,
+         query: str,
+         documents: Sequence[str],
+         config: Optional[RerankConfig] = None,
+     ) -> List[float]:
+         """Rerank documents given a query."""
+         pass
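
Since _load_from routes on plugin_id, a concrete implementation is returned at load time. A sketch (not part of the package diff) with placeholder model path, query, and documents:

# Hypothetical sketch; the concrete class returned depends on plugin_id.
from nexaai.rerank import Reranker, RerankConfig

reranker = Reranker._load_from("/path/to/reranker-model")
scores = reranker.rerank(
    query="best local LLM runtime",
    documents=["doc about llama.cpp", "doc about MLX", "unrelated doc"],
    config=RerankConfig(batch_size=2, normalize_method="softmax"),
)
print(scores)  # one relevance score per input document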