ollamadiffuser 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,356 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download utilities for robust model downloading with detailed progress tracking
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import logging
9
+ from typing import Optional, Callable, Any, Dict
10
+ from pathlib import Path
11
+ from huggingface_hub import snapshot_download, hf_hub_download, HfApi
12
+ from tqdm import tqdm
13
+ import threading
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class ProgressTracker:
    """Aggregate download progress across multiple files.

    Thread-safe: every mutating call acquires an internal lock before
    touching shared state, then re-emits a status line via the callback.
    """

    def __init__(self, total_files: int = 0, progress_callback: Optional[Callable] = None):
        # Map of filename -> (bytes downloaded, total bytes).
        self.file_progress = {}
        self.total_files = total_files
        self.completed_files = 0
        self.current_file = ""
        self.progress_callback = progress_callback
        self.lock = threading.Lock()

    def update_file_progress(self, filename: str, downloaded: int, total: int):
        """Record byte-level progress for one file and re-report."""
        with self.lock:
            self.file_progress[filename] = (downloaded, total)
            self._report_progress()

    def complete_file(self, filename: str):
        """Mark a file finished: bump the counter and snap its bytes to 100%."""
        with self.lock:
            self.completed_files += 1
            if filename in self.file_progress:
                _, total = self.file_progress[filename]
                self.file_progress[filename] = (total, total)
            self._report_progress()

    def set_current_file(self, filename: str):
        """Remember which file is downloading right now and re-report."""
        with self.lock:
            self.current_file = filename
            self._report_progress()

    def _report_progress(self):
        """Build a one-line status string and hand it to the callback, if any."""
        if not self.progress_callback:
            return

        # Sum byte counts over every tracked file for the overall percentage.
        snapshots = list(self.file_progress.values())
        done_bytes = sum(done for done, _ in snapshots)
        all_bytes = sum(size for _, size in snapshots)

        parts = [f"Files: {self.completed_files}/{self.total_files}"]
        if all_bytes > 0:
            parts.append(f"Overall: {(done_bytes / all_bytes) * 100:.1f}%")
        if self.current_file:
            parts.append(f"Current: {self.current_file}")

        self.progress_callback(" | ".join(parts))
69
+
70
def configure_hf_environment():
    """Apply download-friendly HuggingFace Hub defaults via environment variables.

    Uses setdefault throughout, so values already exported by the user are
    never overridden.
    """
    defaults = {
        'HF_HUB_DOWNLOAD_TIMEOUT': '600',          # 10 minutes
        'HF_HUB_CONNECTION_TIMEOUT': '120',        # 2 minutes
        'HF_HUB_LOCAL_DIR_USE_SYMLINKS': 'False',  # real files for better compatibility
        'HF_HUB_ENABLE_HF_TRANSFER': 'False',      # plain download path is more compatible
    }
    for name, value in defaults.items():
        os.environ.setdefault(name, value)
81
+
82
def get_repo_file_list(repo_id: str) -> Dict[str, int]:
    """Return a mapping of repository filename -> size in bytes (0 when unknown).

    Falls back to an empty dict on any API failure, so callers can proceed
    without size-based progress information.
    """
    try:
        info = HfApi().repo_info(repo_id=repo_id)
        # Some siblings report no size; record 0 rather than dropping the file.
        return {
            sibling.rfilename: (sibling.size if sibling.size is not None else 0)
            for sibling in info.siblings
        }
    except Exception as e:
        logger.warning(f"Could not get file list for {repo_id}: {e}")
        return {}
98
+
99
def format_size(size_bytes: int) -> str:
    """Render a byte count as a human-readable string, e.g. 2048 -> "2.0 KB"."""
    value = float(size_bytes)
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if value < 1024.0:
            return f"{value:.1f} {unit}"
        value /= 1024.0
    # Anything past terabytes is reported in petabytes.
    return f"{value:.1f} PB"
106
+
107
def robust_snapshot_download(
    repo_id: str,
    local_dir: str,
    cache_dir: Optional[str] = None,
    max_retries: int = 3,
    initial_workers: int = 2,
    force_download: bool = False,
    progress_callback: Optional[Callable] = None
) -> str:
    """
    Download a repository snapshot with retries and detailed progress reporting.

    Args:
        repo_id: Repository ID on HuggingFace Hub
        local_dir: Local directory to download to
        cache_dir: Cache directory
        max_retries: Maximum number of download attempts
        initial_workers: Worker count for the first attempt (drops to 1 on retries)
        force_download: Force re-download even if files already exist locally
        progress_callback: Optional callable receiving human-readable status strings

    Returns:
        Path to downloaded repository

    Raises:
        Exception: the last error from snapshot_download once all attempts
            are exhausted, or RuntimeError if max_retries < 1.
    """
    configure_hf_environment()

    # Fetch file names/sizes up front so we can report a meaningful total.
    if progress_callback:
        progress_callback("📋 Getting repository information...")

    file_sizes = get_repo_file_list(repo_id)
    total_size = sum(file_sizes.values())

    if progress_callback and file_sizes:
        progress_callback(f"📦 Repository: {len(file_sizes)} files, {format_size(total_size)} total")

    # Report what is already on disk; snapshot_download will skip/resume these.
    local_path = Path(local_dir)
    if local_path.exists() and not force_download:
        existing_files = []
        existing_size = 0
        for file_path in local_path.rglob('*'):
            if file_path.is_file():
                existing_files.append(str(file_path.relative_to(local_path)))
                existing_size += file_path.stat().st_size

        if progress_callback and existing_files:
            progress_callback(f"📁 Found {len(existing_files)} existing files ({format_size(existing_size)})")

    last_exception = None

    for attempt in range(max_retries):
        try:
            # Reduce workers on retry attempts to avoid overwhelming connections
            workers = 1 if attempt > 0 else initial_workers

            if progress_callback:
                progress_callback(f"🔄 Download attempt {attempt + 1}/{max_retries} (workers: {workers})")

            logger.info(f"Download attempt {attempt + 1}/{max_retries} with {workers} workers")

            # NOTE(review): resume_download and local_dir_use_symlinks are
            # deprecated no-ops in recent huggingface_hub releases — harmless,
            # but confirm against the pinned hub version.
            result = snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                local_dir_use_symlinks=False,
                cache_dir=cache_dir,
                max_workers=workers,
                resume_download=True,  # Enable resume
                etag_timeout=300 + (attempt * 60),  # Increase timeout on retries
                force_download=force_download,
                tqdm_class=tqdm if progress_callback else None
            )

            if progress_callback:
                progress_callback(f"✅ Successfully downloaded {repo_id}")

            logger.info(f"Successfully downloaded {repo_id}")
            return result

        except Exception as e:
            last_exception = e
            error_msg = str(e)

            # Log the specific error
            logger.warning(f"Download attempt {attempt + 1} failed: {error_msg}")

            if attempt < max_retries - 1:
                # Determine wait time based on error type
                if "timeout" in error_msg.lower():
                    wait_time = 30 + (attempt * 15)  # Longer wait for timeouts
                elif "connection" in error_msg.lower():
                    wait_time = 20 + (attempt * 10)  # Medium wait for connection errors
                else:
                    wait_time = 10 + (attempt * 5)  # Shorter wait for other errors

                logger.info(f"Waiting {wait_time} seconds before retry...")

                if progress_callback:
                    progress_callback(f"⚠️ Download failed, retrying in {wait_time}s... (Error: {error_msg[:100]})")

                time.sleep(wait_time)
            else:
                logger.error(f"All download attempts failed. Final error: {error_msg}")
                if progress_callback:
                    progress_callback(f"❌ All download attempts failed: {error_msg}")

    # All retries failed. Guard against max_retries < 1, where the loop never
    # ran and `raise last_exception` would raise None (a TypeError).
    if last_exception is None:
        raise RuntimeError(f"Download of {repo_id} failed: no attempts were made (max_retries={max_retries})")
    raise last_exception
222
+
223
def robust_file_download(
    repo_id: str,
    filename: str,
    local_dir: str,
    cache_dir: Optional[str] = None,
    max_retries: int = 3,
    progress_callback: Optional[Callable] = None
) -> str:
    """
    Download a single file with retries, progressive backoff and progress reporting.

    Args:
        repo_id: Repository ID on HuggingFace Hub
        filename: File to download (path within the repository)
        local_dir: Local directory to download to
        cache_dir: Cache directory
        max_retries: Maximum number of retry attempts
        progress_callback: Optional progress callback function

    Returns:
        Path to downloaded file

    Raises:
        Exception: the last error from hf_hub_download once all attempts
            are exhausted, or RuntimeError if max_retries < 1.
    """
    configure_hf_environment()

    last_exception = None

    for attempt in range(max_retries):
        try:
            if progress_callback:
                # BUGFIX: interpolate the actual filename instead of the
                # literal text "(unknown)" so messages identify the file.
                progress_callback(f"📥 Downloading {filename} (attempt {attempt + 1}/{max_retries})")

            logger.info(f"File download attempt {attempt + 1}/{max_retries}: {filename}")

            result = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                resume_download=True,  # Enable resume (deprecated no-op on recent hub versions)
                etag_timeout=180 + (attempt * 30)  # Increase timeout on retries
            )

            if progress_callback:
                progress_callback(f"✅ Downloaded {filename}")

            logger.info(f"Successfully downloaded {filename}")
            return result

        except Exception as e:
            last_exception = e
            error_msg = str(e)

            logger.warning(f"File download attempt {attempt + 1} failed: {error_msg}")

            if attempt < max_retries - 1:
                wait_time = 5 + (attempt * 3)  # Progressive backoff

                if progress_callback:
                    progress_callback(f"⚠️ Retrying {filename} in {wait_time}s...")

                time.sleep(wait_time)
            else:
                logger.error(f"All file download attempts failed. Final error: {error_msg}")
                if progress_callback:
                    progress_callback(f"❌ Failed to download {filename}: {error_msg}")

    # All retries failed. Guard against max_retries < 1, where the loop never
    # ran and `raise last_exception` would raise None (a TypeError).
    if last_exception is None:
        raise RuntimeError(f"Download of {filename} failed: no attempts were made (max_retries={max_retries})")
    raise last_exception
291
+
292
def check_download_integrity(local_dir: str, repo_id: str) -> bool:
    """Verify that a downloaded model directory looks complete and usable.

    Requires model_index.json to exist, rejects any unexpected empty file,
    and requires each critical model subdirectory that exists to contain at
    least one non-empty file. Returns False on any failure or exception.
    """
    try:
        root = Path(local_dir)
        if not root.exists():
            return False

        # Check for essential files
        for essential_file in ['model_index.json']:
            if not (root / essential_file).exists():
                logger.warning(f"Missing essential file: {essential_file}")
                return False

        # Bookkeeping/documentation files that may legitimately be empty.
        ignore_patterns = [
            '.lock',           # HuggingFace lock files
            '.metadata',       # HuggingFace metadata files
            '.incomplete',     # Incomplete download files
            '.cache',          # Cache directory
            '.git',            # Git files
            '.gitattributes',  # Git attributes
            'README.md',       # Documentation files
            'LICENSE.md',      # License files
            'dev_grid.jpg'     # Sample images
        ]

        def _ignored(path) -> bool:
            # Substring match against the full path, as the original did.
            return any(pattern in str(path) for pattern in ignore_patterns)

        # Any empty non-ignored file indicates a truncated download.
        for candidate in root.rglob('*'):
            if candidate.is_file() and not _ignored(candidate):
                if candidate.stat().st_size == 0:
                    logger.warning(f"Empty file detected: {candidate}")
                    return False

        # Each critical directory that exists must hold real content.
        for critical_dir in ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2']:
            dir_path = root / critical_dir
            if not dir_path.exists():
                continue
            has_content = any(
                entry.is_file() and entry.stat().st_size > 0 and not _ignored(entry)
                for entry in dir_path.rglob('*')
            )
            if not has_content:
                logger.warning(f"Critical directory {critical_dir} appears to be empty or incomplete")
                return False

        logger.info("Download integrity check passed")
        return True

    except Exception as e:
        logger.error(f"Error checking download integrity: {e}")
        return False