nexaai 1.0.21rc5__cp313-cp313-win_arm64.whl → 1.0.21rc16__cp313-cp313-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic; see the package registry's release page for more details.

Files changed (104)
  1. nexaai/__init__.py +95 -95
  2. nexaai/_stub.cp313-win_arm64.pyd +0 -0
  3. nexaai/_version.py +4 -1
  4. nexaai/asr.py +68 -65
  5. nexaai/asr_impl/mlx_asr_impl.py +92 -92
  6. nexaai/asr_impl/pybind_asr_impl.py +127 -44
  7. nexaai/base.py +39 -39
  8. nexaai/binds/__init__.py +6 -5
  9. nexaai/binds/asr_bind.cp313-win_arm64.pyd +0 -0
  10. nexaai/binds/common_bind.cp313-win_arm64.pyd +0 -0
  11. nexaai/binds/cpu_gpu/ggml-base.dll +0 -0
  12. nexaai/binds/cpu_gpu/ggml-cpu.dll +0 -0
  13. nexaai/binds/cpu_gpu/ggml-opencl.dll +0 -0
  14. nexaai/binds/cpu_gpu/ggml.dll +0 -0
  15. nexaai/binds/cpu_gpu/mtmd.dll +0 -0
  16. nexaai/binds/cpu_gpu/nexa_cpu_gpu.dll +0 -0
  17. nexaai/binds/cpu_gpu/nexa_plugin.dll +0 -0
  18. nexaai/binds/embedder_bind.cp313-win_arm64.pyd +0 -0
  19. nexaai/binds/libcrypto-3-arm64.dll +0 -0
  20. nexaai/binds/libssl-3-arm64.dll +0 -0
  21. nexaai/binds/llm_bind.cp313-win_arm64.pyd +0 -0
  22. nexaai/binds/nexa_bridge.dll +0 -0
  23. nexaai/binds/npu/convnext-sdk.dll +0 -0
  24. nexaai/binds/npu/embed-gemma-sdk.dll +0 -0
  25. nexaai/binds/npu/ggml-base.dll +0 -0
  26. nexaai/binds/npu/ggml-cpu.dll +0 -0
  27. nexaai/binds/{nexaml → npu}/ggml-opencl.dll +0 -0
  28. nexaai/binds/npu/ggml.dll +0 -0
  29. nexaai/binds/npu/granite-nano-sdk.dll +0 -0
  30. nexaai/binds/npu/granite4-sdk.dll +0 -0
  31. nexaai/binds/npu/jina-rerank-sdk.dll +0 -0
  32. nexaai/binds/npu/liquid-sdk.dll +0 -0
  33. nexaai/binds/npu/llama3-3b-sdk.dll +0 -0
  34. nexaai/binds/npu/nexa-mm-process.dll +0 -0
  35. nexaai/binds/npu/nexa-sampling.dll +0 -0
  36. nexaai/binds/npu/nexa_plugin.dll +0 -0
  37. nexaai/binds/npu/omni-neural-sdk.dll +0 -0
  38. nexaai/binds/npu/openblas.dll +0 -0
  39. nexaai/binds/npu/paddleocr-sdk.dll +0 -0
  40. nexaai/binds/npu/parakeet-sdk.dll +0 -0
  41. nexaai/binds/npu/phi3-5-sdk.dll +0 -0
  42. nexaai/binds/npu/phi4-sdk.dll +0 -0
  43. nexaai/binds/npu/pyannote-sdk.dll +0 -0
  44. nexaai/binds/npu/qwen3-4b-sdk.dll +0 -0
  45. nexaai/binds/npu/qwen3vl-sdk.dll +0 -0
  46. nexaai/binds/npu/qwen3vl-vision.dll +0 -0
  47. nexaai/binds/npu/yolov12-sdk.dll +0 -0
  48. nexaai/binds/npu/zlib1.dll +0 -0
  49. nexaai/binds/rerank_bind.cp313-win_arm64.pyd +0 -0
  50. nexaai/binds/vlm_bind.cp313-win_arm64.pyd +0 -0
  51. nexaai/common.py +105 -105
  52. nexaai/cv.py +93 -93
  53. nexaai/cv_impl/mlx_cv_impl.py +89 -89
  54. nexaai/cv_impl/pybind_cv_impl.py +32 -32
  55. nexaai/embedder.py +73 -73
  56. nexaai/embedder_impl/mlx_embedder_impl.py +118 -118
  57. nexaai/embedder_impl/pybind_embedder_impl.py +96 -96
  58. nexaai/image_gen.py +141 -141
  59. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -292
  60. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -85
  61. nexaai/llm.py +98 -98
  62. nexaai/llm_impl/mlx_llm_impl.py +271 -271
  63. nexaai/llm_impl/pybind_llm_impl.py +220 -220
  64. nexaai/log.py +92 -92
  65. nexaai/rerank.py +57 -57
  66. nexaai/rerank_impl/mlx_rerank_impl.py +94 -94
  67. nexaai/rerank_impl/pybind_rerank_impl.py +136 -136
  68. nexaai/runtime.py +68 -68
  69. nexaai/runtime_error.py +24 -24
  70. nexaai/tts.py +75 -75
  71. nexaai/tts_impl/mlx_tts_impl.py +94 -94
  72. nexaai/tts_impl/pybind_tts_impl.py +43 -43
  73. nexaai/utils/decode.py +17 -17
  74. nexaai/utils/manifest_utils.py +531 -531
  75. nexaai/utils/model_manager.py +1562 -1562
  76. nexaai/utils/model_types.py +49 -49
  77. nexaai/utils/progress_tracker.py +384 -384
  78. nexaai/utils/quantization_utils.py +245 -245
  79. nexaai/vlm.py +129 -129
  80. nexaai/vlm_impl/mlx_vlm_impl.py +258 -258
  81. nexaai/vlm_impl/pybind_vlm_impl.py +256 -256
  82. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/METADATA +1 -1
  83. nexaai-1.0.21rc16.dist-info/RECORD +154 -0
  84. nexaai/binds/nexaml/FLAC.dll +0 -0
  85. nexaai/binds/nexaml/fftw3.dll +0 -0
  86. nexaai/binds/nexaml/fftw3f.dll +0 -0
  87. nexaai/binds/nexaml/ggml-base.dll +0 -0
  88. nexaai/binds/nexaml/ggml-cpu.dll +0 -0
  89. nexaai/binds/nexaml/ggml.dll +0 -0
  90. nexaai/binds/nexaml/libmp3lame.DLL +0 -0
  91. nexaai/binds/nexaml/mpg123.dll +0 -0
  92. nexaai/binds/nexaml/nexa-mm-process.dll +0 -0
  93. nexaai/binds/nexaml/nexa-sampling.dll +0 -0
  94. nexaai/binds/nexaml/nexa_plugin.dll +0 -0
  95. nexaai/binds/nexaml/nexaproc.dll +0 -0
  96. nexaai/binds/nexaml/ogg.dll +0 -0
  97. nexaai/binds/nexaml/opus.dll +0 -0
  98. nexaai/binds/nexaml/qwen3-vl.dll +0 -0
  99. nexaai/binds/nexaml/qwen3vl-vision.dll +0 -0
  100. nexaai/binds/nexaml/vorbis.dll +0 -0
  101. nexaai/binds/nexaml/vorbisenc.dll +0 -0
  102. nexaai-1.0.21rc5.dist-info/RECORD +0 -162
  103. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/WHEEL +0 -0
  104. {nexaai-1.0.21rc5.dist-info → nexaai-1.0.21rc16.dist-info}/top_level.txt +0 -0
@@ -1,385 +1,385 @@
1
- """
2
- Progress tracking utilities for downloads with tqdm integration.
3
-
4
- This module provides custom progress tracking classes that can monitor
5
- download progress with callback support and customizable display options.
6
- """
7
-
8
- import os
9
- import sys
10
- import time
11
- from typing import Optional, Callable, Dict, Any
12
- from tqdm.auto import tqdm
13
-
14
-
15
class CustomProgressTqdm(tqdm):
    """tqdm subclass that keeps counting progress but never prints anything.

    Output is routed to ``os.devnull`` and ``display``/``write`` are no-ops,
    so ``n`` and ``total`` stay observable without a bar being drawn.
    """

    def __init__(self, *args, **kwargs):
        # Redirect output to devnull; the handle is released in close().
        devnull = open(os.devnull, 'w')
        kwargs['file'] = devnull
        kwargs['disable'] = False  # Keep enabled so n/total keep updating
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # The bar never took ownership of the handle; don't leak it.
            devnull.close()
            raise

    def display(self, msg=None, pos=None):
        """Suppress all bar rendering."""
        pass

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Suppress messages written through the bar.

        Kept a classmethod, like ``tqdm.write``, so both the documented
        ``CustomProgressTqdm.write(msg)`` call and instance calls work.
        """
        pass

    def close(self):
        """Close the bar without printing, then release its output handle.

        tqdm's own ``close`` runs first (while ``disable`` is still False) so
        the bar is removed from tqdm's internal instance bookkeeping; the
        handle is closed afterwards even if that step fails.
        """
        fp = getattr(self, 'fp', None)
        try:
            # Skip when already closed or when __init__ failed before
            # ``disable`` was assigned.
            if not getattr(self, 'disable', True):
                super().close()
        finally:
            # ``!=`` (not ``is not``): tqdm may wrap the stream in an object
            # whose __eq__ compares the wrapped file.
            if fp is not None and fp != sys.stdout and fp != sys.stderr:
                try:
                    fp.close()
                except (OSError, ValueError):
                    pass
            self.disable = True
42
-
43
-
44
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    Between ``start_tracking()`` and ``stop_tracking()`` the ``tqdm`` class is
    monkey-patched: every bar created in the process writes to ``os.devnull``
    and is registered here. Each registration and update rebuilds an
    aggregate snapshot (see ``get_progress_data``), hands it to
    ``progress_callback`` when one is set, and draws one combined progress
    line on stdout when a callback is set *and* ``show_progress`` is true.

    Args:
        progress_callback: Called with the snapshot dict on every change.
            Exceptions it raises are printed and otherwise ignored.
        show_progress: Whether to draw the combined progress line.
    """

    # tqdm class attributes replaced by start_tracking() and restored by
    # stop_tracking().
    _PATCHED_TQDM_ATTRS = ('__init__', 'update', 'display', 'write')

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        # Per-bar state keyed by str(id(bar)); each entry keeps a strong
        # reference to its bar, so the id stays unique while tracked.
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        # Pre-fetched repository totals (0 = unknown), see set_repo_info().
        self.total_repo_size = 0
        self.repo_file_count = 0
        # tqdm methods captured by start_tracking(); the patched __init__
        # and update delegate to these.
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        # Raw entries of tqdm's own class __dict__ (None = the attribute was
        # inherited), used to restore descriptors such as the classmethod
        # ``write`` exactly.
        self._saved_tqdm_attrs: Dict[str, Any] = {}
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        # Length of the last printed line, so it can be blanked out.
        self.last_display_length = 0

        # Speed tracking: None until the first measurement.
        self.last_downloaded: Optional[int] = None
        self.last_time: Optional[float] = None
        # Recent speed samples (bytes/s), averaged to smooth the display.
        self.speed_history = []
        self.max_speed_history = 10

        # Download status: "idle", "downloading", "completed" or "error".
        self.download_status = "idle"
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Record the repository's total byte size and file count before download.

        Non-zero values take precedence over totals derived from the
        registered bars in ``get_progress_data``.
        """
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Start monitoring ``tqdm_instance`` and report the new state."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance,
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh a registered bar's counters and report the new state.

        ``n`` is accepted for signature compatibility with ``tqdm.update``;
        the bar's own ``n`` and ``total`` are read instead. Bars that were
        never registered are ignored.
        """
        entry = self.progress_data.get(str(id(tqdm_instance)))
        if entry is not None:
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Return the smoothed download speed in bytes per second.

        The first call only records a baseline. Afterwards a new sample is
        taken only once more than 100 ms have passed; a negative byte delta
        is not recorded as a sample but still resets the baseline. Returns
        the mean of the last ``max_speed_history`` samples, or 0.0 if none.
        """
        current_time = time.time()

        if self.last_time is None or self.last_downloaded is None:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = current_time
        else:
            time_diff = current_time - self.last_time
            # Skip very short intervals to avoid noisy samples.
            if time_diff > 0.1:
                bytes_diff = current_downloaded - self.last_downloaded
                if bytes_diff >= 0:  # zero is a valid "stalled" sample
                    self.speed_history.append(bytes_diff / time_diff)
                    if len(self.speed_history) > self.max_speed_history:
                        self.speed_history.pop(0)
                self.last_downloaded = current_downloaded
                self.last_time = current_time

        # Report the last known speed even when this call took no sample.
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)
        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format a byte count as a human-readable string (B .. TB, base 1024)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format a bytes-per-second rate as a human-readable string."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Build an aggregate snapshot of all registered bars.

        Only bars with a positive ``total`` are counted. The total size is the
        pre-fetched repository size when set, else the sum of the counted
        bars' totals, else 0 (unknown). Calling this also feeds the speed
        estimator (see ``calculate_speed``).

        Returns:
            Dict with keys ``status``, ``error_message``, ``progress``
            (``total_downloaded``, ``total_size``, ``percentage``,
            ``files_active``, ``files_total``, ``known_total``), ``speed``
            (``bytes_per_second``, ``formatted``), ``formatting``
            (``downloaded``, ``total_size``) and ``timing``
            (``elapsed_seconds``, ``eta_seconds``, ``start_time``).
        """
        known = [data for data in self.progress_data.values() if data['total'] > 0]
        total_downloaded = sum(data['current'] for data in known)
        total_file_sizes = sum(data['total'] for data in known)
        active_file_count = len(known)

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Prefer the pre-fetched repo size, then the aggregate of file sizes.
        if self.total_repo_size > 0:
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            total_size = total_file_sizes
        else:
            total_size = 0  # unknown yet

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0,
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed),
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size),
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time,
            },
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Redraw the single combined progress line on stdout (if enabled)."""
        if not self.show_progress:
            return

        # Blank out the previously printed line before redrawing.
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)
        status = progress_data.get('status', 'unknown')

        bar_width = 30
        if status == 'completed':
            # The completed line reads "100.0%", so draw a full bar to match.
            bar = '#' * bar_width
        elif known_total and total_size_raw > 0:
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size: a marker that moves over time.
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            # The key is always present (possibly None), so a .get() default
            # alone would never apply.
            progress_line = f"Error: {progress_data.get('error_message') or 'Unknown error'}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Blank out the progress line and move to a new line (if one was drawn)."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Compute a snapshot, pass it to the callback, and redraw the line."""
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                # Best effort: a failing callback must not abort the download.
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Monkey-patch tqdm so every new bar is silenced and tracked here.

        No-op if already tracking. Sets the status to "downloading", records
        the start time and reports an initial snapshot.
        """
        if self.is_tracking:
            return

        # Callables the patched methods delegate to.
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write
        # Raw class-dict entries (None if inherited) for exact restoration.
        class_dict = vars(tqdm)
        self._saved_tqdm_attrs = {
            name: class_dict.get(name) for name in self._PATCHED_TQDM_ATTRS
        }

        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            devnull = open(os.devnull, 'w')
            kwargs['file'] = devnull
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar
            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # The bar never took ownership of the handle; don't leak it.
                devnull.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(*args, **kwargs):
            # tqdm.write is a classmethod usually called as tqdm.write(msg);
            # installed as a staticmethod so class and instance calls are
            # both swallowed instead of raising TypeError.
            pass

        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        tqdm.write = staticmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Undo the tqdm patches, close devnull handles and report completion.

        No-op when not tracking. A "downloading" status becomes "completed"
        (an "error" status is kept); a final snapshot is reported and the
        progress line is cleared.
        """
        if not self.is_tracking:
            return

        # Re-install the original descriptors, or drop our overrides for
        # attributes that were inherited before patching.
        class_dict = vars(tqdm)
        for name, original in self._saved_tqdm_attrs.items():
            if original is not None:
                setattr(tqdm, name, original)
            elif name in class_dict:
                delattr(tqdm, name)
        self._saved_tqdm_attrs = {}

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            fp = getattr(data.get('tqdm_obj'), 'fp', None)
            # ``!=`` because tqdm may wrap the stream in an object whose
            # __eq__ compares the wrapped file.
            if not fp or fp == sys.stdout or fp == sys.stderr:
                continue
            try:
                if not fp.closed:
                    fp.close()
            except (OSError, ValueError, AttributeError):
                pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Mark the download as failed, report it, and clear the progress line."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
1
+ """
2
+ Progress tracking utilities for downloads with tqdm integration.
3
+
4
+ This module provides custom progress tracking classes that can monitor
5
+ download progress with callback support and customizable display options.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Optional, Callable, Dict, Any
12
+ from tqdm.auto import tqdm
13
+
14
+
15
class CustomProgressTqdm(tqdm):
    """tqdm subclass that keeps counting progress but never prints anything.

    Output is routed to ``os.devnull`` and ``display``/``write`` are no-ops,
    so ``n`` and ``total`` stay observable without a bar being drawn.
    """

    def __init__(self, *args, **kwargs):
        # Redirect output to devnull; the handle is released in close().
        devnull = open(os.devnull, 'w')
        kwargs['file'] = devnull
        kwargs['disable'] = False  # Keep enabled so n/total keep updating
        kwargs['leave'] = False  # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # The bar never took ownership of the handle; don't leak it.
            devnull.close()
            raise

    def display(self, msg=None, pos=None):
        """Suppress all bar rendering."""
        pass

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Suppress messages written through the bar.

        Kept a classmethod, like ``tqdm.write``, so both the documented
        ``CustomProgressTqdm.write(msg)`` call and instance calls work.
        """
        pass

    def close(self):
        """Close the bar without printing, then release its output handle.

        tqdm's own ``close`` runs first (while ``disable`` is still False) so
        the bar is removed from tqdm's internal instance bookkeeping; the
        handle is closed afterwards even if that step fails.
        """
        fp = getattr(self, 'fp', None)
        try:
            # Skip when already closed or when __init__ failed before
            # ``disable`` was assigned.
            if not getattr(self, 'disable', True):
                super().close()
        finally:
            # ``!=`` (not ``is not``): tqdm may wrap the stream in an object
            # whose __eq__ compares the wrapped file.
            if fp is not None and fp != sys.stdout and fp != sys.stderr:
                try:
                    fp.close()
                except (OSError, ValueError):
                    pass
            self.disable = True
42
+
43
+
44
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    Between ``start_tracking()`` and ``stop_tracking()`` the ``tqdm`` class is
    monkey-patched: every bar created in the process writes to ``os.devnull``
    and is registered here. Each registration and update rebuilds an
    aggregate snapshot (see ``get_progress_data``), hands it to
    ``progress_callback`` when one is set, and draws one combined progress
    line on stdout when a callback is set *and* ``show_progress`` is true.

    Args:
        progress_callback: Called with the snapshot dict on every change.
            Exceptions it raises are printed and otherwise ignored.
        show_progress: Whether to draw the combined progress line.
    """

    # tqdm class attributes replaced by start_tracking() and restored by
    # stop_tracking().
    _PATCHED_TQDM_ATTRS = ('__init__', 'update', 'display', 'write')

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        # Per-bar state keyed by str(id(bar)); each entry keeps a strong
        # reference to its bar, so the id stays unique while tracked.
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        # Pre-fetched repository totals (0 = unknown), see set_repo_info().
        self.total_repo_size = 0
        self.repo_file_count = 0
        # tqdm methods captured by start_tracking(); the patched __init__
        # and update delegate to these.
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        # Raw entries of tqdm's own class __dict__ (None = the attribute was
        # inherited), used to restore descriptors such as the classmethod
        # ``write`` exactly.
        self._saved_tqdm_attrs: Dict[str, Any] = {}
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        # Length of the last printed line, so it can be blanked out.
        self.last_display_length = 0

        # Speed tracking: None until the first measurement.
        self.last_downloaded: Optional[int] = None
        self.last_time: Optional[float] = None
        # Recent speed samples (bytes/s), averaged to smooth the display.
        self.speed_history = []
        self.max_speed_history = 10

        # Download status: "idle", "downloading", "completed" or "error".
        self.download_status = "idle"
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Record the repository's total byte size and file count before download.

        Non-zero values take precedence over totals derived from the
        registered bars in ``get_progress_data``.
        """
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Start monitoring ``tqdm_instance`` and report the new state."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            'tqdm_obj': tqdm_instance,
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh a registered bar's counters and report the new state.

        ``n`` is accepted for signature compatibility with ``tqdm.update``;
        the bar's own ``n`` and ``total`` are read instead. Bars that were
        never registered are ignored.
        """
        entry = self.progress_data.get(str(id(tqdm_instance)))
        if entry is not None:
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Return the smoothed download speed in bytes per second.

        The first call only records a baseline. Afterwards a new sample is
        taken only once more than 100 ms have passed; a negative byte delta
        is not recorded as a sample but still resets the baseline. Returns
        the mean of the last ``max_speed_history`` samples, or 0.0 if none.
        """
        current_time = time.time()

        if self.last_time is None or self.last_downloaded is None:
            # First measurement - initialize tracking variables
            self.last_downloaded = current_downloaded
            self.last_time = current_time
        else:
            time_diff = current_time - self.last_time
            # Skip very short intervals to avoid noisy samples.
            if time_diff > 0.1:
                bytes_diff = current_downloaded - self.last_downloaded
                if bytes_diff >= 0:  # zero is a valid "stalled" sample
                    self.speed_history.append(bytes_diff / time_diff)
                    if len(self.speed_history) > self.max_speed_history:
                        self.speed_history.pop(0)
                self.last_downloaded = current_downloaded
                self.last_time = current_time

        # Report the last known speed even when this call took no sample.
        if self.speed_history:
            return sum(self.speed_history) / len(self.speed_history)
        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format a byte count as a human-readable string (B .. TB, base 1024)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format a bytes-per-second rate as a human-readable string."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Build an aggregate snapshot of all registered bars.

        Only bars with a positive ``total`` are counted. The total size is the
        pre-fetched repository size when set, else the sum of the counted
        bars' totals, else 0 (unknown). Calling this also feeds the speed
        estimator (see ``calculate_speed``).

        Returns:
            Dict with keys ``status``, ``error_message``, ``progress``
            (``total_downloaded``, ``total_size``, ``percentage``,
            ``files_active``, ``files_total``, ``known_total``), ``speed``
            (``bytes_per_second``, ``formatted``), ``formatting``
            (``downloaded``, ``total_size``) and ``timing``
            (``elapsed_seconds``, ``eta_seconds``, ``start_time``).
        """
        known = [data for data in self.progress_data.values() if data['total'] > 0]
        total_downloaded = sum(data['current'] for data in known)
        total_file_sizes = sum(data['total'] for data in known)
        active_file_count = len(known)

        # Calculate speed (tracking variables are updated internally)
        speed = self.calculate_speed(total_downloaded)

        # Prefer the pre-fetched repo size, then the aggregate of file sizes.
        if self.total_repo_size > 0:
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            total_size = total_file_sizes
        else:
            total_size = 0  # unknown yet

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0,
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed),
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size),
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time,
            },
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Redraw the single combined progress line on stdout (if enabled)."""
        if not self.show_progress:
            return

        # Blank out the previously printed line before redrawing.
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)
        status = progress_data.get('status', 'unknown')

        bar_width = 30
        if status == 'completed':
            # The completed line reads "100.0%", so draw a full bar to match.
            bar = '#' * bar_width
        elif known_total and total_size_raw > 0:
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '#' * filled_width + '-' * (bar_width - filled_width)
        else:
            # Unknown total size: a marker that moves over time.
            animation_pos = int(time.time() * 2) % bar_width
            bar = '-' * animation_pos + '#' + '-' * (bar_width - animation_pos - 1)

        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            # The key is always present (possibly None), so a .get() default
            # alone would never apply.
            progress_line = f"Error: {progress_data.get('error_message') or 'Unknown error'}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Blank out the progress line and move to a new line (if one was drawn)."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Compute a snapshot, pass it to the callback, and redraw the line."""
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                # Best effort: a failing callback must not abort the download.
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Monkey-patch tqdm so every new bar is silenced and tracked here.

        No-op if already tracking. Sets the status to "downloading", records
        the start time and reports an initial snapshot.
        """
        if self.is_tracking:
            return

        # Callables the patched methods delegate to.
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write
        # Raw class-dict entries (None if inherited) for exact restoration.
        class_dict = vars(tqdm)
        self._saved_tqdm_attrs = {
            name: class_dict.get(name) for name in self._PATCHED_TQDM_ATTRS
        }

        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            devnull = open(os.devnull, 'w')
            kwargs['file'] = devnull
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False  # Don't leave progress bar
            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # The bar never took ownership of the handle; don't leak it.
                devnull.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(*args, **kwargs):
            # tqdm.write is a classmethod usually called as tqdm.write(msg);
            # installed as a staticmethod so class and instance calls are
            # both swallowed instead of raising TypeError.
            pass

        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        tqdm.write = staticmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Undo the tqdm patches, close devnull handles and report completion.

        No-op when not tracking. A "downloading" status becomes "completed"
        (an "error" status is kept); a final snapshot is reported and the
        progress line is cleared.
        """
        if not self.is_tracking:
            return

        # Re-install the original descriptors, or drop our overrides for
        # attributes that were inherited before patching.
        class_dict = vars(tqdm)
        for name, original in self._saved_tqdm_attrs.items():
            if original is not None:
                setattr(tqdm, name, original)
            elif name in class_dict:
                delattr(tqdm, name)
        self._saved_tqdm_attrs = {}

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            fp = getattr(data.get('tqdm_obj'), 'fp', None)
            # ``!=`` because tqdm may wrap the stream in an object whose
            # __eq__ compares the wrapped file.
            if not fp or fp == sys.stdout or fp == sys.stderr:
                continue
            try:
                if not fp.closed:
                    fp.close()
            except (OSError, ValueError, AttributeError):
                pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Mark the download as failed, report it, and clear the progress line."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()