nexaai 1.0.21rc5__cp313-cp313-win_arm64.whl → 1.0.21rc14__cp313-cp313-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +95 -95
- nexaai/_stub.cp313-win_arm64.pyd +0 -0
- nexaai/_version.py +4 -1
- nexaai/asr.py +68 -65
- nexaai/asr_impl/mlx_asr_impl.py +92 -92
- nexaai/asr_impl/pybind_asr_impl.py +127 -44
- nexaai/base.py +39 -39
- nexaai/binds/__init__.py +6 -5
- nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
- nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
- nexaai/binds/cpu_gpu/ggml.dll +0 -0
- nexaai/binds/cpu_gpu/mtmd.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
- nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
- nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/libcrypto-3-arm64.dll +0 -0
- nexaai/binds/libssl-3-arm64.dll +0 -0
- nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/npu/convnext-sdk.dll +0 -0
- nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
- nexaai/binds/npu/ggml-base.dll +0 -0
- nexaai/binds/npu/ggml-cpu.dll +0 -0
- nexaai/binds/npu/ggml-opencl.dll +0 -0
- nexaai/binds/npu/ggml.dll +0 -0
- nexaai/binds/npu/granite-nano-sdk.dll +0 -0
- nexaai/binds/npu/granite4-sdk.dll +0 -0
- nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
- nexaai/binds/npu/liquid-sdk.dll +0 -0
- nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
- nexaai/binds/npu/nexa-mm-process.dll +0 -0
- nexaai/binds/npu/nexa-sampling.dll +0 -0
- nexaai/binds/npu/nexa_plugin.dll +0 -0
- nexaai/binds/npu/omni-neural-sdk.dll +0 -0
- nexaai/binds/npu/openblas.dll +0 -0
- nexaai/binds/npu/paddleocr-sdk.dll +0 -0
- nexaai/binds/npu/parakeet-sdk.dll +0 -0
- nexaai/binds/npu/phi3-5-sdk.dll +0 -0
- nexaai/binds/npu/phi4-sdk.dll +0 -0
- nexaai/binds/npu/pyannote-sdk.dll +0 -0
- nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
- nexaai/binds/npu/qwen3vl-vision.dll +0 -0
- nexaai/binds/npu/yolov12-sdk.dll +0 -0
- nexaai/binds/npu/zlib1.dll +0 -0
- nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
- nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
- nexaai/common.py +105 -105
- nexaai/cv.py +93 -93
- nexaai/cv_impl/mlx_cv_impl.py +89 -89
- nexaai/cv_impl/pybind_cv_impl.py +32 -32
- nexaai/embedder.py +73 -73
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
- nexaai/image_gen.py +141 -141
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
- nexaai/llm.py +98 -98
- nexaai/llm_impl/mlx_llm_impl.py +271 -271
- nexaai/llm_impl/pybind_llm_impl.py +220 -220
- nexaai/log.py +92 -92
- nexaai/rerank.py +57 -57
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
- nexaai/runtime.py +68 -68
- nexaai/runtime_error.py +24 -24
- nexaai/tts.py +75 -75
- nexaai/tts_impl/mlx_tts_impl.py +94 -94
- nexaai/tts_impl/pybind_tts_impl.py +43 -43
- nexaai/utils/decode.py +17 -17
- nexaai/utils/manifest_utils.py +531 -531
- nexaai/utils/model_manager.py +1562 -1562
- nexaai/utils/model_types.py +49 -49
- nexaai/utils/progress_tracker.py +384 -384
- nexaai/utils/quantization_utils.py +245 -245
- nexaai/vlm.py +129 -129
- nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
- nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/METADATA +1 -1
- nexaai-1.0.21rc14.dist-info/RECORD +154 -0
- nexaai/binds/nexaml/FLAC.dll +0 -0
- nexaai/binds/nexaml/fftw3.dll +0 -0
- nexaai/binds/nexaml/fftw3f.dll +0 -0
- nexaai/binds/nexaml/ggml-base.dll +0 -0
- nexaai/binds/nexaml/ggml-cpu.dll +0 -0
- nexaai/binds/nexaml/ggml-opencl.dll +0 -0
- nexaai/binds/nexaml/ggml.dll +0 -0
- nexaai/binds/nexaml/libmp3lame.DLL +0 -0
- nexaai/binds/nexaml/mpg123.dll +0 -0
- nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
- nexaai/binds/nexaml/nexa-sampling.dll +0 -0
- nexaai/binds/nexaml/nexa_plugin.dll +0 -0
- nexaai/binds/nexaml/nexaproc.dll +0 -0
- nexaai/binds/nexaml/ogg.dll +0 -0
- nexaai/binds/nexaml/opus.dll +0 -0
- nexaai/binds/nexaml/qwen3-vl.dll +0 -0
- nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
- nexaai/binds/nexaml/vorbis.dll +0 -0
- nexaai/binds/nexaml/vorbisenc.dll +0 -0
- nexaai-1.0.21rc5.dist-info/RECORD +0 -162
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/WHEEL +0 -0
- {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc14.dist-info}/top_level.txt +0 -0
nexaai/utils/progress_tracker.py
CHANGED
|
@@ -1,385 +1,385 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Progress tracking utilities for downloads with tqdm integration.
|
|
3
|
-
|
|
4
|
-
This module provides custom progress tracking classes that can monitor
|
|
5
|
-
download progress with callback support and customizable display options.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import os
|
|
9
|
-
import sys
|
|
10
|
-
import time
|
|
11
|
-
from typing import Optional, Callable, Dict, Any
|
|
12
|
-
from tqdm.auto import tqdm
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class CustomProgressTqdm(tqdm):
    """tqdm subclass that tracks progress but never renders anything.

    Each instance is given its own ``os.devnull`` handle as output file and
    the ``display`` / ``write`` hooks are no-ops, so the counters maintained
    by the parent class remain usable while the terminal stays untouched.
    """

    def __init__(self, *args, **kwargs):
        """Create the bar with a devnull sink.

        Any ``file``, ``disable`` or ``leave`` value supplied by the caller
        is overridden.
        """
        # Redirect output to devnull to completely suppress terminal output
        sink = open(os.devnull, 'w')
        kwargs['file'] = sink
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # Do not leak the devnull handle when construction fails.
            sink.close()
            raise

    def display(self, msg=None, pos=None, **kwargs):
        """Render nothing.

        Extra keyword arguments are accepted and ignored so that callers
        passing more than ``msg`` / ``pos`` do not fail with a TypeError.
        NOTE(review): tqdm's notebook front-end appears to call ``display``
        with additional keywords - confirm against the installed tqdm.
        """

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Discard the message instead of printing it.

        Declared as a classmethod so that both ``CustomProgressTqdm.write(s)``
        and ``instance.write(s)`` work.
        """

    def close(self):
        """Close the devnull handle and mark the bar as disabled."""
        fp = getattr(self, 'fp', None)
        # Equality (not identity) comparison is kept from the original code:
        # `fp` may be a wrapper object rather than the raw stream.
        if fp and fp != sys.stdout and fp != sys.stderr:
            try:
                fp.close()
            except Exception:
                # Best effort: a failure to close the sink must not break
                # the caller's cleanup path.
                pass
        # Disabled first so the parent close() does not try to write to the
        # handle that was just closed.
        self.disable = True
        super().close()
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    While tracking is active, ``tqdm.__init__`` / ``update`` / ``display`` /
    ``write`` are monkey patched: every bar created through the patched class
    is silenced and registered here. Aggregated progress (bytes, percentage,
    smoothed speed, ETA) is handed to ``progress_callback`` and, when
    ``show_progress`` is true, rendered as a single console line.
    """

    # Width, in characters, of the textual progress bar.
    _BAR_WIDTH = 30
    # Minimum number of seconds between two speed samples; avoids dividing
    # by very small time differences.
    _MIN_SAMPLE_INTERVAL = 0.1

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        """Initialise an idle tracker.

        Args:
            progress_callback: Called with the dict built by
                ``get_progress_data`` on every registration, update and
                status change. May be ``None``.
            show_progress: Whether the console progress line may be drawn
                (it is only drawn when a callback is also set).
        """
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking (None means "no previous measurement yet")
        self.last_downloaded = None
        self.last_time = None
        self.speed_history = []
        self.max_speed_history = 10

        # Download status: idle, downloading, completed, error
        self.download_status = "idle"
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring and fire the callback.

        The instance is keyed by ``id()``; a reference to it is stored in the
        entry, which also keeps that id from being reused while the entry
        exists.
        """
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh the stored counters of a registered tqdm instance.

        ``n`` is accepted for symmetry with ``tqdm.update`` but not used: the
        absolute values are read from the instance's ``n`` / ``total``.
        Unregistered instances are ignored.
        """
        tqdm_id = str(id(tqdm_instance))
        if tqdm_id in self.progress_data:
            entry = self.progress_data[tqdm_id]
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Calculate the smoothed download speed in bytes per second.

        A new sample is taken only when more than ``_MIN_SAMPLE_INTERVAL``
        seconds have passed since the previous one. The return value is the
        mean of the last ``max_speed_history`` samples, or ``0.0`` when no
        sample exists yet.
        """
        now = time.time()

        if self.last_time is None or self.last_downloaded is None:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = now
        else:
            elapsed = now - self.last_time
            if elapsed > self._MIN_SAMPLE_INTERVAL:
                delta = current_downloaded - self.last_downloaded
                if delta >= 0:  # 0 is a valid sample (no progress)
                    self.speed_history.append(delta / elapsed)
                    overflow = len(self.speed_history) - self.max_speed_history
                    if overflow > 0:
                        del self.speed_history[:overflow]
                # Always move the baseline. When the counter went backwards
                # no sample is recorded, but without re-baselining every
                # later sample would be measured against stale values.
                self.last_downloaded = current_downloaded
                self.last_time = now

        # Average of the history, so the last known speed is still reported
        # on calls that did not take a new sample.
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)

        return 0.0

    @staticmethod
    def _format_scaled(value, units, overflow_unit) -> str:
        """Divide ``value`` by 1024 until it fits a unit; one decimal place."""
        for unit in units:
            if value < 1024.0:
                return f"{value:.1f} {unit}"
            value /= 1024.0
        return f"{value:.1f} {overflow_unit}"

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string (B, KB, MB, GB, TB)."""
        return self._format_scaled(bytes_value, ('B', 'KB', 'MB', 'GB'), 'TB')

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string; zero is ``"0 B/s"``."""
        if speed == 0:
            return "0 B/s"
        return self._format_scaled(speed, ('B/s', 'KB/s', 'MB/s', 'GB/s'), 'TB/s')

    def get_progress_data(self) -> Dict[str, Any]:
        """Build the aggregated progress snapshot.

        Only entries whose ``total`` is positive contribute to the byte and
        file counts. Calling this also feeds ``calculate_speed``, so it
        updates the speed-tracking state.

        Returns:
            A dict with the keys ``status``, ``error_message``, ``progress``,
            ``speed``, ``formatting`` and ``timing``.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Total size: pre-fetched repo size first, then the sum of the
        # individual file sizes, else unknown (0).
        if self.total_repo_size > 0:
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            total_size = total_file_sizes
        else:
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Percentage is capped at 100 and reported as 0 for an unknown total
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # ETA is only defined while there is speed and something left
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed)
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size)
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time
            }
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Draw the unified progress line on stdout (no trailing newline).

        Does nothing when ``show_progress`` is false. The previously drawn
        line is blanked first, and the new length is remembered for the next
        call.
        """
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)
        status = progress_data.get('status', 'unknown')

        bar_width = self._BAR_WIDTH
        if status == 'completed':
            # The line below states 100.0%, so draw a full bar to match.
            bar = '#' * bar_width
        elif known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size - a single marker that moves with the clock
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        # Format the progress line
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            progress_line = f"Error: {progress_data.get('error_message', 'Unknown error')}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Blank the progress line, if one was drawn, and move to a new line."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Compute a snapshot, pass it to the callback and redraw the bar.

        Exceptions raised by the callback are caught and printed so a faulty
        callback cannot interrupt the download. The bar is drawn only when a
        callback is set and ``show_progress`` is true.
        """
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking (monkey patch tqdm).

        No-op when tracking is already active. Sets the status to
        ``"downloading"``, records the start time and fires the callback.
        """
        if self.is_tracking:
            return

        # Store original methods
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # Reference to self for the nested functions
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            sink = open(os.devnull, 'w')
            kwargs['file'] = sink
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar

            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # Do not leak the devnull handle when construction fails.
                sink.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None, **kwargs):
            # Show nothing; extra keyword arguments are tolerated so callers
            # passing more than msg/pos do not fail with a TypeError.
            pass

        def patched_write(cls, s, file=None, end="\n", nolock=False):
            # Prevent any output
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        # tqdm.write is a classmethod; a plain function here would break
        # class-level calls such as tqdm.write("message").
        tqdm.write = classmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking and restore original tqdm.

        No-op when tracking is not active. Closes the devnull handles of all
        registered bars, turns a ``"downloading"`` status into
        ``"completed"``, fires a final callback and clears the console line.
        """
        if not self.is_tracking:
            return

        # Restore original tqdm methods
        if self.original_tqdm_update:
            tqdm.update = self.original_tqdm_update
        if self.original_tqdm_init:
            tqdm.__init__ = self.original_tqdm_init
        if self.original_tqdm_display:
            tqdm.display = self.original_tqdm_display
        if self.original_tqdm_write:
            tqdm.write = self.original_tqdm_write

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            fp = getattr(data.get('tqdm_obj'), 'fp', None)
            try:
                # Equality (not identity) comparison is kept from the
                # original code: `fp` may be a wrapper around the stream.
                if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                    fp.close()
            except Exception:
                # Best effort: one bad handle must not prevent the remaining
                # cleanup or the final callback.
                pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status and trigger callback."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
|
|
1
|
+
"""
|
|
2
|
+
Progress tracking utilities for downloads with tqdm integration.
|
|
3
|
+
|
|
4
|
+
This module provides custom progress tracking classes that can monitor
|
|
5
|
+
download progress with callback support and customizable display options.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional, Callable, Dict, Any
|
|
12
|
+
from tqdm.auto import tqdm
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CustomProgressTqdm(tqdm):
    """tqdm subclass that tracks progress but never renders anything.

    Each instance is given its own ``os.devnull`` handle as output file and
    the ``display`` / ``write`` hooks are no-ops, so the counters maintained
    by the parent class remain usable while the terminal stays untouched.
    """

    def __init__(self, *args, **kwargs):
        """Create the bar with a devnull sink.

        Any ``file``, ``disable`` or ``leave`` value supplied by the caller
        is overridden.
        """
        # Redirect output to devnull to completely suppress terminal output
        sink = open(os.devnull, 'w')
        kwargs['file'] = sink
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # Do not leak the devnull handle when construction fails.
            sink.close()
            raise

    def display(self, msg=None, pos=None, **kwargs):
        """Render nothing.

        Extra keyword arguments are accepted and ignored so that callers
        passing more than ``msg`` / ``pos`` do not fail with a TypeError.
        NOTE(review): tqdm's notebook front-end appears to call ``display``
        with additional keywords - confirm against the installed tqdm.
        """

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Discard the message instead of printing it.

        Declared as a classmethod so that both ``CustomProgressTqdm.write(s)``
        and ``instance.write(s)`` work.
        """

    def close(self):
        """Close the devnull handle and mark the bar as disabled."""
        fp = getattr(self, 'fp', None)
        # Equality (not identity) comparison is kept from the original code:
        # `fp` may be a wrapper object rather than the raw stream.
        if fp and fp != sys.stdout and fp != sys.stderr:
            try:
                fp.close()
            except Exception:
                # Best effort: a failure to close the sink must not break
                # the caller's cleanup path.
                pass
        # Disabled first so the parent close() does not try to write to the
        # handle that was just closed.
        self.disable = True
        super().close()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DownloadProgressTracker:
|
|
45
|
+
"""Progress tracker for HuggingFace downloads with callback support."""
|
|
46
|
+
|
|
47
|
+
def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
|
|
48
|
+
self.progress_data: Dict[str, Dict[str, Any]] = {}
|
|
49
|
+
self.total_repo_size = 0
|
|
50
|
+
self.repo_file_count = 0
|
|
51
|
+
self.original_tqdm_update = None
|
|
52
|
+
self.original_tqdm_init = None
|
|
53
|
+
self.original_tqdm_display = None
|
|
54
|
+
self.original_tqdm_write = None
|
|
55
|
+
self.is_tracking = False
|
|
56
|
+
|
|
57
|
+
# Callback function
|
|
58
|
+
self.progress_callback = progress_callback
|
|
59
|
+
|
|
60
|
+
# Progress display
|
|
61
|
+
self.show_progress = show_progress
|
|
62
|
+
self.last_display_length = 0
|
|
63
|
+
|
|
64
|
+
# Speed tracking
|
|
65
|
+
self.last_downloaded = None # Use None to indicate no previous measurement
|
|
66
|
+
self.last_time = None # Use None to indicate no previous time measurement
|
|
67
|
+
self.speed_history = []
|
|
68
|
+
self.max_speed_history = 10
|
|
69
|
+
|
|
70
|
+
# Download status
|
|
71
|
+
self.download_status = "idle" # idle, downloading, completed, error
|
|
72
|
+
self.error_message = None
|
|
73
|
+
self.download_start_time = None
|
|
74
|
+
|
|
75
|
+
def set_repo_info(self, total_size: int, file_count: int):
|
|
76
|
+
"""Set the total repository size and file count before download."""
|
|
77
|
+
self.total_repo_size = total_size
|
|
78
|
+
self.repo_file_count = file_count
|
|
79
|
+
|
|
80
|
+
def register_tqdm(self, tqdm_instance):
|
|
81
|
+
"""Register a tqdm instance for monitoring."""
|
|
82
|
+
tqdm_id = str(id(tqdm_instance))
|
|
83
|
+
self.progress_data[tqdm_id] = {
|
|
84
|
+
'current': 0,
|
|
85
|
+
'total': getattr(tqdm_instance, 'total', 0) or 0,
|
|
86
|
+
'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
|
|
87
|
+
'tqdm_obj': tqdm_instance
|
|
88
|
+
}
|
|
89
|
+
# Trigger callback when new file is registered
|
|
90
|
+
self._trigger_callback()
|
|
91
|
+
|
|
92
|
+
def update_progress(self, tqdm_instance, n=1):
|
|
93
|
+
"""Update progress for a tqdm instance."""
|
|
94
|
+
tqdm_id = str(id(tqdm_instance))
|
|
95
|
+
if tqdm_id in self.progress_data:
|
|
96
|
+
self.progress_data[tqdm_id]['current'] = getattr(tqdm_instance, 'n', 0)
|
|
97
|
+
self.progress_data[tqdm_id]['total'] = getattr(tqdm_instance, 'total', 0) or 0
|
|
98
|
+
# Trigger callback on every progress update
|
|
99
|
+
self._trigger_callback()
|
|
100
|
+
|
|
101
|
+
def calculate_speed(self, current_downloaded: int) -> float:
|
|
102
|
+
"""Calculate download speed in bytes per second."""
|
|
103
|
+
current_time = time.time()
|
|
104
|
+
|
|
105
|
+
# Check if we have a previous measurement to compare against
|
|
106
|
+
if self.last_time is not None and self.last_downloaded is not None:
|
|
107
|
+
time_diff = current_time - self.last_time
|
|
108
|
+
|
|
109
|
+
# Only calculate if we have a meaningful time difference (avoid division by very small numbers)
|
|
110
|
+
if time_diff > 0.1: # At least 100ms between measurements
|
|
111
|
+
bytes_diff = current_downloaded - self.last_downloaded
|
|
112
|
+
|
|
113
|
+
# Only calculate speed if bytes actually changed
|
|
114
|
+
if bytes_diff >= 0: # Allow 0 for periods with no progress
|
|
115
|
+
speed = bytes_diff / time_diff
|
|
116
|
+
|
|
117
|
+
# Add to speed history for smoothing
|
|
118
|
+
self.speed_history.append(speed)
|
|
119
|
+
if len(self.speed_history) > self.max_speed_history:
|
|
120
|
+
self.speed_history.pop(0)
|
|
121
|
+
|
|
122
|
+
# Update tracking variables when we actually calculate speed
|
|
123
|
+
self.last_downloaded = current_downloaded
|
|
124
|
+
self.last_time = current_time
|
|
125
|
+
else:
|
|
126
|
+
# First measurement - initialize tracking variables
|
|
127
|
+
self.last_downloaded = current_downloaded
|
|
128
|
+
self.last_time = current_time
|
|
129
|
+
|
|
130
|
+
# Return the average of historical speeds if we have any
|
|
131
|
+
# This ensures we show the last known speed even when skipping updates
|
|
132
|
+
if self.speed_history:
|
|
133
|
+
return sum(self.speed_history) / len(self.speed_history)
|
|
134
|
+
|
|
135
|
+
return 0.0
|
|
136
|
+
|
|
137
|
+
def format_bytes(self, bytes_value: int) -> str:
|
|
138
|
+
"""Format bytes to human readable string."""
|
|
139
|
+
for unit in ['B', 'KB', 'MB', 'GB']:
|
|
140
|
+
if bytes_value < 1024.0:
|
|
141
|
+
return f"{bytes_value:.1f} {unit}"
|
|
142
|
+
bytes_value /= 1024.0
|
|
143
|
+
return f"{bytes_value:.1f} TB"
|
|
144
|
+
|
|
145
|
+
def format_speed(self, speed: float) -> str:
|
|
146
|
+
"""Format speed to human readable string."""
|
|
147
|
+
if speed == 0:
|
|
148
|
+
return "0 B/s"
|
|
149
|
+
|
|
150
|
+
for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
|
|
151
|
+
if speed < 1024.0:
|
|
152
|
+
return f"{speed:.1f} {unit}"
|
|
153
|
+
speed /= 1024.0
|
|
154
|
+
return f"{speed:.1f} TB/s"
|
|
155
|
+
|
|
156
|
+
def get_progress_data(self) -> Dict[str, Any]:
|
|
157
|
+
"""Get current progress data."""
|
|
158
|
+
total_downloaded = 0
|
|
159
|
+
active_file_count = 0
|
|
160
|
+
total_file_sizes = 0
|
|
161
|
+
|
|
162
|
+
for data in self.progress_data.values():
|
|
163
|
+
if data['total'] > 0:
|
|
164
|
+
total_downloaded += data['current']
|
|
165
|
+
total_file_sizes += data['total']
|
|
166
|
+
active_file_count += 1
|
|
167
|
+
|
|
168
|
+
# Calculate speed (tracking variables are updated internally)
|
|
169
|
+
speed = self.calculate_speed(total_downloaded)
|
|
170
|
+
|
|
171
|
+
# Determine total size - prioritize pre-fetched repo size, then aggregate file sizes
|
|
172
|
+
if self.total_repo_size > 0:
|
|
173
|
+
# Use pre-fetched repository info if available
|
|
174
|
+
total_size = self.total_repo_size
|
|
175
|
+
elif total_file_sizes > 0:
|
|
176
|
+
# Use sum of individual file sizes if available
|
|
177
|
+
total_size = total_file_sizes
|
|
178
|
+
else:
|
|
179
|
+
# Last resort - we don't know the total size yet
|
|
180
|
+
total_size = 0
|
|
181
|
+
|
|
182
|
+
file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count
|
|
183
|
+
|
|
184
|
+
# Calculate percentage - handle unknown total size gracefully
|
|
185
|
+
if total_size > 0:
|
|
186
|
+
percentage = min((total_downloaded / total_size * 100), 100.0)
|
|
187
|
+
else:
|
|
188
|
+
percentage = 0
|
|
189
|
+
|
|
190
|
+
# Calculate ETA
|
|
191
|
+
eta_seconds = None
|
|
192
|
+
if speed > 0 and total_size > total_downloaded:
|
|
193
|
+
eta_seconds = (total_size - total_downloaded) / speed
|
|
194
|
+
|
|
195
|
+
# Calculate elapsed time
|
|
196
|
+
elapsed_seconds = None
|
|
197
|
+
if self.download_start_time:
|
|
198
|
+
elapsed_seconds = time.time() - self.download_start_time
|
|
199
|
+
|
|
200
|
+
return {
|
|
201
|
+
'status': self.download_status,
|
|
202
|
+
'error_message': self.error_message,
|
|
203
|
+
'progress': {
|
|
204
|
+
'total_downloaded': total_downloaded,
|
|
205
|
+
'total_size': total_size,
|
|
206
|
+
'percentage': round(percentage, 2),
|
|
207
|
+
'files_active': active_file_count,
|
|
208
|
+
'files_total': file_count,
|
|
209
|
+
'known_total': total_size > 0
|
|
210
|
+
},
|
|
211
|
+
'speed': {
|
|
212
|
+
'bytes_per_second': speed,
|
|
213
|
+
'formatted': self.format_speed(speed)
|
|
214
|
+
},
|
|
215
|
+
'formatting': {
|
|
216
|
+
'downloaded': self.format_bytes(total_downloaded),
|
|
217
|
+
'total_size': self.format_bytes(total_size)
|
|
218
|
+
},
|
|
219
|
+
'timing': {
|
|
220
|
+
'elapsed_seconds': elapsed_seconds,
|
|
221
|
+
'eta_seconds': eta_seconds,
|
|
222
|
+
'start_time': self.download_start_time
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
def _display_progress_bar(self, progress_data: Dict[str, Any]):
|
|
227
|
+
"""Display a custom unified progress bar."""
|
|
228
|
+
if not self.show_progress:
|
|
229
|
+
return
|
|
230
|
+
|
|
231
|
+
# Clear previous line
|
|
232
|
+
if self.last_display_length > 0:
|
|
233
|
+
print('\r' + ' ' * self.last_display_length, end='\r')
|
|
234
|
+
|
|
235
|
+
progress_info = progress_data.get('progress', {})
|
|
236
|
+
speed_info = progress_data.get('speed', {})
|
|
237
|
+
timing_info = progress_data.get('timing', {})
|
|
238
|
+
formatting_info = progress_data.get('formatting', {})
|
|
239
|
+
|
|
240
|
+
percentage = progress_info.get('percentage', 0)
|
|
241
|
+
downloaded = formatting_info.get('downloaded', '0 B')
|
|
242
|
+
total_size_raw = progress_info.get('total_size', 0)
|
|
243
|
+
total_size = formatting_info.get('total_size', 'Unknown')
|
|
244
|
+
speed = speed_info.get('formatted', '0 B/s')
|
|
245
|
+
known_total = progress_info.get('known_total', False)
|
|
246
|
+
|
|
247
|
+
# Create progress bar
|
|
248
|
+
bar_width = 30
|
|
249
|
+
if known_total and total_size_raw > 0:
|
|
250
|
+
# Known total size - show actual progress
|
|
251
|
+
filled_width = int(bar_width * min(percentage, 100) / 100)
|
|
252
|
+
bar = '#' * filled_width + '-' * (bar_width - filled_width)
|
|
253
|
+
else:
|
|
254
|
+
# Unknown total size - show animated progress
|
|
255
|
+
animation_pos = int(time.time() * 2) % bar_width
|
|
256
|
+
bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)
|
|
257
|
+
|
|
258
|
+
# Format the progress line
|
|
259
|
+
status = progress_data.get('status', 'unknown')
|
|
260
|
+
if status == 'downloading':
|
|
261
|
+
if known_total:
|
|
262
|
+
progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
|
|
263
|
+
else:
|
|
264
|
+
progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
|
|
265
|
+
elif status == 'completed':
|
|
266
|
+
progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
|
|
267
|
+
elif status == 'error':
|
|
268
|
+
progress_line = f"Error: {progress_data.get('error_message', 'Unknown error')}"
|
|
269
|
+
else:
|
|
270
|
+
progress_line = f"Starting download..."
|
|
271
|
+
|
|
272
|
+
# Display and track length for next clear
|
|
273
|
+
print(progress_line, end='', flush=True)
|
|
274
|
+
self.last_display_length = len(progress_line)
|
|
275
|
+
|
|
276
|
+
def _clear_progress_bar(self):
    """Erase the in-place progress line and reset its remembered width.

    Nothing happens unless ``show_progress`` is truthy and
    ``last_display_length`` is greater than zero.
    """
    if not self.show_progress or not self.last_display_length > 0:
        return

    # Overwrite the previously printed line with spaces and return the
    # cursor to the start of the line.
    blank_line = ' ' * self.last_display_length
    print('\r' + blank_line, end='\r')
    print()  # Move to next line
    self.last_display_length = 0
|
|
282
|
+
|
|
283
|
+
def _trigger_callback(self):
    """Send the current progress snapshot to the registered callback.

    The snapshot is always fetched via ``get_progress_data()``. When a
    callback is set it is invoked with the snapshot; any ``Exception`` it
    raises is printed instead of propagated. The built-in progress bar is
    drawn afterwards only when a callback is set and ``show_progress`` is
    truthy.
    """
    snapshot = self.get_progress_data()

    if not self.progress_callback:
        # Without a callback there is nothing to notify and no bar to draw.
        return

    try:
        self.progress_callback(snapshot)
    except Exception as exc:
        print(f"Error in progress callback: {exc}")

    # Show custom progress bar only if callback is enabled and show_progress is True
    if self.progress_callback and self.show_progress:
        self._display_progress_bar(snapshot)
|
|
296
|
+
|
|
297
|
+
def start_tracking(self):
    """Start progress tracking by monkey patching tqdm.

    Saves the current ``tqdm.__init__``, ``tqdm.update``, ``tqdm.display``
    and ``tqdm.write`` on this tracker (``original_tqdm_*``) and replaces
    them with wrappers that silence tqdm's own output while forwarding
    construction and updates to ``register_tqdm`` / ``update_progress``.
    Afterwards the tracker is marked as tracking, the status is set to
    ``"downloading"``, the start time is recorded and the callback is
    triggered once.

    Calling this while tracking is already active is a no-op.
    """
    if self.is_tracking:
        return

    # Store original methods
    self.original_tqdm_update = tqdm.update
    self.original_tqdm_init = tqdm.__init__
    self.original_tqdm_display = tqdm.display
    self.original_tqdm_write = tqdm.write

    # The nested functions receive the tqdm instance as their first
    # argument, so keep an explicit reference to the tracker.
    tracker = self

    def patched_init(self_tqdm, *args, **kwargs):
        # Suppress tqdm display by redirecting to devnull
        devnull = open(os.devnull, 'w')
        kwargs['file'] = devnull
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False  # Don't leave progress bar

        try:
            result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
        except BaseException:
            # The bar was never constructed, so nothing else will ever
            # close this handle; release it before re-raising.
            devnull.close()
            raise
        tracker.register_tqdm(self_tqdm)
        return result

    def patched_update(self_tqdm, n=1):
        result = tracker.original_tqdm_update(self_tqdm, n)
        tracker.update_progress(self_tqdm, n)
        return result

    def patched_display(self_tqdm, msg=None, pos=None):
        # Override display to show nothing
        pass

    def patched_write(cls, s, file=None, end="\n", nolock=False):
        # Override write to prevent any output
        pass

    # Apply patches
    tqdm.__init__ = patched_init
    tqdm.update = patched_update
    tqdm.display = patched_display
    # tqdm.write is a classmethod. Installing a plain function would make a
    # class-level call such as ``tqdm.write("msg")`` bind the message to the
    # first parameter and fail with a missing-argument TypeError.
    tqdm.write = classmethod(patched_write)

    self.is_tracking = True
    self.download_status = "downloading"
    self.download_start_time = time.time()

    # Trigger initial callback
    self._trigger_callback()
|
|
346
|
+
|
|
347
|
+
def stop_tracking(self):
    """Stop progress tracking and restore the original tqdm methods.

    Restores every tqdm attribute that was saved on this tracker, closes
    the file handles still attached to tracked tqdm objects (best effort),
    marks the tracker as not tracking, promotes a ``"downloading"`` status
    to ``"completed"``, triggers the callback one final time and clears the
    progress bar.

    Calling this while tracking is not active is a no-op.
    """
    if not self.is_tracking:
        return

    # Restore original tqdm methods
    if self.original_tqdm_update:
        tqdm.update = self.original_tqdm_update
    if self.original_tqdm_init:
        tqdm.__init__ = self.original_tqdm_init

    original_display = getattr(self, 'original_tqdm_display', None)
    if original_display:
        tqdm.display = original_display

    original_write = getattr(self, 'original_tqdm_write', None)
    if original_write:
        # tqdm.write is a classmethod, so the saved object is normally a
        # bound method. Re-wrap its underlying function so the class gets a
        # real classmethod descriptor back instead of a method permanently
        # bound to one class. Fall back to the saved object otherwise.
        write_func = getattr(original_write, '__func__', None)
        if write_func is not None:
            tqdm.write = classmethod(write_func)
        else:
            tqdm.write = original_write

    # Clean up any open devnull file handles from tqdm instances
    for data in self.progress_data.values():
        if 'tqdm_obj' in data and hasattr(data['tqdm_obj'], 'fp'):
            try:
                fp = data['tqdm_obj'].fp
                # Never close the real console streams.
                if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                    fp.close()
            except Exception:
                # Best-effort cleanup: a handle that cannot be inspected or
                # closed must not abort the shutdown, but KeyboardInterrupt
                # and SystemExit are allowed to propagate.
                pass

    self.is_tracking = False
    if self.download_status == "downloading":
        self.download_status = "completed"

    # Trigger final callback and clear progress bar
    self._trigger_callback()
    self._clear_progress_bar()
|
|
379
|
+
|
|
380
|
+
def set_error(self, error_message: str):
    """Mark the download as failed, notify the callback and clear the bar.

    Stores ``"error"`` as the download status together with the supplied
    message before the callback is triggered, so the callback sees both.
    """
    self.download_status, self.error_message = "error", error_message

    # Notify first, then wipe the progress line from the terminal.
    self._trigger_callback()
    self._clear_progress_bar()
|