ollamadiffuser 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/api/server.py +147 -2
- ollamadiffuser/cli/main.py +325 -25
- ollamadiffuser/core/inference/engine.py +180 -9
- ollamadiffuser/core/models/manager.py +136 -2
- ollamadiffuser/core/utils/controlnet_preprocessors.py +317 -0
- ollamadiffuser/core/utils/download_utils.py +209 -60
- ollamadiffuser/ui/templates/index.html +384 -7
- ollamadiffuser/ui/web.py +181 -100
- ollamadiffuser-1.1.1.dist-info/METADATA +470 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/RECORD +14 -13
- ollamadiffuser-1.0.0.dist-info/METADATA +0 -493
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/WHEEL +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/entry_points.txt +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional, Union, Tuple, Dict, Any
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
class ControlNetPreprocessorManager:
    """Lazy-loading manager for ControlNet preprocessors.

    Nothing is imported or downloaded at construction time.  On first use,
    ``initialize()`` tries to load detectors from ``controlnet_aux``; if that
    package is missing or every detector fails to load, basic OpenCV-based
    fallbacks are installed instead, so ``preprocess()`` always has something
    to work with.
    """

    def __init__(self):
        # name -> processor; either a controlnet_aux detector instance or one
        # of the bound-method OpenCV fallbacks from _init_basic_processors().
        self.processors = {}
        self._initialized = False
        self._initialization_attempted = False
        self._available_types = []

    def is_initialized(self) -> bool:
        """Check if preprocessors are initialized"""
        return self._initialized

    def is_available(self) -> bool:
        """Check if ControlNet preprocessors are available.

        Before the first initialization attempt this is a lightweight import
        probe only; afterwards it reflects the actual initialization result.
        """
        if not self._initialization_attempted:
            # Try a lightweight check without full initialization
            try:
                import controlnet_aux  # noqa: F401
                return True
            except ImportError:
                return False
        return self._initialized

    def _try_register(self, name: str, label: str, factory) -> None:
        """Register a single detector, logging (never raising) on failure.

        Args:
            name: Key under which the processor is stored in ``self.processors``.
            label: Human-readable label used in log messages.
            factory: Zero-argument callable that builds the detector.
        """
        try:
            self.processors[name] = factory()
            logger.info(f"{label} initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize {label}: {e}")

    def initialize(self, force: bool = False) -> bool:
        """
        Initialize ControlNet preprocessors

        Args:
            force: Force re-initialization even if already initialized

        Returns:
            True if initialization successful, False otherwise
        """
        if self._initialized and not force:
            return True

        if self._initialization_attempted and not force:
            return self._initialized

        self._initialization_attempted = True
        logger.info("Initializing ControlNet preprocessors...")

        try:
            from controlnet_aux import (
                CannyDetector,
                MidasDetector,
                OpenposeDetector,
                HEDdetector,
                MLSDdetector,
                NormalBaeDetector,
                LineartDetector,
                LineartAnimeDetector,
                ContentShuffleDetector,
                ZoeDetector
            )

            # Initialize processors with proper error handling; each detector
            # is optional, so a single failure never aborts the whole setup.
            self.processors = {}
            repo = 'lllyasviel/Annotators'

            # Canny detector (no model download needed)
            self._try_register('canny', 'Canny detector', CannyDetector)

            # Depth detectors: try MiDaS first, fall back to ZoeDepth.  This
            # pair has asymmetric log labels, so it stays inline rather than
            # going through _try_register.
            try:
                self.processors['depth'] = MidasDetector.from_pretrained(repo)
                logger.info("MiDaS depth detector initialized")
            except Exception as e:
                logger.warning(f"Failed to initialize MiDaS detector: {e}")
                try:
                    self.processors['depth_zoe'] = ZoeDetector.from_pretrained(repo)
                    self.processors['depth'] = self.processors['depth_zoe']  # Use as main depth
                    logger.info("ZoeDepth detector initialized as fallback")
                except Exception as e2:
                    logger.warning(f"Failed to initialize ZoeDepth detector: {e2}")

            self._try_register('openpose', 'OpenPose detector',
                               lambda: OpenposeDetector.from_pretrained(repo))
            self._try_register('hed', 'HED detector',
                               lambda: HEDdetector.from_pretrained(repo))
            self._try_register('mlsd', 'MLSD detector',
                               lambda: MLSDdetector.from_pretrained(repo))
            self._try_register('normal', 'Normal detector',
                               lambda: NormalBaeDetector.from_pretrained(repo))
            self._try_register('lineart', 'Lineart detector',
                               lambda: LineartDetector.from_pretrained(repo))
            self._try_register('lineart_anime', 'Lineart Anime detector',
                               lambda: LineartAnimeDetector.from_pretrained(repo))
            # Content shuffle (no model download needed)
            self._try_register('shuffle', 'Content Shuffle detector', ContentShuffleDetector)

            # Add scribble as alias for HED
            if 'hed' in self.processors:
                self.processors['scribble'] = self.processors['hed']

            if self.processors:
                self._initialized = True
                self._available_types = list(self.processors.keys())
                logger.info(f"ControlNet preprocessors initialized: {self._available_types}")
                return True
            else:
                logger.warning("No ControlNet preprocessors could be initialized, falling back to basic processors")
                self._init_basic_processors()
                return True

        except ImportError as e:
            logger.warning(f"controlnet-aux not available: {e}")
            # Fallback to basic OpenCV-based processors
            self._init_basic_processors()
            return True
        except Exception as e:
            logger.error(f"Error initializing ControlNet preprocessors: {e}")
            # Fallback to basic OpenCV-based processors
            self._init_basic_processors()
            return True

    def _init_basic_processors(self):
        """Initialize basic OpenCV-based processors as fallback"""
        logger.info("Using basic OpenCV-based preprocessors")
        self.processors = {
            'canny': self._canny_opencv,
            'depth': self._depth_basic,
            'scribble': self._scribble_basic,
        }
        self._initialized = True
        self._available_types = list(self.processors.keys())

    def _canny_opencv(self, image: Image.Image, low_threshold: int = 100, high_threshold: int = 200) -> Image.Image:
        """Basic Canny edge detection using OpenCV.

        Args:
            image: RGB input image.
            low_threshold: Lower hysteresis threshold for cv2.Canny.
            high_threshold: Upper hysteresis threshold for cv2.Canny.

        Returns:
            RGB image whose white pixels are detected edges.
        """
        # Convert PIL to OpenCV format (channel order is irrelevant to Canny,
        # but kept for consistency with the rest of the OpenCV pipeline)
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Apply Canny edge detection
        edges = cv2.Canny(image_cv, low_threshold, high_threshold)

        # Convert back to PIL
        edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(edges_rgb)

    def _depth_basic(self, image: Image.Image) -> Image.Image:
        """Basic depth estimation using simple gradients.

        This is only a crude stand-in for a real depth model: gradient
        magnitude of a blurred grayscale image, normalized to 0-255.
        """
        # Convert to grayscale
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

        # Apply Gaussian blur
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Create a simple depth map using gradients
        grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)

        # Combine gradients
        depth = np.sqrt(grad_x**2 + grad_y**2)

        # Guard against division by zero: a perfectly uniform image has no
        # gradients, so np.max(depth) is 0 and the original normalization
        # would produce NaNs. Emit an all-black map instead.
        max_val = float(np.max(depth))
        if max_val > 0:
            depth = np.uint8(255 * depth / max_val)
        else:
            depth = np.zeros(depth.shape, dtype=np.uint8)

        # Convert to RGB
        depth_rgb = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_rgb)

    def _scribble_basic(self, image: Image.Image) -> Image.Image:
        """Basic scribble detection using edge detection"""
        # Convert to grayscale
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

        # Apply edge detection
        edges = cv2.Canny(gray, 50, 150)

        # Dilate to make lines thicker
        kernel = np.ones((3, 3), np.uint8)
        edges = cv2.dilate(edges, kernel, iterations=1)

        # Convert to RGB
        edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(edges_rgb)

    def preprocess(self,
                   image: Union[Image.Image, str],
                   control_type: str,
                   **kwargs) -> Image.Image:
        """
        Preprocess image for ControlNet

        Args:
            image: Input image (PIL Image or path)
            control_type: Type of control (canny, depth, openpose, etc.)
            **kwargs: Additional parameters for specific processors

        Returns:
            Preprocessed control image

        Raises:
            RuntimeError: If preprocessors cannot be initialized at all.
            ValueError: If *image* is neither a PIL Image nor a path, or
                *control_type* is not among the available types.
        """
        # Initialize if not already done
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize ControlNet preprocessors")

        # Load image if path is provided
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be PIL Image or file path")

        # Ensure image is RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Get processor
        if control_type not in self.processors:
            available = list(self.processors.keys())
            raise ValueError(f"Control type '{control_type}' not available. Available: {available}")

        processor = self.processors[control_type]

        try:
            # Both controlnet-aux detectors and the OpenCV fallbacks are
            # plain callables, so a single invocation covers both (the
            # previous callable/non-callable branch was dead code — the two
            # branches were identical).
            return processor(image, **kwargs)
        except Exception as e:
            logger.error(f"Failed to preprocess image with {control_type}: {e}")
            # Return original image as fallback
            return image

    def get_available_types(self) -> list:
        """Get list of available control types.

        Before initialization this is a best-effort estimate (cached types,
        the full controlnet-aux set, or the basic OpenCV set); afterwards it
        is the exact key list of the registered processors.
        """
        if not self._initialized:
            # Return cached types if available, otherwise return basic types
            if self._available_types:
                return self._available_types
            elif self.is_available():
                return ['canny', 'depth', 'openpose', 'hed', 'mlsd', 'normal', 'lineart', 'lineart_anime', 'shuffle', 'scribble']
            else:
                return ['canny', 'depth', 'scribble']  # Basic OpenCV types
        return list(self.processors.keys())

    def resize_for_controlnet(self,
                              image: Image.Image,
                              width: int = 512,
                              height: int = 512) -> Image.Image:
        """Resize image for ControlNet while maintaining aspect ratio.

        The image is scaled to fit inside (width, height) and letterboxed
        onto a black canvas, centered.
        """
        # Calculate aspect ratio
        aspect_ratio = image.width / image.height

        if aspect_ratio > 1:
            # Landscape
            new_width = width
            new_height = int(width / aspect_ratio)
        else:
            # Portrait (or square)
            new_height = height
            new_width = int(height * aspect_ratio)

        # Resize image
        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Create new image with target size and paste resized image
        result = Image.new('RGB', (width, height), (0, 0, 0))

        # Calculate position to center the image
        x = (width - new_width) // 2
        y = (height - new_height) // 2

        result.paste(resized, (x, y))
        return result
|
|
315
|
+
|
|
316
|
+
# Global manager instance - no initialization at import time.
# Detectors (and any model downloads) are loaded lazily on first use
# via initialize()/preprocess().
controlnet_preprocessor = ControlNetPreprocessorManager()
|
|
@@ -11,25 +11,107 @@ from pathlib import Path
|
|
|
11
11
|
from huggingface_hub import snapshot_download, hf_hub_download, HfApi
|
|
12
12
|
from tqdm import tqdm
|
|
13
13
|
import threading
|
|
14
|
+
import requests
|
|
14
15
|
|
|
15
16
|
logger = logging.getLogger(__name__)
|
|
16
17
|
|
|
17
|
-
class
|
|
18
|
-
"""
|
|
18
|
+
class EnhancedProgressTracker:
|
|
19
|
+
"""Enhanced progress tracker that provides Ollama-style detailed progress information"""
|
|
19
20
|
|
|
20
21
|
    def __init__(self, total_files: int = 0, progress_callback: Optional[Callable] = None):
        """Track multi-file download progress and emit Ollama-style messages.

        Args:
            total_files: Number of files expected in the download.
            progress_callback: Optional callable invoked with formatted
                progress strings (one message per update).
        """
        self.total_files = total_files
        self.completed_files = 0
        self.current_file = ""
        # filename -> (downloaded_bytes, total_bytes)
        self.file_progress = {}
        # filename -> epoch seconds when the file's download started
        self.file_start_times = {}
        # filename -> latest speed estimate in bytes/second
        self.file_speeds = {}
        self.progress_callback = progress_callback
        # Guards all mutable state above; updates may arrive from worker threads.
        self.lock = threading.Lock()
        self.overall_start_time = time.time()
        # Aggregate byte counters across all files
        self.total_size = 0
        self.downloaded_size = 0
|
|
27
33
|
|
|
34
|
+
    def set_total_size(self, total_size: int) -> None:
        """Set the total size for all files (in bytes)."""
        with self.lock:
            self.total_size = total_size
|
|
38
|
+
|
|
39
|
+
def start_file(self, filename: str, file_size: int = 0):
|
|
40
|
+
"""Mark a file as started"""
|
|
41
|
+
with self.lock:
|
|
42
|
+
self.current_file = filename
|
|
43
|
+
self.file_start_times[filename] = time.time()
|
|
44
|
+
self.file_progress[filename] = (0, file_size)
|
|
45
|
+
|
|
46
|
+
# Extract hash-like identifier for Ollama-style display
|
|
47
|
+
import re
|
|
48
|
+
hash_match = re.search(r'([a-f0-9]{8,})', filename)
|
|
49
|
+
if hash_match:
|
|
50
|
+
display_name = hash_match.group(1)[:12] # First 12 characters
|
|
51
|
+
else:
|
|
52
|
+
# Fallback to filename without extension
|
|
53
|
+
display_name = Path(filename).stem[:12]
|
|
54
|
+
|
|
55
|
+
if self.progress_callback:
|
|
56
|
+
self.progress_callback(f"pulling {display_name}")
|
|
57
|
+
|
|
28
58
|
    def update_file_progress(self, filename: str, downloaded: int, total: int):
        """Update progress for a specific file with speed calculation.

        Args:
            filename: File being downloaded (used as the tracking key).
            downloaded: Cumulative bytes downloaded so far for this file.
            total: Total expected bytes for this file (0 disables reporting).
        """
        with self.lock:
            current_time = time.time()

            # Read the previous count BEFORE overwriting, so the overall
            # counter below is advanced by exactly the new delta.
            old_downloaded = self.file_progress.get(filename, (0, 0))[0]
            self.file_progress[filename] = (downloaded, total)

            # Update overall downloaded size
            size_diff = downloaded - old_downloaded
            self.downloaded_size += size_diff

            # Calculate speed for this file (average since start_file)
            if filename in self.file_start_times:
                elapsed = current_time - self.file_start_times[filename]
                if elapsed > 0 and downloaded > 0:
                    speed = downloaded / elapsed  # bytes per second
                    self.file_speeds[filename] = speed

            # Report progress in Ollama style
            if self.progress_callback and total > 0:
                percentage = (downloaded / total) * 100

                # Format sizes
                downloaded_mb = downloaded / (1024 * 1024)
                total_mb = total / (1024 * 1024)

                # Calculate speed in MB/s (0 if no speed sample yet)
                speed_mbps = self.file_speeds.get(filename, 0) / (1024 * 1024)

                # Calculate ETA from the remaining size at the current speed;
                # "?" when no speed estimate exists yet.
                if speed_mbps > 0:
                    remaining_mb = total_mb - downloaded_mb
                    eta_seconds = remaining_mb / speed_mbps
                    eta_min = int(eta_seconds // 60)
                    eta_sec = int(eta_seconds % 60)
                    eta_str = f"{eta_min}m{eta_sec:02d}s"
                else:
                    eta_str = "?"

                # Extract hash-like identifier for the display name
                import re
                hash_match = re.search(r'([a-f0-9]{8,})', filename)
                if hash_match:
                    display_name = hash_match.group(1)[:12]
                else:
                    display_name = Path(filename).stem[:12]

                # Create a 20-character textual progress bar
                bar_width = 20
                filled = int((percentage / 100) * bar_width)
                bar = "█" * filled + " " * (bar_width - filled)

                progress_msg = f"pulling {display_name}: {percentage:3.0f}% ▕{bar}▏ {downloaded_mb:.0f} MB/{total_mb:.0f} MB {speed_mbps:.0f} MB/s {eta_str}"

                self.progress_callback(progress_msg)
|
|
33
115
|
|
|
34
116
|
def complete_file(self, filename: str):
|
|
35
117
|
"""Mark a file as completed"""
|
|
@@ -38,34 +120,42 @@ class ProgressTracker:
|
|
|
38
120
|
if filename in self.file_progress:
|
|
39
121
|
downloaded, total = self.file_progress[filename]
|
|
40
122
|
self.file_progress[filename] = (total, total)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
123
|
+
|
|
124
|
+
# Report completion
|
|
125
|
+
if self.progress_callback:
|
|
126
|
+
import re
|
|
127
|
+
hash_match = re.search(r'([a-f0-9]{8,})', filename)
|
|
128
|
+
if hash_match:
|
|
129
|
+
display_name = hash_match.group(1)[:12]
|
|
130
|
+
else:
|
|
131
|
+
display_name = Path(filename).stem[:12]
|
|
132
|
+
|
|
133
|
+
total_mb = self.file_progress.get(filename, (0, 0))[1] / (1024 * 1024)
|
|
134
|
+
self.progress_callback(f"pulling {display_name}: 100% ▕████████████████████▏ {total_mb:.0f} MB/{total_mb:.0f} MB")
|
|
48
135
|
|
|
49
|
-
def
|
|
50
|
-
"""Report
|
|
136
|
+
    def report_overall_progress(self):
        """Report overall progress across all files via the callback.

        Emits one aggregated line with percentage, GB counts, average speed
        and ETA. Nothing is reported when no callback is set, total_size is
        unknown (0), or no time has elapsed yet.
        """
        if self.progress_callback:
            if self.total_size > 0:
                overall_percent = (self.downloaded_size / self.total_size) * 100
                downloaded_gb = self.downloaded_size / (1024 * 1024 * 1024)
                total_gb = self.total_size / (1024 * 1024 * 1024)

                elapsed = time.time() - self.overall_start_time
                if elapsed > 0:
                    overall_speed = self.downloaded_size / elapsed / (1024 * 1024)  # MB/s

                    if overall_speed > 0:
                        remaining_gb = total_gb - downloaded_gb
                        eta_seconds = (remaining_gb * 1024) / overall_speed  # Convert GB to MB for calculation
                        eta_min = int(eta_seconds // 60)
                        eta_sec = int(eta_seconds % 60)
                        eta_str = f"{eta_min}m{eta_sec:02d}s"
                    else:
                        eta_str = "?"

                    # NOTE(review): the message is emitted only inside the
                    # elapsed > 0 branch — otherwise overall_speed/eta_str
                    # would be unbound; confirm this matches the intended
                    # behavior when called immediately after construction.
                    progress_msg = f"Overall progress: {overall_percent:.1f}% | {downloaded_gb:.1f} GB/{total_gb:.1f} GB | {overall_speed:.1f} MB/s | ETA: {eta_str}"
                    self.progress_callback(progress_msg)
|
|
69
159
|
|
|
70
160
|
def configure_hf_environment():
|
|
71
161
|
"""Configure HuggingFace Hub environment for better downloads"""
|
|
@@ -132,7 +222,7 @@ def robust_snapshot_download(
|
|
|
132
222
|
|
|
133
223
|
# Get file list and sizes for progress tracking
|
|
134
224
|
if progress_callback:
|
|
135
|
-
progress_callback("
|
|
225
|
+
progress_callback("pulling manifest")
|
|
136
226
|
|
|
137
227
|
file_sizes = get_repo_file_list(repo_id)
|
|
138
228
|
total_size = sum(file_sizes.values())
|
|
@@ -140,20 +230,65 @@ def robust_snapshot_download(
|
|
|
140
230
|
if progress_callback and file_sizes:
|
|
141
231
|
progress_callback(f"📦 Repository: {len(file_sizes)} files, {format_size(total_size)} total")
|
|
142
232
|
|
|
233
|
+
# Initialize enhanced progress tracker
|
|
234
|
+
progress_tracker = EnhancedProgressTracker(len(file_sizes), progress_callback)
|
|
235
|
+
progress_tracker.set_total_size(total_size)
|
|
236
|
+
|
|
143
237
|
# Check what's already downloaded
|
|
144
238
|
local_path = Path(local_dir)
|
|
239
|
+
existing_size = 0
|
|
145
240
|
if local_path.exists() and not force_download:
|
|
146
241
|
existing_files = []
|
|
147
|
-
existing_size = 0
|
|
148
242
|
for file_path in local_path.rglob('*'):
|
|
149
243
|
if file_path.is_file():
|
|
150
244
|
rel_path = file_path.relative_to(local_path)
|
|
151
245
|
existing_files.append(str(rel_path))
|
|
152
|
-
|
|
246
|
+
file_size = file_path.stat().st_size
|
|
247
|
+
existing_size += file_size
|
|
248
|
+
# Mark existing files as completed in progress tracker
|
|
249
|
+
progress_tracker.file_progress[str(rel_path)] = (file_size, file_size)
|
|
250
|
+
progress_tracker.downloaded_size += file_size
|
|
251
|
+
progress_tracker.completed_files += 1
|
|
153
252
|
|
|
154
253
|
if progress_callback and existing_files:
|
|
155
254
|
progress_callback(f"📁 Found {len(existing_files)} existing files ({format_size(existing_size)})")
|
|
156
255
|
|
|
256
|
+
# Custom tqdm class to capture HuggingFace download progress
|
|
257
|
+
class OllamaStyleTqdm(tqdm):
|
|
258
|
+
def __init__(self, *args, **kwargs):
|
|
259
|
+
# Extract description to get filename
|
|
260
|
+
desc = kwargs.get('desc', '')
|
|
261
|
+
self.current_filename = desc
|
|
262
|
+
|
|
263
|
+
# Get file size from our pre-fetched data
|
|
264
|
+
file_size = file_sizes.get(self.current_filename, 0)
|
|
265
|
+
if file_size > 0:
|
|
266
|
+
kwargs['total'] = file_size
|
|
267
|
+
|
|
268
|
+
super().__init__(*args, **kwargs)
|
|
269
|
+
|
|
270
|
+
# Start tracking this file
|
|
271
|
+
if self.current_filename and progress_callback:
|
|
272
|
+
progress_tracker.start_file(self.current_filename, file_size)
|
|
273
|
+
|
|
274
|
+
def update(self, n=1):
|
|
275
|
+
super().update(n)
|
|
276
|
+
|
|
277
|
+
# Update our progress tracker
|
|
278
|
+
if self.current_filename and progress_callback:
|
|
279
|
+
downloaded = getattr(self, 'n', 0)
|
|
280
|
+
total = getattr(self, 'total', 0) or file_sizes.get(self.current_filename, 0)
|
|
281
|
+
|
|
282
|
+
if total > 0:
|
|
283
|
+
progress_tracker.update_file_progress(self.current_filename, downloaded, total)
|
|
284
|
+
|
|
285
|
+
def close(self):
|
|
286
|
+
super().close()
|
|
287
|
+
|
|
288
|
+
# Mark file as completed
|
|
289
|
+
if self.current_filename and progress_callback:
|
|
290
|
+
progress_tracker.complete_file(self.current_filename)
|
|
291
|
+
|
|
157
292
|
last_exception = None
|
|
158
293
|
|
|
159
294
|
for attempt in range(max_retries):
|
|
@@ -166,12 +301,6 @@ def robust_snapshot_download(
|
|
|
166
301
|
|
|
167
302
|
logger.info(f"Download attempt {attempt + 1}/{max_retries} with {workers} workers")
|
|
168
303
|
|
|
169
|
-
# Create a custom progress callback for tqdm
|
|
170
|
-
def tqdm_callback(t):
|
|
171
|
-
def inner(chunk_size):
|
|
172
|
-
t.update(chunk_size)
|
|
173
|
-
return inner
|
|
174
|
-
|
|
175
304
|
result = snapshot_download(
|
|
176
305
|
repo_id=repo_id,
|
|
177
306
|
local_dir=local_dir,
|
|
@@ -181,7 +310,7 @@ def robust_snapshot_download(
|
|
|
181
310
|
resume_download=True, # Enable resume
|
|
182
311
|
etag_timeout=300 + (attempt * 60), # Increase timeout on retries
|
|
183
312
|
force_download=force_download,
|
|
184
|
-
tqdm_class=
|
|
313
|
+
tqdm_class=OllamaStyleTqdm if progress_callback else None
|
|
185
314
|
)
|
|
186
315
|
|
|
187
316
|
if progress_callback:
|
|
@@ -296,8 +425,23 @@ def check_download_integrity(local_dir: str, repo_id: str) -> bool:
|
|
|
296
425
|
if not local_path.exists():
|
|
297
426
|
return False
|
|
298
427
|
|
|
299
|
-
#
|
|
300
|
-
|
|
428
|
+
# Determine model type based on repo_id
|
|
429
|
+
is_controlnet = 'controlnet' in repo_id.lower()
|
|
430
|
+
|
|
431
|
+
# Check for essential files based on model type
|
|
432
|
+
if is_controlnet:
|
|
433
|
+
# ControlNet models have different essential files
|
|
434
|
+
essential_files = ['config.json'] # ControlNet models use config.json instead of model_index.json
|
|
435
|
+
# Also check for model files
|
|
436
|
+
model_files = ['diffusion_pytorch_model.safetensors', 'diffusion_pytorch_model.bin']
|
|
437
|
+
has_model_file = any((local_path / model_file).exists() for model_file in model_files)
|
|
438
|
+
if not has_model_file:
|
|
439
|
+
logger.warning(f"Missing model file: expected one of {model_files}")
|
|
440
|
+
return False
|
|
441
|
+
else:
|
|
442
|
+
# Regular diffusion models
|
|
443
|
+
essential_files = ['model_index.json']
|
|
444
|
+
|
|
301
445
|
for essential_file in essential_files:
|
|
302
446
|
if not (local_path / essential_file).exists():
|
|
303
447
|
logger.warning(f"Missing essential file: {essential_file}")
|
|
@@ -329,24 +473,29 @@ def check_download_integrity(local_dir: str, repo_id: str) -> bool:
|
|
|
329
473
|
logger.warning(f"Empty file detected: {file_path}")
|
|
330
474
|
return False
|
|
331
475
|
|
|
332
|
-
# Check for critical model files
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
476
|
+
# Check for critical model files based on model type
|
|
477
|
+
if is_controlnet:
|
|
478
|
+
# ControlNet models are simpler - just need config.json and model weights
|
|
479
|
+
logger.info("ControlNet model integrity check passed")
|
|
480
|
+
else:
|
|
481
|
+
# Check for critical directories in regular diffusion models
|
|
482
|
+
critical_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2']
|
|
483
|
+
for critical_dir in critical_dirs:
|
|
484
|
+
dir_path = local_path / critical_dir
|
|
485
|
+
if dir_path.exists():
|
|
486
|
+
# Check if directory has any non-empty files
|
|
487
|
+
has_content = False
|
|
488
|
+
for file_path in dir_path.rglob('*'):
|
|
489
|
+
if file_path.is_file() and file_path.stat().st_size > 0:
|
|
490
|
+
# Skip ignored files
|
|
491
|
+
should_ignore = any(pattern in str(file_path) for pattern in ignore_patterns)
|
|
492
|
+
if not should_ignore:
|
|
493
|
+
has_content = True
|
|
494
|
+
break
|
|
495
|
+
|
|
496
|
+
if not has_content:
|
|
497
|
+
logger.warning(f"Critical directory {critical_dir} appears to be empty or incomplete")
|
|
498
|
+
return False
|
|
350
499
|
|
|
351
500
|
logger.info("Download integrity check passed")
|
|
352
501
|
return True
|