ollamadiffuser 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,771 @@
1
+ """
2
+ GGUF Model Loader and Interface
3
+
4
+ This module provides support for loading and running GGUF quantized models,
5
+ specifically for FLUX.1-dev-gguf variants using stable-diffusion.cpp Python bindings.
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Dict, Any, Optional, List, Union
12
+ import torch
13
+ from PIL import Image
14
+ import numpy as np
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# Optional native backend: stable-diffusion-cpp-python provides the
# StableDiffusion class used for all GGUF inference in this module.
# When it is not installed we keep the symbols defined (so later
# references don't NameError) and record unavailability in GGUF_AVAILABLE.
try:
    from stable_diffusion_cpp import StableDiffusion
    GGUF_AVAILABLE = True
    logger.info("stable-diffusion-cpp-python is available")
except ImportError:
    # Placeholder keeps `StableDiffusion` importable from this module.
    StableDiffusion = None
    GGUF_AVAILABLE = False
    logger.warning("stable-diffusion-cpp-python not available. GGUF models will not work.")
26
+
27
+ class GGUFModelLoader:
28
+ """Loader for GGUF quantized diffusion models using stable-diffusion.cpp"""
29
+
30
+ def __init__(self):
31
+ self.model = None
32
+ self.model_path = None
33
+ self.model_config = None
34
+ self.loaded_model_name = None
35
+ self.stable_diffusion = None
36
+
37
+ def is_gguf_model(self, model_name: str, model_config: Dict[str, Any]) -> bool:
38
+ """Check if a model is a GGUF model"""
39
+ variant = model_config.get('variant', '')
40
+ return 'gguf' in variant.lower() or model_name.endswith('-gguf') or 'gguf' in model_name.lower()
41
+
42
    def get_gguf_file_path(self, model_dir: Path, variant: str) -> Optional[Path]:
        """Find the appropriate GGUF file based on variant.

        Args:
            model_dir: Directory expected to contain the .gguf weight file.
            variant: Variant identifier (e.g. 'gguf-q4ks', 'gguf-schnell-q8').
                Matched case-insensitively against the mapping below.

        Returns:
            Path to the variant's expected file when it exists, otherwise the
            first *.gguf file found in the directory, otherwise None (also
            None when the directory itself is missing).
        """
        if not model_dir.exists():
            return None

        # Map variant to actual file names
        variant_mapping = {
            # FLUX.1-dev variants
            'gguf-q2k': 'flux1-dev-Q2_K.gguf',
            'gguf-q3ks': 'flux1-dev-Q3_K_S.gguf',
            'gguf-q4ks': 'flux1-dev-Q4_K_S.gguf',
            'gguf-q4-0': 'flux1-dev-Q4_0.gguf',
            'gguf-q4-1': 'flux1-dev-Q4_1.gguf',
            'gguf-q5ks': 'flux1-dev-Q5_K_S.gguf',
            'gguf-q5-0': 'flux1-dev-Q5_0.gguf',
            'gguf-q5-1': 'flux1-dev-Q5_1.gguf',
            'gguf-q6k': 'flux1-dev-Q6_K.gguf',
            'gguf-q8': 'flux1-dev-Q8_0.gguf',
            'gguf-f16': 'flux1-dev-F16.gguf',

            # FLUX.1-schnell variants
            'gguf-schnell': 'flux1-schnell-F16.gguf',  # Default to F16
            'gguf-schnell-q2k': 'flux1-schnell-Q2_K.gguf',
            'gguf-schnell-q3ks': 'flux1-schnell-Q3_K_S.gguf',
            'gguf-schnell-q4-0': 'flux1-schnell-Q4_0.gguf',
            'gguf-schnell-q4-1': 'flux1-schnell-Q4_1.gguf',
            'gguf-schnell-q4ks': 'flux1-schnell-Q4_K_S.gguf',
            'gguf-schnell-q5-0': 'flux1-schnell-Q5_0.gguf',
            'gguf-schnell-q5-1': 'flux1-schnell-Q5_1.gguf',
            'gguf-schnell-q5ks': 'flux1-schnell-Q5_K_S.gguf',
            'gguf-schnell-q6k': 'flux1-schnell-Q6_K.gguf',
            'gguf-schnell-q8': 'flux1-schnell-Q8_0.gguf',
            'gguf-schnell-f16': 'flux1-schnell-F16.gguf',

            # Stable Diffusion 3.5 Large variants
            'gguf-large': 'sd3.5_large-F16.gguf',  # Default to F16
            'gguf-large-q4-0': 'sd3.5_large-Q4_0.gguf',
            'gguf-large-q4-1': 'sd3.5_large-Q4_1.gguf',
            'gguf-large-q5-0': 'sd3.5_large-Q5_0.gguf',
            'gguf-large-q5-1': 'sd3.5_large-Q5_1.gguf',
            'gguf-large-q8-0': 'sd3.5_large-Q8_0.gguf',
            'gguf-large-f16': 'sd3.5_large-F16.gguf',

            # Stable Diffusion 3.5 Large Turbo variants
            'gguf-large-turbo': 'sd3.5_large_turbo.gguf',  # Default to standard format
            'gguf-large-turbo-q4-0': 'sd3.5_large_turbo-Q4_0.gguf',
            'gguf-large-turbo-q4-1': 'sd3.5_large_turbo-Q4_1.gguf',
            'gguf-large-turbo-q5-0': 'sd3.5_large_turbo-Q5_0.gguf',
            'gguf-large-turbo-q5-1': 'sd3.5_large_turbo-Q5_1.gguf',
            'gguf-large-turbo-q8-0': 'sd3.5_large_turbo-Q8_0.gguf',
            'gguf-large-turbo-f16': 'sd3.5_large_turbo-F16.gguf',

            # Other model variants
            'gguf-medium': 'sd3.5-medium-F16.gguf',
            'gguf-sd3-medium': 'sd3-medium-F16.gguf',
            'gguf-lite': 'flux-lite-8b-F16.gguf',
            'gguf-distilled': 'flux-dev-de-distill-F16.gguf',
            'gguf-fill': 'flux-fill-dev-F16.gguf',
            'gguf-full': 'hidream-i1-full-F16.gguf',
            'gguf-dev': 'hidream-i1-dev-F16.gguf',
            'gguf-fast': 'hidream-i1-fast-F16.gguf',
            'gguf-i2v': 'ltx-video-i2v-F16.gguf',
            'gguf-2b': 'ltx-video-2b-F16.gguf',
            'gguf-t2v': 'hunyuan-video-t2v-F16.gguf',

            'gguf': 'flux1-dev-Q4_K_S.gguf',  # Default to Q4_K_S
        }

        # Exact variant match first; only a file that actually exists is used.
        filename = variant_mapping.get(variant.lower())
        if filename:
            gguf_file = model_dir / filename
            if gguf_file.exists():
                return gguf_file

        # Fallback: search for any .gguf file
        gguf_files = list(model_dir.glob('*.gguf'))
        if gguf_files:
            return gguf_files[0]  # Return first found

        return None
122
+
123
+ def get_additional_model_files(self, model_dir: Path) -> Dict[str, Optional[Path]]:
124
+ """Find additional model files required for FLUX GGUF inference"""
125
+ files = {
126
+ 'vae': None,
127
+ 'clip_l': None,
128
+ 't5xxl': None
129
+ }
130
+
131
+ # Common file patterns for FLUX models
132
+ vae_patterns = ['ae.safetensors', 'vae.safetensors', 'flux_vae.safetensors']
133
+ clip_l_patterns = ['clip_l.safetensors', 'text_encoder.safetensors']
134
+ t5xxl_patterns = ['t5xxl_fp16.safetensors', 't5xxl.safetensors', 't5_encoder.safetensors']
135
+
136
+ # Search for VAE
137
+ for pattern in vae_patterns:
138
+ vae_file = model_dir / pattern
139
+ if vae_file.exists():
140
+ files['vae'] = vae_file
141
+ break
142
+
143
+ # Search for CLIP-L
144
+ for pattern in clip_l_patterns:
145
+ clip_file = model_dir / pattern
146
+ if clip_file.exists():
147
+ files['clip_l'] = clip_file
148
+ break
149
+
150
+ # Search for T5XXL
151
+ for pattern in t5xxl_patterns:
152
+ t5_file = model_dir / pattern
153
+ if t5_file.exists():
154
+ files['t5xxl'] = t5_file
155
+ break
156
+
157
+ return files
158
+
159
    def load_model(self, model_config: Dict[str, Any], model_name: Optional[str] = None, model_path: Optional[Path] = None) -> bool:
        """Load GGUF model using stable-diffusion.cpp.

        Args:
            model_config: Model metadata dict; its 'name' and 'path' entries
                are used as fallbacks when the explicit arguments are omitted.
            model_name: Optional model name; defaults to model_config['name'].
            model_path: Optional directory containing the .gguf file; defaults
                to Path(model_config['path']).

        Returns:
            True when the GGUF file and all required companion components were
            found and the backend initialized; False on any failure (errors
            are logged, not raised).
        """
        # Extract parameters from model_config if not provided separately
        if model_name is None:
            model_name = model_config.get('name', 'unknown')
        if model_path is None:
            model_path = Path(model_config.get('path', ''))

        logger.info(f"Loading GGUF model: {model_name}")

        try:
            # Find the GGUF file
            gguf_files = list(model_path.glob("*.gguf"))
            if not gguf_files:
                logger.error(f"No GGUF files found in {model_path}")
                return False

            gguf_file = gguf_files[0]  # Use the first GGUF file found
            logger.info(f"Using GGUF file: {gguf_file}")

            # Download required components (VAE / text encoders); may hit the network.
            components = self.download_required_components(model_path)

            # Detect model type for appropriate validation
            is_sd35 = any(pattern in model_name.lower() for pattern in ['3.5', 'sd3.5', 'stable-diffusion-3-5'])

            # Validate components based on model type
            if is_sd35:
                # SD 3.5 models need VAE, CLIP-L, CLIP-G, and T5XXL
                required_components = ['vae', 'clip_l', 'clip_g', 't5xxl']
                missing_components = [name for name in required_components if not components.get(name)]
                if missing_components:
                    logger.error(f"Missing required SD 3.5 components: {missing_components}")
                    return False
            else:
                # FLUX models need VAE, CLIP-L, and T5XXL (no CLIP-G)
                required_components = ['vae', 'clip_l', 't5xxl']
                missing_components = [name for name in required_components if not components.get(name)]
                if missing_components:
                    logger.error(f"Missing required FLUX components: {missing_components}")
                    return False

            # Initialize the stable-diffusion.cpp model
            logger.info("Loading GGUF model with stable-diffusion.cpp...")

            if is_sd35:
                logger.info("Detected SD 3.5 model - using appropriate configuration")

                # Optional component paths are only passed when present.
                sd_params = {
                    'diffusion_model_path': str(gguf_file),
                    'n_threads': 4
                }

                if components['vae']:
                    sd_params['vae_path'] = str(components['vae'])
                if components['clip_l']:
                    sd_params['clip_l_path'] = str(components['clip_l'])
                if components['clip_g']:
                    sd_params['clip_g_path'] = str(components['clip_g'])
                if components['t5xxl']:
                    sd_params['t5xxl_path'] = str(components['t5xxl'])

                logger.info(f"Initializing SD 3.5 model with params: {sd_params}")
                self.stable_diffusion = StableDiffusion(**sd_params)

            else:
                # FLUX models use different parameter structure
                # NOTE(review): n_threads differs between branches (4 vs -1 = auto);
                # presumably intentional — confirm against the backend's docs.
                logger.info("Detected FLUX model - using CLIP-L and T5-XXL configuration")
                self.stable_diffusion = StableDiffusion(
                    diffusion_model_path=str(gguf_file),
                    vae_path=str(components['vae']),
                    clip_l_path=str(components['clip_l']),
                    t5xxl_path=str(components['t5xxl']),
                    vae_decode_only=True,
                    n_threads=-1
                )

            # Record the loaded state only after successful initialization.
            self.model_path = str(gguf_file)
            self.model_config = model_config
            self.loaded_model_name = model_name

            logger.info(f"Successfully loaded GGUF model: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to load GGUF model {model_name}: {e}")
            # Drop any partially-initialized backend so is_loaded() stays accurate.
            if hasattr(self, 'stable_diffusion') and self.stable_diffusion:
                self.stable_diffusion = None
            return False
248
+
249
    def generate_image(self, prompt: str, **kwargs) -> Optional[Image.Image]:
        """Generate image using stable-diffusion.cpp FLUX inference.

        Args:
            prompt: Text prompt for generation.
            **kwargs: Supported keys (both naming conventions accepted):
                width/height (default 1024), steps or num_inference_steps
                (default 20), cfg_scale or guidance_scale (default 1.0),
                seed (default 42 — fixed, so repeated calls are deterministic),
                negative_prompt (default ""), sampler or sample_method
                (default 'dpmpp2m').

        Returns:
            The first generated PIL Image, or None on any failure (errors are
            logged, not raised). A debug copy is also saved under the
            configured outputs directory on a best-effort basis.
        """
        if not self.stable_diffusion:
            logger.error("GGUF model not loaded")
            return None

        try:
            # Extract parameters with FLUX-optimized defaults
            # Support both parameter naming conventions for compatibility
            width = kwargs.get('width', 1024)
            height = kwargs.get('height', 1024)

            # Support both 'steps' and 'num_inference_steps' - ensure not None
            steps = kwargs.get('steps') or kwargs.get('num_inference_steps') or 20

            # Support both 'cfg_scale' and 'guidance_scale' - FLUX works best with low CFG - ensure not None
            cfg_scale = kwargs.get('cfg_scale') or kwargs.get('guidance_scale') or 1.0

            seed = kwargs.get('seed', 42)
            negative_prompt = kwargs.get('negative_prompt', "")

            # Allow custom sampler, with FLUX-optimized default
            sampler = kwargs.get('sampler', kwargs.get('sample_method', 'dpmpp2m'))

            # Validate sampler and provide fallback
            valid_samplers = ['euler_a', 'euler', 'heun', 'dpm2', 'dpmpp2s_a', 'dpmpp2m', 'dpmpp2mv2', 'ipndm', 'ipndm_v', 'lcm', 'ddim_trailing', 'tcd']
            if sampler not in valid_samplers:
                logger.warning(f"Invalid sampler '{sampler}', falling back to 'dpmpp2m'")
                sampler = 'dpmpp2m'

            # Ensure all values are proper types and not None
            steps = int(steps) if steps is not None else 20
            cfg_scale = float(cfg_scale) if cfg_scale is not None else 1.0
            width = int(width) if width is not None else 1024
            height = int(height) if height is not None else 1024
            seed = int(seed) if seed is not None else 42

            logger.info(f"Generating image: {width}x{height}, steps={steps}, cfg={cfg_scale}, sampler={sampler}, negative_prompt={negative_prompt}")

            # Log model quantization info for quality assessment
            # (inferred from the quantization tag embedded in the filename)
            if hasattr(self, 'model_path'):
                if 'Q2' in str(self.model_path):
                    logger.warning("Using Q2 quantization - expect lower quality. Consider Q4_K_S or higher for better results.")
                elif 'Q3' in str(self.model_path):
                    logger.info("Using Q3 quantization - moderate quality. Consider Q4_K_S or higher for better results.")
                elif 'Q4' in str(self.model_path):
                    logger.info("Using Q4 quantization - good balance of quality and size.")
                elif any(x in str(self.model_path) for x in ['Q5', 'Q6', 'Q8', 'F16']):
                    logger.info("Using high precision quantization - excellent quality expected.")

            # Generate image using stable-diffusion.cpp
            # According to the documentation, txt_to_img returns a list of PIL Images
            try:
                result = self.stable_diffusion.txt_to_img(
                    prompt=prompt,
                    negative_prompt=negative_prompt if negative_prompt else "",
                    cfg_scale=cfg_scale,
                    width=width,
                    height=height,
                    sample_method=sampler,
                    sample_steps=steps,
                    seed=seed
                )
                logger.info(f"txt_to_img returned: {type(result)}, length: {len(result) if result else 'None'}")
            except Exception as e:
                logger.error(f"txt_to_img call failed: {e}")
                return None

            if not result:
                logger.error("txt_to_img returned None")
                return None

            if not isinstance(result, list) or len(result) == 0:
                logger.error(f"txt_to_img returned unexpected format: {type(result)}")
                return None

            # Get the first PIL Image from the result list
            image = result[0]
            logger.info(f"Retrieved PIL Image: {type(image)}")

            # Verify it's a PIL Image (duck-typed via the 'save' attribute)
            if not hasattr(image, 'save'):
                logger.error(f"Result[0] is not a PIL Image: {type(image)}")
                return None

            # Optionally save a copy for debugging/history; failures here are
            # non-fatal and only logged.
            try:
                from ..config.settings import settings
                output_dir = settings.config_dir / "outputs"
                output_dir.mkdir(exist_ok=True)

                output_path = output_dir / f"gguf_output_{seed}.png"
                image.save(output_path)
                logger.info(f"Generated image also saved to: {output_path}")
            except Exception as e:
                logger.warning(f"Failed to save debug copy: {e}")

            # Return the PIL Image directly for API compatibility
            logger.info("Returning PIL Image for API use")
            return image

        except Exception as e:
            logger.error(f"Failed to generate image with GGUF model: {e}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None
355
+
356
+ def unload_model(self):
357
+ """Unload the GGUF model"""
358
+ if self.stable_diffusion:
359
+ try:
360
+ # stable-diffusion-cpp handles cleanup automatically
361
+ self.stable_diffusion = None
362
+ self.model_path = None
363
+ self.model_config = None
364
+ self.loaded_model_name = None
365
+ logger.info("GGUF model unloaded")
366
+ except Exception as e:
367
+ logger.error(f"Error unloading GGUF model: {e}")
368
+
369
+ def get_model_info(self) -> Dict[str, Any]:
370
+ """Get information about the loaded model"""
371
+ if not self.stable_diffusion:
372
+ return {
373
+ 'gguf_available': GGUF_AVAILABLE,
374
+ 'loaded': False
375
+ }
376
+
377
+ return {
378
+ 'type': 'gguf',
379
+ 'variant': self.model_config.get('variant', 'unknown'),
380
+ 'path': str(self.model_path),
381
+ 'name': self.loaded_model_name,
382
+ 'loaded': True,
383
+ 'gguf_available': GGUF_AVAILABLE,
384
+ 'backend': 'stable-diffusion.cpp'
385
+ }
386
+
387
+ def is_loaded(self) -> bool:
388
+ """Check if a model is loaded"""
389
+ return self.stable_diffusion is not None
390
+
391
    def get_gguf_download_patterns(self, variant: str) -> Dict[str, List[str]]:
        """Get file patterns for downloading specific GGUF variant.

        Args:
            variant: Model variant (e.g., 'gguf-q4-1', 'gguf-q4ks')

        Returns:
            Dict with 'allow_patterns' and 'ignore_patterns' lists, suitable
            for passing to a HuggingFace snapshot download. allow_patterns
            covers metadata files plus the one GGUF file for this variant;
            ignore_patterns excludes other model families and the sibling
            quantizations of the same family.
        """
        # Map variant to specific GGUF file patterns
        variant_patterns = {
            # FLUX.1-dev variants
            'gguf-q2k': ['*Q2_K*.gguf'],
            'gguf-q3ks': ['*Q3_K_S*.gguf'],
            'gguf-q4ks': ['*Q4_K_S*.gguf'],
            'gguf-q4-0': ['*Q4_0*.gguf'],
            'gguf-q4-1': ['*Q4_1*.gguf'],
            'gguf-q5ks': ['*Q5_K_S*.gguf'],
            'gguf-q5-0': ['*Q5_0*.gguf'],
            'gguf-q5-1': ['*Q5_1*.gguf'],
            'gguf-q6k': ['*Q6_K*.gguf'],
            'gguf-q8': ['*Q8_0*.gguf'],
            'gguf-q8-0': ['*Q8_0*.gguf'],  # Keep for backward compatibility
            'gguf-f16': ['*F16*.gguf'],

            # FLUX.1-schnell variants
            'gguf-schnell': ['*flux1-schnell*F16*.gguf'],
            'gguf-schnell-q2k': ['*flux1-schnell*Q2_K*.gguf'],
            'gguf-schnell-q3ks': ['*flux1-schnell*Q3_K_S*.gguf'],
            'gguf-schnell-q4-0': ['*flux1-schnell*Q4_0*.gguf'],
            'gguf-schnell-q4-1': ['*flux1-schnell*Q4_1*.gguf'],
            'gguf-schnell-q4ks': ['*flux1-schnell*Q4_K_S*.gguf'],
            'gguf-schnell-q5-0': ['*flux1-schnell*Q5_0*.gguf'],
            'gguf-schnell-q5-1': ['*flux1-schnell*Q5_1*.gguf'],
            'gguf-schnell-q5ks': ['*flux1-schnell*Q5_K_S*.gguf'],
            'gguf-schnell-q6k': ['*flux1-schnell*Q6_K*.gguf'],
            'gguf-schnell-q8': ['*flux1-schnell*Q8_0*.gguf'],
            'gguf-schnell-f16': ['*flux1-schnell*F16*.gguf'],

            # Stable Diffusion 3.5 Large variants
            'gguf-large': ['*sd3.5_large-F16*.gguf'],
            'gguf-large-q4-0': ['*sd3.5_large-Q4_0*.gguf'],
            'gguf-large-q4-1': ['*sd3.5_large-Q4_1*.gguf'],
            'gguf-large-q5-0': ['*sd3.5_large-Q5_0*.gguf'],
            'gguf-large-q5-1': ['*sd3.5_large-Q5_1*.gguf'],
            'gguf-large-q8-0': ['*sd3.5_large-Q8_0*.gguf'],
            'gguf-large-f16': ['*sd3.5_large-F16*.gguf'],

            # Stable Diffusion 3.5 Large Turbo variants
            'gguf-large-turbo': ['*sd3.5_large_turbo*F16*.gguf'],
            'gguf-large-turbo-q4-0': ['*sd3.5_large_turbo*Q4_0*.gguf'],
            'gguf-large-turbo-q4-1': ['*sd3.5_large_turbo*Q4_1*.gguf'],
            'gguf-large-turbo-q5-0': ['*sd3.5_large_turbo*Q5_0*.gguf'],
            'gguf-large-turbo-q5-1': ['*sd3.5_large_turbo*Q5_1*.gguf'],
            'gguf-large-turbo-q8-0': ['*sd3.5_large_turbo*Q8_0*.gguf'],
            'gguf-large-turbo-f16': ['*sd3.5_large_turbo*F16*.gguf'],

            # Other model variants
            'gguf-medium': ['*sd3.5-medium*.gguf'],
            'gguf-sd3-medium': ['*sd3-medium*.gguf'],
            'gguf-lite': ['*flux-lite-8b*.gguf'],
            'gguf-distilled': ['*flux-dev-de-distill*.gguf'],
            'gguf-fill': ['*flux-fill-dev*.gguf'],
            'gguf-full': ['*hidream-i1-full*.gguf'],
            'gguf-dev': ['*hidream-i1-dev*.gguf'],
            'gguf-fast': ['*hidream-i1-fast*.gguf'],
            'gguf-i2v': ['*ltx-video-i2v*.gguf', '*hunyuan-video-i2v*.gguf'],
            'gguf-2b': ['*ltx-video-2b*.gguf'],
            'gguf-t2v': ['*hunyuan-video-t2v*.gguf'],
        }

        # Get the specific GGUF file pattern for this variant
        # (unknown variants fall back to downloading every .gguf file)
        gguf_pattern = variant_patterns.get(variant, ['*.gguf'])

        # Essential files to download
        essential_files = [
            # Configuration and metadata
            'model_index.json',
            'README.md',
            'LICENSE*',
            '.gitattributes',
            'config.json',
        ]

        # Include the specific GGUF model file
        allow_patterns = essential_files + gguf_pattern

        # Create ignore patterns based on variant name (not pattern content)
        # This prevents conflicts between allow and ignore patterns
        ignore_patterns = []

        # Determine model family from variant name
        if variant.startswith('gguf-schnell') or 'schnell' in variant:
            # FLUX.1-schnell variants - ignore other model types
            ignore_patterns = [
                '*flux1-dev*.gguf',  # Ignore FLUX.1-dev
                '*sd3.5*.gguf',  # Ignore SD 3.5
                '*ltx-video*.gguf',  # Ignore video models
                '*hidream*.gguf',  # Ignore HiDream
                '*hunyuan*.gguf'  # Ignore Hunyuan
            ]
            # Ignore other schnell quantizations except the one we want
            for other_variant, other_patterns in variant_patterns.items():
                if (other_variant.startswith('gguf-schnell') and
                    other_variant != variant and
                    other_variant != 'gguf'):
                    # Only ignore if it doesn't conflict with our allow patterns
                    for pattern in other_patterns:
                        if pattern not in gguf_pattern:
                            ignore_patterns.append(pattern)

        elif (variant.startswith('gguf-large-turbo') or
              'large-turbo' in variant or
              variant.startswith('gguf-large') or
              'sd3.5' in variant or
              'stable-diffusion-3' in variant):
            # SD 3.5 variants - ignore other model types
            ignore_patterns = [
                '*flux1-dev*.gguf',  # Ignore FLUX.1-dev
                '*flux1-schnell*.gguf',  # Ignore FLUX.1-schnell
                '*ltx-video*.gguf',  # Ignore video models
                '*hidream*.gguf',  # Ignore HiDream
                '*hunyuan*.gguf'  # Ignore Hunyuan
            ]
            # Ignore other SD 3.5 quantizations except the one we want
            for other_variant, other_patterns in variant_patterns.items():
                if (('large' in other_variant or 'sd3.5' in other_variant or 'stable-diffusion-3' in other_variant) and
                    other_variant != variant and
                    other_variant != 'gguf'):
                    # Only ignore if it doesn't conflict with our allow patterns
                    for pattern in other_patterns:
                        if pattern not in gguf_pattern:
                            ignore_patterns.append(pattern)

        elif ('video' in variant or
              'i2v' in variant or
              't2v' in variant or
              '2b' in variant):
            # Video model variants
            ignore_patterns = [
                '*flux1-dev*.gguf',
                '*flux1-schnell*.gguf',
                '*sd3.5*.gguf'
            ]

        elif ('hidream' in variant or
              'full' in variant or
              'fast' in variant):
            # HiDream variants
            ignore_patterns = [
                '*flux1-dev*.gguf',
                '*flux1-schnell*.gguf',
                '*sd3.5*.gguf',
                '*ltx-video*.gguf',
                '*hunyuan*.gguf'
            ]

        else:
            # FLUX.1-dev variants (default case) - ignore other model types
            ignore_patterns = [
                '*flux1-schnell*.gguf',  # Ignore FLUX.1-schnell
                '*sd3.5*.gguf',  # Ignore SD 3.5
                '*ltx-video*.gguf',  # Ignore video models
                '*hidream*.gguf',  # Ignore HiDream
                '*hunyuan*.gguf'  # Ignore Hunyuan
            ]
            # Ignore other FLUX.1-dev quantizations except the one we want
            # NOTE(review): variants like 'gguf-fast' or 'gguf-t2v' pass this
            # name-based filter (their names contain none of the excluded
            # substrings), so their patterns are appended here too — harmless
            # duplicates of the family-wide ignores above, but worth knowing.
            for other_variant, other_patterns in variant_patterns.items():
                if (not other_variant.startswith('gguf-schnell') and
                    not 'large' in other_variant and
                    not 'sd3.5' in other_variant and
                    not 'video' in other_variant and
                    not 'hidream' in other_variant and
                    other_variant != variant and
                    other_variant != 'gguf'):
                    # Only ignore if it doesn't conflict with our allow patterns
                    for pattern in other_patterns:
                        if pattern not in gguf_pattern:
                            ignore_patterns.append(pattern)

        return {
            'allow_patterns': allow_patterns,
            'ignore_patterns': ignore_patterns
        }
575
+
576
+ def _get_model_family(self, pattern: str) -> str:
577
+ """Extract model family from a pattern (e.g., flux1-dev, flux1-schnell, sd3.5-large)"""
578
+ if 'flux1-dev' in pattern:
579
+ return 'flux1-dev'
580
+ elif 'flux1-schnell' in pattern:
581
+ return 'flux1-schnell'
582
+ elif 'sd3.5-large-turbo' in pattern:
583
+ return 'sd3.5-large-turbo'
584
+ elif 'sd3.5-large' in pattern:
585
+ return 'sd3.5-large'
586
+ elif 'sd3.5' in pattern:
587
+ return 'sd3.5'
588
+ else:
589
+ return pattern.split('*')[1].split('*')[0] if '*' in pattern else pattern
590
+
591
    def download_required_components(self, model_path: Path) -> Dict[str, Optional[Path]]:
        """Download or locate required VAE, CLIP-L, and T5XXL components.

        For different model types:
        - FLUX GGUF models need: ae.safetensors (VAE), clip_l.safetensors, t5xxl_fp16.safetensors
        - SD 3.5 models need: different text encoders and VAE (including CLIP-G)

        Args:
            model_path: Directory of the GGUF model; its *name* is used to
                detect the model family, and components are placed in sibling
                directories (e.g. '<parent>/flux_vae').

        Returns:
            Dict with keys 'vae', 'clip_l', 'clip_g', 't5xxl' mapping to the
            located file paths, or None for components that could not be
            found/downloaded ('clip_g' stays None for FLUX models). Download
            failures are logged, never raised.
        """
        from ..utils.download_utils import robust_snapshot_download
        from ..config.settings import settings

        components = {
            'vae': None,
            'clip_l': None,
            'clip_g': None,  # Needed for SD 3.5 models
            't5xxl': None
        }

        # Detect model type based on model path or name
        model_name = model_path.name.lower()
        is_sd35 = any(pattern in model_name for pattern in ['3.5', 'sd3.5', 'stable-diffusion-3-5'])
        is_flux = any(x in model_name for x in ['flux', 'flux1'])

        logger.info(f"Downloading required components for model type: {'SD3.5' if is_sd35 else 'FLUX' if is_flux else 'Unknown'}")

        try:
            if is_sd35:
                # SD 3.5 models - use SD 3.5 specific components
                logger.info("Downloading SD 3.5 components...")

                # Download SD 3.5 VAE (skipped when already present on disk)
                vae_dir = model_path.parent / "sd35_vae"
                if not (vae_dir / "vae.safetensors").exists():
                    logger.info("Downloading SD 3.5 VAE...")
                    robust_snapshot_download(
                        repo_id="stabilityai/stable-diffusion-3.5-large",
                        local_dir=str(vae_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['vae/diffusion_pytorch_model.safetensors'],
                        max_retries=3
                    )
                    # Move to expected location if needed (snapshot keeps the
                    # repo's 'vae/' subdirectory layout)
                    vae_source = vae_dir / "vae" / "diffusion_pytorch_model.safetensors"
                    vae_target = vae_dir / "vae.safetensors"
                    if vae_source.exists() and not vae_target.exists():
                        vae_source.rename(vae_target)

                vae_path = vae_dir / "vae.safetensors"
                if vae_path.exists():
                    components['vae'] = vae_path
                    logger.info(f"SD 3.5 VAE found at: {vae_path}")

                # Download SD 3.5 text encoders
                text_encoders_dir = model_path.parent / "sd35_text_encoders"

                # Download CLIP-L for SD 3.5
                if not (text_encoders_dir / "clip_l.safetensors").exists():
                    logger.info("Downloading SD 3.5 CLIP-L text encoder...")
                    robust_snapshot_download(
                        repo_id="stabilityai/stable-diffusion-3.5-large",
                        local_dir=str(text_encoders_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['text_encoders/clip_l.safetensors'],
                        max_retries=3
                    )
                    # Move to expected location if needed
                    clip_source = text_encoders_dir / "text_encoders" / "clip_l.safetensors"
                    clip_target = text_encoders_dir / "clip_l.safetensors"
                    if clip_source.exists() and not clip_target.exists():
                        clip_source.rename(clip_target)

                clip_l_path = text_encoders_dir / "clip_l.safetensors"
                if clip_l_path.exists():
                    components['clip_l'] = clip_l_path
                    logger.info(f"SD 3.5 CLIP-L found at: {clip_l_path}")

                # Download CLIP-G for SD 3.5
                if not (text_encoders_dir / "clip_g.safetensors").exists():
                    logger.info("Downloading SD 3.5 CLIP-G text encoder...")
                    robust_snapshot_download(
                        repo_id="stabilityai/stable-diffusion-3.5-large",
                        local_dir=str(text_encoders_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['text_encoders/clip_g.safetensors'],
                        max_retries=3
                    )
                    # Move to expected location if needed
                    clipg_source = text_encoders_dir / "text_encoders" / "clip_g.safetensors"
                    clipg_target = text_encoders_dir / "clip_g.safetensors"
                    if clipg_source.exists() and not clipg_target.exists():
                        clipg_source.rename(clipg_target)

                clip_g_path = text_encoders_dir / "clip_g.safetensors"
                if clip_g_path.exists():
                    components['clip_g'] = clip_g_path
                    logger.info(f"SD 3.5 CLIP-G found at: {clip_g_path}")

                # Download T5XXL for SD 3.5
                if not (text_encoders_dir / "t5xxl_fp16.safetensors").exists():
                    logger.info("Downloading SD 3.5 T5XXL text encoder...")
                    robust_snapshot_download(
                        repo_id="stabilityai/stable-diffusion-3.5-large",
                        local_dir=str(text_encoders_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['text_encoders/t5xxl_fp16.safetensors'],
                        max_retries=3
                    )
                    # Move to expected location if needed
                    t5_source = text_encoders_dir / "text_encoders" / "t5xxl_fp16.safetensors"
                    t5_target = text_encoders_dir / "t5xxl_fp16.safetensors"
                    if t5_source.exists() and not t5_target.exists():
                        t5_source.rename(t5_target)

                t5xxl_path = text_encoders_dir / "t5xxl_fp16.safetensors"
                if t5xxl_path.exists():
                    components['t5xxl'] = t5xxl_path
                    logger.info(f"SD 3.5 T5XXL found at: {t5xxl_path}")

            else:
                # FLUX models (default) - use FLUX specific components
                logger.info("Downloading FLUX components...")

                # Download VAE from official FLUX repository
                vae_dir = model_path.parent / "flux_vae"
                if not (vae_dir / "ae.safetensors").exists():
                    logger.info("Downloading FLUX VAE...")
                    robust_snapshot_download(
                        repo_id="black-forest-labs/FLUX.1-dev",
                        local_dir=str(vae_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['ae.safetensors'],
                        max_retries=3
                    )

                vae_path = vae_dir / "ae.safetensors"
                if vae_path.exists():
                    components['vae'] = vae_path
                    logger.info(f"FLUX VAE found at: {vae_path}")

                # Download text encoders
                text_encoders_dir = model_path.parent / "flux_text_encoders"

                # Download CLIP-L
                if not (text_encoders_dir / "clip_l.safetensors").exists():
                    logger.info("Downloading FLUX CLIP-L text encoder...")
                    robust_snapshot_download(
                        repo_id="comfyanonymous/flux_text_encoders",
                        local_dir=str(text_encoders_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['clip_l.safetensors'],
                        max_retries=3
                    )

                clip_l_path = text_encoders_dir / "clip_l.safetensors"
                if clip_l_path.exists():
                    components['clip_l'] = clip_l_path
                    logger.info(f"FLUX CLIP-L found at: {clip_l_path}")

                # Download T5XXL
                if not (text_encoders_dir / "t5xxl_fp16.safetensors").exists():
                    logger.info("Downloading FLUX T5XXL text encoder...")
                    robust_snapshot_download(
                        repo_id="comfyanonymous/flux_text_encoders",
                        local_dir=str(text_encoders_dir),
                        cache_dir=str(settings.cache_dir),
                        allow_patterns=['t5xxl_fp16.safetensors'],
                        max_retries=3
                    )

                t5xxl_path = text_encoders_dir / "t5xxl_fp16.safetensors"
                if t5xxl_path.exists():
                    components['t5xxl'] = t5xxl_path
                    logger.info(f"FLUX T5XXL found at: {t5xxl_path}")

        except Exception as e:
            # Best-effort: callers (load_model) validate the returned dict and
            # report which components are missing.
            logger.error(f"Failed to download components: {e}")

        return components
768
+
769
+
770
# Global GGUF loader instance
# Module-level singleton; constructing it touches no model weights
# (all fields start as None until load_model() is called).
gguf_loader = GGUFModelLoader()