nexaai 1.0.21rc16__cp312-cp312-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (154)
  1. nexaai/__init__.py +95 -0
  2. nexaai/_stub.cp312-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +6 -0
  10. nexaai/binds/asr_bind.cp312-win_arm64.pyd +0 -0
  11. nexaai/binds/common_bind.cp312-win_arm64.pyd +0 -0
  12. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  15. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  16. nexaai/binds/cpu_gpu/libomp140.aarch64.dll +0 -0
  17. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  18. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  19. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  20. nexaai/binds/embedder_bind.cp312-win_arm64.pyd +0 -0
  21. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  22. nexaai/binds/libssl-3-arm64.dll +0 -0
  23. nexaai/binds/llm_bind.cp312-win_arm64.pyd +0 -0
  24. nexaai/binds/nexa_bridge.dll +0 -0
  25. nexaai/binds/npu/FLAC.dll +0 -0
  26. nexaai/binds/npu/convnext-sdk.dll +0 -0
  27. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  28. nexaai/binds/npu/fftw3.dll +0 -0
  29. nexaai/binds/npu/fftw3f.dll +0 -0
  30. nexaai/binds/npu/ggml-base.dll +0 -0
  31. nexaai/binds/npu/ggml-cpu.dll +0 -0
  32. nexaai/binds/npu/ggml-opencl.dll +0 -0
  33. nexaai/binds/npu/ggml.dll +0 -0
  34. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  35. nexaai/binds/npu/granite4-sdk.dll +0 -0
  36. nexaai/binds/npu/htp-files/Genie.dll +0 -0
  37. nexaai/binds/npu/htp-files/PlatformValidatorShared.dll +0 -0
  38. nexaai/binds/npu/htp-files/QnnChrometraceProfilingReader.dll +0 -0
  39. nexaai/binds/npu/htp-files/QnnCpu.dll +0 -0
  40. nexaai/binds/npu/htp-files/QnnCpuNetRunExtensions.dll +0 -0
  41. nexaai/binds/npu/htp-files/QnnDsp.dll +0 -0
  42. nexaai/binds/npu/htp-files/QnnDspNetRunExtensions.dll +0 -0
  43. nexaai/binds/npu/htp-files/QnnDspV66CalculatorStub.dll +0 -0
  44. nexaai/binds/npu/htp-files/QnnDspV66Stub.dll +0 -0
  45. nexaai/binds/npu/htp-files/QnnGenAiTransformer.dll +0 -0
  46. nexaai/binds/npu/htp-files/QnnGenAiTransformerCpuOpPkg.dll +0 -0
  47. nexaai/binds/npu/htp-files/QnnGenAiTransformerModel.dll +0 -0
  48. nexaai/binds/npu/htp-files/QnnGpu.dll +0 -0
  49. nexaai/binds/npu/htp-files/QnnGpuNetRunExtensions.dll +0 -0
  50. nexaai/binds/npu/htp-files/QnnGpuProfilingReader.dll +0 -0
  51. nexaai/binds/npu/htp-files/QnnHtp.dll +0 -0
  52. nexaai/binds/npu/htp-files/QnnHtpNetRunExtensions.dll +0 -0
  53. nexaai/binds/npu/htp-files/QnnHtpOptraceProfilingReader.dll +0 -0
  54. nexaai/binds/npu/htp-files/QnnHtpPrepare.dll +0 -0
  55. nexaai/binds/npu/htp-files/QnnHtpProfilingReader.dll +0 -0
  56. nexaai/binds/npu/htp-files/QnnHtpV68CalculatorStub.dll +0 -0
  57. nexaai/binds/npu/htp-files/QnnHtpV68Stub.dll +0 -0
  58. nexaai/binds/npu/htp-files/QnnHtpV73CalculatorStub.dll +0 -0
  59. nexaai/binds/npu/htp-files/QnnHtpV73Stub.dll +0 -0
  60. nexaai/binds/npu/htp-files/QnnIr.dll +0 -0
  61. nexaai/binds/npu/htp-files/QnnJsonProfilingReader.dll +0 -0
  62. nexaai/binds/npu/htp-files/QnnModelDlc.dll +0 -0
  63. nexaai/binds/npu/htp-files/QnnSaver.dll +0 -0
  64. nexaai/binds/npu/htp-files/QnnSystem.dll +0 -0
  65. nexaai/binds/npu/htp-files/SNPE.dll +0 -0
  66. nexaai/binds/npu/htp-files/SnpeDspV66Stub.dll +0 -0
  67. nexaai/binds/npu/htp-files/SnpeHtpPrepare.dll +0 -0
  68. nexaai/binds/npu/htp-files/SnpeHtpV68Stub.dll +0 -0
  69. nexaai/binds/npu/htp-files/SnpeHtpV73Stub.dll +0 -0
  70. nexaai/binds/npu/htp-files/calculator.dll +0 -0
  71. nexaai/binds/npu/htp-files/calculator_htp.dll +0 -0
  72. nexaai/binds/npu/htp-files/libCalculator_skel.so +0 -0
  73. nexaai/binds/npu/htp-files/libQnnHtpV73.so +0 -0
  74. nexaai/binds/npu/htp-files/libQnnHtpV73QemuDriver.so +0 -0
  75. nexaai/binds/npu/htp-files/libQnnHtpV73Skel.so +0 -0
  76. nexaai/binds/npu/htp-files/libQnnSaver.so +0 -0
  77. nexaai/binds/npu/htp-files/libQnnSystem.so +0 -0
  78. nexaai/binds/npu/htp-files/libSnpeHtpV73Skel.so +0 -0
  79. nexaai/binds/npu/htp-files/libqnnhtpv73.cat +0 -0
  80. nexaai/binds/npu/htp-files/libsnpehtpv73.cat +0 -0
  81. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  82. nexaai/binds/npu/libcrypto-3-arm64.dll +0 -0
  83. nexaai/binds/npu/libmp3lame.DLL +0 -0
  84. nexaai/binds/npu/libomp140.aarch64.dll +0 -0
  85. nexaai/binds/npu/libssl-3-arm64.dll +0 -0
  86. nexaai/binds/npu/liquid-sdk.dll +0 -0
  87. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  88. nexaai/binds/npu/mpg123.dll +0 -0
  89. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  90. nexaai/binds/npu/nexa-sampling.dll +0 -0
  91. nexaai/binds/npu/nexa_plugin.dll +0 -0
  92. nexaai/binds/npu/nexaproc.dll +0 -0
  93. nexaai/binds/npu/ogg.dll +0 -0
  94. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  95. nexaai/binds/npu/openblas.dll +0 -0
  96. nexaai/binds/npu/opus.dll +0 -0
  97. nexaai/binds/npu/paddle-ocr-proc-lib.dll +0 -0
  98. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  99. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  100. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  101. nexaai/binds/npu/phi4-sdk.dll +0 -0
  102. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  103. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  104. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  105. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  106. nexaai/binds/npu/rtaudio.dll +0 -0
  107. nexaai/binds/npu/vorbis.dll +0 -0
  108. nexaai/binds/npu/vorbisenc.dll +0 -0
  109. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  110. nexaai/binds/npu/zlib1.dll +0 -0
  111. nexaai/binds/rerank_bind.cp312-win_arm64.pyd +0 -0
  112. nexaai/binds/vlm_bind.cp312-win_arm64.pyd +0 -0
  113. nexaai/common.py +105 -0
  114. nexaai/cv.py +93 -0
  115. nexaai/cv_impl/__init__.py +0 -0
  116. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  117. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  118. nexaai/embedder.py +73 -0
  119. nexaai/embedder_impl/__init__.py +0 -0
  120. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  121. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  122. nexaai/image_gen.py +141 -0
  123. nexaai/image_gen_impl/__init__.py +0 -0
  124. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  125. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  126. nexaai/llm.py +98 -0
  127. nexaai/llm_impl/__init__.py +0 -0
  128. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  129. nexaai/llm_impl/pybind_llm_impl.py +220 -0
  130. nexaai/log.py +92 -0
  131. nexaai/rerank.py +57 -0
  132. nexaai/rerank_impl/__init__.py +0 -0
  133. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  134. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  135. nexaai/runtime.py +68 -0
  136. nexaai/runtime_error.py +24 -0
  137. nexaai/tts.py +75 -0
  138. nexaai/tts_impl/__init__.py +0 -0
  139. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  140. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  141. nexaai/utils/decode.py +18 -0
  142. nexaai/utils/manifest_utils.py +531 -0
  143. nexaai/utils/model_manager.py +1562 -0
  144. nexaai/utils/model_types.py +49 -0
  145. nexaai/utils/progress_tracker.py +385 -0
  146. nexaai/utils/quantization_utils.py +245 -0
  147. nexaai/vlm.py +130 -0
  148. nexaai/vlm_impl/__init__.py +0 -0
  149. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  150. nexaai/vlm_impl/pybind_vlm_impl.py +256 -0
  151. nexaai-1.0.21rc16.dist-info/METADATA +31 -0
  152. nexaai-1.0.21rc16.dist-info/RECORD +154 -0
  153. nexaai-1.0.21rc16.dist-info/WHEEL +5 -0
  154. nexaai-1.0.21rc16.dist-info/top_level.txt +1 -0
@@ -0,0 +1,49 @@
1
+ """
2
+ Model type mappings for HuggingFace pipeline tags to our internal model types.
3
+
4
+ This module provides centralized model type mapping functionality to avoid
5
+ circular imports between other utility modules.
6
+ """
7
+
8
from enum import Enum
from typing import Dict, Optional
10
+
11
+
12
class ModelTypeMapping(Enum):
    """Pair each HuggingFace ``pipeline_tag`` with its internal model type name.

    Each member's value is a ``(pipeline_tag, model_type)`` tuple. The Enum
    machinery unpacks that tuple into ``__init__``, which exposes both halves
    as attributes on the member.
    """
    TEXT_GENERATION = ("text-generation", "llm")
    IMAGE_TEXT_TO_TEXT = ("image-text-to-text", "vlm")
    ANY_TO_ANY = ("any-to-any", "ata")
    AUTOMATIC_SPEECH_RECOGNITION = ("automatic-speech-recognition", "asr")

    def __init__(self, pipeline_tag: str, model_type: str):
        # Expose the two halves of the value tuple under readable names.
        self.pipeline_tag, self.model_type = pipeline_tag, model_type
22
+
23
+
24
# Forward and reverse lookup tables, both derived from ModelTypeMapping in
# member order. If two members shared a key, the later member would win.
PIPELINE_TO_MODEL_TYPE: Dict[str, str] = dict(
    (member.pipeline_tag, member.model_type) for member in ModelTypeMapping
)

MODEL_TYPE_TO_PIPELINE: Dict[str, str] = dict(
    (member.model_type, member.pipeline_tag) for member in ModelTypeMapping
)
34
+
35
+
36
def map_pipeline_tag_to_model_type(pipeline_tag: str) -> str:
    """Translate a HuggingFace ``pipeline_tag`` into an internal model type.

    Args:
        pipeline_tag: Tag such as ``"text-generation"``.

    Returns:
        The mapped model type, or ``"other"`` when the tag is falsy
        (``None``/``""``) or has no known mapping.
    """
    return PIPELINE_TO_MODEL_TYPE.get(pipeline_tag, "other") if pipeline_tag else "other"
42
+
43
+
44
def map_model_type_to_pipeline_tag(model_type: str) -> Optional[str]:
    """Translate an internal model type back into a HuggingFace ``pipeline_tag``.

    Args:
        model_type: Internal type name such as ``"llm"`` or ``"vlm"``.

    Returns:
        The matching pipeline tag, or ``None`` when ``model_type`` is falsy or
        has no known mapping.
    """
    if not model_type:
        return None

    return MODEL_TYPE_TO_PIPELINE.get(model_type)
@@ -0,0 +1,385 @@
1
+ """
2
+ Progress tracking utilities for downloads with tqdm integration.
3
+
4
+ This module provides custom progress tracking classes that can monitor
5
+ download progress with callback support and customizable display options.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Optional, Callable, Dict, Any
12
+ from tqdm.auto import tqdm
13
+
14
+
15
class CustomProgressTqdm(tqdm):
    """tqdm subclass that keeps counting progress but never writes to the terminal.

    All output goes to a private ``os.devnull`` handle opened in ``__init__``.
    Rendering and ``write`` are no-ops, and the handle is released again in
    ``close``.
    """

    def __init__(self, *args, **kwargs):
        # Keep our own reference to the handle. tqdm may wrap ``file`` before
        # storing it as ``self.fp``, and something else may replace ``file``.
        # close() must release exactly what this constructor opened.
        devnull = open(os.devnull, 'w')
        self._devnull = devnull
        kwargs['file'] = devnull
        kwargs['disable'] = False  # Keep enabled so counters keep updating
        kwargs['leave'] = False  # Nothing should remain after close
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # Construction failed: do not leak the handle opened above.
            devnull.close()
            raise

    def display(self, msg=None, pos=None):
        # Rendering is suppressed entirely.
        pass

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        # tqdm.write is a classmethod (used as ``SomeTqdm.write(msg)``). Keep
        # that shape so class-level calls still work, but drop the message.
        pass

    def close(self):
        """Close the bar without printing anything and release the devnull handle."""
        try:
            # Run the base close *before* marking the bar disabled. The base
            # implementation returns immediately when ``disable`` is already
            # True, which would skip its own cleanup. ``super(tqdm, self)``
            # resolves ``close`` starting after the ``tqdm.auto`` class itself.
            super(tqdm, self).close()
        finally:
            self.disable = True
            devnull = getattr(self, '_devnull', None)
            if devnull is not None and not devnull.closed:
                try:
                    devnull.close()
                except OSError:
                    # Best effort: failing to close devnull is not actionable.
                    pass
42
+
43
+
44
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    ``start_tracking`` monkey-patches the ``tqdm`` class so that every bar
    created while tracking is active is silenced and registered here.
    ``get_progress_data`` aggregates the per-bar counters into one snapshot.
    That snapshot goes to ``progress_callback`` and can optionally be rendered
    as a single progress line. ``stop_tracking`` undoes the patches.
    """

    # tqdm class attributes replaced by start_tracking() and restored by stop_tracking().
    _PATCHED_TQDM_ATTRS = ('__init__', 'update', 'display', 'write')
    # Marks an attribute that tqdm inherited rather than defined itself. For
    # these, stop_tracking() deletes the override instead of re-assigning a value.
    _INHERITED = object()

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        """
        Args:
            progress_callback: Receives the ``get_progress_data`` snapshot on
                bar registration, progress updates, start/stop and errors.
            show_progress: Print a unified progress line to stdout. This only
                takes effect when ``progress_callback`` is also set.
        """
        # Per-bar state keyed by ``str(id(bar))``.
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        # Public references to the tqdm methods seen when tracking started.
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        # Raw tqdm class-dict entries captured by start_tracking(), used for exact restoration.
        self._saved_tqdm_attrs: Dict[str, Any] = {}
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking; None means "no previous measurement yet".
        self.last_downloaded = None
        self.last_time = None
        self.speed_history = []
        self.max_speed_history = 10

        # Download status: idle, downloading, completed, error
        self.download_status = "idle"
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring and fire the callback."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance,
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh the stored counters of a registered bar and fire the callback.

        ``n`` is accepted to mirror ``tqdm.update``. The bar's own ``n`` and
        ``total`` attributes are read instead. Unregistered bars are ignored.
        """
        tqdm_id = str(id(tqdm_instance))
        if tqdm_id in self.progress_data:
            self.progress_data[tqdm_id]['current'] = getattr(tqdm_instance, 'n', 0)
            self.progress_data[tqdm_id]['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Return the smoothed download speed in bytes per second.

        A new sample is taken only when at least 100 ms have passed since the
        previous one and the byte count has not decreased. The return value is
        the mean of up to ``max_speed_history`` recent samples, or 0.0 before
        any sample exists.
        """
        current_time = time.time()

        if self.last_time is not None and self.last_downloaded is not None:
            time_diff = current_time - self.last_time

            # Avoid dividing by tiny intervals.
            if time_diff > 0.1:
                bytes_diff = current_downloaded - self.last_downloaded

                # Allow 0 for periods with no progress
                if bytes_diff >= 0:
                    speed = bytes_diff / time_diff

                    # Keep a bounded history for smoothing
                    self.speed_history.append(speed)
                    if len(self.speed_history) > self.max_speed_history:
                        self.speed_history.pop(0)

                    # Move the baseline only when a sample was actually taken
                    self.last_downloaded = current_downloaded
                    self.last_time = current_time
        else:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = current_time

        # Report the last known average even when this call took no sample
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)

        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Build the aggregated progress snapshot passed to the callback.

        Only bars with a positive ``total`` contribute to the byte counts.
        The total size prefers the value from ``set_repo_info`` and otherwise
        falls back to the sum of known bar totals.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Prefer the pre-fetched repo size, then the aggregate of file sizes
        if self.total_repo_size > 0:
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            total_size = total_file_sizes
        else:
            # Total size is not known yet
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Percentage is 0 while the total size is unknown
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # Calculate ETA
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        # Calculate elapsed time
        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed)
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size)
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time
            }
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Overwrite the current stdout line with a unified progress bar."""
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        timing_info = progress_data.get('timing', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)

        # Create progress bar
        bar_width = 30
        if known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size - show a moving marker driven by wall-clock time
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        # Format the progress line
        status = progress_data.get('status', 'unknown')
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            progress_line = f"Error: {progress_data.get('error_message', 'Unknown error')}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Clear the progress bar display."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Compute a fresh snapshot, pass it to the callback, and optionally render it."""
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking by monkey-patching the ``tqdm`` class.

        While active, every new ``tqdm`` writes to ``os.devnull`` and is
        registered here. Each ``update`` refreshes this tracker and fires the
        callback. ``display`` and ``write`` become no-ops. Calling this again
        while already tracking does nothing.
        """
        if self.is_tracking:
            return

        # Public references, kept for callers that inspect them.
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # Snapshot the raw entries in tqdm's *own* class dict. Attribute
        # access would return a bound method for the ``write`` classmethod and
        # hides whether a method was merely inherited. Restoring from these
        # entries is therefore exact.
        own_attrs = vars(tqdm)
        self._saved_tqdm_attrs = {
            name: own_attrs.get(name, self._INHERITED)
            for name in self._PATCHED_TQDM_ATTRS
        }

        # Create references to self for the nested functions
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            devnull = open(os.devnull, 'w')
            kwargs['file'] = devnull
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar

            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # The bar was never registered, so stop_tracking() could not close this handle.
                devnull.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(cls, s, file=None, end="\n", nolock=False):
            # Override write to prevent any output
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        # tqdm.write is a classmethod called as ``tqdm.write(msg)``. A plain
        # function here would bind ``msg`` to the first parameter and raise TypeError.
        tqdm.write = classmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking, restore tqdm, and report the final state.

        A status still marked ``"downloading"`` becomes ``"completed"``. An
        ``"error"`` status set via ``set_error`` is preserved. Does nothing if
        tracking is not active.
        """
        if not self.is_tracking:
            return

        # Put back exactly what tqdm's class dict held before patching. For
        # attributes that were only inherited, delete the override instead.
        for name, original in self._saved_tqdm_attrs.items():
            if original is self._INHERITED:
                if name in vars(tqdm):
                    delattr(tqdm, name)
            else:
                setattr(tqdm, name, original)
        self._saved_tqdm_attrs = {}

        # Close the output handles of registered bars (the devnull files opened
        # in patched_init), but never the standard streams.
        for data in self.progress_data.values():
            tqdm_obj = data.get('tqdm_obj')
            if tqdm_obj is None or not hasattr(tqdm_obj, 'fp'):
                continue
            try:
                fp = tqdm_obj.fp
                if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                    fp.close()
            except Exception:
                # Best-effort cleanup: a close failure must not abort stop_tracking.
                pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status and trigger callback."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
@@ -0,0 +1,245 @@
1
+ """
2
+ Quantization utilities for extracting quantization types from model files and configurations.
3
+
4
+ This module provides utilities to extract quantization information from:
5
+ - GGUF model filenames
6
+ - MLX model repository IDs
7
+ - MLX model config.json files
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import re
13
+ import logging
14
+ from enum import Enum
15
+ from typing import Optional
16
+
17
# Logger named after this module, so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
19
+
20
+
21
class QuantizationType(str, Enum):
    """Quantization formats recognised for GGUF files and MLX models.

    Mixing in ``str`` makes each member compare equal to its value
    (``QuantizationType.Q4_K_M == "Q4_K_M"``). Member order is part of the
    observable iteration order, so new members should be appended.
    """

    # --- GGUF formats (values mirror the upper-cased filename tag) ---
    BF16 = "BF16"
    F16 = "F16"
    Q2_K = "Q2_K"
    Q2_K_L = "Q2_K_L"
    Q3_K = "Q3_K"
    Q3_K_M = "Q3_K_M"
    Q3_K_S = "Q3_K_S"
    Q4_0 = "Q4_0"
    Q4_1 = "Q4_1"
    Q4_K = "Q4_K"
    Q4_K_M = "Q4_K_M"
    Q4_K_S = "Q4_K_S"
    Q5_K = "Q5_K"
    Q5_K_M = "Q5_K_M"
    Q5_K_S = "Q5_K_S"
    Q6_K = "Q6_K"
    Q8_0 = "Q8_0"
    MXFP4 = "MXFP4"
    MXFP8 = "MXFP8"

    # --- MLX bit-width formats ("<N>BIT") ---
    BIT_1 = "1BIT"
    BIT_2 = "2BIT"
    BIT_3 = "3BIT"
    BIT_4 = "4BIT"
    BIT_5 = "5BIT"
    BIT_6 = "6BIT"
    BIT_7 = "7BIT"
    BIT_8 = "8BIT"
    BIT_16 = "16BIT"
54
+
55
+
56
def extract_quantization_from_filename(filename: str) -> Optional[QuantizationType]:
    """
    Extract the GGUF quantization type encoded in a filename.

    Matching is case-insensitive. A tag is recognised when it is immediately
    followed by ``"."`` (e.g. ``model-Q4_K_M.gguf``). As a fallback it is also
    recognised when followed by ``"-"``, as in sharded files like
    ``model-Q4_K_M-00001-of-00002.gguf``. Every ``"."`` form is tried before
    any ``"-"`` form, so names already recognised by the dotted rule keep
    their result.

    Args:
        filename: The filename to extract quantization from. ``None`` or an
            empty string yields ``None``.

    Returns:
        QuantizationType enum value or None if not found
    """
    if not filename:
        return None

    # Lowercase tag -> enum. The separator is appended while matching, so that
    # e.g. "q4_0" does not match inside "q4_0_xl".
    tag_to_enum = {
        'bf16': QuantizationType.BF16,
        'f16': QuantizationType.F16,
        'q2_k_l': QuantizationType.Q2_K_L,
        'q2_k': QuantizationType.Q2_K,
        'q3_k': QuantizationType.Q3_K,
        'q3_k_m': QuantizationType.Q3_K_M,
        'q3_k_s': QuantizationType.Q3_K_S,
        'q4_k_m': QuantizationType.Q4_K_M,
        'q4_k_s': QuantizationType.Q4_K_S,
        'q4_0': QuantizationType.Q4_0,
        'q4_1': QuantizationType.Q4_1,
        'q4_k': QuantizationType.Q4_K,
        'q5_k': QuantizationType.Q5_K,
        'q5_k_m': QuantizationType.Q5_K_M,
        'q5_k_s': QuantizationType.Q5_K_S,
        'q6_k': QuantizationType.Q6_K,
        'q8_0': QuantizationType.Q8_0,
        'mxfp4': QuantizationType.MXFP4,
        'mxfp8': QuantizationType.MXFP8,
    }

    filename_lower = filename.lower()

    # Longest tags first: "f16." is a substring of "bf16.", so BF16 must be
    # tried before F16. The sort is stable, so equal lengths keep dict order.
    ordered_tags = sorted(tag_to_enum, key=len, reverse=True)

    for separator in ('.', '-'):
        for tag in ordered_tags:
            if tag + separator in filename_lower:
                return tag_to_enum[tag]

    return None
99
+
100
+
101
def extract_quantization_from_repo_id(repo_id: str) -> Optional[QuantizationType]:
    """
    Infer an MLX bit-width quantization type from a repository ID.

    Two conventions are searched in the lower-cased ID, in this order:

    1. ``<N>bit`` (e.g. ``mlx-community/some-model-4bit``)
    2. ``-q<N>`` / ``_q<N>`` (e.g. ``some-model-q8``)

    The first number that is a supported bit width (1-8 or 16) wins.

    Args:
        repo_id: The repository ID to extract quantization from

    Returns:
        QuantizationType enum value or None if not found
    """
    supported_bits = {
        1: QuantizationType.BIT_1,
        2: QuantizationType.BIT_2,
        3: QuantizationType.BIT_3,
        4: QuantizationType.BIT_4,
        5: QuantizationType.BIT_5,
        6: QuantizationType.BIT_6,
        7: QuantizationType.BIT_7,
        8: QuantizationType.BIT_8,
        16: QuantizationType.BIT_16,
    }

    lowered = repo_id.lower()

    # (regex, debug-message template) pairs, in priority order.
    searches = (
        (r'(\d+)bit', "Found {bits}bit quantization in repo_id: {repo_id}"),
        (r'[-_]q(\d+)', "Found Q{bits} quantization in repo_id: {repo_id}"),
    )

    for regex, message in searches:
        for digits in re.findall(regex, lowered):
            try:
                bits = int(digits)
            except ValueError:
                # e.g. digit runs beyond Python's int string-conversion limit
                continue
            if bits not in supported_bits:
                continue
            logger.debug(message.format(bits=bits, repo_id=repo_id))
            return supported_bits[bits]

    return None
151
+
152
+
153
def extract_quantization_from_mlx_config(mlx_folder_path: str) -> Optional[QuantizationType]:
    """
    Extract quantization type from an MLX model's ``config.json``.

    Reads ``config["quantization"]["bits"]`` and maps the supported integer
    bit widths (1-8 and 16) to the corresponding ``BIT_*`` member.

    Args:
        mlx_folder_path: Path to the MLX model folder

    Returns:
        QuantizationType enum value, or None when the file is missing or
        unreadable, has no integer ``bits`` value, or uses an unsupported
        width. Read/parse failures are logged as warnings, not raised.
    """
    config_path = os.path.join(mlx_folder_path, "config.json")

    if not os.path.exists(config_path):
        logger.debug(f"Config file not found: {config_path}")
        return None

    bit_to_enum = {
        1: QuantizationType.BIT_1,
        2: QuantizationType.BIT_2,
        3: QuantizationType.BIT_3,
        4: QuantizationType.BIT_4,
        5: QuantizationType.BIT_5,
        6: QuantizationType.BIT_6,
        7: QuantizationType.BIT_7,
        8: QuantizationType.BIT_8,
        16: QuantizationType.BIT_16,
    }

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Look for quantization.bits field
        quantization_config = config.get("quantization", {})
        if isinstance(quantization_config, dict):
            bits = quantization_config.get("bits")
            # bool is a subclass of int. Without the explicit exclusion, a JSON
            # `true` would compare equal to 1 and be reported as 1-bit.
            if isinstance(bits, int) and not isinstance(bits, bool):
                if bits in bit_to_enum:
                    logger.debug(f"Found {bits}bit quantization in config.json: {config_path}")
                    return bit_to_enum[bits]
                logger.debug(f"Unsupported quantization bits value: {bits}")

    except (json.JSONDecodeError, IOError) as e:
        logger.warning(f"Error reading config.json from {config_path}: {e}")
    except Exception as e:
        # Best effort: unexpected config shapes must not break model discovery.
        logger.warning(f"Unexpected error reading config.json from {config_path}: {e}")

    return None
203
+
204
+
205
def extract_gguf_quantization(filename: str) -> str:
    """
    Return the GGUF quantization tag of ``filename`` as a plain string.

    Kept for backward compatibility. It wraps
    ``extract_quantization_from_filename`` and returns the enum's ``.value``.

    Args:
        filename: The GGUF filename

    Returns:
        The quantization value (e.g. ``"Q4_K_M"``), or ``"UNKNOWN"`` when no
        known quantization tag is found.
    """
    detected = extract_quantization_from_filename(filename)
    return detected.value if detected else "UNKNOWN"
222
+
223
+
224
def detect_quantization_for_mlx(repo_id: str, directory_path: str) -> Optional[QuantizationType]:
    """
    Detect the quantization of an MLX model, trying cheap sources first.

    The repository ID is inspected first. ``config.json`` inside
    ``directory_path`` is read only if the ID yields nothing.

    Args:
        repo_id: The repository ID
        directory_path: Path to the model directory

    Returns:
        QuantizationType enum value or None if not found
    """
    # `or` short-circuits, so the config file is touched only when needed.
    return (
        extract_quantization_from_repo_id(repo_id)
        or extract_quantization_from_mlx_config(directory_path)
        or None
    )