ollamadiffuser-1.1.6-py3-none-any.whl → ollamadiffuser-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,437 @@
+ """
+ GGUF Model Loader and Interface
+
+ This module provides support for loading and running GGUF quantized models,
+ specifically for FLUX.1-dev-gguf variants using stable-diffusion.cpp Python bindings.
+ """
+
+ import os
+ import logging
+ from pathlib import Path
+ from typing import Dict, Any, Optional, List, Union
+ import torch
+ from PIL import Image
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     from stable_diffusion_cpp import StableDiffusion
+     GGUF_AVAILABLE = True
+     logger.info("stable-diffusion-cpp-python is available")
+ except ImportError:
+     StableDiffusion = None
+     GGUF_AVAILABLE = False
+     logger.warning("stable-diffusion-cpp-python not available. GGUF models will not work.")
+
+ class GGUFModelLoader:
+     """Loader for GGUF quantized diffusion models using stable-diffusion.cpp"""
+
+     def __init__(self):
+         self.model = None
+         self.model_path = None
+         self.model_config = None
+         self.loaded_model_name = None
+         self.stable_diffusion = None
+
+     def is_gguf_model(self, model_name: str, model_config: Dict[str, Any]) -> bool:
+         """Check if a model is a GGUF model"""
+         variant = model_config.get('variant', '')
+         return 'gguf' in variant.lower() or model_name.endswith('-gguf') or 'gguf' in model_name.lower()
+
+     def get_gguf_file_path(self, model_dir: Path, variant: str) -> Optional[Path]:
+         """Find the appropriate GGUF file based on variant"""
+         if not model_dir.exists():
+             return None
+
+         # Map variant to actual file names
+         variant_mapping = {
+             'gguf-q2k': 'flux1-dev-Q2_K.gguf',
+             'gguf-q3ks': 'flux1-dev-Q3_K_S.gguf',
+             'gguf-q4ks': 'flux1-dev-Q4_K_S.gguf',
+             'gguf-q4-0': 'flux1-dev-Q4_0.gguf',
+             'gguf-q4-1': 'flux1-dev-Q4_1.gguf',
+             'gguf-q5ks': 'flux1-dev-Q5_K_S.gguf',
+             'gguf-q5-0': 'flux1-dev-Q5_0.gguf',
+             'gguf-q5-1': 'flux1-dev-Q5_1.gguf',
+             'gguf-q6k': 'flux1-dev-Q6_K.gguf',
+             'gguf-q8': 'flux1-dev-Q8_0.gguf',
+             'gguf-f16': 'flux1-dev-F16.gguf',
+             'gguf': 'flux1-dev-Q4_K_S.gguf',  # Default to Q4_K_S
+         }
+
+         filename = variant_mapping.get(variant.lower())
+         if filename:
+             gguf_file = model_dir / filename
+             if gguf_file.exists():
+                 return gguf_file
+
+         # Fallback: search for any .gguf file
+         gguf_files = list(model_dir.glob('*.gguf'))
+         if gguf_files:
+             return gguf_files[0]  # Return first found
+
+         return None
+
+     def get_additional_model_files(self, model_dir: Path) -> Dict[str, Optional[Path]]:
+         """Find additional model files required for FLUX GGUF inference"""
+         files = {
+             'vae': None,
+             'clip_l': None,
+             't5xxl': None
+         }
+
+         # Common file patterns for FLUX models
+         vae_patterns = ['ae.safetensors', 'vae.safetensors', 'flux_vae.safetensors']
+         clip_l_patterns = ['clip_l.safetensors', 'text_encoder.safetensors']
+         t5xxl_patterns = ['t5xxl_fp16.safetensors', 't5xxl.safetensors', 't5_encoder.safetensors']
+
+         # Search for VAE
+         for pattern in vae_patterns:
+             vae_file = model_dir / pattern
+             if vae_file.exists():
+                 files['vae'] = vae_file
+                 break
+
+         # Search for CLIP-L
+         for pattern in clip_l_patterns:
+             clip_file = model_dir / pattern
+             if clip_file.exists():
+                 files['clip_l'] = clip_file
+                 break
+
+         # Search for T5XXL
+         for pattern in t5xxl_patterns:
+             t5_file = model_dir / pattern
+             if t5_file.exists():
+                 files['t5xxl'] = t5_file
+                 break
+
+         return files
+
+     def load_model(self, model_config: Dict[str, Any], model_name: Optional[str] = None, model_path: Optional[Path] = None) -> bool:
+         """Load GGUF model using stable-diffusion.cpp"""
+         # Extract parameters from model_config if not provided separately
+         if model_name is None:
+             model_name = model_config.get('name', 'unknown')
+         if model_path is None:
+             model_path = Path(model_config.get('path', ''))
+
+         logger.info(f"Loading GGUF model: {model_name}")
+
+         try:
+             # Find the GGUF file
+             gguf_files = list(model_path.glob("*.gguf"))
+             if not gguf_files:
+                 logger.error(f"No GGUF files found in {model_path}")
+                 return False
+
+             gguf_file = gguf_files[0]  # Use the first GGUF file found
+             logger.info(f"Using GGUF file: {gguf_file}")
+
+             # Download required components
+             components = self.download_required_components(model_path)
+
+             # Verify all components are available
+             missing_components = [name for name, path in components.items() if path is None]
+             if missing_components:
+                 logger.error(f"Missing required components: {missing_components}")
+                 return False
+
+             # Initialize stable-diffusion.cpp
+             if not GGUF_AVAILABLE:
+                 logger.error("stable-diffusion-cpp-python not properly installed")
+                 return False
+
+             # Create StableDiffusion instance with correct API for FLUX
+             # For FLUX models, use diffusion_model_path instead of model_path
+             self.stable_diffusion = StableDiffusion(
+                 diffusion_model_path=str(gguf_file),  # FLUX GGUF models use this parameter
+                 vae_path=str(components['vae']),
+                 clip_l_path=str(components['clip_l']),
+                 t5xxl_path=str(components['t5xxl']),
+                 vae_decode_only=True,  # For txt2img only
+                 n_threads=-1  # Auto-detect threads
+             )
+
+             self.model_path = str(gguf_file)
+             self.model_config = model_config
+             self.loaded_model_name = model_name
+
+             logger.info(f"Successfully loaded GGUF model: {model_name}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to load GGUF model {model_name}: {e}")
+             if hasattr(self, 'stable_diffusion') and self.stable_diffusion:
+                 self.stable_diffusion = None
+             return False
+
+     def generate_image(self, prompt: str, **kwargs) -> Optional[Image.Image]:
+         """Generate image using stable-diffusion.cpp FLUX inference"""
+         if not self.stable_diffusion:
+             logger.error("GGUF model not loaded")
+             return None
+
+         try:
+             # Extract parameters with FLUX-optimized defaults
+             # Support both parameter naming conventions for compatibility
+             width = kwargs.get('width', 1024)
+             height = kwargs.get('height', 1024)
+
+             # Support both 'steps' and 'num_inference_steps'
+             steps = kwargs.get('steps') or kwargs.get('num_inference_steps', 20)  # Increased for better quality
+
+             # Support both 'cfg_scale' and 'guidance_scale' - FLUX works best with low CFG
+             cfg_scale = kwargs.get('cfg_scale') or kwargs.get('guidance_scale', 1.0)  # FLUX optimized CFG (reduced from 1.2)
+
+             seed = kwargs.get('seed', 42)
+             negative_prompt = kwargs.get('negative_prompt', "")
+
+             # Allow custom sampler, with FLUX-optimized default
+             sampler = kwargs.get('sampler', kwargs.get('sample_method', 'dpmpp2m'))  # Better sampler for FLUX (fixed name)
+
+             # Validate sampler and provide fallback
+             valid_samplers = ['euler_a', 'euler', 'heun', 'dpm2', 'dpmpp2s_a', 'dpmpp2m', 'dpmpp2mv2', 'ipndm', 'ipndm_v', 'lcm', 'ddim_trailing', 'tcd']
+             if sampler not in valid_samplers:
+                 logger.warning(f"Invalid sampler '{sampler}', falling back to 'dpmpp2m'")
+                 sampler = 'dpmpp2m'
+
+             logger.info(f"Generating image: {width}x{height}, steps={steps}, cfg={cfg_scale}, sampler={sampler}, negative_prompt={negative_prompt}")
+
+             # Log model quantization info for quality assessment
+             if hasattr(self, 'model_path'):
+                 if 'Q2' in str(self.model_path):
+                     logger.warning("Using Q2 quantization - expect lower quality. Consider Q4_K_S or higher for better results.")
+                 elif 'Q3' in str(self.model_path):
+                     logger.info("Using Q3 quantization - moderate quality. Consider Q4_K_S or higher for better results.")
+                 elif 'Q4' in str(self.model_path):
+                     logger.info("Using Q4 quantization - good balance of quality and size.")
+                 elif any(x in str(self.model_path) for x in ['Q5', 'Q6', 'Q8', 'F16']):
+                     logger.info("Using high precision quantization - excellent quality expected.")
+
+             # Generate image using stable-diffusion.cpp
+             # According to the documentation, txt_to_img returns a list of PIL Images
+             try:
+                 result = self.stable_diffusion.txt_to_img(
+                     prompt=prompt,
+                     negative_prompt=negative_prompt if negative_prompt else "",
+                     cfg_scale=cfg_scale,
+                     width=width,
+                     height=height,
+                     sample_method=sampler,  # Use optimized sampler
+                     sample_steps=steps,
+                     seed=seed
+                 )
+                 logger.info(f"txt_to_img returned: {type(result)}, length: {len(result) if result else 'None'}")
+             except Exception as e:
+                 logger.error(f"txt_to_img call failed: {e}")
+                 return None
+
+             if not result:
+                 logger.error("txt_to_img returned None")
+                 return None
+
+             if not isinstance(result, list) or len(result) == 0:
+                 logger.error(f"txt_to_img returned unexpected format: {type(result)}")
+                 return None
+
+             # Get the first PIL Image from the result list
+             image = result[0]
+             logger.info(f"Retrieved PIL Image: {type(image)}")
+
+             # Verify it's a PIL Image
+             if not hasattr(image, 'save'):
+                 logger.error(f"Result[0] is not a PIL Image: {type(image)}")
+                 return None
+
+             # Optionally save a copy for debugging/history
+             try:
+                 from ..config.settings import settings
+                 output_dir = settings.config_dir / "outputs"
+                 output_dir.mkdir(exist_ok=True)
+
+                 output_path = output_dir / f"gguf_output_{seed}.png"
+                 image.save(output_path)
+                 logger.info(f"Generated image also saved to: {output_path}")
+             except Exception as e:
+                 logger.warning(f"Failed to save debug copy: {e}")
+
+             # Return the PIL Image directly for API compatibility
+             logger.info("Returning PIL Image for API use")
+             return image
+
+         except Exception as e:
+             logger.error(f"Failed to generate image with GGUF model: {e}")
+             import traceback
+             logger.error(f"Traceback: {traceback.format_exc()}")
+             return None
+
+     def unload_model(self):
+         """Unload the GGUF model"""
+         if self.stable_diffusion:
+             try:
+                 # stable-diffusion-cpp handles cleanup automatically
+                 self.stable_diffusion = None
+                 self.model_path = None
+                 self.model_config = None
+                 self.loaded_model_name = None
+                 logger.info("GGUF model unloaded")
+             except Exception as e:
+                 logger.error(f"Error unloading GGUF model: {e}")
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """Get information about the loaded model"""
+         if not self.stable_diffusion:
+             return {
+                 'gguf_available': GGUF_AVAILABLE,
+                 'loaded': False
+             }
+
+         return {
+             'type': 'gguf',
+             'variant': self.model_config.get('variant', 'unknown'),
+             'path': str(self.model_path),
+             'name': self.loaded_model_name,
+             'loaded': True,
+             'gguf_available': GGUF_AVAILABLE,
+             'backend': 'stable-diffusion.cpp'
+         }
+
+     def is_loaded(self) -> bool:
+         """Check if a model is loaded"""
+         return self.stable_diffusion is not None
+
+     def get_gguf_download_patterns(self, variant: str) -> Dict[str, List[str]]:
+         """Get file patterns for downloading specific GGUF variant
+
+         Args:
+             variant: Model variant (e.g., 'gguf-q4-1', 'gguf-q4ks')
+
+         Returns:
+             Dict with 'allow_patterns' and 'ignore_patterns' lists
+         """
+         # Map variant to specific GGUF file patterns
+         variant_patterns = {
+             'gguf-q2k': ['*Q2_K*.gguf'],
+             'gguf-q3ks': ['*Q3_K_S*.gguf'],
+             'gguf-q4-0': ['*Q4_0*.gguf'],
+             'gguf-q4-1': ['*Q4_1*.gguf'],
+             'gguf-q4ks': ['*Q4_K_S*.gguf'],
+             'gguf-q5-0': ['*Q5_0*.gguf'],
+             'gguf-q5-1': ['*Q5_1*.gguf'],
+             'gguf-q5ks': ['*Q5_K_S*.gguf'],
+             'gguf-q6k': ['*Q6_K*.gguf'],
+             'gguf-q8-0': ['*Q8_0*.gguf'],
+             'gguf-f16': ['*F16*.gguf']
+         }
+
+         # Get the specific GGUF file pattern for this variant
+         gguf_pattern = variant_patterns.get(variant, ['*.gguf'])
+
+         # Essential files to download
+         essential_files = [
+             # Configuration and metadata
+             'model_index.json',
+             'README.md',
+             'LICENSE*',
+             '.gitattributes',
+             'config.json',
+         ]
+
+         # Include the specific GGUF model file
+         allow_patterns = essential_files + gguf_pattern
+
+         # Create ignore patterns - ignore all other GGUF variants
+         all_gguf_variants = []
+         for pattern_list in variant_patterns.values():
+             all_gguf_variants.extend(pattern_list)
+
+         # Remove the current variant from ignore list
+         ignore_patterns = [p for p in all_gguf_variants if p not in gguf_pattern]
+
+         return {
+             'allow_patterns': allow_patterns,
+             'ignore_patterns': ignore_patterns
+         }
+
+     def download_required_components(self, model_path: Path) -> Dict[str, Optional[Path]]:
+         """Download or locate required VAE, CLIP-L, and T5XXL components
+
+         For FLUX GGUF models, these components need to be downloaded separately:
+         - VAE: ae.safetensors from black-forest-labs/FLUX.1-dev
+         - CLIP-L: clip_l.safetensors from comfyanonymous/flux_text_encoders
+         - T5XXL: t5xxl_fp16.safetensors from comfyanonymous/flux_text_encoders
+         """
+         from ..utils.download_utils import robust_snapshot_download
+         from ..config.settings import settings
+
+         components = {
+             'vae': None,
+             'clip_l': None,
+             't5xxl': None
+         }
+
+         logger.info("Downloading required FLUX components...")
+
+         try:
+             # Download VAE from official FLUX repository
+             vae_dir = model_path.parent / "flux_vae"
+             if not (vae_dir / "ae.safetensors").exists():
+                 logger.info("Downloading FLUX VAE...")
+                 robust_snapshot_download(
+                     repo_id="black-forest-labs/FLUX.1-dev",
+                     local_dir=str(vae_dir),
+                     cache_dir=str(settings.cache_dir),
+                     allow_patterns=['ae.safetensors'],
+                     max_retries=3
+                 )
+
+             vae_path = vae_dir / "ae.safetensors"
+             if vae_path.exists():
+                 components['vae'] = vae_path
+                 logger.info(f"VAE found at: {vae_path}")
+
+             # Download text encoders
+             text_encoders_dir = model_path.parent / "flux_text_encoders"
+
+             # Download CLIP-L
+             if not (text_encoders_dir / "clip_l.safetensors").exists():
+                 logger.info("Downloading CLIP-L text encoder...")
+                 robust_snapshot_download(
+                     repo_id="comfyanonymous/flux_text_encoders",
+                     local_dir=str(text_encoders_dir),
+                     cache_dir=str(settings.cache_dir),
+                     allow_patterns=['clip_l.safetensors'],
+                     max_retries=3
+                 )
+
+             clip_l_path = text_encoders_dir / "clip_l.safetensors"
+             if clip_l_path.exists():
+                 components['clip_l'] = clip_l_path
+                 logger.info(f"CLIP-L found at: {clip_l_path}")
+
+             # Download T5XXL
+             if not (text_encoders_dir / "t5xxl_fp16.safetensors").exists():
+                 logger.info("Downloading T5XXL text encoder...")
+                 robust_snapshot_download(
+                     repo_id="comfyanonymous/flux_text_encoders",
+                     local_dir=str(text_encoders_dir),
+                     cache_dir=str(settings.cache_dir),
+                     allow_patterns=['t5xxl_fp16.safetensors'],
+                     max_retries=3
+                 )
+
+             t5xxl_path = text_encoders_dir / "t5xxl_fp16.safetensors"
+             if t5xxl_path.exists():
+                 components['t5xxl'] = t5xxl_path
+                 logger.info(f"T5XXL found at: {t5xxl_path}")
+
+         except Exception as e:
+             logger.error(f"Failed to download FLUX components: {e}")
+
+         return components
+
+
+ # Global GGUF loader instance
+ gguf_loader = GGUFModelLoader()
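
For context, a minimal usage sketch of the new module (not part of the package diff): it assumes the file is importable as `ollamadiffuser.core.models.gguf_loader` and that a FLUX GGUF model directory already exists locally. The model name, directory path, and config keys shown here are illustrative assumptions, not values shipped with the package.

```python
# Illustrative sketch only: import path and local model directory are assumptions.
from pathlib import Path

from ollamadiffuser.core.models.gguf_loader import gguf_loader  # module path assumed

model_dir = Path("/models/flux1-dev-gguf")  # hypothetical directory containing a *.gguf file
model_config = {"variant": "gguf-q4ks", "path": str(model_dir)}

# load_model() finds the .gguf file and fetches the FLUX VAE and text encoders if missing.
if gguf_loader.load_model(model_config, model_name="flux.1-dev-gguf", model_path=model_dir):
    image = gguf_loader.generate_image(
        "a lighthouse at dusk, photorealistic",
        width=1024,
        height=1024,
        num_inference_steps=20,  # 'steps' is also accepted
        guidance_scale=1.0,      # 'cfg_scale' is also accepted
        sampler="dpmpp2m",
        seed=42,
    )
    if image is not None:
        image.save("flux_gguf_sample.png")
    gguf_loader.unload_model()
```

Note that the first `load_model()` call downloads `ae.safetensors`, `clip_l.safetensors`, and `t5xxl_fp16.safetensors` from Hugging Face next to the model directory, so it can take significantly longer than subsequent loads.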