nexaai 1.0.21rc5__cp313-cp313-win_arm64.whl → 1.0.21rc16__cp313-cp313-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai may be problematic; see the registry's advisory page for this release for more details.
- nexaai/__init__.py +95 -95
- nexaai/_stub.cp313-win_arm64.pyd +0 -0
- nexaai/_version.py +4 -1
- nexaai/asr.py +68 -65
- nexaai/asr_impl/mlx_asr_impl.py +92 -92
- nexaai/asr_impl/pybind_asr_impl.py +127 -44
- nexaai/base.py +39 -39
- nexaai/binds/__init__.py +6 -5
- nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/libssl-3-arm64.dll +0 -0
- nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/npu/convnext-sdk.dll +0 -0
- nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
- nexaai/binds/npu/ggml-base.dll +0 -0
- nexaai/binds/npu/ggml-cpu.dll +0 -0
- nexaai/binds/{nexaml → npu}/ggml-opencl.dll +0 -0
- nexaai/binds/npu/ggml.dll +0 -0
- nexaai/binds/npu/granite-nano-sdk.dll +0 -0
- nexaai/binds/npu/granite4-sdk.dll +0 -0
- nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
- nexaai/binds/npu/liquid-sdk.dll +0 -0
- nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
- nexaai/binds/npu/nexa-mm-process.dll +0 -0
- nexaai/binds/npu/nexa-sampling.dll +0 -0
- nexaai/binds/npu/nexa_plugin.dll +0 -0
- nexaai/binds/npu/omni-neural-sdk.dll +0 -0
- nexaai/binds/npu/openblas.dll +0 -0
- nexaai/binds/npu/paddleocr-sdk.dll +0 -0
- nexaai/binds/npu/parakeet-sdk.dll +0 -0
- nexaai/binds/npu/phi3-5-sdk.dll +0 -0
- nexaai/binds/npu/phi4-sdk.dll +0 -0
- nexaai/binds/npu/pyannote-sdk.dll +0 -0
- nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-vision.dll +0 -0
- nexaai/binds/npu/yolov12-sdk.dll +0 -0
- nexaai/binds/npu/zlib1.dll +0 -0
- nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/common.py +105 -105
- nexaai/cv.py +93 -93
- nexaai/cv_impl/mlx_cv_impl.py +89 -89
- nexaai/cv_impl/pybind_cv_impl.py +32 -32
- nexaai/embedder.py +73 -73
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
- nexaai/image_gen.py +141 -141
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
- nexaai/llm.py +98 -98
- nexaai/llm_impl/mlx_llm_impl.py +271 -271
- nexaai/llm_impl/pybind_llm_impl.py +220 -220
- nexaai/log.py +92 -92
- nexaai/rerank.py +57 -57
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
- nexaai/runtime.py +68 -68
- nexaai/runtime_error.py +24 -24
- nexaai/tts.py +75 -75
- nexaai/tts_impl/mlx_tts_impl.py +94 -94
- nexaai/tts_impl/pybind_tts_impl.py +43 -43
- nexaai/utils/decode.py +17 -17
- nexaai/utils/manifest_utils.py +531 -531
- nexaai/utils/model_manager.py +1562 -1562
- nexaai/utils/model_types.py +49 -49
- nexaai/utils/progress_tracker.py +384 -384
- nexaai/utils/quantization_utils.py +245 -245
- nexaai/vlm.py +129 -129
- nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
- nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/METADATA +1 -1
- nexaai-1.0.21rc16.dist-info/RECORD +154 -0
- nexaai/binds/nexaml/FLAC.dll +0 -0
- nexaai/binds/nexaml/fftw3.dll +0 -0
- nexaai/binds/nexaml/fftw3f.dll +0 -0
- nexaai/binds/nexaml/ggml-base.dll +0 -0
- nexaai/binds/nexaml/ggml-cpu.dll +0 -0
- nexaai/binds/nexaml/ggml.dll +0 -0
- nexaai/binds/nexaml/libmp3lame.DLL +0 -0
- nexaai/binds/nexaml/mpg123.dll +0 -0
- nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
- nexaai/binds/nexaml/nexa-sampling.dll +0 -0
- nexaai/binds/nexaml/nexa_plugin.dll +0 -0
- nexaai/binds/nexaml/nexaproc.dll +0 -0
- nexaai/binds/nexaml/ogg.dll +0 -0
- nexaai/binds/nexaml/opus.dll +0 -0
- nexaai/binds/nexaml/qwen3-vl.dll +0 -0
- nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
- nexaai/binds/nexaml/vorbis.dll +0 -0
- nexaai/binds/nexaml/vorbisenc.dll +0 -0
- nexaai-1.0.21rc5.dist-info/RECORD +0 -162
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/WHEEL +0 -0
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/top_level.txt +0 -0
nexaai/llm_impl/mlx_llm_impl.py
CHANGED
|
@@ -1,271 +1,271 @@
|
|
|
1
|
-
from typing import Generator, Optional, Any, Sequence, Union
|
|
2
|
-
|
|
3
|
-
from nexaai.base import ProfilingData
|
|
4
|
-
from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
|
|
5
|
-
from nexaai.llm import LLM
|
|
6
|
-
from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
|
|
7
|
-
from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class MLXLLMImpl(LLM):
    """LLM implementation backed by the MLX runtime (``nexaai.mlx_backend``).

    Wraps an ``MLXLLMInterface`` and translates nexaai config and message
    types into their MLX counterparts. Every operation except ``eject``
    raises ``RuntimeError("MLX LLM not loaded")`` when no model is loaded.
    """

    def __init__(self, m_cfg: Optional[ModelConfig] = None):
        """Initialize the wrapper without loading a model.

        Args:
            m_cfg: Model configuration forwarded to the base ``LLM`` class.
                ``None`` (the default) means a fresh ``ModelConfig()``. A new
                object is built per call so that default-constructed
                instances do not share one mutable config object.
        """
        super().__init__(ModelConfig() if m_cfg is None else m_cfg)
        self._mlx_llm = None

    @classmethod
    def _load_from(cls,
                   local_path: str,
                   model_name: Optional[str] = None,
                   tokenizer_path: Optional[str] = None,
                   m_cfg: Optional[ModelConfig] = None,
                   plugin_id: Union[PluginID, str] = PluginID.MLX,
                   device_id: Optional[str] = None
                   ) -> 'MLXLLMImpl':
        """Create an instance and load a model from ``local_path`` via MLX.

        Args:
            local_path: Model location, passed to ``MLXLLMInterface`` as
                ``model_path``.
            model_name: Accepted for interface compatibility. It is not
                forwarded to MLX (see the FIXME below).
            tokenizer_path: Tokenizer location. Falls back to ``local_path``
                when ``None`` or empty.
            m_cfg: Model configuration. ``None`` means ``ModelConfig()``.
            plugin_id: Accepted for interface compatibility and not used here.
            device_id: Forwarded to ``MLXLLMInterface`` as ``device``.

        Returns:
            A new ``MLXLLMImpl`` with its MLX model loaded.

        Raises:
            RuntimeError: If config conversion or model loading fails.
        """
        if m_cfg is None:
            m_cfg = ModelConfig()
        try:
            # Copy the nexaai ModelConfig field-by-field into MLX's config type.
            mlx_config = MLXModelConfig()
            mlx_config.n_ctx = m_cfg.n_ctx
            mlx_config.n_threads = m_cfg.n_threads
            mlx_config.n_threads_batch = m_cfg.n_threads_batch
            mlx_config.n_batch = m_cfg.n_batch
            mlx_config.n_ubatch = m_cfg.n_ubatch
            mlx_config.n_seq_max = m_cfg.n_seq_max
            mlx_config.chat_template_path = m_cfg.chat_template_path
            mlx_config.chat_template_content = m_cfg.chat_template_content

            instance = cls(m_cfg)
            instance._mlx_llm = MLXLLMInterface(
                model_path=local_path,
                # model_name=model_name,  # FIXME: For MLX LLM, model_name is not used
                tokenizer_path=tokenizer_path or local_path,
                config=mlx_config,
                device=device_id,
            )
            return instance
        except Exception as e:
            raise RuntimeError(f"Failed to load MLX LLM: {str(e)}") from e

    def _require_loaded(self):
        """Return the loaded MLX interface, or raise if no model is loaded."""
        if self._mlx_llm is None:
            raise RuntimeError("MLX LLM not loaded")
        return self._mlx_llm

    @staticmethod
    def _to_mlx_generation_config(g_cfg: GenerationConfig) -> 'MLXGenerationConfig':
        """Translate a nexaai ``GenerationConfig`` into MLX's ``GenerationConfig``.

        The sampler settings are copied only when ``g_cfg.sampler_config`` is
        truthy. Otherwise MLX's default sampler config is left in place.
        """
        mlx_gen_config = MLXGenerationConfig()
        mlx_gen_config.max_tokens = g_cfg.max_tokens
        mlx_gen_config.stop = g_cfg.stop_words
        mlx_gen_config.image_paths = g_cfg.image_paths
        mlx_gen_config.audio_paths = g_cfg.audio_paths

        sampler = g_cfg.sampler_config
        if sampler:
            mlx_sampler_config = MLXSamplerConfig()
            mlx_sampler_config.temperature = sampler.temperature
            mlx_sampler_config.top_p = sampler.top_p
            mlx_sampler_config.top_k = sampler.top_k
            mlx_sampler_config.repetition_penalty = sampler.repetition_penalty
            mlx_sampler_config.presence_penalty = sampler.presence_penalty
            mlx_sampler_config.frequency_penalty = sampler.frequency_penalty
            mlx_sampler_config.seed = sampler.seed
            mlx_sampler_config.grammar_path = sampler.grammar_path
            mlx_sampler_config.grammar_string = sampler.grammar_string
            mlx_gen_config.sampler_config = mlx_sampler_config
        return mlx_gen_config

    def eject(self):
        """Destroy the MLX model and drop the reference. No-op if nothing is loaded."""
        if self._mlx_llm is not None:
            self._mlx_llm.destroy()
            self._mlx_llm = None

    def apply_chat_template(
        self,
        messages: Sequence[ChatMessage],
        tools: Optional[str] = None,
        enable_thinking: bool = True,
        add_generation_prompt: bool = True
    ) -> str:
        """Render ``messages`` through the MLX model's chat template.

        Args:
            messages: Chat messages. Each may be an object with ``role`` and
                ``content`` attributes, or a mapping with those keys.
            tools: Passed through to MLX unchanged.
            enable_thinking: Passed through to MLX unchanged.
            add_generation_prompt: Passed through to MLX unchanged.

        Returns:
            Whatever the MLX interface's ``apply_chat_template`` returns
            (annotated as ``str``).

        Raises:
            RuntimeError: If no model is loaded, or conversion or templating fails.
        """
        mlx_llm = self._require_loaded()

        # Minimal role/content carrier handed to the MLX template. It is
        # defined once per call rather than once per message.
        class MLXChatMessage:
            def __init__(self, role, content):
                self.role = role
                self.content = content

        try:
            mlx_messages = []
            for msg in messages:
                # Accept both attribute-style objects and dict-style messages.
                if hasattr(msg, 'role') and hasattr(msg, 'content'):
                    mlx_messages.append(MLXChatMessage(msg.role, msg.content))
                else:
                    mlx_messages.append(MLXChatMessage(msg["role"], msg["content"]))

            return mlx_llm.apply_chat_template(
                mlx_messages,
                tools=tools,
                enable_thinking=enable_thinking,
                add_generation_prompt=add_generation_prompt,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to apply chat template: {str(e)}") from e

    def generate_stream(self, prompt: str, g_cfg: Optional[GenerationConfig] = None) -> Generator[str, None, None]:
        """Stream generated tokens for ``prompt``.

        MLX generation runs on a worker thread that pushes tokens into a queue.
        This generator yields them as they arrive. If the consumer stops
        early (``close()`` or an abandoned generator), the cancel event is set
        and the worker is joined, so generation does not keep running unread.

        Args:
            prompt: Prompt text.
            g_cfg: Generation configuration. ``None`` means ``GenerationConfig()``.

        Yields:
            Token strings in generation order.

        Raises:
            RuntimeError: If no model is loaded (raised on first iteration), or
                if generation fails.
        """
        mlx_llm = self._require_loaded()
        if g_cfg is None:
            g_cfg = GenerationConfig()

        try:
            import queue
            import threading

            mlx_gen_config = self._to_mlx_generation_config(g_cfg)

            token_queue = queue.Queue()
            exception_container = [None]
            # reset_cancel/_cancel_event are not defined in this class;
            # presumably provided by the base LLM class.
            self.reset_cancel()  # clear any cancellation left from a previous call

            def token_callback(token: str, user_data: Any = None) -> bool:
                # Returning False asks MLX to stop generating.
                if self._cancel_event.is_set():
                    token_queue.put(('end', None))
                    return False
                try:
                    token_queue.put(('token', token))
                    return True
                except Exception as e:
                    exception_container[0] = e
                    return False

            def generate():
                try:
                    mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
                except Exception as e:
                    exception_container[0] = e
                finally:
                    # Always unblock the consumer, even on failure.
                    token_queue.put(('end', None))

            thread = threading.Thread(target=generate)
            thread.start()

            finished = False
            try:
                while True:
                    if exception_container[0]:
                        raise exception_container[0]

                    try:
                        msg_type, token = token_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Worker died without posting 'end': stop waiting.
                        if not thread.is_alive():
                            break
                        continue

                    if msg_type == 'end':
                        break
                    elif msg_type == 'token':
                        yield token
                finished = True
            finally:
                if not finished and thread.is_alive():
                    # Consumer stopped early or an error is propagating: make
                    # the token callback return False so the worker winds down.
                    self._cancel_event.set()
                thread.join()

            if exception_container[0]:
                raise exception_container[0]

        except Exception as e:
            raise RuntimeError(f"Failed to generate streaming text: {str(e)}") from e

    def generate(self, prompt: str, g_cfg: Optional[GenerationConfig] = None) -> str:
        """
        Generate text without streaming.

        Args:
            prompt (str): The prompt to generate text from.
            g_cfg (GenerationConfig): Generation configuration. ``None`` means
                ``GenerationConfig()``.

        Returns:
            str: The value returned by the MLX interface's ``generate_stream``.
                Per the original comment this is presumably the full
                generated text; confirm against the MLX backend.

        Raises:
            RuntimeError: If no model is loaded or generation fails.
        """
        mlx_llm = self._require_loaded()
        if g_cfg is None:
            g_cfg = GenerationConfig()

        try:
            mlx_gen_config = self._to_mlx_generation_config(g_cfg)

            # Clear a cancellation left over from an earlier call, as
            # generate_stream does. Otherwise this call would stop immediately.
            self.reset_cancel()

            def token_callback(token: str, user_data: Any = None) -> bool:
                # Continue until cancellation is requested.
                return not self._cancel_event.is_set()

            return mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
        except Exception as e:
            raise RuntimeError(f"Failed to generate text: {str(e)}") from e

    def get_profiling_data(self) -> Optional[ProfilingData]:
        """Return profiling data from the MLX interface for the last generation."""
        return self._require_loaded().get_profiling_data()

    def save_kv_cache(self, path: str):
        """
        Save the key-value cache to the file.

        Args:
            path (str): The path to the file.

        Raises:
            RuntimeError: If no model is loaded, the MLX call raises, or it
                reports failure by returning a falsy value.
        """
        mlx_llm = self._require_loaded()
        try:
            success = mlx_llm.save_kv_cache(path)
        except Exception as e:
            raise RuntimeError(f"Failed to save KV cache: {str(e)}") from e
        # Checked outside the try so this error is not re-wrapped with a
        # duplicated message.
        if not success:
            raise RuntimeError("Failed to save KV cache")

    def load_kv_cache(self, path: str):
        """
        Load the key-value cache from the file.

        Args:
            path (str): The path to the file.

        Raises:
            RuntimeError: If no model is loaded, the MLX call raises, or it
                reports failure by returning a falsy value.
        """
        mlx_llm = self._require_loaded()
        try:
            success = mlx_llm.load_kv_cache(path)
        except Exception as e:
            raise RuntimeError(f"Failed to load KV cache: {str(e)}") from e
        # Checked outside the try so this error is not re-wrapped with a
        # duplicated message.
        if not success:
            raise RuntimeError("Failed to load KV cache")

    def reset(self):
        """
        Reset the LLM model context and KV cache.

        Raises:
            RuntimeError: If no model is loaded or the MLX reset fails.
        """
        mlx_llm = self._require_loaded()
        try:
            mlx_llm.reset()
        except Exception as e:
            raise RuntimeError(f"Failed to reset MLX LLM: {str(e)}") from e
|
|
1
|
+
from typing import Generator, Optional, Any, Sequence, Union
|
|
2
|
+
|
|
3
|
+
from nexaai.base import ProfilingData
|
|
4
|
+
from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
|
|
5
|
+
from nexaai.llm import LLM
|
|
6
|
+
from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
|
|
7
|
+
from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MLXLLMImpl(LLM):
    """LLM implementation backed by the MLX runtime (``nexaai.mlx_backend``).

    Wraps an ``MLXLLMInterface`` and translates nexaai config and message
    types into their MLX counterparts. Every operation except ``eject``
    raises ``RuntimeError("MLX LLM not loaded")`` when no model is loaded.
    """

    def __init__(self, m_cfg: Optional[ModelConfig] = None):
        """Initialize the wrapper without loading a model.

        Args:
            m_cfg: Model configuration forwarded to the base ``LLM`` class.
                ``None`` (the default) means a fresh ``ModelConfig()``. A new
                object is built per call so that default-constructed
                instances do not share one mutable config object.
        """
        super().__init__(ModelConfig() if m_cfg is None else m_cfg)
        self._mlx_llm = None

    @classmethod
    def _load_from(cls,
                   local_path: str,
                   model_name: Optional[str] = None,
                   tokenizer_path: Optional[str] = None,
                   m_cfg: Optional[ModelConfig] = None,
                   plugin_id: Union[PluginID, str] = PluginID.MLX,
                   device_id: Optional[str] = None
                   ) -> 'MLXLLMImpl':
        """Create an instance and load a model from ``local_path`` via MLX.

        Args:
            local_path: Model location, passed to ``MLXLLMInterface`` as
                ``model_path``.
            model_name: Accepted for interface compatibility. It is not
                forwarded to MLX (see the FIXME below).
            tokenizer_path: Tokenizer location. Falls back to ``local_path``
                when ``None`` or empty.
            m_cfg: Model configuration. ``None`` means ``ModelConfig()``.
            plugin_id: Accepted for interface compatibility and not used here.
            device_id: Forwarded to ``MLXLLMInterface`` as ``device``.

        Returns:
            A new ``MLXLLMImpl`` with its MLX model loaded.

        Raises:
            RuntimeError: If config conversion or model loading fails.
        """
        if m_cfg is None:
            m_cfg = ModelConfig()
        try:
            # Copy the nexaai ModelConfig field-by-field into MLX's config type.
            mlx_config = MLXModelConfig()
            mlx_config.n_ctx = m_cfg.n_ctx
            mlx_config.n_threads = m_cfg.n_threads
            mlx_config.n_threads_batch = m_cfg.n_threads_batch
            mlx_config.n_batch = m_cfg.n_batch
            mlx_config.n_ubatch = m_cfg.n_ubatch
            mlx_config.n_seq_max = m_cfg.n_seq_max
            mlx_config.chat_template_path = m_cfg.chat_template_path
            mlx_config.chat_template_content = m_cfg.chat_template_content

            instance = cls(m_cfg)
            instance._mlx_llm = MLXLLMInterface(
                model_path=local_path,
                # model_name=model_name,  # FIXME: For MLX LLM, model_name is not used
                tokenizer_path=tokenizer_path or local_path,
                config=mlx_config,
                device=device_id,
            )
            return instance
        except Exception as e:
            raise RuntimeError(f"Failed to load MLX LLM: {str(e)}") from e

    def _require_loaded(self):
        """Return the loaded MLX interface, or raise if no model is loaded."""
        if self._mlx_llm is None:
            raise RuntimeError("MLX LLM not loaded")
        return self._mlx_llm

    @staticmethod
    def _to_mlx_generation_config(g_cfg: GenerationConfig) -> 'MLXGenerationConfig':
        """Translate a nexaai ``GenerationConfig`` into MLX's ``GenerationConfig``.

        The sampler settings are copied only when ``g_cfg.sampler_config`` is
        truthy. Otherwise MLX's default sampler config is left in place.
        """
        mlx_gen_config = MLXGenerationConfig()
        mlx_gen_config.max_tokens = g_cfg.max_tokens
        mlx_gen_config.stop = g_cfg.stop_words
        mlx_gen_config.image_paths = g_cfg.image_paths
        mlx_gen_config.audio_paths = g_cfg.audio_paths

        sampler = g_cfg.sampler_config
        if sampler:
            mlx_sampler_config = MLXSamplerConfig()
            mlx_sampler_config.temperature = sampler.temperature
            mlx_sampler_config.top_p = sampler.top_p
            mlx_sampler_config.top_k = sampler.top_k
            mlx_sampler_config.repetition_penalty = sampler.repetition_penalty
            mlx_sampler_config.presence_penalty = sampler.presence_penalty
            mlx_sampler_config.frequency_penalty = sampler.frequency_penalty
            mlx_sampler_config.seed = sampler.seed
            mlx_sampler_config.grammar_path = sampler.grammar_path
            mlx_sampler_config.grammar_string = sampler.grammar_string
            mlx_gen_config.sampler_config = mlx_sampler_config
        return mlx_gen_config

    def eject(self):
        """Destroy the MLX model and drop the reference. No-op if nothing is loaded."""
        if self._mlx_llm is not None:
            self._mlx_llm.destroy()
            self._mlx_llm = None

    def apply_chat_template(
        self,
        messages: Sequence[ChatMessage],
        tools: Optional[str] = None,
        enable_thinking: bool = True,
        add_generation_prompt: bool = True
    ) -> str:
        """Render ``messages`` through the MLX model's chat template.

        Args:
            messages: Chat messages. Each may be an object with ``role`` and
                ``content`` attributes, or a mapping with those keys.
            tools: Passed through to MLX unchanged.
            enable_thinking: Passed through to MLX unchanged.
            add_generation_prompt: Passed through to MLX unchanged.

        Returns:
            Whatever the MLX interface's ``apply_chat_template`` returns
            (annotated as ``str``).

        Raises:
            RuntimeError: If no model is loaded, or conversion or templating fails.
        """
        mlx_llm = self._require_loaded()

        # Minimal role/content carrier handed to the MLX template. It is
        # defined once per call rather than once per message.
        class MLXChatMessage:
            def __init__(self, role, content):
                self.role = role
                self.content = content

        try:
            mlx_messages = []
            for msg in messages:
                # Accept both attribute-style objects and dict-style messages.
                if hasattr(msg, 'role') and hasattr(msg, 'content'):
                    mlx_messages.append(MLXChatMessage(msg.role, msg.content))
                else:
                    mlx_messages.append(MLXChatMessage(msg["role"], msg["content"]))

            return mlx_llm.apply_chat_template(
                mlx_messages,
                tools=tools,
                enable_thinking=enable_thinking,
                add_generation_prompt=add_generation_prompt,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to apply chat template: {str(e)}") from e

    def generate_stream(self, prompt: str, g_cfg: Optional[GenerationConfig] = None) -> Generator[str, None, None]:
        """Stream generated tokens for ``prompt``.

        MLX generation runs on a worker thread that pushes tokens into a queue.
        This generator yields them as they arrive. If the consumer stops
        early (``close()`` or an abandoned generator), the cancel event is set
        and the worker is joined, so generation does not keep running unread.

        Args:
            prompt: Prompt text.
            g_cfg: Generation configuration. ``None`` means ``GenerationConfig()``.

        Yields:
            Token strings in generation order.

        Raises:
            RuntimeError: If no model is loaded (raised on first iteration), or
                if generation fails.
        """
        mlx_llm = self._require_loaded()
        if g_cfg is None:
            g_cfg = GenerationConfig()

        try:
            import queue
            import threading

            mlx_gen_config = self._to_mlx_generation_config(g_cfg)

            token_queue = queue.Queue()
            exception_container = [None]
            # reset_cancel/_cancel_event are not defined in this class;
            # presumably provided by the base LLM class.
            self.reset_cancel()  # clear any cancellation left from a previous call

            def token_callback(token: str, user_data: Any = None) -> bool:
                # Returning False asks MLX to stop generating.
                if self._cancel_event.is_set():
                    token_queue.put(('end', None))
                    return False
                try:
                    token_queue.put(('token', token))
                    return True
                except Exception as e:
                    exception_container[0] = e
                    return False

            def generate():
                try:
                    mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
                except Exception as e:
                    exception_container[0] = e
                finally:
                    # Always unblock the consumer, even on failure.
                    token_queue.put(('end', None))

            thread = threading.Thread(target=generate)
            thread.start()

            finished = False
            try:
                while True:
                    if exception_container[0]:
                        raise exception_container[0]

                    try:
                        msg_type, token = token_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Worker died without posting 'end': stop waiting.
                        if not thread.is_alive():
                            break
                        continue

                    if msg_type == 'end':
                        break
                    elif msg_type == 'token':
                        yield token
                finished = True
            finally:
                if not finished and thread.is_alive():
                    # Consumer stopped early or an error is propagating: make
                    # the token callback return False so the worker winds down.
                    self._cancel_event.set()
                thread.join()

            if exception_container[0]:
                raise exception_container[0]

        except Exception as e:
            raise RuntimeError(f"Failed to generate streaming text: {str(e)}") from e

    def generate(self, prompt: str, g_cfg: Optional[GenerationConfig] = None) -> str:
        """
        Generate text without streaming.

        Args:
            prompt (str): The prompt to generate text from.
            g_cfg (GenerationConfig): Generation configuration. ``None`` means
                ``GenerationConfig()``.

        Returns:
            str: The value returned by the MLX interface's ``generate_stream``.
                Per the original comment this is presumably the full
                generated text; confirm against the MLX backend.

        Raises:
            RuntimeError: If no model is loaded or generation fails.
        """
        mlx_llm = self._require_loaded()
        if g_cfg is None:
            g_cfg = GenerationConfig()

        try:
            mlx_gen_config = self._to_mlx_generation_config(g_cfg)

            # Clear a cancellation left over from an earlier call, as
            # generate_stream does. Otherwise this call would stop immediately.
            self.reset_cancel()

            def token_callback(token: str, user_data: Any = None) -> bool:
                # Continue until cancellation is requested.
                return not self._cancel_event.is_set()

            return mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
        except Exception as e:
            raise RuntimeError(f"Failed to generate text: {str(e)}") from e

    def get_profiling_data(self) -> Optional[ProfilingData]:
        """Return profiling data from the MLX interface for the last generation."""
        return self._require_loaded().get_profiling_data()

    def save_kv_cache(self, path: str):
        """
        Save the key-value cache to the file.

        Args:
            path (str): The path to the file.

        Raises:
            RuntimeError: If no model is loaded, the MLX call raises, or it
                reports failure by returning a falsy value.
        """
        mlx_llm = self._require_loaded()
        try:
            success = mlx_llm.save_kv_cache(path)
        except Exception as e:
            raise RuntimeError(f"Failed to save KV cache: {str(e)}") from e
        # Checked outside the try so this error is not re-wrapped with a
        # duplicated message.
        if not success:
            raise RuntimeError("Failed to save KV cache")

    def load_kv_cache(self, path: str):
        """
        Load the key-value cache from the file.

        Args:
            path (str): The path to the file.

        Raises:
            RuntimeError: If no model is loaded, the MLX call raises, or it
                reports failure by returning a falsy value.
        """
        mlx_llm = self._require_loaded()
        try:
            success = mlx_llm.load_kv_cache(path)
        except Exception as e:
            raise RuntimeError(f"Failed to load KV cache: {str(e)}") from e
        # Checked outside the try so this error is not re-wrapped with a
        # duplicated message.
        if not success:
            raise RuntimeError("Failed to load KV cache")

    def reset(self):
        """
        Reset the LLM model context and KV cache.

        Raises:
            RuntimeError: If no model is loaded or the MLX reset fails.
        """
        mlx_llm = self._require_loaded()
        try:
            mlx_llm.reset()
        except Exception as e:
            raise RuntimeError(f"Failed to reset MLX LLM: {str(e)}") from e
|