nexaai 1.0.21__cp313-cp313-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai has been flagged as possibly problematic; see the registry's advisory page for details.
- nexaai/__init__.py +95 -0
- nexaai/_stub.cp313-win_arm64.pyd +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +92 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +6 -0
- nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/libomp140.aarch64.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/libssl-3-arm64.dll +0 -0
- nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/npu/FLAC.dll +0 -0
- nexaai/binds/npu/convnext-sdk.dll +0 -0
- nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
- nexaai/binds/npu/fftw3.dll +0 -0
- nexaai/binds/npu/fftw3f.dll +0 -0
- nexaai/binds/npu/ggml-base.dll +0 -0
- nexaai/binds/npu/ggml-cpu.dll +0 -0
- nexaai/binds/npu/ggml-opencl.dll +0 -0
- nexaai/binds/npu/ggml.dll +0 -0
- nexaai/binds/npu/granite-nano-sdk.dll +0 -0
- nexaai/binds/npu/granite4-sdk.dll +0 -0
- nexaai/binds/npu/htp-files/Genie.dll +0 -0
- nexaai/binds/npu/htp-files/PlatformValidatorShared.dll +0 -0
- nexaai/binds/npu/htp-files/QnnChrometraceProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnCpu.dll +0 -0
- nexaai/binds/npu/htp-files/QnnCpuNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDsp.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspV66CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnDspV66Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformer.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformerCpuOpPkg.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGenAiTransformerModel.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpu.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpuNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnGpuProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtp.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpNetRunExtensions.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpOptraceProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpPrepare.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV68CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV68Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV73CalculatorStub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnHtpV73Stub.dll +0 -0
- nexaai/binds/npu/htp-files/QnnIr.dll +0 -0
- nexaai/binds/npu/htp-files/QnnJsonProfilingReader.dll +0 -0
- nexaai/binds/npu/htp-files/QnnModelDlc.dll +0 -0
- nexaai/binds/npu/htp-files/QnnSaver.dll +0 -0
- nexaai/binds/npu/htp-files/QnnSystem.dll +0 -0
- nexaai/binds/npu/htp-files/SNPE.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeDspV66Stub.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpPrepare.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpV68Stub.dll +0 -0
- nexaai/binds/npu/htp-files/SnpeHtpV73Stub.dll +0 -0
- nexaai/binds/npu/htp-files/calculator.dll +0 -0
- nexaai/binds/npu/htp-files/calculator_htp.dll +0 -0
- nexaai/binds/npu/htp-files/libCalculator_skel.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73QemuDriver.so +0 -0
- nexaai/binds/npu/htp-files/libQnnHtpV73Skel.so +0 -0
- nexaai/binds/npu/htp-files/libQnnSaver.so +0 -0
- nexaai/binds/npu/htp-files/libQnnSystem.so +0 -0
- nexaai/binds/npu/htp-files/libSnpeHtpV73Skel.so +0 -0
- nexaai/binds/npu/htp-files/libqnnhtpv73.cat +0 -0
- nexaai/binds/npu/htp-files/libsnpehtpv73.cat +0 -0
- nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
- nexaai/binds/npu/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/npu/libmp3lame.DLL +0 -0
- nexaai/binds/npu/libomp140.aarch64.dll +0 -0
- nexaai/binds/npu/libssl-3-arm64.dll +0 -0
- nexaai/binds/npu/liquid-sdk.dll +0 -0
- nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
- nexaai/binds/npu/mpg123.dll +0 -0
- nexaai/binds/npu/nexa-mm-process.dll +0 -0
- nexaai/binds/npu/nexa-sampling.dll +0 -0
- nexaai/binds/npu/nexa_plugin.dll +0 -0
- nexaai/binds/npu/nexaproc.dll +0 -0
- nexaai/binds/npu/ogg.dll +0 -0
- nexaai/binds/npu/omni-neural-sdk.dll +0 -0
- nexaai/binds/npu/openblas.dll +0 -0
- nexaai/binds/npu/opus.dll +0 -0
- nexaai/binds/npu/paddle-ocr-proc-lib.dll +0 -0
- nexaai/binds/npu/paddleocr-sdk.dll +0 -0
- nexaai/binds/npu/parakeet-sdk.dll +0 -0
- nexaai/binds/npu/phi3-5-sdk.dll +0 -0
- nexaai/binds/npu/phi4-sdk.dll +0 -0
- nexaai/binds/npu/pyannote-sdk.dll +0 -0
- nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-vision.dll +0 -0
- nexaai/binds/npu/rtaudio.dll +0 -0
- nexaai/binds/npu/vorbis.dll +0 -0
- nexaai/binds/npu/vorbisenc.dll +0 -0
- nexaai/binds/npu/yolov12-sdk.dll +0 -0
- nexaai/binds/npu/zlib1.dll +0 -0
- nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/common.py +105 -0
- nexaai/cv.py +93 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +89 -0
- nexaai/cv_impl/pybind_cv_impl.py +32 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +220 -0
- nexaai/log.py +92 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1562 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +385 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +256 -0
- nexaai-1.0.21.dist-info/METADATA +31 -0
- nexaai-1.0.21.dist-info/RECORD +154 -0
- nexaai-1.0.21.dist-info/WHEEL +5 -0
- nexaai-1.0.21.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model type mappings for HuggingFace pipeline tags to our internal model types.
|
|
3
|
+
|
|
4
|
+
This module provides centralized model type mapping functionality to avoid
|
|
5
|
+
circular imports between other utility modules.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Dict
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelTypeMapping(Enum):
    """Enum for mapping HuggingFace pipeline_tag to our ModelType.

    Each member's value is a 2-tuple ``(pipeline_tag, model_type)``; the
    custom ``__init__`` unpacks that tuple onto named attributes so lookups
    can go in either direction.
    """

    TEXT_GENERATION = ("text-generation", "llm")
    IMAGE_TEXT_TO_TEXT = ("image-text-to-text", "vlm")
    ANY_TO_ANY = ("any-to-any", "ata")
    AUTOMATIC_SPEECH_RECOGNITION = ("automatic-speech-recognition", "asr")

    def __init__(self, tag: str, internal_type: str):
        # Enum calls __init__ once per member with the tuple value unpacked.
        self.pipeline_tag = tag
        self.model_type = internal_type
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Forward lookup built from the enum: HuggingFace pipeline tag -> internal type.
PIPELINE_TO_MODEL_TYPE: Dict[str, str] = {
    entry.pipeline_tag: entry.model_type for entry in ModelTypeMapping
}
|
|
29
|
+
|
|
30
|
+
# Reverse lookup built from the enum: internal type -> HuggingFace pipeline tag.
MODEL_TYPE_TO_PIPELINE: Dict[str, str] = {
    entry.model_type: entry.pipeline_tag for entry in ModelTypeMapping
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def map_pipeline_tag_to_model_type(pipeline_tag: str) -> str:
    """Map a HuggingFace pipeline_tag to our internal model type.

    Args:
        pipeline_tag: Tag reported by the HuggingFace Hub (may be empty/None).

    Returns:
        The internal model type string, or "other" for empty/unknown tags.
    """
    if pipeline_tag:
        return PIPELINE_TO_MODEL_TYPE.get(pipeline_tag, "other")
    return "other"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def map_model_type_to_pipeline_tag(model_type: str) -> str:
|
|
45
|
+
"""Reverse map ModelType back to HuggingFace pipeline_tag."""
|
|
46
|
+
if not model_type:
|
|
47
|
+
return None
|
|
48
|
+
|
|
49
|
+
return MODEL_TYPE_TO_PIPELINE.get(model_type)
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Progress tracking utilities for downloads with tqdm integration.
|
|
3
|
+
|
|
4
|
+
This module provides custom progress tracking classes that can monitor
|
|
5
|
+
download progress with callback support and customizable display options.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional, Callable, Dict, Any
|
|
12
|
+
from tqdm.auto import tqdm
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CustomProgressTqdm(tqdm):
    """A tqdm subclass that tracks progress but completely hides terminal output.

    The bar stays enabled so its ``n``/``total`` counters keep updating for
    programmatic consumers; display is suppressed by writing to ``os.devnull``
    and overriding the output hooks.
    """

    def __init__(self, *args, **kwargs):
        # Redirect output to devnull so nothing reaches the terminal while
        # the bar remains enabled for tracking.
        kwargs['file'] = open(os.devnull, 'w')
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar
        super().__init__(*args, **kwargs)

    def display(self, msg=None, pos=None):
        """Override display to show nothing."""
        pass

    def write(self, s, file=None, end="\n", nolock=False):
        """Override write to prevent any output."""
        pass

    def close(self):
        """Close without printing and release the devnull handle opened in __init__."""
        if hasattr(self, 'fp') and self.fp and self.fp != sys.stdout and self.fp != sys.stderr:
            try:
                self.fp.close()
            # Fixed: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; only I/O errors should be ignored here.
            except Exception:
                pass
        self.disable = True
        # Skip tqdm.close()'s display logic by dispatching past tqdm in the MRO.
        # NOTE(review): assumes a class above tqdm in the MRO provides close();
        # verify against the installed tqdm version.
        super(tqdm, self).close()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    Works by monkey-patching ``tqdm`` (see start_tracking) so that every bar
    created during a download is silenced and registered here instead.
    Aggregated progress is delivered through an optional callback and an
    optional single-line console progress bar.
    """

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        # Per-tqdm-instance state, keyed by str(id(instance)).
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        # Originals saved so stop_tracking() can undo the monkey-patching.
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking
        self.last_downloaded = None  # Use None to indicate no previous measurement
        self.last_time = None  # Use None to indicate no previous time measurement
        self.speed_history = []
        self.max_speed_history = 10

        # Download status
        self.download_status = "idle"  # idle, downloading, completed, error
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring.

        Called from the patched tqdm.__init__ for every bar created while
        tracking is active.
        """
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Update progress for a tqdm instance (called from patched tqdm.update)."""
        tqdm_id = str(id(tqdm_instance))
        if tqdm_id in self.progress_data:
            self.progress_data[tqdm_id]['current'] = getattr(tqdm_instance, 'n', 0)
            self.progress_data[tqdm_id]['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Calculate download speed in bytes per second.

        Maintains a short rolling history (max_speed_history samples) and
        returns its average, so the reported speed is smoothed and the last
        known speed is still shown when a sample is skipped.
        """
        current_time = time.time()

        # Check if we have a previous measurement to compare against
        if self.last_time is not None and self.last_downloaded is not None:
            time_diff = current_time - self.last_time

            # Only calculate if we have a meaningful time difference (avoid division by very small numbers)
            if time_diff > 0.1:  # At least 100ms between measurements
                bytes_diff = current_downloaded - self.last_downloaded

                # Only calculate speed if bytes actually changed
                if bytes_diff >= 0:  # Allow 0 for periods with no progress
                    speed = bytes_diff / time_diff

                    # Add to speed history for smoothing
                    self.speed_history.append(speed)
                    if len(self.speed_history) > self.max_speed_history:
                        self.speed_history.pop(0)

                    # Update tracking variables when we actually calculate speed
                    self.last_downloaded = current_downloaded
                    self.last_time = current_time
        else:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = current_time

        # Return the average of historical speeds if we have any
        # This ensures we show the last known speed even when skipping updates
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)

        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string (e.g. '1.5 MB')."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string (e.g. '2.3 MB/s')."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Get current aggregated progress data as a nested dict.

        Returns a dict with 'status', 'error_message', and nested 'progress',
        'speed', 'formatting', and 'timing' sections; this is the payload
        passed to the progress callback.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        # Aggregate across all registered bars; bars with unknown total are skipped.
        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Determine total size - prioritize pre-fetched repo size, then aggregate file sizes
        if self.total_repo_size > 0:
            # Use pre-fetched repository info if available
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            # Use sum of individual file sizes if available
            total_size = total_file_sizes
        else:
            # Last resort - we don't know the total size yet
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Calculate percentage - handle unknown total size gracefully
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # Calculate ETA
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        # Calculate elapsed time
        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed)
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size)
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time
            }
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Display a custom unified progress bar on a single console line."""
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        timing_info = progress_data.get('timing', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)

        # Create progress bar
        bar_width = 30
        if known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size - show animated progress (marker sweeps with time)
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        # Format the progress line
        status = progress_data.get('status', 'unknown')
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            progress_line = f"Error: {progress_data.get('error_message', 'Unknown error')}"
        else:
            progress_line = f"Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Clear the progress bar display and move to the next line."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Trigger the progress callback if one is set, then refresh the bar."""
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                # A faulty callback must not abort the download.
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking (monkey patch tqdm).

        Replaces tqdm.__init__/update/display/write globally so every bar
        created afterwards is silenced and reports into this tracker.
        Idempotent while already tracking.
        """
        if self.is_tracking:
            return

        # Store original methods
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # Create references to self for the nested functions
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            kwargs['file'] = open(os.devnull, 'w')
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar

            result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(self_tqdm, s, file=None, end="\n", nolock=False):
            # Override write to prevent any output
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        tqdm.write = patched_write

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking and restore original tqdm methods."""
        if not self.is_tracking:
            return

        # Restore original tqdm methods
        if self.original_tqdm_update:
            tqdm.update = self.original_tqdm_update
        if self.original_tqdm_init:
            tqdm.__init__ = self.original_tqdm_init
        if hasattr(self, 'original_tqdm_display') and self.original_tqdm_display:
            tqdm.display = self.original_tqdm_display
        if hasattr(self, 'original_tqdm_write') and self.original_tqdm_write:
            tqdm.write = self.original_tqdm_write

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            if 'tqdm_obj' in data and hasattr(data['tqdm_obj'], 'fp'):
                try:
                    fp = data['tqdm_obj'].fp
                    if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                        fp.close()
                except:
                    pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status and trigger a final callback."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Quantization utilities for extracting quantization types from model files and configurations.
|
|
3
|
+
|
|
4
|
+
This module provides utilities to extract quantization information from:
|
|
5
|
+
- GGUF model filenames
|
|
6
|
+
- MLX model repository IDs
|
|
7
|
+
- MLX model config.json files
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import json
|
|
12
|
+
import re
|
|
13
|
+
import logging
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
# Set up logger
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class QuantizationType(str, Enum):
    """Enum for GGUF and MLX model quantization types.

    Subclasses str so members compare equal to their string values and
    serialize naturally (e.g. in JSON manifests).
    """
    # GGUF quantization types (llama.cpp naming)
    BF16 = "BF16"
    F16 = "F16"
    Q2_K = "Q2_K"
    Q2_K_L = "Q2_K_L"
    Q3_K = "Q3_K"
    Q3_K_M = "Q3_K_M"
    Q3_K_S = "Q3_K_S"
    Q4_0 = "Q4_0"
    Q4_1 = "Q4_1"
    Q4_K = "Q4_K"
    Q4_K_M = "Q4_K_M"
    Q4_K_S = "Q4_K_S"
    Q5_K = "Q5_K"
    Q5_K_M = "Q5_K_M"
    Q5_K_S = "Q5_K_S"
    Q6_K = "Q6_K"
    Q8_0 = "Q8_0"
    MXFP4 = "MXFP4"
    MXFP8 = "MXFP8"

    # MLX bit-based quantization types
    BIT_1 = "1BIT"
    BIT_2 = "2BIT"
    BIT_3 = "3BIT"
    BIT_4 = "4BIT"
    BIT_5 = "5BIT"
    BIT_6 = "6BIT"
    BIT_7 = "7BIT"
    BIT_8 = "8BIT"
    BIT_16 = "16BIT"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def extract_quantization_from_filename(filename: str) -> Optional[QuantizationType]:
    """Extract the quantization type embedded in a model filename.

    Each pattern includes a trailing "." so matches are precise (e.g. "q4_0."
    matches "model.Q4_0.gguf" but not "q4_0_xl").

    Args:
        filename: The filename to inspect (matched case-insensitively).

    Returns:
        The matching QuantizationType, or None if no pattern is found.
    """
    pattern_to_enum = {
        'bf16.': QuantizationType.BF16,
        'f16.': QuantizationType.F16,
        'q2_k_l.': QuantizationType.Q2_K_L,
        'q2_k.': QuantizationType.Q2_K,
        'q3_k.': QuantizationType.Q3_K,
        'q3_k_m.': QuantizationType.Q3_K_M,
        'q3_k_s.': QuantizationType.Q3_K_S,
        'q4_k_m.': QuantizationType.Q4_K_M,
        'q4_k_s.': QuantizationType.Q4_K_S,
        'q4_0.': QuantizationType.Q4_0,
        'q4_1.': QuantizationType.Q4_1,
        'q4_k.': QuantizationType.Q4_K,
        'q5_k.': QuantizationType.Q5_K,
        'q5_k_m.': QuantizationType.Q5_K_M,
        'q5_k_s.': QuantizationType.Q5_K_S,
        'q6_k.': QuantizationType.Q6_K,
        'q8_0.': QuantizationType.Q8_0,
        'mxfp4.': QuantizationType.MXFP4,
        'mxfp8.': QuantizationType.MXFP8,
    }

    haystack = filename.lower()

    # Longest patterns first so e.g. 'q2_k_l.' wins over its prefix 'q2_k.'.
    for candidate in sorted(pattern_to_enum, key=len, reverse=True):
        if candidate in haystack:
            return pattern_to_enum[candidate]

    return None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def extract_quantization_from_repo_id(repo_id: str) -> Optional[QuantizationType]:
    """Extract quantization type from an MLX repo_id via bit-count patterns.

    Two spellings are recognized (case-insensitively): "<N>bit" anywhere in
    the id, and "-q<N>" / "_q<N>" suffix-style markers.

    Args:
        repo_id: The repository ID to inspect.

    Returns:
        The matching bit-based QuantizationType, or None if no pattern is found.
    """
    bit_to_enum = {
        1: QuantizationType.BIT_1,
        2: QuantizationType.BIT_2,
        3: QuantizationType.BIT_3,
        4: QuantizationType.BIT_4,
        5: QuantizationType.BIT_5,
        6: QuantizationType.BIT_6,
        7: QuantizationType.BIT_7,
        8: QuantizationType.BIT_8,
        16: QuantizationType.BIT_16,
    }

    lowered = repo_id.lower()

    # First check for patterns like "4bit", "8bit" etc.
    for raw in re.findall(r'(\d+)bit', lowered):
        try:
            bit_number = int(raw)
        except ValueError:
            continue
        if bit_number in bit_to_enum:
            logger.debug(f"Found {bit_number}bit quantization in repo_id: {repo_id}")
            return bit_to_enum[bit_number]

    # Also check for patterns like "-q8", "_Q4" etc.
    for raw in re.findall(r'[-_]q(\d+)', lowered):
        try:
            bit_number = int(raw)
        except ValueError:
            continue
        if bit_number in bit_to_enum:
            logger.debug(f"Found Q{bit_number} quantization in repo_id: {repo_id}")
            return bit_to_enum[bit_number]

    return None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def extract_quantization_from_mlx_config(mlx_folder_path: str) -> Optional[QuantizationType]:
    """Extract quantization type from an MLX model's config.json.

    Reads the "quantization.bits" field, when present, and maps it onto the
    bit-based QuantizationType members.

    Args:
        mlx_folder_path: Path to the MLX model folder containing config.json.

    Returns:
        The matching QuantizationType, or None if absent/unreadable/unsupported.
    """
    config_path = os.path.join(mlx_folder_path, "config.json")

    if not os.path.exists(config_path):
        logger.debug(f"Config file not found: {config_path}")
        return None

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Look for the quantization.bits field.
        quantization_config = config.get("quantization", {})
        if isinstance(quantization_config, dict):
            bits = quantization_config.get("bits")
            if isinstance(bits, int):
                bit_to_enum = {
                    1: QuantizationType.BIT_1,
                    2: QuantizationType.BIT_2,
                    3: QuantizationType.BIT_3,
                    4: QuantizationType.BIT_4,
                    5: QuantizationType.BIT_5,
                    6: QuantizationType.BIT_6,
                    7: QuantizationType.BIT_7,
                    8: QuantizationType.BIT_8,
                    16: QuantizationType.BIT_16,
                }
                if bits in bit_to_enum:
                    logger.debug(f"Found {bits}bit quantization in config.json: {config_path}")
                    return bit_to_enum[bits]
                logger.debug(f"Unsupported quantization bits value: {bits}")
    except (json.JSONDecodeError, IOError) as e:
        logger.warning(f"Error reading config.json from {config_path}: {e}")
    except Exception as e:
        logger.warning(f"Unexpected error reading config.json from {config_path}: {e}")

    return None
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def extract_gguf_quantization(filename: str) -> str:
    """Extract the quantization level from a GGUF filename as a plain string.

    Backward-compatible wrapper around extract_quantization_from_filename.

    Args:
        filename: The GGUF filename to inspect.

    Returns:
        The quantization type's string value, or "UNKNOWN" if none matches.
    """
    found = extract_quantization_from_filename(filename)
    return found.value if found else "UNKNOWN"
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def detect_quantization_for_mlx(repo_id: str, directory_path: str) -> Optional[QuantizationType]:
    """Detect quantization for MLX models, trying sources in priority order.

    Priority: (1) bit patterns in the repo_id, (2) the model's config.json.

    Args:
        repo_id: The repository ID.
        directory_path: Path to the model directory.

    Returns:
        The detected QuantizationType, or None if neither source yields one.
    """
    probes = (
        (extract_quantization_from_repo_id, repo_id),
        (extract_quantization_from_mlx_config, directory_path),
    )
    for probe, argument in probes:
        detected = probe(argument)
        if detected:
            return detected
    return None
|