ollamadiffuser 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,317 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import logging
5
+ from typing import Optional, Union, Tuple, Dict, Any
6
+ import torch
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class ControlNetPreprocessorManager:
    """Lazy-loading manager for ControlNet preprocessors.

    Detector models from ``controlnet_aux`` are loaded only on the first call
    to ``initialize()`` (or implicitly via ``preprocess()``), never at import
    time.  If ``controlnet_aux`` is unavailable or fails to load, lightweight
    OpenCV-based processors are used as a fallback so the basic control types
    ('canny', 'depth', 'scribble') keep working.
    """

    def __init__(self):
        self.processors = {}               # control type -> processor callable
        self._initialized = False          # True once processors are ready
        self._initialization_attempted = False
        self._available_types = []         # cached list of ready control types

    def is_initialized(self) -> bool:
        """Check if preprocessors are initialized"""
        return self._initialized

    def is_available(self) -> bool:
        """Check if ControlNet preprocessors are available.

        Before any initialization attempt this is a lightweight import probe
        that does not download or load any models.
        """
        if not self._initialization_attempted:
            try:
                import controlnet_aux  # noqa: F401 -- probe only
                return True
            except ImportError:
                return False
        return self._initialized

    def _try_init(self, key, factory, success_msg, failure_name) -> bool:
        """Register one processor under *key* using *factory*().

        Failures are logged and swallowed so one broken detector does not
        prevent the others from loading.  Returns True on success.
        """
        try:
            self.processors[key] = factory()
            logger.info(success_msg)
            return True
        except Exception as e:
            logger.warning(f"Failed to initialize {failure_name}: {e}")
            return False

    def initialize(self, force: bool = False) -> bool:
        """
        Initialize ControlNet preprocessors

        Args:
            force: Force re-initialization even if already initialized

        Returns:
            True if initialization successful, False otherwise
        """
        if self._initialized and not force:
            return True

        if self._initialization_attempted and not force:
            return self._initialized

        self._initialization_attempted = True
        logger.info("Initializing ControlNet preprocessors...")

        try:
            from controlnet_aux import (
                CannyDetector,
                MidasDetector,
                OpenposeDetector,
                HEDdetector,
                MLSDdetector,
                NormalBaeDetector,
                LineartDetector,
                LineartAnimeDetector,
                ContentShuffleDetector,
                ZoeDetector
            )

            self.processors = {}
            hub = 'lllyasviel/Annotators'  # shared checkpoint repo for all detectors

            # Canny detector (no model download needed)
            self._try_init('canny', CannyDetector,
                           "Canny detector initialized", "Canny detector")

            # Depth: prefer MiDaS, fall back to ZoeDepth
            if not self._try_init('depth',
                                  lambda: MidasDetector.from_pretrained(hub),
                                  "MiDaS depth detector initialized",
                                  "MiDaS detector"):
                if self._try_init('depth_zoe',
                                  lambda: ZoeDetector.from_pretrained(hub),
                                  "ZoeDepth detector initialized as fallback",
                                  "ZoeDepth detector"):
                    self.processors['depth'] = self.processors['depth_zoe']  # Use as main depth

            self._try_init('openpose', lambda: OpenposeDetector.from_pretrained(hub),
                           "OpenPose detector initialized", "OpenPose detector")
            self._try_init('hed', lambda: HEDdetector.from_pretrained(hub),
                           "HED detector initialized", "HED detector")
            self._try_init('mlsd', lambda: MLSDdetector.from_pretrained(hub),
                           "MLSD detector initialized", "MLSD detector")
            self._try_init('normal', lambda: NormalBaeDetector.from_pretrained(hub),
                           "Normal detector initialized", "Normal detector")
            self._try_init('lineart', lambda: LineartDetector.from_pretrained(hub),
                           "Lineart detector initialized", "Lineart detector")
            self._try_init('lineart_anime', lambda: LineartAnimeDetector.from_pretrained(hub),
                           "Lineart Anime detector initialized", "Lineart Anime detector")

            # Content shuffle (no model download needed)
            self._try_init('shuffle', ContentShuffleDetector,
                           "Content Shuffle detector initialized", "Content Shuffle detector")

            # Add scribble as alias for HED
            if 'hed' in self.processors:
                self.processors['scribble'] = self.processors['hed']

            if self.processors:
                self._initialized = True
                self._available_types = list(self.processors.keys())
                logger.info(f"ControlNet preprocessors initialized: {self._available_types}")
                return True
            else:
                logger.warning("No ControlNet preprocessors could be initialized, falling back to basic processors")
                self._init_basic_processors()
                return True

        except ImportError as e:
            logger.warning(f"controlnet-aux not available: {e}")
            # Fallback to basic OpenCV-based processors
            self._init_basic_processors()
            return True
        except Exception as e:
            logger.error(f"Error initializing ControlNet preprocessors: {e}")
            # Fallback to basic OpenCV-based processors
            self._init_basic_processors()
            return True

    def _init_basic_processors(self):
        """Initialize basic OpenCV-based processors as fallback"""
        logger.info("Using basic OpenCV-based preprocessors")
        self.processors = {
            'canny': self._canny_opencv,
            'depth': self._depth_basic,
            'scribble': self._scribble_basic,
        }
        self._initialized = True
        self._available_types = list(self.processors.keys())

    def _canny_opencv(self, image: "Image.Image", low_threshold: int = 100, high_threshold: int = 200) -> "Image.Image":
        """Basic Canny edge detection using OpenCV"""
        # Convert PIL to OpenCV format
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Apply Canny edge detection
        edges = cv2.Canny(image_cv, low_threshold, high_threshold)

        # Convert back to PIL
        edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(edges_rgb)

    def _depth_basic(self, image: "Image.Image") -> "Image.Image":
        """Basic depth estimation using simple gradients.

        Rough stand-in for a real monocular depth model: gradient magnitude
        is used as a proxy for depth edges.
        """
        # Convert to grayscale and smooth
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)

        # Create a simple depth map using gradients
        grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
        depth = np.sqrt(grad_x**2 + grad_y**2)

        # Normalize to 0-255; guard against division by zero on flat images
        peak = float(np.max(depth))
        if peak > 0:
            depth = np.uint8(255 * depth / peak)
        else:
            depth = np.zeros_like(depth, dtype=np.uint8)

        # Convert to RGB
        depth_rgb = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_rgb)

    def _scribble_basic(self, image: "Image.Image") -> "Image.Image":
        """Basic scribble detection using edge detection"""
        # Convert to grayscale
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

        # Apply edge detection
        edges = cv2.Canny(gray, 50, 150)

        # Dilate to make lines thicker
        kernel = np.ones((3, 3), np.uint8)
        edges = cv2.dilate(edges, kernel, iterations=1)

        # Convert to RGB
        edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(edges_rgb)

    def preprocess(self,
                   image: Union["Image.Image", str],
                   control_type: str,
                   **kwargs) -> "Image.Image":
        """
        Preprocess image for ControlNet

        Args:
            image: Input image (PIL Image or path)
            control_type: Type of control (canny, depth, openpose, etc.)
            **kwargs: Additional parameters for specific processors

        Returns:
            Preprocessed control image; the original image is returned
            unchanged if the processor raises.

        Raises:
            RuntimeError: if preprocessors cannot be initialized.
            ValueError: for an unsupported image argument or control type.
        """
        # Initialize if not already done
        if not self._initialized:
            if not self.initialize():
                raise RuntimeError("Failed to initialize ControlNet preprocessors")

        # Load image if path is provided
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be PIL Image or file path")

        # Ensure image is RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if control_type not in self.processors:
            available = list(self.processors.keys())
            raise ValueError(f"Control type '{control_type}' not available. Available: {available}")

        processor = self.processors[control_type]

        try:
            # Both controlnet_aux detectors and the OpenCV fallbacks are plain
            # callables taking (image, **kwargs) -- no branching needed.
            return processor(image, **kwargs)
        except Exception as e:
            logger.error(f"Failed to preprocess image with {control_type}: {e}")
            # Return original image as fallback
            return image

    def get_available_types(self) -> list:
        """Get list of available control types"""
        if not self._initialized:
            # Return cached types if available, otherwise return basic types
            if self._available_types:
                return self._available_types
            elif self.is_available():
                # Full set controlnet_aux can provide once initialized
                return ['canny', 'depth', 'openpose', 'hed', 'mlsd', 'normal', 'lineart', 'lineart_anime', 'shuffle', 'scribble']
            else:
                return ['canny', 'depth', 'scribble']  # Basic OpenCV types
        return list(self.processors.keys())

    def resize_for_controlnet(self,
                              image: "Image.Image",
                              width: int = 512,
                              height: int = 512) -> "Image.Image":
        """Resize image for ControlNet while maintaining aspect ratio.

        The image is scaled to fit inside (width, height) and centered on a
        black canvas of exactly that size (letterboxing).
        """
        aspect_ratio = image.width / image.height

        if aspect_ratio > 1:
            # Landscape: constrain by width
            new_width = width
            new_height = int(width / aspect_ratio)
        else:
            # Portrait or square: constrain by height
            new_height = height
            new_width = int(height * aspect_ratio)

        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center the resized image on a black canvas of the target size
        result = Image.new('RGB', (width, height), (0, 0, 0))
        x = (width - new_width) // 2
        y = (height - new_height) // 2
        result.paste(resized, (x, y))
        return result
315
+
316
# Global manager instance - no initialization at import time
# (detector models load lazily on the first initialize()/preprocess() call)
controlnet_preprocessor = ControlNetPreprocessorManager()
@@ -11,25 +11,107 @@ from pathlib import Path
11
11
  from huggingface_hub import snapshot_download, hf_hub_download, HfApi
12
12
  from tqdm import tqdm
13
13
  import threading
14
+ import requests
14
15
 
15
16
  logger = logging.getLogger(__name__)
16
17
 
17
- class ProgressTracker:
18
- """Track download progress across multiple files"""
18
+ class EnhancedProgressTracker:
19
+ """Enhanced progress tracker that provides Ollama-style detailed progress information"""
19
20
 
20
21
    def __init__(self, total_files: int = 0, progress_callback: Optional[Callable] = None):
        """Create a tracker; *progress_callback* receives human-readable status strings."""
        self.total_files = total_files
        self.completed_files = 0            # files fully downloaded so far
        self.current_file = ""              # name of the file currently downloading
        self.file_progress = {}             # filename -> (downloaded_bytes, total_bytes)
        self.file_start_times = {}          # filename -> start timestamp (time.time())
        self.file_speeds = {}               # filename -> measured speed in bytes/sec
        self.progress_callback = progress_callback
        self.lock = threading.Lock()        # guards the mutable state above
        self.overall_start_time = time.time()
        self.total_size = 0                 # total bytes across all files (set via set_total_size)
        self.downloaded_size = 0            # bytes downloaded so far across all files
27
33
 
34
    def set_total_size(self, total_size: int):
        """Set the total size for all files"""
        # Called once up front so overall-percentage math has a denominator.
        with self.lock:
            self.total_size = total_size
38
+
39
    def start_file(self, filename: str, file_size: int = 0):
        """Mark a file as started"""
        with self.lock:
            self.current_file = filename
            self.file_start_times[filename] = time.time()
            self.file_progress[filename] = (0, file_size)

            # Extract hash-like identifier for Ollama-style display
            # NOTE(review): callback is invoked while holding self.lock --
            # presumably callbacks are fast and non-reentrant; confirm.
            import re
            hash_match = re.search(r'([a-f0-9]{8,})', filename)
            if hash_match:
                display_name = hash_match.group(1)[:12]  # First 12 characters
            else:
                # Fallback to filename without extension
                display_name = Path(filename).stem[:12]

            if self.progress_callback:
                self.progress_callback(f"pulling {display_name}")
57
+
28
58
    def update_file_progress(self, filename: str, downloaded: int, total: int):
        """Update progress for a specific file with speed calculation"""
        with self.lock:
            current_time = time.time()

            # Update file progress (remember previous byte count so the
            # overall counter only advances by the delta)
            old_downloaded = self.file_progress.get(filename, (0, 0))[0]
            self.file_progress[filename] = (downloaded, total)

            # Update overall downloaded size
            size_diff = downloaded - old_downloaded
            self.downloaded_size += size_diff

            # Calculate average speed for this file since start_file()
            if filename in self.file_start_times:
                elapsed = current_time - self.file_start_times[filename]
                if elapsed > 0 and downloaded > 0:
                    speed = downloaded / elapsed  # bytes per second
                    self.file_speeds[filename] = speed

            # Report progress in Ollama style
            if self.progress_callback and total > 0:
                percentage = (downloaded / total) * 100

                # Format sizes
                downloaded_mb = downloaded / (1024 * 1024)
                total_mb = total / (1024 * 1024)

                # Calculate speed in MB/s
                speed_mbps = self.file_speeds.get(filename, 0) / (1024 * 1024)

                # Calculate ETA from the average speed so far
                if speed_mbps > 0:
                    remaining_mb = total_mb - downloaded_mb
                    eta_seconds = remaining_mb / speed_mbps
                    eta_min = int(eta_seconds // 60)
                    eta_sec = int(eta_seconds % 60)
                    eta_str = f"{eta_min}m{eta_sec:02d}s"
                else:
                    eta_str = "?"

                # Extract hash for display
                import re
                hash_match = re.search(r'([a-f0-9]{8,})', filename)
                if hash_match:
                    display_name = hash_match.group(1)[:12]
                else:
                    display_name = Path(filename).stem[:12]

                # Create progress bar
                bar_width = 20
                filled = int((percentage / 100) * bar_width)
                bar = "█" * filled + " " * (bar_width - filled)

                progress_msg = f"pulling {display_name}: {percentage:3.0f}% ▕{bar}▏ {downloaded_mb:.0f} MB/{total_mb:.0f} MB {speed_mbps:.0f} MB/s {eta_str}"

                self.progress_callback(progress_msg)
33
115
 
34
116
  def complete_file(self, filename: str):
35
117
  """Mark a file as completed"""
@@ -38,34 +120,42 @@ class ProgressTracker:
38
120
  if filename in self.file_progress:
39
121
  downloaded, total = self.file_progress[filename]
40
122
  self.file_progress[filename] = (total, total)
41
- self._report_progress()
42
-
43
- def set_current_file(self, filename: str):
44
- """Set the currently downloading file"""
45
- with self.lock:
46
- self.current_file = filename
47
- self._report_progress()
123
+
124
+ # Report completion
125
+ if self.progress_callback:
126
+ import re
127
+ hash_match = re.search(r'([a-f0-9]{8,})', filename)
128
+ if hash_match:
129
+ display_name = hash_match.group(1)[:12]
130
+ else:
131
+ display_name = Path(filename).stem[:12]
132
+
133
+ total_mb = self.file_progress.get(filename, (0, 0))[1] / (1024 * 1024)
134
+ self.progress_callback(f"pulling {display_name}: 100% ▕████████████████████▏ {total_mb:.0f} MB/{total_mb:.0f} MB")
48
135
 
49
    def report_overall_progress(self):
        """Report overall progress"""
        # Emits a single aggregate line (percent, GB done/total, MB/s, ETA).
        if self.progress_callback:
            if self.total_size > 0:
                overall_percent = (self.downloaded_size / self.total_size) * 100
                downloaded_gb = self.downloaded_size / (1024 * 1024 * 1024)
                total_gb = self.total_size / (1024 * 1024 * 1024)

                elapsed = time.time() - self.overall_start_time
                # NOTE(review): nothing is reported when elapsed == 0 -- confirm intended.
                if elapsed > 0:
                    overall_speed = self.downloaded_size / elapsed / (1024 * 1024)  # MB/s

                    if overall_speed > 0:
                        remaining_gb = total_gb - downloaded_gb
                        eta_seconds = (remaining_gb * 1024) / overall_speed  # Convert GB to MB for calculation
                        eta_min = int(eta_seconds // 60)
                        eta_sec = int(eta_seconds % 60)
                        eta_str = f"{eta_min}m{eta_sec:02d}s"
                    else:
                        eta_str = "?"

                    progress_msg = f"Overall progress: {overall_percent:.1f}% | {downloaded_gb:.1f} GB/{total_gb:.1f} GB | {overall_speed:.1f} MB/s | ETA: {eta_str}"
                    self.progress_callback(progress_msg)
69
159
 
70
160
  def configure_hf_environment():
71
161
  """Configure HuggingFace Hub environment for better downloads"""
@@ -132,7 +222,7 @@ def robust_snapshot_download(
132
222
 
133
223
  # Get file list and sizes for progress tracking
134
224
  if progress_callback:
135
- progress_callback("📋 Getting repository information...")
225
+ progress_callback("pulling manifest")
136
226
 
137
227
  file_sizes = get_repo_file_list(repo_id)
138
228
  total_size = sum(file_sizes.values())
@@ -140,20 +230,65 @@ def robust_snapshot_download(
140
230
  if progress_callback and file_sizes:
141
231
  progress_callback(f"📦 Repository: {len(file_sizes)} files, {format_size(total_size)} total")
142
232
 
233
+ # Initialize enhanced progress tracker
234
+ progress_tracker = EnhancedProgressTracker(len(file_sizes), progress_callback)
235
+ progress_tracker.set_total_size(total_size)
236
+
143
237
  # Check what's already downloaded
144
238
  local_path = Path(local_dir)
239
+ existing_size = 0
145
240
  if local_path.exists() and not force_download:
146
241
  existing_files = []
147
- existing_size = 0
148
242
  for file_path in local_path.rglob('*'):
149
243
  if file_path.is_file():
150
244
  rel_path = file_path.relative_to(local_path)
151
245
  existing_files.append(str(rel_path))
152
- existing_size += file_path.stat().st_size
246
+ file_size = file_path.stat().st_size
247
+ existing_size += file_size
248
+ # Mark existing files as completed in progress tracker
249
+ progress_tracker.file_progress[str(rel_path)] = (file_size, file_size)
250
+ progress_tracker.downloaded_size += file_size
251
+ progress_tracker.completed_files += 1
153
252
 
154
253
  if progress_callback and existing_files:
155
254
  progress_callback(f"📁 Found {len(existing_files)} existing files ({format_size(existing_size)})")
156
255
 
256
+ # Custom tqdm class to capture HuggingFace download progress
257
+ class OllamaStyleTqdm(tqdm):
258
+ def __init__(self, *args, **kwargs):
259
+ # Extract description to get filename
260
+ desc = kwargs.get('desc', '')
261
+ self.current_filename = desc
262
+
263
+ # Get file size from our pre-fetched data
264
+ file_size = file_sizes.get(self.current_filename, 0)
265
+ if file_size > 0:
266
+ kwargs['total'] = file_size
267
+
268
+ super().__init__(*args, **kwargs)
269
+
270
+ # Start tracking this file
271
+ if self.current_filename and progress_callback:
272
+ progress_tracker.start_file(self.current_filename, file_size)
273
+
274
+ def update(self, n=1):
275
+ super().update(n)
276
+
277
+ # Update our progress tracker
278
+ if self.current_filename and progress_callback:
279
+ downloaded = getattr(self, 'n', 0)
280
+ total = getattr(self, 'total', 0) or file_sizes.get(self.current_filename, 0)
281
+
282
+ if total > 0:
283
+ progress_tracker.update_file_progress(self.current_filename, downloaded, total)
284
+
285
+ def close(self):
286
+ super().close()
287
+
288
+ # Mark file as completed
289
+ if self.current_filename and progress_callback:
290
+ progress_tracker.complete_file(self.current_filename)
291
+
157
292
  last_exception = None
158
293
 
159
294
  for attempt in range(max_retries):
@@ -166,12 +301,6 @@ def robust_snapshot_download(
166
301
 
167
302
  logger.info(f"Download attempt {attempt + 1}/{max_retries} with {workers} workers")
168
303
 
169
- # Create a custom progress callback for tqdm
170
- def tqdm_callback(t):
171
- def inner(chunk_size):
172
- t.update(chunk_size)
173
- return inner
174
-
175
304
  result = snapshot_download(
176
305
  repo_id=repo_id,
177
306
  local_dir=local_dir,
@@ -181,7 +310,7 @@ def robust_snapshot_download(
181
310
  resume_download=True, # Enable resume
182
311
  etag_timeout=300 + (attempt * 60), # Increase timeout on retries
183
312
  force_download=force_download,
184
- tqdm_class=tqdm if progress_callback else None
313
+ tqdm_class=OllamaStyleTqdm if progress_callback else None
185
314
  )
186
315
 
187
316
  if progress_callback:
@@ -296,8 +425,23 @@ def check_download_integrity(local_dir: str, repo_id: str) -> bool:
296
425
  if not local_path.exists():
297
426
  return False
298
427
 
299
- # Check for essential files
300
- essential_files = ['model_index.json']
428
+ # Determine model type based on repo_id
429
+ is_controlnet = 'controlnet' in repo_id.lower()
430
+
431
+ # Check for essential files based on model type
432
+ if is_controlnet:
433
+ # ControlNet models have different essential files
434
+ essential_files = ['config.json'] # ControlNet models use config.json instead of model_index.json
435
+ # Also check for model files
436
+ model_files = ['diffusion_pytorch_model.safetensors', 'diffusion_pytorch_model.bin']
437
+ has_model_file = any((local_path / model_file).exists() for model_file in model_files)
438
+ if not has_model_file:
439
+ logger.warning(f"Missing model file: expected one of {model_files}")
440
+ return False
441
+ else:
442
+ # Regular diffusion models
443
+ essential_files = ['model_index.json']
444
+
301
445
  for essential_file in essential_files:
302
446
  if not (local_path / essential_file).exists():
303
447
  logger.warning(f"Missing essential file: {essential_file}")
@@ -329,24 +473,29 @@ def check_download_integrity(local_dir: str, repo_id: str) -> bool:
329
473
  logger.warning(f"Empty file detected: {file_path}")
330
474
  return False
331
475
 
332
- # Check for critical model files
333
- critical_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2']
334
- for critical_dir in critical_dirs:
335
- dir_path = local_path / critical_dir
336
- if dir_path.exists():
337
- # Check if directory has any non-empty files
338
- has_content = False
339
- for file_path in dir_path.rglob('*'):
340
- if file_path.is_file() and file_path.stat().st_size > 0:
341
- # Skip ignored files
342
- should_ignore = any(pattern in str(file_path) for pattern in ignore_patterns)
343
- if not should_ignore:
344
- has_content = True
345
- break
346
-
347
- if not has_content:
348
- logger.warning(f"Critical directory {critical_dir} appears to be empty or incomplete")
349
- return False
476
+ # Check for critical model files based on model type
477
+ if is_controlnet:
478
+ # ControlNet models are simpler - just need config.json and model weights
479
+ logger.info("ControlNet model integrity check passed")
480
+ else:
481
+ # Check for critical directories in regular diffusion models
482
+ critical_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2']
483
+ for critical_dir in critical_dirs:
484
+ dir_path = local_path / critical_dir
485
+ if dir_path.exists():
486
+ # Check if directory has any non-empty files
487
+ has_content = False
488
+ for file_path in dir_path.rglob('*'):
489
+ if file_path.is_file() and file_path.stat().st_size > 0:
490
+ # Skip ignored files
491
+ should_ignore = any(pattern in str(file_path) for pattern in ignore_patterns)
492
+ if not should_ignore:
493
+ has_content = True
494
+ break
495
+
496
+ if not has_content:
497
+ logger.warning(f"Critical directory {critical_dir} appears to be empty or incomplete")
498
+ return False
350
499
 
351
500
  logger.info("Download integrity check passed")
352
501
  return True