nexaai 1.0.4rc13__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (59) hide show
  1. nexaai/__init__.py +71 -0
  2. nexaai/_stub.cp310-win_amd64.pyd +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +60 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +91 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +43 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +3 -0
  10. nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
  11. nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
  12. nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
  13. nexaai/binds/nexa_bridge.dll +0 -0
  14. nexaai/binds/nexa_llama_cpp/ggml-base.dll +0 -0
  15. nexaai/binds/nexa_llama_cpp/ggml-cpu.dll +0 -0
  16. nexaai/binds/nexa_llama_cpp/ggml-cuda.dll +0 -0
  17. nexaai/binds/nexa_llama_cpp/ggml-vulkan.dll +0 -0
  18. nexaai/binds/nexa_llama_cpp/ggml.dll +0 -0
  19. nexaai/binds/nexa_llama_cpp/llama.dll +0 -0
  20. nexaai/binds/nexa_llama_cpp/mtmd.dll +0 -0
  21. nexaai/binds/nexa_llama_cpp/nexa_plugin.dll +0 -0
  22. nexaai/common.py +61 -0
  23. nexaai/cv.py +87 -0
  24. nexaai/cv_impl/__init__.py +0 -0
  25. nexaai/cv_impl/mlx_cv_impl.py +88 -0
  26. nexaai/cv_impl/pybind_cv_impl.py +31 -0
  27. nexaai/embedder.py +68 -0
  28. nexaai/embedder_impl/__init__.py +0 -0
  29. nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
  30. nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
  31. nexaai/image_gen.py +136 -0
  32. nexaai/image_gen_impl/__init__.py +0 -0
  33. nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
  34. nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
  35. nexaai/llm.py +89 -0
  36. nexaai/llm_impl/__init__.py +0 -0
  37. nexaai/llm_impl/mlx_llm_impl.py +249 -0
  38. nexaai/llm_impl/pybind_llm_impl.py +207 -0
  39. nexaai/rerank.py +51 -0
  40. nexaai/rerank_impl/__init__.py +0 -0
  41. nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
  42. nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
  43. nexaai/runtime.py +64 -0
  44. nexaai/tts.py +70 -0
  45. nexaai/tts_impl/__init__.py +0 -0
  46. nexaai/tts_impl/mlx_tts_impl.py +93 -0
  47. nexaai/tts_impl/pybind_tts_impl.py +42 -0
  48. nexaai/utils/avatar_fetcher.py +104 -0
  49. nexaai/utils/decode.py +18 -0
  50. nexaai/utils/model_manager.py +1195 -0
  51. nexaai/utils/progress_tracker.py +372 -0
  52. nexaai/vlm.py +120 -0
  53. nexaai/vlm_impl/__init__.py +0 -0
  54. nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
  55. nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
  56. nexaai-1.0.4rc13.dist-info/METADATA +26 -0
  57. nexaai-1.0.4rc13.dist-info/RECORD +59 -0
  58. nexaai-1.0.4rc13.dist-info/WHEEL +5 -0
  59. nexaai-1.0.4rc13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1195 @@
1
+ import os
2
+ import shutil
3
+ import json
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Callable, Dict, Any, List, Union
7
+ import functools
8
+ from tqdm.auto import tqdm
9
+ from huggingface_hub import HfApi
10
+ from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
11
+
12
+ from .progress_tracker import CustomProgressTqdm, DownloadProgressTracker
13
+ from .avatar_fetcher import get_avatar_url_for_repo
14
+
15
+ # Default path for model storage
16
+ DEFAULT_MODEL_SAVING_PATH = "~/.cache/nexa.ai/nexa_sdk/models/"
17
+
18
+
19
@dataclass
class DownloadedModel:
    """Metadata describing one locally downloaded model repository."""
    repo_id: str                    # e.g. "owner/repo", or a bare repo name
    files: List[str]                # relative paths of the files inside local_path
    folder_type: str                # 'owner_repo' or 'direct_repo'
    local_path: str                 # full path to the model directory
    size_bytes: int                 # total size of all files in bytes
    file_count: int                 # number of files in the repository
    full_repo_download_complete: bool = True   # False when incomplete-download markers were found
    pipeline_tag: Optional[str] = None         # pipeline tag from HuggingFace model info, if recorded
    download_time: Optional[str] = None        # ISO-format download timestamp, if recorded
    avatar_url: Optional[str] = None           # avatar URL of the model author, if recorded

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict (backward-compatible format)."""
        return dict(
            repo_id=self.repo_id,
            files=self.files,
            folder_type=self.folder_type,
            local_path=self.local_path,
            size_bytes=self.size_bytes,
            file_count=self.file_count,
            full_repo_download_complete=self.full_repo_download_complete,
            pipeline_tag=self.pipeline_tag,
            download_time=self.download_time,
            avatar_url=self.avatar_url,
        )
48
+
49
+
50
+ ##########################################################################
51
+ # List downloaded models #
52
+ ##########################################################################
53
+
54
+
55
def _check_for_incomplete_downloads(directory_path: str) -> bool:
    """
    Check whether a model directory contains an unfinished download.

    Looks for ``*.incomplete`` marker files inside the
    ``.cache/huggingface/download`` sub-directory that huggingface_hub
    creates while fetching files.

    NOTE(review): callers' documentation also mentions ``.lock`` files,
    but this implementation only treats ``.incomplete`` files as evidence
    of an unfinished download — confirm whether ``.lock`` should count.

    Args:
        directory_path: Path to the model directory.

    Returns:
        bool: True if the download appears complete (no ``.incomplete``
            file found, cache directory absent, or cache unreadable),
            False if an incomplete-download marker is present.
    """
    # huggingface_hub stages in-progress transfers under this cache directory
    cache_dir = os.path.join(directory_path, '.cache', 'huggingface', 'download')

    # No cache directory at all -> nothing was in progress; treat as complete
    if not os.path.exists(cache_dir):
        return True

    try:
        # Walk the cache tree looking for in-progress markers
        for root, dirs, files in os.walk(cache_dir):
            for filename in files:
                # A *.incomplete file means a transfer was interrupted
                if filename.endswith('.incomplete'):
                    return False  # Found incomplete download

        # No incomplete files found
        return True
    except (OSError, IOError):
        # Best-effort: if the cache can't be read, assume complete
        return True
90
+
91
+
92
+ def _load_download_metadata(directory_path: str) -> Dict[str, Any]:
93
+ """Load download metadata from download_metadata.json if it exists."""
94
+ metadata_path = os.path.join(directory_path, 'download_metadata.json')
95
+ if os.path.exists(metadata_path):
96
+ try:
97
+ with open(metadata_path, 'r', encoding='utf-8') as f:
98
+ return json.load(f)
99
+ except (json.JSONDecodeError, IOError):
100
+ pass
101
+ return {}
102
+
103
+
104
+ def _save_download_metadata(directory_path: str, metadata: Dict[str, Any]) -> None:
105
+ """Save download metadata to download_metadata.json."""
106
+ metadata_path = os.path.join(directory_path, 'download_metadata.json')
107
+ try:
108
+ with open(metadata_path, 'w', encoding='utf-8') as f:
109
+ json.dump(metadata, f, indent=2)
110
+ except IOError:
111
+ # If we can't save metadata, don't fail the download
112
+ pass
113
+
114
+
115
+ def _get_directory_size_and_files(directory_path: str) -> tuple[int, List[str]]:
116
+ """Get total size and list of files in a directory."""
117
+ total_size = 0
118
+ files = []
119
+
120
+ try:
121
+ for root, dirs, filenames in os.walk(directory_path):
122
+ for filename in filenames:
123
+ file_path = os.path.join(root, filename)
124
+ try:
125
+ file_size = os.path.getsize(file_path)
126
+ total_size += file_size
127
+ # Store relative path from the directory
128
+ rel_path = os.path.relpath(file_path, directory_path)
129
+ files.append(rel_path)
130
+ except (OSError, IOError):
131
+ # Skip files that can't be accessed
132
+ continue
133
+ except (OSError, IOError):
134
+ # Skip directories that can't be accessed
135
+ pass
136
+
137
+ return total_size, files
138
+
139
+
140
def _scan_for_repo_folders(base_path: str) -> List[DownloadedModel]:
    """
    Scan *base_path* for downloaded repositories and build their metadata.

    Two layouts are recognized:
      * owner/repo: a first-level folder containing sub-folders; each
        sub-folder with files becomes one 'owner_repo' model.
      * direct repo: a first-level folder containing only files; it becomes
        one 'direct_repo' model.

    NOTE(review): a folder that mixes files and sub-folders is classified as
    an owner folder and its loose files are silently ignored — confirm that
    is intended. Reported sizes include everything on disk, including any
    download_metadata.json and cache files.

    Args:
        base_path: Directory to scan.

    Returns:
        List[DownloadedModel]: one entry per repository found; empty when
            base_path is missing or unreadable.
    """
    models = []

    try:
        if not os.path.exists(base_path):
            return models

        for item in os.listdir(base_path):
            item_path = os.path.join(base_path, item)

            # Skip non-directory items
            if not os.path.isdir(item_path):
                continue

            # Check if this might be an owner folder by looking for subdirectories
            has_subdirs = False
            direct_files = []

            try:
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        has_subdirs = True
                        # This looks like owner/repo structure
                        size_bytes, files = _get_directory_size_and_files(subitem_path)
                        if files:  # Only include if there are files
                            # Check if the download is complete
                            download_complete = _check_for_incomplete_downloads(subitem_path)
                            # Load metadata if it exists
                            metadata = _load_download_metadata(subitem_path)
                            models.append(DownloadedModel(
                                repo_id=f"{item}/{subitem}",
                                files=files,
                                folder_type='owner_repo',
                                local_path=subitem_path,
                                size_bytes=size_bytes,
                                file_count=len(files),
                                full_repo_download_complete=download_complete,
                                pipeline_tag=metadata.get('pipeline_tag'),
                                download_time=metadata.get('download_time'),
                                avatar_url=metadata.get('avatar_url')
                            ))
                    else:
                        direct_files.append(subitem)
            except (OSError, IOError):
                # Skip directories that can't be accessed
                # (continues with the next first-level item)
                continue

            # Direct repo folder (no owner structure)
            if not has_subdirs and direct_files:
                size_bytes, files = _get_directory_size_and_files(item_path)
                if files:  # Only include if there are files
                    # Check if the download is complete
                    download_complete = _check_for_incomplete_downloads(item_path)
                    # Load metadata if it exists
                    metadata = _load_download_metadata(item_path)
                    models.append(DownloadedModel(
                        repo_id=item,
                        files=files,
                        folder_type='direct_repo',
                        local_path=item_path,
                        size_bytes=size_bytes,
                        file_count=len(files),
                        full_repo_download_complete=download_complete,
                        pipeline_tag=metadata.get('pipeline_tag'),
                        download_time=metadata.get('download_time'),
                        avatar_url=metadata.get('avatar_url')
                    ))

    except (OSError, IOError):
        # Skip if base path can't be accessed
        pass

    return models
215
+
216
+
217
def list_downloaded_models(local_dir: Optional[str] = None) -> List[DownloadedModel]:
    """
    List all downloaded models in the specified directory.

    Scans *local_dir* for downloaded repositories, handling both naming
    conventions:
      * owner/repo structure (e.g., "microsoft/DialoGPT-small")
      * direct repo folders (repos without an owner prefix)

    Args:
        local_dir (str, optional): Directory to scan for downloaded models.
            If None, uses DEFAULT_MODEL_SAVING_PATH.

    Returns:
        List[DownloadedModel]: One entry per repository found, sorted by
            repo_id, carrying files, folder_type, local_path, size_bytes,
            file_count, full_repo_download_complete, pipeline_tag,
            download_time and avatar_url.
    """
    # Resolve the scan directory, defaulting to the SDK cache location.
    base_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH) if local_dir is None else local_dir
    base_dir = os.path.abspath(base_dir)

    if not os.path.exists(base_dir):
        return []

    # Collect repositories, then sort for deterministic output.
    found = _scan_for_repo_folders(base_dir)
    found.sort(key=lambda m: m.repo_id)
    return found
263
+
264
+
265
+ ##########################################################################
266
+ # Remove model functions #
267
+ ##########################################################################
268
+
269
+
270
+ def _parse_model_path(model_path: str) -> tuple[str, str | None]:
271
+ """
272
+ Parse model_path to extract repo_id and optional filename.
273
+
274
+ Examples:
275
+ "microsoft/DialoGPT-small" -> ("microsoft/DialoGPT-small", None)
276
+ "microsoft/DialoGPT-small/pytorch_model.bin" -> ("microsoft/DialoGPT-small", "pytorch_model.bin")
277
+ "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" -> ("Qwen/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf")
278
+
279
+ Args:
280
+ model_path: The model path string
281
+
282
+ Returns:
283
+ Tuple of (repo_id, filename) where filename can be None
284
+ """
285
+ parts = model_path.strip().split('/')
286
+
287
+ if len(parts) < 2:
288
+ # Invalid format, assume it's just a repo name without owner
289
+ return model_path, None
290
+ elif len(parts) == 2:
291
+ # Format: "owner/repo"
292
+ return model_path, None
293
+ else:
294
+ # Format: "owner/repo/file" or "owner/repo/subdir/file"
295
+ repo_id = f"{parts[0]}/{parts[1]}"
296
+ filename = '/'.join(parts[2:])
297
+ return repo_id, filename
298
+
299
+
300
def _validate_and_parse_input(model_path: str) -> tuple[str, Optional[str]]:
    """Validate *model_path* and split it into (repo_id, optional filename).

    Raises:
        ValueError: If model_path is None, not a string, or blank.
    """
    is_valid = isinstance(model_path, str) and bool(model_path) and bool(model_path.strip())
    if not is_valid:
        raise ValueError("model_path is required and must be a non-empty string")
    return _parse_model_path(model_path.strip())
307
+
308
+
309
def _find_target_model(repo_id: str, local_dir: str) -> DownloadedModel:
    """Return the downloaded model matching *repo_id*.

    Raises:
        FileNotFoundError: If no downloaded repository has that id; the
            message lists the repositories that are available.
    """
    downloaded = list_downloaded_models(local_dir)

    match = next((m for m in downloaded if m.repo_id == repo_id), None)
    if match is not None:
        return match

    available_repos = [m.repo_id for m in downloaded]
    raise FileNotFoundError(
        f"Repository '{repo_id}' not found in downloaded models. "
        f"Available repositories: {available_repos}"
    )
322
+
323
+
324
def _clean_empty_owner_directory(target_model: DownloadedModel) -> None:
    """Delete the owner directory once its last repository is gone (best-effort).

    Only meaningful for the 'owner_repo' layout; any filesystem error is
    swallowed since this is purely cosmetic cleanup.
    """
    if target_model.folder_type != 'owner_repo':
        return

    owner_dir = os.path.dirname(target_model.local_path)
    try:
        if os.path.exists(owner_dir) and not os.listdir(owner_dir):
            os.rmdir(owner_dir)
    except OSError:
        pass
335
+
336
+
337
def _remove_specific_file(target_model: DownloadedModel, file_name: str, local_dir: str) -> DownloadedModel:
    """
    Remove a single file from a downloaded repository.

    Args:
        target_model: The repository the file belongs to.
        file_name: Relative path of the file inside the repository.
        local_dir: Base models directory (kept for signature symmetry with
            _remove_entire_repository; not used directly here).

    Returns:
        DownloadedModel: Updated metadata reflecting the removal. When the
            last file is removed, the repository directory itself is also
            deleted (best-effort).

    Raises:
        FileNotFoundError: If the file is not tracked by the model or is
            missing on disk.
        OSError: If the file cannot be deleted.
    """
    # Validate file exists in model
    if file_name not in target_model.files:
        raise FileNotFoundError(
            f"File '{file_name}' not found in repository '{target_model.repo_id}'. "
            f"Available files: {target_model.files[:10]}{'...' if len(target_model.files) > 10 else ''}"
        )

    # Construct full file path and validate it exists on disk
    file_path = os.path.join(target_model.local_path, file_name)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File does not exist on disk: {file_path}")

    # Capture the size before deletion so the repo total can be adjusted
    try:
        file_size = os.path.getsize(file_path)
    except OSError:
        file_size = 0

    # Remove the file
    try:
        os.remove(file_path)
    except OSError as e:
        raise OSError(f"Failed to remove file '{file_path}': {e}")

    # Create updated model object
    updated_files = [f for f in target_model.files if f != file_name]
    # Re-check download completeness after file removal
    download_complete = _check_for_incomplete_downloads(target_model.local_path)
    # Bug fix: pipeline_tag/download_time/avatar_url were previously dropped
    # here; carry them over from the original model so metadata survives a
    # single-file removal.
    updated_model = DownloadedModel(
        repo_id=target_model.repo_id,
        files=updated_files,
        folder_type=target_model.folder_type,
        local_path=target_model.local_path,
        size_bytes=target_model.size_bytes - file_size,
        file_count=len(updated_files),
        full_repo_download_complete=download_complete,
        pipeline_tag=target_model.pipeline_tag,
        download_time=target_model.download_time,
        avatar_url=target_model.avatar_url,
    )

    # If no files are left, remove the entire directory (best-effort)
    if not updated_files:
        try:
            shutil.rmtree(target_model.local_path)
            _clean_empty_owner_directory(target_model)
        except OSError:
            pass

    return updated_model
387
+
388
+
389
def _remove_entire_repository(target_model: DownloadedModel, local_dir: str) -> DownloadedModel:
    """Delete a whole repository directory and prune an emptied owner folder.

    Raises:
        OSError: If the repository directory cannot be removed.
    """
    try:
        shutil.rmtree(target_model.local_path)
    except OSError as e:
        raise OSError(f"Failed to remove directory '{target_model.local_path}': {e}")

    # If this was the owner's last repository, drop the owner folder too.
    _clean_empty_owner_directory(target_model)

    return target_model
401
+
402
+
403
def remove_model_or_file(
    model_path: str,
    local_dir: Optional[str] = None
) -> DownloadedModel:
    """
    Remove a downloaded repository, or a single file inside one.

    Two forms of *model_path* are accepted:
      * "owner/repo"           - deletes the whole repository
      * "owner/repo/file.ext"  - deletes only that file and updates the
        repository metadata

    Args:
        model_path (str): Repository id or file path (see above).
        local_dir (str, optional): Directory holding downloaded models.
            Defaults to DEFAULT_MODEL_SAVING_PATH when None.

    Returns:
        DownloadedModel: What was removed; for single-file removal, the
            repository's updated metadata after the removal.

    Raises:
        ValueError: If model_path is empty or not a string.
        FileNotFoundError: If the directory, repository or file is missing.
        OSError: On filesystem removal errors.
    """
    # Validate input and split into repo id + optional filename.
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Resolve the models directory, defaulting to the SDK cache location.
    base_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH) if local_dir is None else local_dir
    base_dir = os.path.abspath(base_dir)

    if not os.path.exists(base_dir):
        raise FileNotFoundError(f"Local directory does not exist: {base_dir}")

    target_model = _find_target_model(repo_id, base_dir)

    # A filename selects single-file removal; otherwise drop the whole repo.
    if file_name:
        return _remove_specific_file(target_model, file_name, base_dir)
    return _remove_entire_repository(target_model, base_dir)
453
+
454
+
455
+ ##########################################################################
456
+ # Check model existence functions #
457
+ ##########################################################################
458
+
459
+
460
def check_model_existence(
    model_path: str,
    local_dir: Optional[str] = None
) -> bool:
    """
    Report whether a repository, or a single file inside one, exists locally.

    Accepted forms of *model_path*:
      * "owner/repo"           - checks for the whole repository
      * "owner/repo/file.ext"  - checks for that file within the repository

    Args:
        model_path (str): Repository id or file path.
        local_dir (str, optional): Directory holding downloaded models.
            Defaults to DEFAULT_MODEL_SAVING_PATH when None.

    Returns:
        bool: True when the requested repository/file is present locally.

    Raises:
        ValueError: If model_path is empty or not a string.
    """
    # Validate input and split into repo id + optional filename.
    repo_id, file_name = _validate_and_parse_input(model_path)

    # Resolve the models directory, defaulting to the SDK cache location.
    base_dir = os.path.expanduser(DEFAULT_MODEL_SAVING_PATH) if local_dir is None else local_dir
    base_dir = os.path.abspath(base_dir)

    # Nothing can exist inside a missing directory.
    if not os.path.exists(base_dir):
        return False

    for model in list_downloaded_models(base_dir):
        if model.repo_id != repo_id:
            continue
        # Repository matched: either that is all that was asked for,
        # or we additionally require the specific file to be tracked.
        return True if file_name is None else file_name in model.files

    return False
511
+
512
+
513
+ ##########################################################################
514
+ # HuggingFace Downloader Class #
515
+ ##########################################################################
516
+
517
+
518
+ class HuggingFaceDownloader:
519
+ """Class to handle downloads from HuggingFace Hub with unified API usage."""
520
+
521
    def __init__(
        self,
        endpoint: Optional[str] = None,
        token: Union[bool, str, None] = None,
        enable_transfer: bool = True
    ):
        """
        Initialize the downloader with HuggingFace API.

        Args:
            endpoint: Custom endpoint URL (e.g., "https://hf-mirror.com").
                If None, uses default HuggingFace Hub.
            token: Authentication token for private repositories.
                NOTE(review): only a *string* token is honoured here —
                both True and None collapse to False (authentication
                disabled). Confirm that callers passing token=True (meaning
                "use the cached login") are not expecting authentication.
            enable_transfer: Whether to enable HF transfer for faster downloads.
        """
        # Always create an HfApi instance - either with custom endpoint or default
        self.token = token if isinstance(token, str) else False  # False means disable authentication
        self.api = HfApi(endpoint=endpoint, token=self.token) if endpoint else HfApi(token=self.token)
        self.enable_transfer = enable_transfer
        # Saved value of HF_HUB_ENABLE_HF_TRANSFER; set by _setup_hf_transfer_env
        # and restored by _cleanup_hf_transfer_env.
        self.original_hf_transfer = None
        self.endpoint = endpoint  # Store endpoint for avatar fetching
542
+
543
+ def _create_repo_directory(self, local_dir: str, repo_id: str) -> str:
544
+ """Create a directory structure for the repository following HF convention."""
545
+ if '/' in repo_id:
546
+ # Standard format: owner/repo
547
+ owner, repo = repo_id.split('/', 1)
548
+ repo_dir = os.path.join(local_dir, owner, repo)
549
+ else:
550
+ # Direct repo name without owner
551
+ repo_dir = os.path.join(local_dir, repo_id)
552
+
553
+ os.makedirs(repo_dir, exist_ok=True)
554
+ return repo_dir
555
+
556
+ def _created_dir_if_not_exists(self, local_dir: Optional[str]) -> str:
557
+ """Create directory if it doesn't exist and return the expanded path."""
558
+ if local_dir is None:
559
+ local_dir = DEFAULT_MODEL_SAVING_PATH
560
+
561
+ local_dir = os.path.expanduser(local_dir)
562
+ os.makedirs(local_dir, exist_ok=True)
563
+ return local_dir
564
+
565
    def _get_repo_info_for_progress(
        self,
        repo_id: str,
        file_name: Optional[Union[str, List[str]]] = None
    ) -> tuple[int, int]:
        """
        Get total repository size and file count for progress tracking.

        Queries the Hub for per-file metadata and sums the sizes of the
        files selected by *file_name* (a single filename, a list of
        filenames, or None for the whole repository).

        Args:
            repo_id: Repository to inspect.
            file_name: Optional filename or list of filenames to restrict
                the count to.

        Returns:
            tuple[int, int]: (total_size_bytes, file_count). Files without
                size metadata are counted but contribute 0 bytes. On any
                failure — or when nothing matches — returns (0, 1) /
                a count of at least 1 so progress math never divides by zero.
        """
        try:
            info = self.api.model_info(repo_id, files_metadata=True, token=self.token)

            total_size = 0
            file_count = 0

            if info.siblings:
                for sibling in info.siblings:
                    # Handle different file_name types
                    if file_name is not None:
                        if isinstance(file_name, str):
                            # Single file - only count if it matches
                            if sibling.rfilename != file_name:
                                continue
                        elif isinstance(file_name, list):
                            # Multiple files - only count if in the list
                            if sibling.rfilename not in file_name:
                                continue

                    # For all matching files (or all files if file_name is None)
                    if hasattr(sibling, 'size') and sibling.size is not None:
                        total_size += sibling.size
                        file_count += 1
                    else:
                        # Count files without size info
                        file_count += 1

            return total_size, file_count if file_count > 0 else 1
        except Exception:
            # If we can't get info, return defaults
            return 0, 1
602
+
603
+ def _validate_and_setup_params(
604
+ self,
605
+ repo_id: str,
606
+ file_name: Optional[Union[str, List[str]]]
607
+ ) -> tuple[str, Optional[Union[str, List[str]]]]:
608
+ """Validate and normalize input parameters."""
609
+ if not repo_id:
610
+ raise ValueError("repo_id is required")
611
+
612
+ repo_id = repo_id.strip()
613
+
614
+ # Handle file_name parameter
615
+ if file_name is not None:
616
+ if isinstance(file_name, str):
617
+ file_name = file_name.strip()
618
+ if not file_name:
619
+ file_name = None
620
+ elif isinstance(file_name, list):
621
+ # Filter out empty strings and strip whitespace
622
+ file_name = [f.strip() for f in file_name if f and f.strip()]
623
+ if not file_name:
624
+ file_name = None
625
+ else:
626
+ raise ValueError("file_name must be a string, list of strings, or None")
627
+
628
+ return repo_id, file_name
629
+
630
+ def _setup_progress_tracker(
631
+ self,
632
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]],
633
+ show_progress: bool,
634
+ repo_id: str,
635
+ file_name: Optional[Union[str, List[str]]]
636
+ ) -> Optional[DownloadProgressTracker]:
637
+ """Initialize progress tracker if callback is provided."""
638
+ if not progress_callback:
639
+ return None
640
+
641
+ progress_tracker = DownloadProgressTracker(progress_callback, show_progress)
642
+ # Get repo info for progress tracking - now handles all cases
643
+ total_size, file_count = self._get_repo_info_for_progress(repo_id, file_name)
644
+ progress_tracker.set_repo_info(total_size, file_count)
645
+ return progress_tracker
646
+
647
+ def _setup_hf_transfer_env(self) -> None:
648
+ """Set up HF transfer environment."""
649
+ self.original_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
650
+ if self.enable_transfer:
651
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
652
+
653
+ def _cleanup_hf_transfer_env(self) -> None:
654
+ """Restore original HF transfer environment."""
655
+ if self.original_hf_transfer is not None:
656
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = self.original_hf_transfer
657
+ else:
658
+ os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
659
+
660
    def _validate_repository_and_get_info(
        self,
        repo_id: str,
        progress_tracker: Optional[DownloadProgressTracker]
    ):
        """
        Validate that *repo_id* exists on the Hub and return its model info.

        Any failure is first reported to *progress_tracker* (when present)
        before being re-raised.

        NOTE(review): the re-raised HfHubHTTPError is constructed from the
        message alone, so the original HTTP response object is lost to
        callers — confirm no caller inspects `.response` on it.

        Raises:
            RepositoryNotFoundError: When the repository does not exist
                (including a 404 surfaced as HfHubHTTPError).
            HfHubHTTPError: For any other HTTP error from the Hub.
        """
        try:
            info = self.api.model_info(repo_id, token=self.token)
            return info
        except RepositoryNotFoundError:
            error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
            if progress_tracker:
                progress_tracker.set_error(error_msg)
            raise RepositoryNotFoundError(error_msg)
        except HfHubHTTPError as e:
            # Some endpoints surface a missing repo as a plain 404 instead
            # of RepositoryNotFoundError; normalize that case.
            if e.response.status_code == 404:
                error_msg = f"Repository '{repo_id}' not found. Please check the repository ID."
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                raise RepositoryNotFoundError(error_msg)
            else:
                error_msg = f"HTTP error while accessing repository '{repo_id}': {e}"
                if progress_tracker:
                    progress_tracker.set_error(error_msg)
                raise HfHubHTTPError(error_msg)
685
+
686
+ def _validate_file_exists_in_repo(
687
+ self,
688
+ file_name: str,
689
+ info,
690
+ repo_id: str,
691
+ progress_tracker: Optional[DownloadProgressTracker]
692
+ ) -> None:
693
+ """Validate that the file exists in the repository."""
694
+ file_exists = False
695
+ if info.siblings:
696
+ for sibling in info.siblings:
697
+ if sibling.rfilename == file_name:
698
+ file_exists = True
699
+ break
700
+
701
+ if not file_exists:
702
+ available_files = [sibling.rfilename for sibling in info.siblings] if info.siblings else []
703
+ error_msg = (
704
+ f"File '{file_name}' not found in repository '{repo_id}'. "
705
+ f"Available files: {available_files[:10]}{'...' if len(available_files) > 10 else ''}"
706
+ )
707
+ if progress_tracker:
708
+ progress_tracker.set_error(error_msg)
709
+ progress_tracker.stop_tracking()
710
+ raise ValueError(error_msg)
711
+
712
+ def _check_file_exists_and_valid(
713
+ self,
714
+ file_path: str,
715
+ expected_size: Optional[int] = None
716
+ ) -> bool:
717
+ """Check if a file exists and is valid (non-empty, correct size if known)."""
718
+ if not os.path.exists(file_path):
719
+ return False
720
+
721
+ # Check file is not empty
722
+ try:
723
+ file_size = os.path.getsize(file_path)
724
+ if file_size == 0:
725
+ return False
726
+ except (OSError, IOError):
727
+ return False
728
+
729
+ # If we have expected size, check it matches
730
+ if expected_size is not None and file_size != expected_size:
731
+ return False
732
+
733
+ # If no expected size, just check that file is not empty
734
+ return os.path.getsize(file_path) > 0
735
+
736
    def _fetch_and_save_metadata(self, repo_id: str, local_dir: str) -> None:
        """
        Fetch model info and persist download metadata after a successful download.

        Writes pipeline_tag, the current timestamp, and the author avatar URL
        into download_metadata.json inside *local_dir*. Entirely best-effort:
        any failure (network, auth, disk) is swallowed so the download itself
        is never reported as failed because of metadata.

        NOTE(review): the timestamp uses naive local time
        (datetime.now().isoformat()) — confirm consumers don't expect UTC.
        """
        try:
            # Fetch model info to get pipeline_tag
            info = self.api.model_info(repo_id, token=self.token)
            pipeline_tag = info.pipeline_tag if hasattr(info, 'pipeline_tag') else None

            # Get avatar URL
            avatar_url = get_avatar_url_for_repo(repo_id, custom_endpoint=self.endpoint)

            # Prepare metadata
            metadata = {
                'pipeline_tag': pipeline_tag,
                'download_time': datetime.now().isoformat(),
                'avatar_url': avatar_url
            }

            # Save metadata to the repository directory
            _save_download_metadata(local_dir, metadata)

        except Exception:
            # Don't fail the download if metadata fetch fails
            pass
759
+
760
    def _download_single_file(
        self,
        repo_id: str,
        file_name: str,
        local_dir: str,
        progress_tracker: Optional[DownloadProgressTracker],
        force_download: bool = False
    ) -> str:
        """
        Download one file from *repo_id* into the repo's local directory.

        Skips the transfer when a non-empty copy already exists (unless
        *force_download*). On success, repository metadata is refreshed.

        NOTE(review): `local_dir_use_symlinks` is deprecated in recent
        huggingface_hub releases — confirm the pinned hub version still
        accepts it.

        Args:
            repo_id: Repository to download from.
            file_name: File (relative path) inside the repository.
            local_dir: Base models directory; the owner/repo subfolder is
                created beneath it.
            progress_tracker: Optional tracker; stopped on completion/error.
            force_download: Re-download even when a local copy exists.

        Returns:
            str: Path of the downloaded (or pre-existing) file.

        Raises:
            ValueError: When the Hub reports the file missing (HTTP 404).
            HfHubHTTPError: For any other HTTP error.
        """
        # Create repo-specific directory for the single file
        file_local_dir = self._create_repo_directory(local_dir, repo_id)

        # Check if file already exists
        local_file_path = os.path.join(file_local_dir, file_name)
        if not force_download and self._check_file_exists_and_valid(local_file_path):
            print(f"✓ File already exists, skipping: {file_name}")
            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()
            return local_file_path

        try:
            # Note: hf_hub_download doesn't support tqdm_class parameter
            # Progress tracking works through the global tqdm monkey patching
            downloaded_path = self.api.hf_hub_download(
                repo_id=repo_id,
                filename=file_name,
                local_dir=file_local_dir,
                local_dir_use_symlinks=False,
                token=self.token,
                force_download=force_download
            )

            # Stop progress tracking
            if progress_tracker:
                progress_tracker.stop_tracking()

            # Save metadata after successful download
            self._fetch_and_save_metadata(repo_id, file_local_dir)

            return downloaded_path

        except HfHubHTTPError as e:
            error_msg = f"Error downloading file '{file_name}': {e}"
            if progress_tracker:
                progress_tracker.set_error(error_msg)
                progress_tracker.stop_tracking()
            # Translate a 404 into a clearer, catchable ValueError.
            if e.response.status_code == 404:
                raise ValueError(f"File '{file_name}' not found in repository '{repo_id}'")
            else:
                raise HfHubHTTPError(error_msg)
811
+
812
+ def _download_entire_repository(
813
+ self,
814
+ repo_id: str,
815
+ local_dir: str,
816
+ progress_tracker: Optional[DownloadProgressTracker],
817
+ force_download: bool = False
818
+ ) -> str:
819
+ """Download the entire repository."""
820
+ # Create a subdirectory for this specific repo
821
+ repo_local_dir = self._create_repo_directory(local_dir, repo_id)
822
+
823
+ # Check if repository already exists (basic check for directory existence)
824
+ if not force_download and os.path.exists(repo_local_dir) and os.listdir(repo_local_dir):
825
+ print(f"✓ Repository already exists, skipping: {repo_id}")
826
+ # Stop progress tracking
827
+ if progress_tracker:
828
+ progress_tracker.stop_tracking()
829
+ return repo_local_dir
830
+
831
+ try:
832
+ download_kwargs = {
833
+ 'repo_id': repo_id,
834
+ 'local_dir': repo_local_dir,
835
+ 'local_dir_use_symlinks': False,
836
+ 'token': self.token,
837
+ 'force_download': force_download
838
+ }
839
+
840
+ # Add tqdm_class if progress tracking is enabled
841
+ if progress_tracker:
842
+ download_kwargs['tqdm_class'] = CustomProgressTqdm
843
+
844
+ downloaded_path = self.api.snapshot_download(**download_kwargs)
845
+
846
+ # Stop progress tracking
847
+ if progress_tracker:
848
+ progress_tracker.stop_tracking()
849
+
850
+ # Save metadata after successful download
851
+ self._fetch_and_save_metadata(repo_id, repo_local_dir)
852
+
853
+ return downloaded_path
854
+
855
+ except HfHubHTTPError as e:
856
+ error_msg = f"Error downloading repository '{repo_id}': {e}"
857
+ if progress_tracker:
858
+ progress_tracker.set_error(error_msg)
859
+ progress_tracker.stop_tracking()
860
+ raise HfHubHTTPError(error_msg)
861
+
862
+ def _download_multiple_files_from_hf(
863
+ self,
864
+ repo_id: str,
865
+ file_names: List[str],
866
+ local_dir: str,
867
+ progress_tracker: Optional[DownloadProgressTracker],
868
+ force_download: bool = False
869
+ ) -> str:
870
+ """Download multiple specific files from HuggingFace Hub."""
871
+ # Create repo-specific directory
872
+ repo_local_dir = self._create_repo_directory(local_dir, repo_id)
873
+
874
+ # Create overall progress bar for multiple files
875
+ overall_progress = tqdm(
876
+ total=len(file_names),
877
+ unit='file',
878
+ desc=f"Downloading {len(file_names)} files from {repo_id}",
879
+ position=0,
880
+ leave=True
881
+ )
882
+
883
+ try:
884
+ for file_name in file_names:
885
+ overall_progress.set_postfix_str(f"Current: {os.path.basename(file_name)}")
886
+
887
+ # Check if file already exists
888
+ local_file_path = os.path.join(repo_local_dir, file_name)
889
+ if not force_download and self._check_file_exists_and_valid(local_file_path):
890
+ print(f"✓ File already exists, skipping: {file_name}")
891
+ overall_progress.update(1)
892
+ continue
893
+
894
+ # Download each file using hf_hub_download
895
+ self.api.hf_hub_download(
896
+ repo_id=repo_id,
897
+ filename=file_name,
898
+ local_dir=repo_local_dir,
899
+ local_dir_use_symlinks=False,
900
+ token=self.token,
901
+ force_download=force_download
902
+ )
903
+
904
+ overall_progress.update(1)
905
+
906
+ overall_progress.close()
907
+
908
+ # Stop progress tracking
909
+ if progress_tracker:
910
+ progress_tracker.stop_tracking()
911
+
912
+ # Save metadata after successful download
913
+ self._fetch_and_save_metadata(repo_id, repo_local_dir)
914
+
915
+ return repo_local_dir
916
+
917
+ except HfHubHTTPError as e:
918
+ overall_progress.close()
919
+ error_msg = f"Error downloading files from '{repo_id}': {e}"
920
+ if progress_tracker:
921
+ progress_tracker.set_error(error_msg)
922
+ progress_tracker.stop_tracking()
923
+ raise HfHubHTTPError(error_msg)
924
+ except Exception as e:
925
+ overall_progress.close()
926
+ if progress_tracker:
927
+ progress_tracker.set_error(str(e))
928
+ progress_tracker.stop_tracking()
929
+ raise
930
+
931
+ def download(
932
+ self,
933
+ repo_id: str,
934
+ file_name: Optional[Union[str, List[str]]] = None,
935
+ local_dir: Optional[str] = None,
936
+ progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
937
+ show_progress: bool = True,
938
+ force_download: bool = False
939
+ ) -> str:
940
+ """
941
+ Main download method that handles all download scenarios.
942
+
943
+ Args:
944
+ repo_id: Repository ID to download from
945
+ file_name: Optional file name(s) to download
946
+ local_dir: Local directory to save files
947
+ progress_callback: Callback for progress updates
948
+ show_progress: Whether to show progress bar
949
+ force_download: Force re-download even if files exist
950
+
951
+ Returns:
952
+ Path to downloaded file or directory
953
+ """
954
+ # Validate and normalize parameters
955
+ repo_id, file_name = self._validate_and_setup_params(repo_id, file_name)
956
+
957
+ # Set up local directory
958
+ local_dir = self._created_dir_if_not_exists(local_dir)
959
+
960
+ # Set up progress tracker
961
+ file_name_for_progress = file_name if isinstance(file_name, str) else None
962
+ progress_tracker = self._setup_progress_tracker(
963
+ progress_callback, show_progress, repo_id, file_name_for_progress
964
+ )
965
+
966
+ # Set up HF transfer environment
967
+ self._setup_hf_transfer_env()
968
+
969
+ try:
970
+ # Validate repository and get info
971
+ info = self._validate_repository_and_get_info(repo_id, progress_tracker)
972
+
973
+ # Start progress tracking
974
+ if progress_tracker:
975
+ progress_tracker.start_tracking()
976
+
977
+ # Choose download strategy based on file_name
978
+ if file_name is None:
979
+ # Download entire repository
980
+ return self._download_entire_repository(
981
+ repo_id, local_dir, progress_tracker, force_download
982
+ )
983
+ elif isinstance(file_name, str):
984
+ # Download specific single file
985
+ self._validate_file_exists_in_repo(file_name, info, repo_id, progress_tracker)
986
+ return self._download_single_file(
987
+ repo_id, file_name, local_dir, progress_tracker, force_download
988
+ )
989
+ else: # file_name is a list
990
+ # Download multiple specific files
991
+ # Validate all files exist
992
+ for fname in file_name:
993
+ self._validate_file_exists_in_repo(fname, info, repo_id, progress_tracker)
994
+
995
+ return self._download_multiple_files_from_hf(
996
+ repo_id, file_name, local_dir, progress_tracker, force_download
997
+ )
998
+
999
+ except Exception as e:
1000
+ # Handle any unexpected errors
1001
+ if progress_tracker and progress_tracker.download_status != "error":
1002
+ progress_tracker.set_error(str(e))
1003
+ progress_tracker.stop_tracking()
1004
+ raise
1005
+
1006
+ finally:
1007
+ # Restore original HF transfer setting
1008
+ self._cleanup_hf_transfer_env()
1009
+
1010
+
1011
+ ##########################################################################
1012
+ # Public Download Function #
1013
+ ##########################################################################
1014
+
1015
+
1016
def download_from_huggingface(
    repo_id: str,
    file_name: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    enable_transfer: bool = True,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    show_progress: bool = True,
    token: Union[bool, str, None] = None,
    custom_endpoint: Optional[str] = None,
    force_download: bool = False
) -> str:
    """Download a model repository (or selected files) from HuggingFace Hub
    or a HuggingFace-compatible mirror.

    Thin convenience wrapper: builds a HuggingFaceDownloader configured with
    the given endpoint/token/transfer settings and delegates to its
    ``download`` method.

    Args:
        repo_id: Repository ID to download from (e.g. "microsoft/DialoGPT-medium").
        file_name: Single filename or list of filenames to download; None
            downloads the entire repository.
        local_dir: Destination directory; None uses DEFAULT_MODEL_SAVING_PATH.
        enable_transfer: Whether to enable HF transfer for faster downloads.
        progress_callback: Receives a dict with keys 'status' ('idle',
            'downloading', 'completed', 'error'), optional 'error_message',
            and nested 'progress' (total_downloaded, total_size, percentage,
            files_active, files_total, known_total), 'speed'
            (bytes_per_second, formatted), 'formatting' (downloaded,
            total_size) and 'timing' (elapsed_seconds, eta_seconds,
            start_time) sections.
        show_progress: Whether to show a unified terminal progress bar.
        token: True to read the token from the HuggingFace config folder, a
            string to use it directly, or None for default behavior.
        custom_endpoint: Base URL of a HuggingFace-compatible endpoint with
            no path components, e.g. "https://hf-mirror.com" or
            "https://huggingface.co" (default). Used to initialize HfApi.
        force_download: If True, download files even if they already exist
            locally (default False skips existing files).

    Returns:
        str: Path to the downloaded file or directory.

    Raises:
        ValueError: If repo_id is invalid or file_name doesn't exist in the repo.
        RepositoryNotFoundError: If the repository doesn't exist.
        HfHubHTTPError: If there's an HTTP error during download.
    """
    # Configure a downloader for the requested endpoint, then delegate.
    downloader = HuggingFaceDownloader(
        endpoint=custom_endpoint,
        token=token,
        enable_transfer=enable_transfer
    )

    return downloader.download(
        repo_id=repo_id,
        file_name=file_name,
        local_dir=local_dir,
        progress_callback=progress_callback,
        show_progress=show_progress,
        force_download=force_download
    )
+
1105
+
1106
+ ##########################################################################
1107
+ # Auto-download decorator #
1108
+ ##########################################################################
1109
+
1110
+
1111
def auto_download_model(func: Callable) -> Callable:
    """
    Decorator that automatically downloads models from HuggingFace if they
    don't exist locally.

    Apply to ``__init__`` methods whose first argument after ``self`` is
    ``name_or_path``. When that path does not exist as a local file or
    directory, it is parsed into a repo id (and optional filename) and
    downloaded via ``download_from_huggingface``; the wrapped function is
    then called with the downloaded path.

    Supported ``name_or_path`` formats:
        - "microsoft/DialoGPT-small" (downloads entire repo)
        - "microsoft/DialoGPT-small/pytorch_model.bin" (downloads specific file)
        - "Qwen/Qwen3-4B-GGUF/Qwen3-4B-Q4_K_M.gguf" (downloads specific file)

    Optional kwargs extracted (and NOT forwarded to the wrapped function):
        - progress_callback: Callback for download progress updates.
        - token: HuggingFace authentication token for private repositories.

    Args:
        func: The __init__ method to wrap.

    Returns:
        Wrapped function that handles automatic model downloading.

    Raises:
        RuntimeError: If the download fails; the original error is chained.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Locate name_or_path: first positional after self, or a keyword.
        if len(args) >= 2:
            name_or_path = args[1]
            is_positional = True
        elif 'name_or_path' in kwargs:
            name_or_path = kwargs['name_or_path']
            is_positional = False
        else:
            # Nothing to resolve; defer to the wrapped function unchanged.
            return func(*args, **kwargs)

        # These kwargs only drive the download; strip them so they are never
        # forwarded to the wrapped function.
        progress_callback = kwargs.pop('progress_callback', None)
        token = kwargs.pop('token', None)

        # An existing local file or directory is used as-is.
        if os.path.exists(name_or_path):
            return func(*args, **kwargs)

        try:
            # Split "org/repo[/file]" into a repo id and optional filename.
            repo_id, file_name = _parse_model_path(name_or_path)

            downloaded_path = download_from_huggingface(
                repo_id=repo_id,
                file_name=file_name,
                local_dir=None,  # Use default cache directory
                enable_transfer=True,
                progress_callback=progress_callback,
                show_progress=True,
                token=token
            )
        except Exception as e:
            # Chain so the underlying download failure stays in the traceback.
            raise RuntimeError(f"Could not load model from '{name_or_path}': {e}") from e

        # Swap the unresolved path for the downloaded one.
        if is_positional:
            args = args[:1] + (downloaded_path,) + args[2:]
        else:
            kwargs['name_or_path'] = downloaded_path

        # Outside the try block so model-construction errors bubble up unwrapped.
        return func(*args, **kwargs)

    return wrapper