nexaai-1.0.21rc5-cp313-cp313-win_arm64.whl → nexaai-1.0.21rc14-cp313-cp313-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of nexaai might be problematic.
Files changed (105)
  1. nexaai/__init__.py +95 -95
  2. nexaai/_stub.cp313-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -1
  4. nexaai/asr.py +68 -65
  5. nexaai/asr_impl/mlx_asr_impl.py +92 -92
  6. nexaai/asr_impl/pybind_asr_impl.py +127 -44
  7. nexaai/base.py +39 -39
  8. nexaai/binds/__init__.py +6 -5
  9. nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
  10. nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
  11. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  12. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  15. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  16. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  17. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  18. nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
  19. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  20. nexaai/binds/libssl-3-arm64.dll +0 -0
  21. nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
  22. nexaai/binds/nexa_bridge.dll +0 -0
  23. nexaai/binds/npu/convnext-sdk.dll +0 -0
  24. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  25. nexaai/binds/npu/ggml-base.dll +0 -0
  26. nexaai/binds/npu/ggml-cpu.dll +0 -0
  27. nexaai/binds/npu/ggml-opencl.dll +0 -0
  28. nexaai/binds/npu/ggml.dll +0 -0
  29. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  30. nexaai/binds/npu/granite4-sdk.dll +0 -0
  31. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  32. nexaai/binds/npu/liquid-sdk.dll +0 -0
  33. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  34. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  35. nexaai/binds/npu/nexa-sampling.dll +0 -0
  36. nexaai/binds/npu/nexa_plugin.dll +0 -0
  37. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  38. nexaai/binds/npu/openblas.dll +0 -0
  39. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  40. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  41. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  42. nexaai/binds/npu/phi4-sdk.dll +0 -0
  43. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  44. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  45. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  46. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  47. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  48. nexaai/binds/npu/zlib1.dll +0 -0
  49. nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
  50. nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
  51. nexaai/common.py +105 -105
  52. nexaai/cv.py +93 -93
  53. nexaai/cv_impl/mlx_cv_impl.py +89 -89
  54. nexaai/cv_impl/pybind_cv_impl.py +32 -32
  55. nexaai/embedder.py +73 -73
  56. nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
  57. nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
  58. nexaai/image_gen.py +141 -141
  59. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
  60. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
  61. nexaai/llm.py +98 -98
  62. nexaai/llm_impl/mlx_llm_impl.py +271 -271
  63. nexaai/llm_impl/pybind_llm_impl.py +220 -220
  64. nexaai/log.py +92 -92
  65. nexaai/rerank.py +57 -57
  66. nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
  67. nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
  68. nexaai/runtime.py +68 -68
  69. nexaai/runtime_error.py +24 -24
  70. nexaai/tts.py +75 -75
  71. nexaai/tts_impl/mlx_tts_impl.py +94 -94
  72. nexaai/tts_impl/pybind_tts_impl.py +43 -43
  73. nexaai/utils/decode.py +17 -17
  74. nexaai/utils/manifest_utils.py +531 -531
  75. nexaai/utils/model_manager.py +1562 -1562
  76. nexaai/utils/model_types.py +49 -49
  77. nexaai/utils/progress_tracker.py +384 -384
  78. nexaai/utils/quantization_utils.py +245 -245
  79. nexaai/vlm.py +129 -129
  80. nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
  81. nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
  82. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/METADATA +1 -1
  83. nexaai-1.0.21rc14.dist-info/RECORD +154 -0
  84. nexaai/binds/nexaml/FLAC.dll +0 -0
  85. nexaai/binds/nexaml/fftw3.dll +0 -0
  86. nexaai/binds/nexaml/fftw3f.dll +0 -0
  87. nexaai/binds/nexaml/ggml-base.dll +0 -0
  88. nexaai/binds/nexaml/ggml-cpu.dll +0 -0
  89. nexaai/binds/nexaml/ggml-opencl.dll +0 -0
  90. nexaai/binds/nexaml/ggml.dll +0 -0
  91. nexaai/binds/nexaml/libmp3lame.DLL +0 -0
  92. nexaai/binds/nexaml/mpg123.dll +0 -0
  93. nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
  94. nexaai/binds/nexaml/nexa-sampling.dll +0 -0
  95. nexaai/binds/nexaml/nexa_plugin.dll +0 -0
  96. nexaai/binds/nexaml/nexaproc.dll +0 -0
  97. nexaai/binds/nexaml/ogg.dll +0 -0
  98. nexaai/binds/nexaml/opus.dll +0 -0
  99. nexaai/binds/nexaml/qwen3-vl.dll +0 -0
  100. nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
  101. nexaai/binds/nexaml/vorbis.dll +0 -0
  102. nexaai/binds/nexaml/vorbisenc.dll +0 -0
  103. nexaai-1.0.21rc5.dist-info/RECORD +0 -162
  104. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/WHEEL +0 -0
  105. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/top_level.txt +0 -0
nexaai/vlm_impl/mlx_vlm_impl.py
@@ -1,259 +1,259 @@
Every line of the file is marked removed and re-added, but the removed and added text is identical, so the change is presumably whitespace- or line-ending-only. The file content, shown once:

from typing import Generator, Optional, List, Dict, Any, Union

from nexaai.base import ProfilingData
from nexaai.common import ModelConfig, GenerationConfig, MultiModalMessage, PluginID
from nexaai.vlm import VLM
from nexaai.mlx_backend.vlm.interface import VLM as MLXVLMInterface
from nexaai.mlx_backend.ml import ModelConfig as MLXModelConfig, SamplerConfig as MLXSamplerConfig, GenerationConfig as MLXGenerationConfig, EmbeddingConfig


class MlxVlmImpl(VLM):
    def __init__(self, m_cfg: ModelConfig = ModelConfig()):
        """Initialize MLX VLM implementation."""
        super().__init__(m_cfg)
        self._mlx_vlm = None

    @classmethod
    def _load_from(cls,
                   local_path: str,
                   mmproj_path: str = None,
                   model_name: Optional[str] = None,
                   m_cfg: ModelConfig = ModelConfig(),
                   plugin_id: Union[PluginID, str] = PluginID.MLX,
                   device_id: Optional[str] = None
                   ) -> 'MlxVlmImpl':
        """Load VLM model from local path using MLX backend.

        Args:
            local_path: Path to the main model file
            mmproj_path: Path to the multimodal projection file (not used in MLX VLM)
            m_cfg: Model configuration
            plugin_id: Plugin identifier
            device_id: Optional device ID

        Returns:
            MlxVlmImpl instance
        """
        try:
            # MLX interface is already imported

            # Create instance and load MLX VLM
            instance = cls(m_cfg)
            instance._mlx_vlm = MLXVLMInterface(
                model_name=model_name,
                model_path=local_path,
                mmproj_path=mmproj_path,  # MLX VLM may not use this, but pass it anyway
                context_length=m_cfg.n_ctx,
                device=device_id
            )

            return instance
        except Exception as e:
            raise RuntimeError(f"Failed to load MLX VLM: {str(e)}")

    def eject(self):
        """Release the model from memory."""
        if self._mlx_vlm:
            self._mlx_vlm.destroy()
            self._mlx_vlm = None

    def reset(self):
        """
        Reset the VLM model context and KV cache.
        """
        if not self._mlx_vlm:
            raise RuntimeError("MLX VLM not loaded")

        try:
            self._mlx_vlm.reset()
        except Exception as e:
            raise RuntimeError(f"Failed to reset MLX VLM: {str(e)}")

    def apply_chat_template(
        self,
        messages: List[MultiModalMessage],
        tools: Optional[List[Dict[str, Any]]] = None,
        enable_thinking: bool = True
    ) -> str:
        """Apply the chat template to multimodal messages."""
        if not self._mlx_vlm:
            raise RuntimeError("MLX VLM not loaded")

        try:
            mlx_messages = []
            total_images = 0
            total_audios = 0

            for msg in messages:
                # Create a simple object with role and content attributes
                class MLXChatMessage:
                    def __init__(self, role, content):
                        self.role = role
                        self.content = content

                # Extract text content and count media files
                text_content = ""
                first_content = True

                for content_item in msg["content"]:
                    content_type = content_item.get("type", "")

                    if content_type == "text":
                        if not first_content:
                            text_content += " "
                        text_content += content_item.get("text", "")
                        first_content = False
                    elif content_type == "image":
                        total_images += 1
                    elif content_type == "audio":
                        total_audios += 1

                mlx_messages.append(MLXChatMessage(msg["role"], text_content))

            if total_images > 0 or total_audios > 0:
                # Use apply_chat_template_with_media when media is present
                return self._mlx_vlm.apply_chat_template_with_media(
                    mlx_messages,
                    num_images=total_images,
                    num_audios=total_audios,
                    tools=tools,
                    enable_thinking=enable_thinking
                )
            else:
                # Use regular apply_chat_template for text-only messages
                return self._mlx_vlm.apply_chat_template(mlx_messages)

        except Exception as e:
            raise RuntimeError(f"Failed to apply chat template: {str(e)}")

    def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
        """Generate text with streaming."""
        if not self._mlx_vlm:
            raise RuntimeError("MLX VLM not loaded")

        try:
            # Convert GenerationConfig to MLX format
            mlx_gen_config = MLXGenerationConfig()
            mlx_gen_config.max_tokens = g_cfg.max_tokens
            mlx_gen_config.stop = g_cfg.stop_words
            mlx_gen_config.image_paths = g_cfg.image_paths
            mlx_gen_config.audio_paths = g_cfg.audio_paths

            if g_cfg.sampler_config:
                mlx_sampler_config = MLXSamplerConfig()
                mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
                mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
                mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
                mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
                mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
                mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
                mlx_sampler_config.seed = g_cfg.sampler_config.seed
                mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
                mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
                mlx_gen_config.sampler_config = mlx_sampler_config

            import queue
            import threading

            # Create a queue for streaming tokens
            token_queue = queue.Queue()
            exception_container = [None]
            self.reset_cancel()  # Reset cancel flag before generation

            def token_callback(token: str, user_data: Any = None) -> bool:
                if self._cancel_event.is_set():
                    token_queue.put(('end', None))
                    return False
                try:
                    token_queue.put(('token', token))
                    return True
                except Exception as e:
                    exception_container[0] = e
                    return False

            # Run generation in a separate thread
            def generate():
                try:
                    self._mlx_vlm.generate_stream(prompt, mlx_gen_config, token_callback)
                except Exception as e:
                    exception_container[0] = e
                finally:
                    token_queue.put(('end', None))

            thread = threading.Thread(target=generate)
            thread.start()

            # Yield tokens as they come from the queue
            while True:
                if exception_container[0]:
                    raise exception_container[0]

                try:
                    msg_type, token = token_queue.get(timeout=0.1)
                    if msg_type == 'end':
                        break
                    elif msg_type == 'token':
                        yield token
                except queue.Empty:
                    if not thread.is_alive():
                        break
                    continue

            thread.join()

            if exception_container[0]:
                raise exception_container[0]

        except Exception as e:
            raise RuntimeError(f"Failed to generate streaming text: {str(e)}")

    def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
        """
        Generate text without streaming.

        Args:
            prompt (str): The prompt to generate text from.
            g_cfg (GenerationConfig): Generation configuration.

        Returns:
            str: The generated text.
        """
        if not self._mlx_vlm:
            raise RuntimeError("MLX VLM not loaded")

        try:
            # Convert GenerationConfig to MLX format
            mlx_gen_config = MLXGenerationConfig()
            mlx_gen_config.max_tokens = g_cfg.max_tokens
            mlx_gen_config.stop = g_cfg.stop_words
            mlx_gen_config.image_paths = g_cfg.image_paths
            mlx_gen_config.audio_paths = g_cfg.audio_paths

            if g_cfg.sampler_config:
                mlx_sampler_config = MLXSamplerConfig()
                mlx_sampler_config.temperature = g_cfg.sampler_config.temperature
                mlx_sampler_config.top_p = g_cfg.sampler_config.top_p
                mlx_sampler_config.top_k = g_cfg.sampler_config.top_k
                mlx_sampler_config.repetition_penalty = g_cfg.sampler_config.repetition_penalty
                mlx_sampler_config.presence_penalty = g_cfg.sampler_config.presence_penalty
                mlx_sampler_config.frequency_penalty = g_cfg.sampler_config.frequency_penalty
                mlx_sampler_config.seed = g_cfg.sampler_config.seed
                mlx_sampler_config.grammar_path = g_cfg.sampler_config.grammar_path
                mlx_sampler_config.grammar_string = g_cfg.sampler_config.grammar_string
                mlx_gen_config.sampler_config = mlx_sampler_config

            # Simple token callback that just continues
            def token_callback(token: str, user_data: Any = None) -> bool:
                return not self._cancel_event.is_set()

            # Use MLX streaming generation and return the full result
            return self._mlx_vlm.generate_stream(prompt, mlx_gen_config, token_callback)

        except Exception as e:
            raise RuntimeError(f"Failed to generate text: {str(e)}")

    def get_profiling_data(self) -> Optional[ProfilingData]:
        """Get profiling data from the last generation."""
        if not self._mlx_vlm:
            raise RuntimeError("MLX VLM not loaded")

        return self._mlx_vlm.get_profiling_data()
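For orientation, a minimal sketch of how this class's streaming path would be consumed. This is an assumption-laden example, not documented nexaai usage: the model and image paths are placeholders, the message shape is inferred from apply_chat_template above (a dict with "role" and a "content" list of typed items), and it assumes GenerationConfig fields such as image_paths are plain assignable attributes rather than constructor-only parameters.

from nexaai.common import ModelConfig, GenerationConfig
from nexaai.vlm_impl.mlx_vlm_impl import MlxVlmImpl

# Hypothetical local model directory; _load_from is the loader defined above.
vlm = MlxVlmImpl._load_from(local_path="/models/example-vlm-mlx", m_cfg=ModelConfig())

# Image items in "content" are only counted by apply_chat_template;
# the actual file path travels via GenerationConfig.image_paths.
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image"},
    ],
}]

prompt = vlm.apply_chat_template(messages)

g_cfg = GenerationConfig()
g_cfg.image_paths = ["/tmp/example.png"]  # placeholder path

# generate_stream runs the MLX call on a worker thread and yields tokens
# from the internal queue as the native callback delivers them.
for token in vlm.generate_stream(prompt, g_cfg):
    print(token, end="", flush=True)

vlm.eject()  # release the model from memory

The queue-plus-worker-thread design is what turns the callback-based native API into a plain Python generator: the callback returning False, which happens once the base class's cancel event is set, is what interrupts generation mid-stream.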