nexaai 1.0.4rc13__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (59) hide show
  1. nexaai/__init__.py +71 -0
  2. nexaai/_stub.cp310-win_amd64.pyd +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +60 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +91 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +43 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +3 -0
  10. nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
  11. nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
  12. nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
  13. nexaai/binds/nexa_bridge.dll +0 -0
  14. nexaai/binds/nexa_llama_cpp/ggml-base.dll +0 -0
  15. nexaai/binds/nexa_llama_cpp/ggml-cpu.dll +0 -0
  16. nexaai/binds/nexa_llama_cpp/ggml-cuda.dll +0 -0
  17. nexaai/binds/nexa_llama_cpp/ggml-vulkan.dll +0 -0
  18. nexaai/binds/nexa_llama_cpp/ggml.dll +0 -0
  19. nexaai/binds/nexa_llama_cpp/llama.dll +0 -0
  20. nexaai/binds/nexa_llama_cpp/mtmd.dll +0 -0
  21. nexaai/binds/nexa_llama_cpp/nexa_plugin.dll +0 -0
  22. nexaai/common.py +61 -0
  23. nexaai/cv.py +87 -0
  24. nexaai/cv_impl/__init__.py +0 -0
  25. nexaai/cv_impl/mlx_cv_impl.py +88 -0
  26. nexaai/cv_impl/pybind_cv_impl.py +31 -0
  27. nexaai/embedder.py +68 -0
  28. nexaai/embedder_impl/__init__.py +0 -0
  29. nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
  30. nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
  31. nexaai/image_gen.py +136 -0
  32. nexaai/image_gen_impl/__init__.py +0 -0
  33. nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
  34. nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
  35. nexaai/llm.py +89 -0
  36. nexaai/llm_impl/__init__.py +0 -0
  37. nexaai/llm_impl/mlx_llm_impl.py +249 -0
  38. nexaai/llm_impl/pybind_llm_impl.py +207 -0
  39. nexaai/rerank.py +51 -0
  40. nexaai/rerank_impl/__init__.py +0 -0
  41. nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
  42. nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
  43. nexaai/runtime.py +64 -0
  44. nexaai/tts.py +70 -0
  45. nexaai/tts_impl/__init__.py +0 -0
  46. nexaai/tts_impl/mlx_tts_impl.py +93 -0
  47. nexaai/tts_impl/pybind_tts_impl.py +42 -0
  48. nexaai/utils/avatar_fetcher.py +104 -0
  49. nexaai/utils/decode.py +18 -0
  50. nexaai/utils/model_manager.py +1195 -0
  51. nexaai/utils/progress_tracker.py +372 -0
  52. nexaai/vlm.py +120 -0
  53. nexaai/vlm_impl/__init__.py +0 -0
  54. nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
  55. nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
  56. nexaai-1.0.4rc13.dist-info/METADATA +26 -0
  57. nexaai-1.0.4rc13.dist-info/RECORD +59 -0
  58. nexaai-1.0.4rc13.dist-info/WHEEL +5 -0
  59. nexaai-1.0.4rc13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,372 @@
1
+ """
2
+ Progress tracking utilities for downloads with tqdm integration.
3
+
4
+ This module provides custom progress tracking classes that can monitor
5
+ download progress with callback support and customizable display options.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Optional, Callable, Dict, Any
12
+ from tqdm.auto import tqdm
13
+
14
+
15
class CustomProgressTqdm(tqdm):
    """Custom tqdm that tracks progress but completely hides terminal output.

    The bar stays enabled so its counters keep updating, while every output
    path (the stream, ``display`` and ``write``) is silenced.
    """

    def __init__(self, *args, **kwargs):
        """Create the bar with its output redirected to ``os.devnull``.

        Any ``file``, ``disable`` or ``leave`` keyword supplied by the caller
        is overridden.
        """
        # Redirect output to devnull to completely suppress terminal output
        sink = open(os.devnull, 'w')
        kwargs['file'] = sink
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # The bar never came to life, so close() will not run for it;
            # release the devnull handle here instead of leaking it.
            sink.close()
            raise

    def display(self, msg=None, pos=None):
        """Do nothing: rendering of the bar is suppressed."""
        pass

    def write(self, s, file=None, end="\n", nolock=False):
        """Do nothing: messages written through the bar are discarded."""
        pass

    def close(self):
        """Close the devnull stream and mark the bar as disabled."""
        fp = getattr(self, 'fp', None)
        # Never close the interpreter's own standard streams.
        if fp and fp != sys.stdout and fp != sys.stderr:
            try:
                fp.close()
            except Exception:
                # Best-effort cleanup: a failure to close the sink must not
                # break the download that owns this bar.
                pass
        self.disable = True
        # Lookup starts *after* the imported ``tqdm`` class in the MRO, so that
        # class's own close() is skipped.
        # NOTE(review): presumably done to avoid any closing output - confirm.
        super(tqdm, self).close()
42
+
43
+
44
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    While tracking is active (between ``start_tracking`` and
    ``stop_tracking``) the imported ``tqdm`` class is monkey-patched so that
    every bar created registers itself here and reports each update. The
    aggregated state is delivered to ``progress_callback`` as a nested dict
    (see ``get_progress_data``) and, optionally, rendered as a single
    console progress line.
    """

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        """Initialise an idle tracker.

        Args:
            progress_callback: Called with the progress dict on every
                registration, update and status change. May be None.
            show_progress: Whether the unified console progress line may be
                printed. It is only printed when a callback is also set.
        """
        # Per-bar state keyed by str(id(bar)): current, total, desc, tqdm_obj
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking
        self.last_downloaded = 0
        self.last_time = time.time()
        self.speed_history = []
        self.max_speed_history = 10

        # Download status
        self.download_status = "idle"  # idle, downloading, completed, error
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring and fire the callback."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            # Holding the object keeps its stream reachable for cleanup in
            # stop_tracking().
            'tqdm_obj': tqdm_instance,
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh the stored counters of a registered tqdm instance.

        The values are read from the instance's ``n`` and ``total``
        attributes; the ``n`` argument itself is not used. Unregistered
        instances are ignored.
        """
        tqdm_id = str(id(tqdm_instance))
        if tqdm_id in self.progress_data:
            entry = self.progress_data[tqdm_id]
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Calculate download speed in bytes per second.

        A new sample is taken only when time has advanced since the last
        measurement and a previous byte count exists; the return value is
        the mean of up to ``max_speed_history`` samples. Otherwise 0.0 is
        returned.
        """
        current_time = time.time()
        time_diff = current_time - self.last_time

        if time_diff > 0 and self.last_downloaded > 0:
            bytes_diff = current_downloaded - self.last_downloaded
            speed = bytes_diff / time_diff

            # Add to speed history for smoothing
            self.speed_history.append(speed)
            if len(self.speed_history) > self.max_speed_history:
                self.speed_history.pop(0)

            # Return smoothed speed
            return sum(self.speed_history) / len(self.speed_history)

        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string (1024-based, one decimal)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string (1024-based, one decimal)."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Get current progress data.

        Side effect: the speed-tracking baseline (``last_downloaded`` and
        ``last_time``) is advanced on every call.

        Returns:
            A dict with the keys ``status``, ``error_message``, ``progress``,
            ``speed``, ``formatting`` and ``timing``.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        # Only bars with a known, positive total contribute to the aggregate.
        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed
        speed = self.calculate_speed(total_downloaded)

        # Update tracking variables
        self.last_downloaded = total_downloaded
        self.last_time = time.time()

        # Determine total size - prioritize pre-fetched repo size, then aggregate file sizes
        if self.total_repo_size > 0:
            # Use pre-fetched repository info if available
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            # Use sum of individual file sizes if available
            total_size = total_file_sizes
        else:
            # Last resort - we don't know the total size yet
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Calculate percentage - handle unknown total size gracefully
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # Calculate ETA
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        # Calculate elapsed time
        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed)
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size)
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time
            }
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Display a custom unified progress bar on the current console line.

        Does nothing when ``show_progress`` is False.
        """
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)

        # Create progress bar
        bar_width = 30
        if known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '█' * filled_width + '░' * (bar_width - filled_width)
        else:
            # Unknown total size - show animated progress
            animation_pos = int(time.time() * 2) % bar_width
            bar = '░' * animation_pos + '█' + '░' * (bar_width - animation_pos - 1)

        # Format the progress line
        status = progress_data.get('status', 'unknown')
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            # The key is normally present with a None value, so a plain
            # .get() default would print "Error: None".
            error_text = progress_data.get('error_message') or 'Unknown error'
            progress_line = f"Error: {error_text}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Clear the progress bar display."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Trigger the progress callback if one is set.

        Exceptions raised by the callback are reported with ``print`` and
        otherwise ignored.
        """
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking (monkey patch tqdm).

        Replaces ``__init__``, ``update``, ``display`` and ``write`` on the
        imported ``tqdm`` class, sets the status to "downloading" and fires
        an initial callback. Calling it while already tracking is a no-op.
        """
        if self.is_tracking:
            return

        # Store original methods
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # Create references to self for the nested functions
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            sink = open(os.devnull, 'w')
            kwargs['file'] = sink
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar

            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # The bar is never registered, so stop_tracking() could not
                # close this handle later.
                sink.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(*args, **kwargs):
            # Override write to prevent any output. Accepts any call shape:
            # tqdm.write is a classmethod upstream, so it is invoked both as
            # bar.write(s) and as tqdm.write(s); a fixed (self, s) signature
            # would raise TypeError for the class-level form.
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        tqdm.write = patched_write

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking and restore original tqdm.

        Also closes the devnull streams of all registered bars, marks a
        still-running download as "completed", fires a final callback and
        clears the console progress line. A no-op when not tracking.
        """
        if not self.is_tracking:
            return

        # Restore original tqdm methods
        if self.original_tqdm_update:
            tqdm.update = self.original_tqdm_update
        if self.original_tqdm_init:
            tqdm.__init__ = self.original_tqdm_init
        if self.original_tqdm_display:
            tqdm.display = self.original_tqdm_display
        if self.original_tqdm_write:
            tqdm.write = self.original_tqdm_write

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            if 'tqdm_obj' in data and hasattr(data['tqdm_obj'], 'fp'):
                try:
                    fp = data['tqdm_obj'].fp
                    if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                        fp.close()
                except Exception:
                    # Best-effort cleanup; one bad stream must not prevent
                    # the remaining ones from being closed.
                    pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status and trigger callback."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
nexaai/vlm.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import Generator, Optional, List, Dict, Any, Union
2
+ from abc import abstractmethod
3
+ import queue
4
+ import threading
5
+ import base64
6
+ from pathlib import Path
7
+
8
+ from nexaai.common import ModelConfig, GenerationConfig, MultiModalMessage
9
+ from nexaai.base import BaseModel
10
+
11
+
12
class VLM(BaseModel):
    """Base class for vision-language models.

    Concrete back ends are selected by ``_load_from`` and must implement the
    abstract methods declared here.
    """

    def __init__(self, m_cfg: ModelConfig = ModelConfig()):
        """Initialize base VLM class."""
        self._m_cfg = m_cfg
        self._cancel_event = threading.Event()  # New attribute to control cancellation

    @classmethod
    def _load_from(cls,
                   local_path: str,
                   mmproj_path: str,
                   m_cfg: ModelConfig = ModelConfig(),
                   plugin_id: str = "llama_cpp",
                   device_id: Optional[str] = None
                   ) -> 'VLM':
        """Load VLM model from local path, routing to appropriate implementation.

        Args:
            local_path: Path to the main model file
            mmproj_path: Path to the multimodal projection file
            m_cfg: Model configuration
            plugin_id: Plugin identifier
            device_id: Optional device ID (not used in current binding)

        Returns:
            VLM instance
        """
        # Implementations are imported lazily so that only the selected back
        # end's module is loaded.
        if plugin_id == "mlx":
            from nexaai.vlm_impl.mlx_vlm_impl import MlxVlmImpl
            return MlxVlmImpl._load_from(local_path, mmproj_path, m_cfg, plugin_id, device_id)
        else:
            from nexaai.vlm_impl.pybind_vlm_impl import PyBindVLMImpl
            return PyBindVLMImpl._load_from(local_path, mmproj_path, m_cfg, plugin_id, device_id)

    @abstractmethod
    def eject(self):
        """Release the model from memory."""
        pass

    def cancel_generation(self):
        """Signal to cancel any ongoing stream generation."""
        self._cancel_event.set()

    def reset_cancel(self):
        """Reset the cancel event. Call before starting a new generation if needed."""
        self._cancel_event.clear()

    @abstractmethod
    def reset(self):
        """
        Reset the VLM model context and KV cache. If not reset, the model will skip the number of evaluated tokens and treat tokens after those as the new incremental tokens.
        If your past chat history changed, or you are starting a new chat, you should always reset the model before running generate.
        """
        pass

    def _process_image(self, image: Union[bytes, str, Path]) -> bytes:
        """Process image input to bytes format.

        Args:
            image: Image data as bytes, a ``data:image...`` URL carrying a
                base64 payload, or a file path (``str`` or ``Path``). A
                string that does not start with ``data:image`` is treated
                as a file path.

        Returns:
            Image data as bytes

        Raises:
            ValueError: If the type is unsupported or a data URL has no
                ``,`` separating its header from the payload.
        """
        if isinstance(image, bytes):
            return image

        if isinstance(image, str) and image.startswith('data:image'):
            # Split once at the first comma: header on the left, payload on
            # the right.
            _header, separator, payload = image.partition(',')
            if not separator:
                # Without a separator there is no payload; decoding the
                # header itself would only yield garbage or an opaque error.
                raise ValueError("Malformed image data URL: missing ',' before the base64 payload")
            return base64.b64decode(payload)

        if isinstance(image, (str, Path)):
            # Assume it's a file path
            with open(image, 'rb') as f:
                return f.read()

        raise ValueError(f"Unsupported image type: {type(image)}")

    @abstractmethod
    def apply_chat_template(
        self,
        messages: List[MultiModalMessage],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """Apply the chat template to multimodal messages."""
        pass

    @abstractmethod
    def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
        """Generate text with streaming."""
        pass

    @abstractmethod
    def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
        """
        Generate text without streaming.

        Args:
            prompt (str): The prompt to generate text from. For chat models, this is the chat messages after chat template is applied.
            g_cfg (GenerationConfig): Generation configuration.

        Returns:
            str: The generated text.
        """
        pass
File without changes