nexaai-1.0.21rc5-cp313-cp313-win_arm64.whl → nexaai-1.0.21rc14-cp313-cp313-win_arm64.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.

Potentially problematic release: this version of nexaai might be problematic.

Files changed (105)
  1. nexaai/__init__.py +95 -95
  2. nexaai/_stub.cp313-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -1
  4. nexaai/asr.py +68 -65
  5. nexaai/asr_impl/mlx_asr_impl.py +92 -92
  6. nexaai/asr_impl/pybind_asr_impl.py +127 -44
  7. nexaai/base.py +39 -39
  8. nexaai/binds/__init__.py +6 -5
  9. nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
  10. nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
  11. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  12. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  15. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  16. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  17. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  18. nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
  19. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  20. nexaai/binds/libssl-3-arm64.dll +0 -0
  21. nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
  22. nexaai/binds/nexa_bridge.dll +0 -0
  23. nexaai/binds/npu/convnext-sdk.dll +0 -0
  24. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  25. nexaai/binds/npu/ggml-base.dll +0 -0
  26. nexaai/binds/npu/ggml-cpu.dll +0 -0
  27. nexaai/binds/npu/ggml-opencl.dll +0 -0
  28. nexaai/binds/npu/ggml.dll +0 -0
  29. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  30. nexaai/binds/npu/granite4-sdk.dll +0 -0
  31. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  32. nexaai/binds/npu/liquid-sdk.dll +0 -0
  33. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  34. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  35. nexaai/binds/npu/nexa-sampling.dll +0 -0
  36. nexaai/binds/npu/nexa_plugin.dll +0 -0
  37. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  38. nexaai/binds/npu/openblas.dll +0 -0
  39. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  40. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  41. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  42. nexaai/binds/npu/phi4-sdk.dll +0 -0
  43. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  44. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  45. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  46. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  47. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  48. nexaai/binds/npu/zlib1.dll +0 -0
  49. nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
  50. nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
  51. nexaai/common.py +105 -105
  52. nexaai/cv.py +93 -93
  53. nexaai/cv_impl/mlx_cv_impl.py +89 -89
  54. nexaai/cv_impl/pybind_cv_impl.py +32 -32
  55. nexaai/embedder.py +73 -73
  56. nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
  57. nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
  58. nexaai/image_gen.py +141 -141
  59. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
  60. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
  61. nexaai/llm.py +98 -98
  62. nexaai/llm_impl/mlx_llm_impl.py +271 -271
  63. nexaai/llm_impl/pybind_llm_impl.py +220 -220
  64. nexaai/log.py +92 -92
  65. nexaai/rerank.py +57 -57
  66. nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
  67. nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
  68. nexaai/runtime.py +68 -68
  69. nexaai/runtime_error.py +24 -24
  70. nexaai/tts.py +75 -75
  71. nexaai/tts_impl/mlx_tts_impl.py +94 -94
  72. nexaai/tts_impl/pybind_tts_impl.py +43 -43
  73. nexaai/utils/decode.py +17 -17
  74. nexaai/utils/manifest_utils.py +531 -531
  75. nexaai/utils/model_manager.py +1562 -1562
  76. nexaai/utils/model_types.py +49 -49
  77. nexaai/utils/progress_tracker.py +384 -384
  78. nexaai/utils/quantization_utils.py +245 -245
  79. nexaai/vlm.py +129 -129
  80. nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
  81. nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
  82. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/METADATA +1 -1
  83. nexaai-1.0.21rc14.dist-info/RECORD +154 -0
  84. nexaai/binds/nexaml/FLAC.dll +0 -0
  85. nexaai/binds/nexaml/fftw3.dll +0 -0
  86. nexaai/binds/nexaml/fftw3f.dll +0 -0
  87. nexaai/binds/nexaml/ggml-base.dll +0 -0
  88. nexaai/binds/nexaml/ggml-cpu.dll +0 -0
  89. nexaai/binds/nexaml/ggml-opencl.dll +0 -0
  90. nexaai/binds/nexaml/ggml.dll +0 -0
  91. nexaai/binds/nexaml/libmp3lame.DLL +0 -0
  92. nexaai/binds/nexaml/mpg123.dll +0 -0
  93. nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
  94. nexaai/binds/nexaml/nexa-sampling.dll +0 -0
  95. nexaai/binds/nexaml/nexa_plugin.dll +0 -0
  96. nexaai/binds/nexaml/nexaproc.dll +0 -0
  97. nexaai/binds/nexaml/ogg.dll +0 -0
  98. nexaai/binds/nexaml/opus.dll +0 -0
  99. nexaai/binds/nexaml/qwen3-vl.dll +0 -0
  100. nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
  101. nexaai/binds/nexaml/vorbis.dll +0 -0
  102. nexaai/binds/nexaml/vorbisenc.dll +0 -0
  103. nexaai-1.0.21rc5.dist-info/RECORD +0 -162
  104. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/WHEEL +0 -0
  105. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/top_level.txt +0 -0
The diff body below covers nexaai/llm_impl/mlx_llm_impl.py. The removed (-) and re-added (+) lines are textually identical, matching the symmetric +271/-271 count for that file in the list above.

@@ -1,271 +1,271 @@
- from typing import Generator, Optional, Any, Sequence, Union
-
- from nexaai.base import ProfilingData
- from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
- from nexaai.llm import LLM
- from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
- from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
-
-
- class MLXLLMImpl(LLM):
-     def __init__(self, m_cfg: ModelConfig = ModelConfig()):
-         """Initialize MLX LLM implementation."""
-         super().__init__(m_cfg)
-         self._mlx_llm = None
-
-     @classmethod
-     def _load_from(cls,
-                    local_path: str,
-                    model_name: Optional[str] = None,
-                    tokenizer_path: Optional[str] = None,
-                    m_cfg: ModelConfig = ModelConfig(),
-                    plugin_id: Union[PluginID, str] = PluginID.MLX,
-                    device_id: Optional[str] = None
-                    ) -> 'MLXLLMImpl':
-         """Load model from local path using MLX backend."""
-         try:
-             # MLX interface and configs are already imported
-
-             # Convert our ModelConfig to MLX ModelConfig
-             mlx_config = MLXModelConfig()
-             mlx_config.n_ctx = m_cfg.n_ctx
-             mlx_config.n_threads = m_cfg.n_threads
-             mlx_config.n_threads_batch = m_cfg.n_threads_batch
-             mlx_config.n_batch = m_cfg.n_batch
-             mlx_config.n_ubatch = m_cfg.n_ubatch
-             mlx_config.n_seq_max = m_cfg.n_seq_max
-             mlx_config.chat_template_path = m_cfg.chat_template_path
-             mlx_config.chat_template_content = m_cfg.chat_template_content
-
-             # Create instance and load MLX model
-             instance = cls(m_cfg)
-             instance._mlx_llm = MLXLLMInterface(
-                 model_path=local_path,
-                 # model_name=model_name,  # FIXME: For MLX LLM, model_name is not used
-                 tokenizer_path=tokenizer_path or local_path,
-                 config=mlx_config,
-                 device=device_id
-             )
-
-             return instance
-         except Exception as e:
-             raise RuntimeError(f"Failed to load MLX LLM: {str(e)}")
-
-     def eject(self):
-         """Release the model from memory."""
-         if self._mlx_llm:
-             self._mlx_llm.destroy()
-             self._mlx_llm = None
-
-     def apply_chat_template(
-         self,
-         messages: Sequence[ChatMessage],
-         tools: Optional[str] = None,
-         enable_thinking: bool = True,
-         add_generation_prompt: bool = True
-     ) -> str:
-         """Apply the chat template to messages."""
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             # Convert to MLX ChatMessage format
-             mlx_messages = []
-             for msg in messages:
-                 # Create a simple object with role and content attributes
-                 class MLXChatMessage:
-                     def __init__(self, role, content):
-                         self.role = role
-                         self.content = content
-
-                 # Handle both dict-style and attribute-style access
-                 if hasattr(msg, 'role') and hasattr(msg, 'content'):
-                     # Message is already an object with attributes
-                     mlx_messages.append(MLXChatMessage(msg.role, msg.content))
-                 else:
-                     # Message is a dict
-                     mlx_messages.append(MLXChatMessage(msg["role"], msg["content"]))
-
-             return self._mlx_llm.apply_chat_template(mlx_messages, tools=tools, enable_thinking=enable_thinking, add_generation_prompt=add_generation_prompt)
-         except Exception as e:
-             raise RuntimeError(f"Failed to apply chat template: {str(e)}")
-
-     def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
-         """Generate text with streaming."""
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             import queue
-             import threading
-
-             # Convert GenerationConfig to MLX format
-
-             mlx_gen_config = MLXGenerationConfig()
-             mlx_gen_config.max_tokens = g_cfg.max_tokens
-             mlx_gen_config.stop = g_cfg.stop_words
-             mlx_gen_config.image_paths = g_cfg.image_paths
-             mlx_gen_config.audio_paths = g_cfg.audio_paths
-
-             if g_cfg.sampler_config:
-                 mlx_sampler_config = MLXSamplerConfig()
-                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
-                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
-                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
-                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
-                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
-                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
-                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
-                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
-                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
-                 mlx_gen_config.sampler_config = mlx_sampler_config
-
-             # Create a queue for streaming tokens
-             token_queue = queue.Queue()
-             exception_container = [None]
-             self.reset_cancel()  # Reset cancel flag before generation
-
-             def token_callback(token: str, user_data: Any = None) -> bool:
-                 if self._cancel_event.is_set():
-                     token_queue.put(('end', None))
-                     return False
-                 try:
-                     token_queue.put(('token', token))
-                     return True
-                 except Exception as e:
-                     exception_container[0] = e
-                     return False
-
-             # Run generation in a separate thread
-             def generate():
-                 try:
-                     self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
-                 except Exception as e:
-                     exception_container[0] = e
-                 finally:
-                     token_queue.put(('end', None))
-
-             thread = threading.Thread(target=generate)
-             thread.start()
-
-             # Yield tokens as they come from the queue
-             while True:
-                 if exception_container[0]:
-                     raise exception_container[0]
-
-                 try:
-                     msg_type, token = token_queue.get(timeout=0.1)
-                     if msg_type == 'end':
-                         break
-                     elif msg_type == 'token':
-                         yield token
-                 except queue.Empty:
-                     if not thread.is_alive():
-                         break
-                     continue
-
-             thread.join()
-
-             if exception_container[0]:
-                 raise exception_container[0]
-
-         except Exception as e:
-             raise RuntimeError(f"Failed to generate streaming text: {str(e)}")
-
-     def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
-         """
-         Generate text without streaming.
-
-         Args:
-             prompt (str): The prompt to generate text from.
-             g_cfg (GenerationConfig): Generation configuration.
-
-         Returns:
-             str: The generated text.
-         """
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             # Convert GenerationConfig to MLX format
-
-             mlx_gen_config = MLXGenerationConfig()
-             mlx_gen_config.max_tokens = g_cfg.max_tokens
-             mlx_gen_config.stop = g_cfg.stop_words
-             mlx_gen_config.image_paths = g_cfg.image_paths
-             mlx_gen_config.audio_paths = g_cfg.audio_paths
-
-             if g_cfg.sampler_config:
-                 mlx_sampler_config = MLXSamplerConfig()
-                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
-                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
-                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
-                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
-                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
-                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
-                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
-                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
-                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
-                 mlx_gen_config.sampler_config = mlx_sampler_config
-
-             # Simple token callback that just continues
-             def token_callback(token: str, user_data: Any = None) -> bool:
-                 return not self._cancel_event.is_set()
-
-             # Use MLX streaming generation and return the full result
-             return self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
-
-         except Exception as e:
-             raise RuntimeError(f"Failed to generate text: {str(e)}")
-
-     def get_profiling_data(self) -> Optional[ProfilingData]:
-         """Get profiling data from the last generation."""
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-         return self._mlx_llm.get_profiling_data()
-
-     def save_kv_cache(self, path: str):
-         """
-         Save the key-value cache to the file.
-
-         Args:
-             path (str): The path to the file.
-         """
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             success = self._mlx_llm.save_kv_cache(path)
-             if not success:
-                 raise RuntimeError("Failed to save KV cache")
-         except Exception as e:
-             raise RuntimeError(f"Failed to save KV cache: {str(e)}")
-
-     def load_kv_cache(self, path: str):
-         """
-         Load the key-value cache from the file.
-
-         Args:
-             path (str): The path to the file.
-         """
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             success = self._mlx_llm.load_kv_cache(path)
-             if not success:
-                 raise RuntimeError("Failed to load KV cache")
-         except Exception as e:
-             raise RuntimeError(f"Failed to load KV cache: {str(e)}")
-
-     def reset(self):
-         """
-         Reset the LLM model context and KV cache.
-         """
-         if not self._mlx_llm:
-             raise RuntimeError("MLX LLM not loaded")
-
-         try:
-             self._mlx_llm.reset()
-         except Exception as e:
-             raise RuntimeError(f"Failed to reset MLX LLM: {str(e)}")
+ from typing import Generator, Optional, Any, Sequence, Union
+
+ from nexaai.base import ProfilingData
+ from nexaai.common import ModelConfig, GenerationConfig, ChatMessage, PluginID
+ from nexaai.llm import LLM
+ from nexaai.mlx_backend.llm.interface import LLM as MLXLLMInterface
+ from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig
+
+
+ class MLXLLMImpl(LLM):
+     def __init__(self, m_cfg: ModelConfig = ModelConfig()):
+         """Initialize MLX LLM implementation."""
+         super().__init__(m_cfg)
+         self._mlx_llm = None
+
+     @classmethod
+     def _load_from(cls,
+                    local_path: str,
+                    model_name: Optional[str] = None,
+                    tokenizer_path: Optional[str] = None,
+                    m_cfg: ModelConfig = ModelConfig(),
+                    plugin_id: Union[PluginID, str] = PluginID.MLX,
+                    device_id: Optional[str] = None
+                    ) -> 'MLXLLMImpl':
+         """Load model from local path using MLX backend."""
+         try:
+             # MLX interface and configs are already imported
+
+             # Convert our ModelConfig to MLX ModelConfig
+             mlx_config = MLXModelConfig()
+             mlx_config.n_ctx = m_cfg.n_ctx
+             mlx_config.n_threads = m_cfg.n_threads
+             mlx_config.n_threads_batch = m_cfg.n_threads_batch
+             mlx_config.n_batch = m_cfg.n_batch
+             mlx_config.n_ubatch = m_cfg.n_ubatch
+             mlx_config.n_seq_max = m_cfg.n_seq_max
+             mlx_config.chat_template_path = m_cfg.chat_template_path
+             mlx_config.chat_template_content = m_cfg.chat_template_content
+
+             # Create instance and load MLX model
+             instance = cls(m_cfg)
+             instance._mlx_llm = MLXLLMInterface(
+                 model_path=local_path,
+                 # model_name=model_name,  # FIXME: For MLX LLM, model_name is not used
+                 tokenizer_path=tokenizer_path or local_path,
+                 config=mlx_config,
+                 device=device_id
+             )
+
+             return instance
+         except Exception as e:
+             raise RuntimeError(f"Failed to load MLX LLM: {str(e)}")
+
+     def eject(self):
+         """Release the model from memory."""
+         if self._mlx_llm:
+             self._mlx_llm.destroy()
+             self._mlx_llm = None
+
+     def apply_chat_template(
+         self,
+         messages: Sequence[ChatMessage],
+         tools: Optional[str] = None,
+         enable_thinking: bool = True,
+         add_generation_prompt: bool = True
+     ) -> str:
+         """Apply the chat template to messages."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             # Convert to MLX ChatMessage format
+             mlx_messages = []
+             for msg in messages:
+                 # Create a simple object with role and content attributes
+                 class MLXChatMessage:
+                     def __init__(self, role, content):
+                         self.role = role
+                         self.content = content
+
+                 # Handle both dict-style and attribute-style access
+                 if hasattr(msg, 'role') and hasattr(msg, 'content'):
+                     # Message is already an object with attributes
+                     mlx_messages.append(MLXChatMessage(msg.role, msg.content))
+                 else:
+                     # Message is a dict
+                     mlx_messages.append(MLXChatMessage(msg["role"], msg["content"]))
+
+             return self._mlx_llm.apply_chat_template(mlx_messages, tools=tools, enable_thinking=enable_thinking, add_generation_prompt=add_generation_prompt)
+         except Exception as e:
+             raise RuntimeError(f"Failed to apply chat template: {str(e)}")
+
+     def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
+         """Generate text with streaming."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             import queue
+             import threading
+
+             # Convert GenerationConfig to MLX format
+
+             mlx_gen_config = MLXGenerationConfig()
+             mlx_gen_config.max_tokens = g_cfg.max_tokens
+             mlx_gen_config.stop = g_cfg.stop_words
+             mlx_gen_config.image_paths = g_cfg.image_paths
+             mlx_gen_config.audio_paths = g_cfg.audio_paths
+
+             if g_cfg.sampler_config:
+                 mlx_sampler_config = MLXSamplerConfig()
+                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
+                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
+                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
+                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
+                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
+                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
+                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
+                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
+                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
+                 mlx_gen_config.sampler_config = mlx_sampler_config
+
+             # Create a queue for streaming tokens
+             token_queue = queue.Queue()
+             exception_container = [None]
+             self.reset_cancel()  # Reset cancel flag before generation
+
+             def token_callback(token: str, user_data: Any = None) -> bool:
+                 if self._cancel_event.is_set():
+                     token_queue.put(('end', None))
+                     return False
+                 try:
+                     token_queue.put(('token', token))
+                     return True
+                 except Exception as e:
+                     exception_container[0] = e
+                     return False
+
+             # Run generation in a separate thread
+             def generate():
+                 try:
+                     self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
+                 except Exception as e:
+                     exception_container[0] = e
+                 finally:
+                     token_queue.put(('end', None))
+
+             thread = threading.Thread(target=generate)
+             thread.start()
+
+             # Yield tokens as they come from the queue
+             while True:
+                 if exception_container[0]:
+                     raise exception_container[0]
+
+                 try:
+                     msg_type, token = token_queue.get(timeout=0.1)
+                     if msg_type == 'end':
+                         break
+                     elif msg_type == 'token':
+                         yield token
+                 except queue.Empty:
+                     if not thread.is_alive():
+                         break
+                     continue
+
+             thread.join()
+
+             if exception_container[0]:
+                 raise exception_container[0]
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate streaming text: {str(e)}")
+
+     def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
+         """
+         Generate text without streaming.
+
+         Args:
+             prompt (str): The prompt to generate text from.
+             g_cfg (GenerationConfig): Generation configuration.
+
+         Returns:
+             str: The generated text.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             # Convert GenerationConfig to MLX format
+
+             mlx_gen_config = MLXGenerationConfig()
+             mlx_gen_config.max_tokens = g_cfg.max_tokens
+             mlx_gen_config.stop = g_cfg.stop_words
+             mlx_gen_config.image_paths = g_cfg.image_paths
+             mlx_gen_config.audio_paths = g_cfg.audio_paths
+
+             if g_cfg.sampler_config:
+                 mlx_sampler_config = MLXSamplerConfig()
+                 mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
+                 mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
+                 mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
+                 mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
+                 mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
+                 mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
+                 mlx_sampler_config.seed = g_cfg.sampler_config.seed
+                 mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
+                 mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
+                 mlx_gen_config.sampler_config = mlx_sampler_config
+
+             # Simple token callback that just continues
+             def token_callback(token: str, user_data: Any = None) -> bool:
+                 return not self._cancel_event.is_set()
+
+             # Use MLX streaming generation and return the full result
+             return self._mlx_llm.generate_stream(prompt, mlx_gen_config, token_callback)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate text: {str(e)}")
+
+     def get_profiling_data(self) -> Optional[ProfilingData]:
+         """Get profiling data from the last generation."""
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+         return self._mlx_llm.get_profiling_data()
+
+     def save_kv_cache(self, path: str):
+         """
+         Save the key-value cache to the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             success = self._mlx_llm.save_kv_cache(path)
+             if not success:
+                 raise RuntimeError("Failed to save KV cache")
+         except Exception as e:
+             raise RuntimeError(f"Failed to save KV cache: {str(e)}")
+
+     def load_kv_cache(self, path: str):
+         """
+         Load the key-value cache from the file.
+
+         Args:
+             path (str): The path to the file.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             success = self._mlx_llm.load_kv_cache(path)
+             if not success:
+                 raise RuntimeError("Failed to load KV cache")
+         except Exception as e:
+             raise RuntimeError(f"Failed to load KV cache: {str(e)}")
+
+     def reset(self):
+         """
+         Reset the LLM model context and KV cache.
+         """
+         if not self._mlx_llm:
+             raise RuntimeError("MLX LLM not loaded")
+
+         try:
+             self._mlx_llm.reset()
+         except Exception as e:
+             raise RuntimeError(f"Failed to reset MLX LLM: {str(e)}")