ollamadiffuser 1.2.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ollamadiffuser/__init__.py +1 -1
  2. ollamadiffuser/api/server.py +312 -312
  3. ollamadiffuser/cli/config_commands.py +119 -0
  4. ollamadiffuser/cli/lora_commands.py +169 -0
  5. ollamadiffuser/cli/main.py +85 -1233
  6. ollamadiffuser/cli/model_commands.py +664 -0
  7. ollamadiffuser/cli/recommend_command.py +205 -0
  8. ollamadiffuser/cli/registry_commands.py +197 -0
  9. ollamadiffuser/core/config/model_registry.py +562 -11
  10. ollamadiffuser/core/config/settings.py +24 -2
  11. ollamadiffuser/core/inference/__init__.py +5 -0
  12. ollamadiffuser/core/inference/base.py +182 -0
  13. ollamadiffuser/core/inference/engine.py +204 -1405
  14. ollamadiffuser/core/inference/strategies/__init__.py +1 -0
  15. ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
  16. ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
  17. ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
  18. ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
  19. ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
  20. ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
  21. ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
  22. ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
  23. ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
  24. ollamadiffuser/mcp/__init__.py +0 -0
  25. ollamadiffuser/mcp/server.py +184 -0
  26. ollamadiffuser/ui/templates/index.html +62 -1
  27. ollamadiffuser/ui/web.py +116 -54
  28. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +337 -108
  29. ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
  30. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
  31. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
  32. ollamadiffuser/core/models/registry.py +0 -384
  33. ollamadiffuser/ui/samples/.DS_Store +0 -0
  34. ollamadiffuser-1.2.2.dist-info/RECORD +0 -45
  35. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
  36. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
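The single hunk reproduced below is ollamadiffuser/core/inference/engine.py (item 13 above): the 1400-plus line monolithic InferenceEngine is replaced by a thin facade that delegates to per-model strategy classes. Reconstructed from the new signatures visible in that hunk, caller-side usage would look roughly like the sketch below; the ModelConfig constructor arguments and import paths are assumptions inferred from the diff, not documented API.

# Hypothetical usage sketch of the 2.0.0 facade; ModelConfig field names are
# inferred from attributes referenced in the diff (name, path, model_type,
# variant, parameters) and may not match the real constructor.
from ollamadiffuser.core.config.settings import ModelConfig
from ollamadiffuser.core.inference.engine import InferenceEngine

config = ModelConfig(
    name="stable-diffusion-xl",
    path="stabilityai/stable-diffusion-xl-base-1.0",  # Hub IDs are now accepted, not only local paths
    model_type="sdxl",                                # routed to SDXLStrategy by _get_strategy()
    variant="fp16",
    parameters={"num_inference_steps": 28, "guidance_scale": 5.0},
)

engine = InferenceEngine()
if engine.load_model(config):          # returns False instead of raising on failure
    image = engine.generate_image(
        prompt="a watercolor fox in a misty forest",
        width=1024,
        height=1024,
        seed=42,                       # seed is an explicit parameter in 2.0.0
    )
    image.save("fox.png")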
@@ -1,1427 +1,226 @@
- import os
- import logging
- import torch
- import numpy as np
- from diffusers import (
-     StableDiffusionPipeline,
-     StableDiffusionXLPipeline,
-     StableDiffusion3Pipeline,
-     FluxPipeline,
-     StableDiffusionControlNetPipeline,
-     StableDiffusionXLControlNetPipeline,
-     ControlNetModel,
-     AnimateDiffPipeline,
-     MotionAdapter
- )
- # Try to import HiDreamImagePipeline if available
- try:
-     from diffusers import HiDreamImagePipeline
-     HIDREAM_AVAILABLE = True
- except ImportError:
-     HIDREAM_AVAILABLE = False
-     logger = logging.getLogger(__name__)
-     logger.warning("HiDreamImagePipeline not available. Install latest diffusers from source for HiDream support.")
+ """
+ Inference Engine - Facade that delegates to model-specific strategies.
 
- # Import GGUF support
- try:
-     from ..models.gguf_loader import gguf_loader, GGUF_AVAILABLE
-     logger = logging.getLogger(__name__)
-     if GGUF_AVAILABLE:
-         logger.info("GGUF support available for quantized model inference")
-     else:
-         logger.warning("GGUF support not available. Install with: pip install llama-cpp-python gguf")
- except ImportError:
-     GGUF_AVAILABLE = False
-     logger = logging.getLogger(__name__)
-     logger.warning("GGUF loader module not found")
+ This replaces the former 1400+ line god class with a clean strategy pattern.
+ Each model type (SD1.5, SDXL, FLUX, SD3, ControlNet, Video, HiDream, GGUF)
+ has its own strategy class that handles loading and generation.
+ """
+
+ import logging
+ from typing import Any, Dict, Optional, Union
 
  from PIL import Image
- from typing import Optional, Dict, Any, Union
- from pathlib import Path
- from ..config.settings import ModelConfig
- from ..utils.controlnet_preprocessors import controlnet_preprocessor
 
- # Global safety checker disabling
- os.environ["DISABLE_NSFW_CHECKER"] = "1"
- os.environ["DIFFUSERS_DISABLE_SAFETY_CHECKER"] = "1"
+ from ..config.settings import ModelConfig
+ from .base import InferenceStrategy
 
  logger = logging.getLogger(__name__)
+
+
+ def _get_strategy(model_type: str) -> InferenceStrategy:
+     """Create the appropriate strategy for a model type."""
+     if model_type == "sd15":
+         from .strategies.sd15_strategy import SD15Strategy
+         return SD15Strategy()
+     elif model_type == "sdxl":
+         from .strategies.sdxl_strategy import SDXLStrategy
+         return SDXLStrategy()
+     elif model_type == "flux":
+         from .strategies.flux_strategy import FluxStrategy
+         return FluxStrategy()
+     elif model_type == "sd3":
+         from .strategies.sd3_strategy import SD3Strategy
+         return SD3Strategy()
+     elif model_type in ("controlnet_sd15", "controlnet_sdxl"):
+         from .strategies.controlnet_strategy import ControlNetStrategy
+         return ControlNetStrategy()
+     elif model_type == "video":
+         from .strategies.video_strategy import VideoStrategy
+         return VideoStrategy()
+     elif model_type == "hidream":
+         from .strategies.hidream_strategy import HiDreamStrategy
+         return HiDreamStrategy()
+     elif model_type == "gguf":
+         from .strategies.gguf_strategy import GGUFStrategy
+         return GGUFStrategy()
+     elif model_type == "generic":
+         from .strategies.generic_strategy import GenericPipelineStrategy
+         return GenericPipelineStrategy()
+     else:
+         raise ValueError(f"Unsupported model type: {model_type}")
+
+
+ def _detect_device() -> str:
+     """Automatically detect the best available device."""
+     import torch
+
+     if torch.cuda.is_available():
+         device = "cuda"
+         logger.debug(f"CUDA device count: {torch.cuda.device_count()}")
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
+
+     logger.info(f"Using device: {device}")
+     if device == "cpu":
+         logger.warning("Using CPU - inference will be slower")
+     return device
+
+
  class InferenceEngine:
50
- """Inference engine responsible for actual image generation"""
51
-
72
+ """
73
+ Facade engine that delegates to model-type-specific strategies.
74
+
75
+ This class maintains backward compatibility with the previous monolithic
76
+ InferenceEngine API while cleanly separating concerns into per-model strategies.
77
+ """
78
+
52
79
  def __init__(self):
53
- self.pipeline = None
80
+ self._strategy: Optional[InferenceStrategy] = None
54
81
  self.model_config: Optional[ModelConfig] = None
55
- self.device = None
56
- self.tokenizer = None
57
- self.max_token_limit = 77
58
- self.current_lora = None # Track current LoRA state
59
- self.controlnet = None # Track ControlNet model
60
- self.is_controlnet_pipeline = False # Track if current pipeline is ControlNet
61
-
62
- def _get_device(self) -> str:
63
- """Automatically detect available device"""
64
- # Debug device availability
65
- logger.debug(f"CUDA available: {torch.cuda.is_available()}")
66
- logger.debug(f"MPS available: {torch.backends.mps.is_available()}")
67
-
68
- # Determine device
69
- if torch.cuda.is_available():
70
- device = "cuda"
71
- logger.debug(f"CUDA device count: {torch.cuda.device_count()}")
72
- elif torch.backends.mps.is_available():
73
- device = "mps" # Apple Silicon GPU
74
- else:
75
- device = "cpu"
76
-
77
- logger.info(f"Using device: {device}")
78
- if device == "cpu":
79
- logger.warning("⚠️ Using CPU - this will be slower for large models")
80
-
81
- return device
82
-
83
- def _get_pipeline_class(self, model_type: str):
84
- """Get corresponding pipeline class based on model type"""
85
- pipeline_map = {
86
- "sd15": StableDiffusionPipeline,
87
- "sdxl": StableDiffusionXLPipeline,
88
- "sd3": StableDiffusion3Pipeline,
89
- "flux": FluxPipeline,
90
- "gguf": "gguf_special", # Special marker for GGUF models
91
- "controlnet_sd15": StableDiffusionControlNetPipeline,
92
- "controlnet_sdxl": StableDiffusionXLControlNetPipeline,
93
- "video": AnimateDiffPipeline,
94
- }
95
-
96
- # Add HiDream support if available
97
- if HIDREAM_AVAILABLE:
98
- pipeline_map["hidream"] = HiDreamImagePipeline
99
-
100
- return pipeline_map.get(model_type)
101
-
82
+ self.device: Optional[str] = None
83
+
84
+ # -- Backward-compatible properties --
85
+
86
+ @property
87
+ def pipeline(self):
88
+ """Access the underlying pipeline (backward compat)."""
89
+ return self._strategy.pipeline if self._strategy else None
90
+
91
+ @property
92
+ def current_lora(self):
93
+ return self._strategy.current_lora if self._strategy else None
94
+
95
+ @property
96
+ def is_controlnet_pipeline(self) -> bool:
97
+ return getattr(self._strategy, "is_controlnet_pipeline", False)
98
+
99
+ # -- Core API --
100
+
102
101
  def load_model(self, model_config: ModelConfig) -> bool:
103
- """Load model"""
102
+ """Load a model using the appropriate strategy."""
104
103
  try:
105
- # Validate model configuration
106
- if not model_config:
107
- logger.error("Model configuration is None")
108
- return False
109
-
110
- if not model_config.path:
111
- logger.error(f"Model path is None for model: {model_config.name}")
104
+ if not model_config or not model_config.path:
105
+ logger.error("Invalid model configuration")
112
106
  return False
113
-
114
- model_path = Path(model_config.path)
115
- if not model_path.exists():
116
- logger.error(f"Model path does not exist: {model_config.path}")
117
- return False
118
-
119
- logger.info(f"Loading model from path: {model_config.path}")
120
-
121
- self.device = self._get_device()
122
- logger.info(f"Using device: {self.device}")
123
-
124
- # Get corresponding pipeline class
125
- pipeline_class = self._get_pipeline_class(model_config.model_type)
126
- if not pipeline_class:
127
- logger.error(f"Unsupported model type: {model_config.model_type}")
107
+
108
+ # Only check existence for local paths, not HuggingFace Hub IDs
109
+ # Hub IDs look like "org/model-name", local paths start with / . or ~
110
+ from pathlib import Path
111
+ model_path = model_config.path
112
+ is_local_path = model_path.startswith(("/", ".", "~")) or (len(model_path) > 1 and model_path[1] == ":")
113
+ if is_local_path and not Path(model_path).expanduser().exists():
114
+ logger.error(f"Model path does not exist: {model_path}")
128
115
  return False
129
-
130
- # Handle GGUF models specially
131
- if model_config.model_type == "gguf" or (model_config.variant and "gguf" in model_config.variant.lower()):
132
- if not GGUF_AVAILABLE:
133
- logger.error("GGUF support not available. Install with: pip install llama-cpp-python gguf")
134
- return False
135
-
136
- logger.info(f"Loading GGUF model: {model_config.name} (variant: {model_config.variant})")
137
-
138
- # Use GGUF loader instead of regular pipeline
139
- model_config_dict = {
140
- 'name': model_config.name,
141
- 'path': model_config.path,
142
- 'variant': model_config.variant,
143
- 'model_type': model_config.model_type,
144
- 'parameters': model_config.parameters
145
- }
146
-
147
- if gguf_loader.load_model(model_config_dict):
148
- # Set pipeline to None since we're using GGUF loader
149
- self.pipeline = None
150
- self.model_config = model_config
151
- self.device = self._get_device()
152
- logger.info(f"GGUF model {model_config.name} loaded successfully")
153
- return True
154
- else:
155
- logger.error(f"Failed to load GGUF model: {model_config.name}")
156
- return False
157
-
158
- # Check if this is a ControlNet model
159
- self.is_controlnet_pipeline = model_config.model_type.startswith("controlnet_")
160
-
161
- # Handle ControlNet models
162
- if self.is_controlnet_pipeline:
163
- return self._load_controlnet_model(model_config, pipeline_class, {})
164
-
165
- # Set loading parameters
166
- load_kwargs = {}
167
- if model_config.variant == "fp16":
168
- load_kwargs["torch_dtype"] = torch.float16
169
- load_kwargs["variant"] = "fp16"
170
- elif model_config.variant == "bf16":
171
- load_kwargs["torch_dtype"] = torch.bfloat16
172
-
173
- # Load pipeline
174
- logger.info(f"Loading model: {model_config.name}")
175
-
176
- # Special handling for FLUX models
177
- if model_config.model_type == "flux":
178
- # FLUX models work best with bfloat16, but use float32 on CPU or float16 on MPS
179
- if self.device == "cpu":
180
- load_kwargs["torch_dtype"] = torch.float32
181
- logger.info("Using float32 for FLUX model on CPU")
182
- logger.warning("⚠️ FLUX.1-dev is a 12B parameter model. CPU inference will be very slow!")
183
- logger.warning("⚠️ For better performance, consider using a GPU with at least 12GB VRAM")
184
- else:
185
- load_kwargs["torch_dtype"] = torch.bfloat16
186
- load_kwargs["use_safetensors"] = True
187
- logger.info("Using bfloat16 for FLUX model")
188
-
189
- # Special handling for Video (AnimateDiff) models
190
- elif model_config.model_type == "video":
191
- # AnimateDiff requires motion adapter
192
- logger.info("Loading AnimateDiff (video) model")
193
- motion_adapter_path = getattr(model_config, 'motion_adapter_path', None)
194
- if not motion_adapter_path:
195
- # Use default motion adapter if not specified
196
- motion_adapter_path = "guoyww/animatediff-motion-adapter-v1-5-2"
197
- logger.info(f"Using default motion adapter: {motion_adapter_path}")
198
-
199
- try:
200
- # Load motion adapter
201
- motion_adapter = MotionAdapter.from_pretrained(
202
- motion_adapter_path,
203
- torch_dtype=load_kwargs.get("torch_dtype", torch.float16)
204
- )
205
- load_kwargs["motion_adapter"] = motion_adapter
206
- logger.info(f"Motion adapter loaded from: {motion_adapter_path}")
207
- except Exception as e:
208
- logger.error(f"Failed to load motion adapter: {e}")
209
- return False
210
-
211
- # Disable safety checker for AnimateDiff
212
- load_kwargs["safety_checker"] = None
213
- load_kwargs["requires_safety_checker"] = False
214
- load_kwargs["feature_extractor"] = None
215
- logger.info("Safety checker disabled for AnimateDiff models")
216
-
217
- # Special handling for HiDream models
218
- elif model_config.model_type == "hidream":
219
- if not HIDREAM_AVAILABLE:
220
- logger.error("HiDream models require diffusers to be installed from source. Please install with: pip install git+https://github.com/huggingface/diffusers.git")
221
- return False
222
-
223
- logger.info("Loading HiDream model")
224
- # HiDream models work best with bfloat16
225
- if self.device == "cpu":
226
- load_kwargs["torch_dtype"] = torch.float32
227
- logger.info("Using float32 for HiDream model on CPU")
228
- logger.warning("⚠️ HiDream models are large. CPU inference will be slow!")
229
- else:
230
- load_kwargs["torch_dtype"] = torch.bfloat16
231
- logger.info("Using bfloat16 for HiDream model")
232
-
233
- # Disable safety checker for HiDream models
234
- load_kwargs["safety_checker"] = None
235
- load_kwargs["requires_safety_checker"] = False
236
- load_kwargs["feature_extractor"] = None
237
- logger.info("Safety checker disabled for HiDream models")
238
-
239
- # Disable safety checker for SD 1.5 to prevent false NSFW detections
240
- if model_config.model_type == "sd15" or model_config.model_type == "sdxl":
241
- load_kwargs["safety_checker"] = None
242
- load_kwargs["requires_safety_checker"] = False
243
- load_kwargs["feature_extractor"] = None
244
- # Use float32 for better numerical stability on SD 1.5
245
- if model_config.variant == "fp16" and (self.device == "cpu" or self.device == "mps"):
246
- load_kwargs["torch_dtype"] = torch.float32
247
- load_kwargs.pop("variant", None)
248
- logger.info(f"Using float32 for {self.device} inference to improve stability")
249
- elif self.device == "mps":
250
- # Force float32 on MPS for SD 1.5 to avoid NaN issues
251
- load_kwargs["torch_dtype"] = torch.float32
252
- logger.info("Using float32 for MPS inference to avoid NaN issues with SD 1.5")
253
- logger.info("Safety checker disabled for SD 1.5 to prevent false NSFW detections")
254
-
255
- # Disable safety checker for FLUX models to prevent false NSFW detections
256
- if model_config.model_type == "flux":
257
- load_kwargs["safety_checker"] = None
258
- load_kwargs["requires_safety_checker"] = False
259
- load_kwargs["feature_extractor"] = None
260
- logger.info("Safety checker disabled for FLUX models to prevent false NSFW detections")
261
-
262
- # Load pipeline
263
- self.pipeline = pipeline_class.from_pretrained(
264
- model_config.path,
265
- **load_kwargs
266
- )
267
-
268
- # Move to device with proper error handling
269
- try:
270
- self.pipeline = self.pipeline.to(self.device)
271
- logger.info(f"Pipeline moved to {self.device}")
272
- except Exception as e:
273
- logger.warning(f"Failed to move pipeline to {self.device}: {e}")
274
- if self.device != "cpu":
275
- logger.info("Falling back to CPU")
276
- self.device = "cpu"
277
- self.pipeline = self.pipeline.to("cpu")
278
-
279
- # Enable memory optimizations
280
- if hasattr(self.pipeline, 'enable_attention_slicing'):
281
- self.pipeline.enable_attention_slicing()
282
- logger.info("Enabled attention slicing for memory optimization")
283
-
284
- # Special optimizations for FLUX models
285
- if model_config.model_type == "flux":
286
- if self.device == "cuda":
287
- # CUDA-specific optimizations
288
- if hasattr(self.pipeline, 'enable_model_cpu_offload'):
289
- self.pipeline.enable_model_cpu_offload()
290
- logger.info("Enabled CPU offloading for FLUX model")
291
- elif self.device == "cpu":
292
- # CPU-specific optimizations
293
- logger.info("Applying CPU-specific optimizations for FLUX model")
294
- # Enable memory efficient attention if available
295
- if hasattr(self.pipeline, 'enable_xformers_memory_efficient_attention'):
296
- try:
297
- self.pipeline.enable_xformers_memory_efficient_attention()
298
- logger.info("Enabled xformers memory efficient attention")
299
- except Exception as e:
300
- logger.debug(f"xformers not available: {e}")
301
-
302
- # Set low memory mode
303
- if hasattr(self.pipeline, 'enable_sequential_cpu_offload'):
304
- try:
305
- self.pipeline.enable_sequential_cpu_offload()
306
- logger.info("Enabled sequential CPU offload for memory efficiency")
307
- except Exception as e:
308
- logger.debug(f"Sequential CPU offload not available: {e}")
309
-
310
- # Special optimizations for Video (AnimateDiff) models
311
- elif model_config.model_type == "video":
312
- logger.info("Applying optimizations for AnimateDiff video model")
313
- # Enable VAE slicing for video models to reduce memory usage
314
- if hasattr(self.pipeline, 'enable_vae_slicing'):
315
- self.pipeline.enable_vae_slicing()
316
- logger.info("Enabled VAE slicing for video model")
317
-
318
- # Enable model CPU offload for better memory management
319
- if self.device == "cuda" and hasattr(self.pipeline, 'enable_model_cpu_offload'):
320
- self.pipeline.enable_model_cpu_offload()
321
- logger.info("Enabled model CPU offload for video model")
322
-
323
- # Set scheduler to work well with AnimateDiff
324
- if hasattr(self.pipeline, 'scheduler'):
325
- from diffusers import DDIMScheduler
326
- try:
327
- self.pipeline.scheduler = DDIMScheduler.from_config(
328
- self.pipeline.scheduler.config,
329
- clip_sample=False,
330
- timestep_spacing="linspace",
331
- beta_schedule="linear",
332
- steps_offset=1,
333
- )
334
- logger.info("Configured DDIM scheduler for AnimateDiff")
335
- except Exception as e:
336
- logger.debug(f"Could not configure DDIM scheduler: {e}")
337
-
338
- # Special optimizations for HiDream models
339
- elif model_config.model_type == "hidream":
340
- logger.info("Applying optimizations for HiDream model")
341
- # Enable VAE slicing and tiling for HiDream models
342
- if hasattr(self.pipeline, 'enable_vae_slicing'):
343
- self.pipeline.enable_vae_slicing()
344
- logger.info("Enabled VAE slicing for HiDream model")
345
-
346
- if hasattr(self.pipeline, 'enable_vae_tiling'):
347
- self.pipeline.enable_vae_tiling()
348
- logger.info("Enabled VAE tiling for HiDream model")
349
-
350
- # Enable model CPU offload for better memory management
351
- if self.device == "cuda" and hasattr(self.pipeline, 'enable_model_cpu_offload'):
352
- self.pipeline.enable_model_cpu_offload()
353
- logger.info("Enabled model CPU offload for HiDream model")
354
- elif self.device == "cpu":
355
- # CPU-specific optimizations for HiDream
356
- if hasattr(self.pipeline, 'enable_sequential_cpu_offload'):
357
- try:
358
- self.pipeline.enable_sequential_cpu_offload()
359
- logger.info("Enabled sequential CPU offload for HiDream model")
360
- except Exception as e:
361
- logger.debug(f"Sequential CPU offload not available: {e}")
362
-
363
- # Additional safety checker disabling for SD 1.5 (in case the above didn't work)
364
- if model_config.model_type == "sd15" or model_config.model_type == "sdxl":
365
- if hasattr(self.pipeline, 'safety_checker'):
366
- self.pipeline.safety_checker = None
367
- if hasattr(self.pipeline, 'feature_extractor'):
368
- self.pipeline.feature_extractor = None
369
- if hasattr(self.pipeline, 'requires_safety_checker'):
370
- self.pipeline.requires_safety_checker = False
371
-
372
- # Monkey patch the safety checker call to always return False
373
- def dummy_safety_check(self, images, clip_input):
374
- return images, [False] * len(images)
375
-
376
- # Apply monkey patch if safety checker exists
377
- if hasattr(self.pipeline, '_safety_check'):
378
- self.pipeline._safety_check = dummy_safety_check.__get__(self.pipeline, type(self.pipeline))
379
-
380
- # Also monkey patch the run_safety_checker method if it exists
381
- if hasattr(self.pipeline, 'run_safety_checker'):
382
- def dummy_run_safety_checker(images, device, dtype):
383
- return images, [False] * len(images)
384
- self.pipeline.run_safety_checker = dummy_run_safety_checker
385
-
386
- # Monkey patch the check_inputs method to prevent safety checker validation
387
- if hasattr(self.pipeline, 'check_inputs'):
388
- original_check_inputs = self.pipeline.check_inputs
389
- def patched_check_inputs(*args, **kwargs):
390
- # Call original but ignore safety checker requirements
391
- try:
392
- return original_check_inputs(*args, **kwargs)
393
- except Exception as e:
394
- if "safety_checker" in str(e).lower():
395
- logger.debug(f"Ignoring safety checker validation error: {e}")
396
- return
397
- raise e
398
- self.pipeline.check_inputs = patched_check_inputs
399
-
400
- logger.info("Additional safety checker components disabled with monkey patch")
401
-
402
-
403
- # Load LoRA and other components
404
- if model_config.components and "lora" in model_config.components:
405
- self._load_lora(model_config)
406
-
407
- # Apply optimizations
408
- self._apply_optimizations()
409
-
410
- # Set tokenizer
411
- if hasattr(self.pipeline, 'tokenizer'):
412
- self.tokenizer = self.pipeline.tokenizer
413
-
414
- self.model_config = model_config
415
- logger.info(f"Model {model_config.name} loaded successfully")
416
- return True
417
-
418
- except Exception as e:
419
- logger.error(f"Failed to load model: {e}")
116
+
117
+ # Detect GGUF models by variant
118
+ model_type = model_config.model_type
119
+ if model_config.variant and "gguf" in model_config.variant.lower():
120
+ model_type = "gguf"
121
+
122
+ self.device = _detect_device()
123
+ self._strategy = _get_strategy(model_type)
124
+
125
+ if self._strategy.load(model_config, self.device):
126
+ self.model_config = model_config
127
+ # Update device in case strategy fell back to CPU
128
+ self.device = self._strategy.device
129
+ logger.info(f"Model {model_config.name} loaded successfully via {type(self._strategy).__name__}")
130
+ return True
131
+
132
+ self._strategy = None
420
133
  return False
421
-
422
- def _load_controlnet_model(self, model_config: ModelConfig, pipeline_class, load_kwargs: dict) -> bool:
423
- """Load ControlNet model with base model"""
424
- try:
425
- # Get base model info
426
- base_model_name = getattr(model_config, 'base_model', None)
427
- if not base_model_name:
428
- # Try to extract from model registry
429
- from ..models.manager import model_manager
430
- model_info = model_manager.get_model_info(model_config.name)
431
- if model_info and 'base_model' in model_info:
432
- base_model_name = model_info['base_model']
433
- else:
434
- logger.error(f"No base model specified for ControlNet model: {model_config.name}")
435
- return False
436
-
437
- # Check if base model is installed
438
- from ..models.manager import model_manager
439
- if not model_manager.is_model_installed(base_model_name):
440
- logger.error(f"Base model '{base_model_name}' not installed. Please install it first.")
441
- return False
442
-
443
- # Get base model config
444
- from ..config.settings import settings
445
- base_model_config = settings.models[base_model_name]
446
-
447
- # Set loading parameters based on variant
448
- if model_config.variant == "fp16":
449
- load_kwargs["torch_dtype"] = torch.float16
450
- load_kwargs["variant"] = "fp16"
451
- elif model_config.variant == "bf16":
452
- load_kwargs["torch_dtype"] = torch.bfloat16
453
-
454
- # Handle device-specific optimizations
455
- if self.device == "cpu" or self.device == "mps":
456
- load_kwargs["torch_dtype"] = torch.float32
457
- load_kwargs.pop("variant", None)
458
- logger.info(f"Using float32 for {self.device} inference to improve stability")
459
-
460
- # Disable safety checker
461
- load_kwargs["safety_checker"] = None
462
- load_kwargs["requires_safety_checker"] = False
463
- load_kwargs["feature_extractor"] = None
464
-
465
- # Load ControlNet model
466
- logger.info(f"Loading ControlNet model from: {model_config.path}")
467
- self.controlnet = ControlNetModel.from_pretrained(
468
- model_config.path,
469
- torch_dtype=load_kwargs.get("torch_dtype", torch.float32)
470
- )
471
-
472
- # Load pipeline with ControlNet and base model
473
- logger.info(f"Loading ControlNet pipeline with base model: {base_model_name}")
474
- self.pipeline = pipeline_class.from_pretrained(
475
- base_model_config.path,
476
- controlnet=self.controlnet,
477
- **load_kwargs
478
- )
479
-
480
- # Move to device
481
- try:
482
- self.pipeline = self.pipeline.to(self.device)
483
- self.controlnet = self.controlnet.to(self.device)
484
- logger.info(f"ControlNet pipeline moved to {self.device}")
485
- except Exception as e:
486
- logger.warning(f"Failed to move ControlNet pipeline to {self.device}: {e}")
487
- if self.device != "cpu":
488
- logger.info("Falling back to CPU")
489
- self.device = "cpu"
490
- self.pipeline = self.pipeline.to("cpu")
491
- self.controlnet = self.controlnet.to("cpu")
492
-
493
- # Enable memory optimizations
494
- if hasattr(self.pipeline, 'enable_attention_slicing'):
495
- self.pipeline.enable_attention_slicing()
496
- logger.info("Enabled attention slicing for ControlNet pipeline")
497
-
498
- # Apply additional optimizations
499
- self._apply_optimizations()
500
-
501
- # Set tokenizer
502
- if hasattr(self.pipeline, 'tokenizer'):
503
- self.tokenizer = self.pipeline.tokenizer
504
-
505
- self.model_config = model_config
506
- logger.info(f"ControlNet model {model_config.name} loaded successfully")
507
- return True
508
-
134
+
509
135
  except Exception as e:
510
- logger.error(f"Failed to load ControlNet model: {e}")
136
+ logger.error(f"Failed to load model: {e}")
137
+ self._strategy = None
511
138
  return False
512
-
513
- def _load_lora(self, model_config: ModelConfig):
514
- """Load LoRA weights"""
515
- try:
516
- lora_config = model_config.components["lora"]
517
-
518
- # Check if it's a Hugging Face Hub model
519
- if "repo_id" in lora_config:
520
- # Load from Hugging Face Hub
521
- repo_id = lora_config["repo_id"]
522
- weight_name = lora_config.get("weight_name", "pytorch_lora_weights.safetensors")
523
-
524
- logger.info(f"Loading LoRA from Hugging Face Hub: {repo_id}")
525
- self.pipeline.load_lora_weights(repo_id, weight_name=weight_name)
526
-
527
- # Set LoRA scale if specified
528
- if "scale" in lora_config:
529
- scale = lora_config["scale"]
530
- if hasattr(self.pipeline, 'set_adapters'):
531
- self.pipeline.set_adapters(["default"], adapter_weights=[scale])
532
- logger.info(f"Set LoRA scale to {scale}")
533
-
534
- logger.info(f"LoRA weights loaded successfully from {repo_id}")
535
-
536
- elif "filename" in lora_config:
537
- # Load from local file
538
- components_path = Path(model_config.path) / "components" / "lora"
539
- lora_path = components_path / lora_config["filename"]
540
- if lora_path.exists():
541
- self.pipeline.load_lora_weights(str(components_path), weight_name=lora_config["filename"])
542
- self.pipeline.fuse_lora()
543
- logger.info("LoRA weights loaded successfully from local file")
544
- else:
545
- # Load from directory
546
- components_path = Path(model_config.path) / "components" / "lora"
547
- if components_path.exists():
548
- self.pipeline.load_lora_weights(str(components_path))
549
- self.pipeline.fuse_lora()
550
- logger.info("LoRA weights loaded successfully from directory")
551
-
552
- except Exception as e:
553
- logger.warning(f"Failed to load LoRA weights: {e}")
554
-
555
- def _apply_optimizations(self):
556
- """Apply performance optimizations"""
557
- try:
558
- # Enable torch compile for faster inference
559
- if hasattr(torch, 'compile') and self.device != "mps": # MPS doesn't support torch.compile yet
560
- if hasattr(self.pipeline, 'unet'):
561
- try:
562
- self.pipeline.unet = torch.compile(
563
- self.pipeline.unet,
564
- mode="reduce-overhead",
565
- fullgraph=True
566
- )
567
- logger.info("torch.compile optimization enabled")
568
- except Exception as e:
569
- logger.debug(f"torch.compile failed: {e}")
570
- elif self.device == "mps":
571
- logger.debug("Skipping torch.compile on MPS (not supported yet)")
572
- elif self.device == "cpu":
573
- logger.debug("Skipping torch.compile on CPU for stability")
574
-
575
- except Exception as e:
576
- logger.warning(f"Failed to apply optimization settings: {e}")
577
-
578
- def truncate_prompt(self, prompt: str) -> str:
579
- """Truncate prompt to fit CLIP token limit"""
580
- if not prompt or not self.tokenizer:
581
- return prompt
582
-
583
- # Encode prompt
584
- tokens = self.tokenizer.encode(prompt)
585
-
586
- # Check if truncation is needed
587
- if len(tokens) <= self.max_token_limit:
588
- return prompt
589
-
590
- # Truncate tokens and decode back to text
591
- truncated_tokens = tokens[:self.max_token_limit]
592
- truncated_prompt = self.tokenizer.decode(truncated_tokens)
593
-
594
- logger.warning(f"Prompt truncated: {len(tokens)} -> {len(truncated_tokens)} tokens")
595
- return truncated_prompt
596
-
597
- def generate_image(self,
598
- prompt: str,
599
- negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
600
- num_inference_steps: Optional[int] = None,
601
- guidance_scale: Optional[float] = None,
602
- width: int = 1024,
603
- height: int = 1024,
604
- control_image: Optional[Union[Image.Image, str]] = None,
605
- controlnet_conditioning_scale: float = 1.0,
606
- control_guidance_start: float = 0.0,
607
- control_guidance_end: float = 1.0,
608
- **kwargs) -> Image.Image:
609
- """Generate image"""
610
- # Check if we're using a GGUF model
611
- is_gguf_model = (
612
- self.model_config and
613
- (self.model_config.model_type == "gguf" or
614
- (self.model_config.variant and "gguf" in self.model_config.variant.lower()))
139
+
140
+ def generate_image(
141
+ self,
142
+ prompt: str,
143
+ negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
144
+ num_inference_steps: Optional[int] = None,
145
+ guidance_scale: Optional[float] = None,
146
+ width: int = 1024,
147
+ height: int = 1024,
148
+ seed: Optional[int] = None,
149
+ control_image: Optional[Union[Image.Image, str]] = None,
150
+ controlnet_conditioning_scale: float = 1.0,
151
+ control_guidance_start: float = 0.0,
152
+ control_guidance_end: float = 1.0,
153
+ image: Optional[Image.Image] = None,
154
+ mask_image: Optional[Image.Image] = None,
155
+ strength: float = 0.75,
156
+ **kwargs,
157
+ ) -> Image.Image:
158
+ """Generate an image using the current strategy."""
159
+ if not self._strategy:
160
+ raise RuntimeError("No model loaded")
161
+
162
+ # Build kwargs to pass through
163
+ gen_kwargs = dict(kwargs)
164
+
165
+ # Pass img2img / inpainting params
166
+ if image is not None:
167
+ gen_kwargs["image"] = image
168
+ if mask_image is not None:
169
+ gen_kwargs["mask_image"] = mask_image
170
+ if image is not None or mask_image is not None:
171
+ gen_kwargs["strength"] = strength
172
+
173
+ # Pass ControlNet params
174
+ if control_image is not None:
175
+ gen_kwargs["control_image"] = control_image
176
+ gen_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
177
+ gen_kwargs["control_guidance_start"] = control_guidance_start
178
+ gen_kwargs["control_guidance_end"] = control_guidance_end
179
+
180
+ return self._strategy.generate(
181
+ prompt=prompt,
182
+ negative_prompt=negative_prompt,
183
+ num_inference_steps=num_inference_steps,
184
+ guidance_scale=guidance_scale,
185
+ width=width,
186
+ height=height,
187
+ seed=seed,
188
+ **gen_kwargs,
615
189
  )
616
-
617
- if is_gguf_model:
618
- if not GGUF_AVAILABLE:
619
- raise RuntimeError("GGUF support not available")
620
-
621
- if not gguf_loader.is_loaded():
622
- raise RuntimeError("GGUF model not loaded")
623
-
624
- logger.info(f"Generating image using GGUF model: {prompt[:50]}...")
625
-
626
- # Use model default parameters for GGUF
627
- if num_inference_steps is None:
628
- num_inference_steps = self.model_config.parameters.get("num_inference_steps", 20)
629
-
630
- if guidance_scale is None:
631
- guidance_scale = self.model_config.parameters.get("guidance_scale", 7.5)
632
-
633
- # Generate using GGUF loader
634
- generation_kwargs = {
635
- "prompt": prompt,
636
- "negative_prompt": negative_prompt,
637
- "num_inference_steps": num_inference_steps,
638
- "guidance_scale": guidance_scale,
639
- "width": width,
640
- "height": height,
641
- **kwargs
642
- }
643
-
644
- try:
645
- image = gguf_loader.generate_image(**generation_kwargs)
646
- if image is None:
647
- logger.warning("GGUF generation returned None, creating error image")
648
- return self._create_error_image("GGUF generation failed or not yet implemented", prompt)
649
- return image
650
- except Exception as e:
651
- logger.error(f"GGUF generation failed: {e}")
652
- return self._create_error_image(str(e), prompt)
653
-
654
- # Continue with regular pipeline generation for non-GGUF models
655
- if not self.pipeline:
656
- raise RuntimeError("Model not loaded")
657
-
658
- # Handle ControlNet-specific logic
659
- if self.is_controlnet_pipeline:
660
- if control_image is None:
661
- raise ValueError("ControlNet model requires a control image")
662
-
663
- # Process control image
664
- control_image = self._prepare_control_image(control_image, width, height)
665
-
666
- # Use model default parameters
667
- if num_inference_steps is None:
668
- num_inference_steps = self.model_config.parameters.get("num_inference_steps", 28)
669
-
670
- if guidance_scale is None:
671
- guidance_scale = self.model_config.parameters.get("guidance_scale", 3.5)
672
-
673
- # Truncate prompts
674
- truncated_prompt = self.truncate_prompt(prompt)
675
- truncated_negative_prompt = self.truncate_prompt(negative_prompt)
676
-
677
- try:
678
- logger.info(f"Starting image generation: {truncated_prompt[:50]}...")
679
-
680
- # Generation parameters
681
- generation_kwargs = {
682
- "prompt": truncated_prompt,
683
- "negative_prompt": truncated_negative_prompt,
684
- "num_inference_steps": num_inference_steps,
685
- "guidance_scale": guidance_scale,
686
- **kwargs
687
- }
688
-
689
- # Add ControlNet parameters if this is a ControlNet pipeline
690
- if self.is_controlnet_pipeline and control_image is not None:
691
- generation_kwargs.update({
692
- "image": control_image,
693
- "controlnet_conditioning_scale": controlnet_conditioning_scale,
694
- "control_guidance_start": control_guidance_start,
695
- "control_guidance_end": control_guidance_end
696
- })
697
- logger.info(f"ControlNet parameters: conditioning_scale={controlnet_conditioning_scale}, "
698
- f"guidance_start={control_guidance_start}, guidance_end={control_guidance_end}")
699
-
700
- # Add size parameters based on model type
701
- if self.model_config.model_type in ["sdxl", "sd3", "flux", "controlnet_sdxl"]:
702
- generation_kwargs.update({
703
- "width": width,
704
- "height": height
705
- })
706
-
707
- # FLUX models have special parameters
708
- if self.model_config.model_type == "flux":
709
- # Add max_sequence_length for FLUX
710
- max_seq_len = self.model_config.parameters.get("max_sequence_length", 512)
711
- generation_kwargs["max_sequence_length"] = max_seq_len
712
- logger.info(f"Using max_sequence_length={max_seq_len} for FLUX model")
713
-
714
- # Special handling for FLUX.1-schnell (distilled model)
715
- if "schnell" in self.model_config.name.lower():
716
- # FLUX.1-schnell doesn't use guidance
717
- if guidance_scale != 0.0:
718
- logger.info("FLUX.1-schnell detected - setting guidance_scale to 0.0 (distilled model doesn't use guidance)")
719
- generation_kwargs["guidance_scale"] = 0.0
720
-
721
- # Use fewer steps for schnell (it's designed for 1-4 steps)
722
- if num_inference_steps > 4:
723
- logger.info(f"FLUX.1-schnell detected - reducing steps from {num_inference_steps} to 4 for optimal performance")
724
- generation_kwargs["num_inference_steps"] = 4
725
-
726
- logger.info("🚀 Using FLUX.1-schnell - fast distilled model optimized for 4-step generation")
727
-
728
- # Device-specific adjustments for FLUX
729
- if self.device == "cpu":
730
- # Reduce steps for faster CPU inference
731
- if "schnell" not in self.model_config.name.lower() and num_inference_steps > 20:
732
- num_inference_steps = 20
733
- generation_kwargs["num_inference_steps"] = num_inference_steps
734
- logger.info(f"Reduced inference steps to {num_inference_steps} for CPU performance")
735
-
736
- # Lower guidance scale for CPU stability (except for schnell which uses 0.0)
737
- if "schnell" not in self.model_config.name.lower() and guidance_scale > 5.0:
738
- guidance_scale = 5.0
739
- generation_kwargs["guidance_scale"] = guidance_scale
740
- logger.info(f"Reduced guidance scale to {guidance_scale} for CPU stability")
741
-
742
- logger.warning("🐌 CPU inference detected - this may take several minutes per image")
743
- elif self.device == "mps":
744
- # MPS-specific adjustments for stability (except for schnell which uses 0.0)
745
- if "schnell" not in self.model_config.name.lower() and guidance_scale > 7.0:
746
- guidance_scale = 7.0
747
- generation_kwargs["guidance_scale"] = guidance_scale
748
- logger.info(f"Reduced guidance scale to {guidance_scale} for MPS stability")
749
-
750
- logger.info("🍎 MPS inference - should be faster than CPU but slower than CUDA")
751
-
752
- elif self.model_config.model_type in ["sd15", "controlnet_sd15"]:
753
- # SD 1.5 and ControlNet SD 1.5 work best with 512x512, adjust if different sizes requested
754
- if width != 1024 or height != 1024:
755
- generation_kwargs.update({
756
- "width": width,
757
- "height": height
758
- })
759
- else:
760
- # Use optimal size for SD 1.5
761
- generation_kwargs.update({
762
- "width": 512,
763
- "height": 512
764
- })
765
-
766
- # Special handling for Video (AnimateDiff) models
767
- elif self.model_config.model_type == "video":
768
- logger.info("Configuring AnimateDiff video generation parameters")
769
-
770
- # Video-specific parameters
771
- num_frames = kwargs.get("num_frames", 16)
772
- generation_kwargs["num_frames"] = num_frames
773
-
774
- # AnimateDiff works best with specific resolutions
775
- # Use 512x512 for better compatibility with most motion adapters
776
- generation_kwargs.update({
777
- "width": 512,
778
- "height": 512
779
- })
780
- logger.info(f"Using 512x512 resolution for AnimateDiff compatibility")
781
-
782
- # Set optimal parameters for video generation
783
- if guidance_scale > 7.5:
784
- generation_kwargs["guidance_scale"] = 7.5
785
- logger.info("Reduced guidance scale to 7.5 for video stability")
786
-
787
- if num_inference_steps > 25:
788
- generation_kwargs["num_inference_steps"] = 25
789
- logger.info("Reduced inference steps to 25 for video generation")
790
-
791
- logger.info(f"Generating {num_frames} frames for video output")
792
-
793
- # Special handling for HiDream models
794
- elif self.model_config.model_type == "hidream":
795
- logger.info("Configuring HiDream model parameters")
796
-
797
- # HiDream models support high resolution
798
- generation_kwargs.update({
799
- "width": width,
800
- "height": height
801
- })
802
-
803
- # HiDream models have multiple text encoders, handle if provided
804
- if "prompt_2" in kwargs:
805
- generation_kwargs["prompt_2"] = self.truncate_prompt(kwargs["prompt_2"])
806
- if "prompt_3" in kwargs:
807
- generation_kwargs["prompt_3"] = self.truncate_prompt(kwargs["prompt_3"])
808
- if "prompt_4" in kwargs:
809
- generation_kwargs["prompt_4"] = self.truncate_prompt(kwargs["prompt_4"])
810
-
811
- # Set optimal parameters for HiDream
812
- max_seq_len = self.model_config.parameters.get("max_sequence_length", 128)
813
- generation_kwargs["max_sequence_length"] = max_seq_len
814
-
815
- # HiDream models use different guidance scale defaults
816
- if guidance_scale is None or guidance_scale == 3.5:
817
- generation_kwargs["guidance_scale"] = 5.0
818
- logger.info("Using default guidance_scale=5.0 for HiDream model")
819
-
820
- logger.info(f"Using max_sequence_length={max_seq_len} for HiDream model")
821
-
822
- # Generate image
823
- logger.info(f"Generation parameters: steps={num_inference_steps}, guidance={guidance_scale}")
824
-
825
- # Add generator for reproducible results
826
- if self.device == "cpu":
827
- generator = torch.Generator().manual_seed(42)
828
- else:
829
- generator = torch.Generator(device=self.device).manual_seed(42)
830
- generation_kwargs["generator"] = generator
831
-
832
- # For SD 1.5, use a more conservative approach to avoid numerical issues
833
- if self.model_config.model_type == "sd15":
834
- # Lower guidance scale to prevent numerical instability
835
- if generation_kwargs["guidance_scale"] > 7.0:
836
- generation_kwargs["guidance_scale"] = 7.0
837
- logger.info("Reduced guidance scale to 7.0 for stability")
838
-
839
- # Ensure we're using float32 for better numerical stability
840
- if self.device == "mps":
841
- # For Apple Silicon, use specific optimizations
842
- generation_kwargs["guidance_scale"] = min(generation_kwargs["guidance_scale"], 6.0)
843
- logger.info("Applied MPS-specific optimizations")
844
-
845
- # For SD 1.5, use manual pipeline execution to completely bypass safety checker
846
- if self.model_config.model_type == "sd15" and not self.is_controlnet_pipeline:
847
- logger.info("Using manual pipeline execution for SD 1.5 to bypass safety checker")
848
- try:
849
- # Manual pipeline execution with safety checks disabled
850
- with torch.no_grad():
851
- # Encode prompt
852
- text_inputs = self.pipeline.tokenizer(
853
- generation_kwargs["prompt"],
854
- padding="max_length",
855
- max_length=self.pipeline.tokenizer.model_max_length,
856
- truncation=True,
857
- return_tensors="pt",
858
- )
859
- text_embeddings = self.pipeline.text_encoder(text_inputs.input_ids.to(self.device))[0]
860
-
861
- # Encode negative prompt
862
- uncond_inputs = self.pipeline.tokenizer(
863
- generation_kwargs["negative_prompt"],
864
- padding="max_length",
865
- max_length=self.pipeline.tokenizer.model_max_length,
866
- truncation=True,
867
- return_tensors="pt",
868
- )
869
- uncond_embeddings = self.pipeline.text_encoder(uncond_inputs.input_ids.to(self.device))[0]
870
-
871
- # Concatenate embeddings
872
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
873
-
874
- # Generate latents
875
- latents = torch.randn(
876
- (1, self.pipeline.unet.config.in_channels,
877
- generation_kwargs["height"] // 8, generation_kwargs["width"] // 8),
878
- generator=generation_kwargs["generator"],
879
- device=self.device,
880
- dtype=text_embeddings.dtype,
881
- )
882
-
883
- logger.debug(f"Initial latents stats - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
884
- logger.debug(f"Text embeddings stats - mean: {text_embeddings.mean().item():.4f}, std: {text_embeddings.std().item():.4f}")
885
-
886
- # Set scheduler
887
- self.pipeline.scheduler.set_timesteps(generation_kwargs["num_inference_steps"])
888
- latents = latents * self.pipeline.scheduler.init_noise_sigma
889
-
890
- logger.debug(f"Latents after noise scaling - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
891
- logger.debug(f"Scheduler init_noise_sigma: {self.pipeline.scheduler.init_noise_sigma}")
892
-
893
- # Denoising loop
894
- for i, t in enumerate(self.pipeline.scheduler.timesteps):
895
- latent_model_input = torch.cat([latents] * 2)
896
- latent_model_input = self.pipeline.scheduler.scale_model_input(latent_model_input, t)
897
-
898
- # Check for NaN before UNet
899
- if torch.isnan(latent_model_input).any():
900
- logger.error(f"NaN detected in latent_model_input at step {i}")
901
- break
902
-
903
- noise_pred = self.pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
904
-
905
- # Check for NaN after UNet
906
- if torch.isnan(noise_pred).any():
907
- logger.error(f"NaN detected in noise_pred at step {i}")
908
- break
909
-
910
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
911
- noise_pred = noise_pred_uncond + generation_kwargs["guidance_scale"] * (noise_pred_text - noise_pred_uncond)
912
-
913
- # Check for NaN after guidance
914
- if torch.isnan(noise_pred).any():
915
- logger.error(f"NaN detected after guidance at step {i}")
916
- break
917
-
918
- latents = self.pipeline.scheduler.step(noise_pred, t, latents).prev_sample
919
-
920
- # Check for NaN after scheduler step
921
- if torch.isnan(latents).any():
922
- logger.error(f"NaN detected in latents after scheduler step {i}")
923
- break
924
-
925
- if i == 0: # Log first step for debugging
926
- logger.debug(f"Step {i}: latents mean={latents.mean().item():.4f}, std={latents.std().item():.4f}")
927
-
928
- # Decode latents
929
- latents = 1 / self.pipeline.vae.config.scaling_factor * latents
930
-
931
- # Debug latents before VAE decode
932
- logger.debug(f"Latents stats before VAE decode - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
933
- logger.debug(f"Latents range: [{latents.min().item():.4f}, {latents.max().item():.4f}]")
934
-
935
- with torch.no_grad():
936
- # Ensure latents are on correct device and dtype
937
- latents = latents.to(device=self.device, dtype=self.pipeline.vae.dtype)
938
-
939
- try:
940
- image = self.pipeline.vae.decode(latents).sample
941
- logger.debug(f"VAE decode successful - image shape: {image.shape}")
942
- except Exception as e:
943
- logger.error(f"VAE decode failed: {e}")
944
- # Create a fallback image
945
- image = torch.randn_like(latents).repeat(1, 3, 8, 8) * 0.1 + 0.5
946
- logger.warning("Using fallback random image due to VAE decode failure")
947
-
948
- # Convert to PIL with proper NaN/inf handling
949
- logger.debug(f"Image stats after VAE decode - mean: {image.mean().item():.4f}, std: {image.std().item():.4f}")
950
- logger.debug(f"Image range: [{image.min().item():.4f}, {image.max().item():.4f}]")
951
-
952
- image = (image / 2 + 0.5).clamp(0, 1)
953
-
954
- logger.debug(f"Image stats after normalization - mean: {image.mean().item():.4f}, std: {image.std().item():.4f}")
955
- logger.debug(f"Image range after norm: [{image.min().item():.4f}, {image.max().item():.4f}]")
956
-
957
- # Check for NaN or infinite values before conversion
958
- if torch.isnan(image).any() or torch.isinf(image).any():
959
- logger.warning("NaN or infinite values detected in image tensor, applying selective fixes")
960
- # Only replace NaN/inf values, keep valid pixels intact
961
- nan_mask = torch.isnan(image)
962
- inf_mask = torch.isinf(image)
963
-
964
- # Replace only problematic pixels
965
- image = torch.where(nan_mask, torch.tensor(0.5, device=image.device, dtype=image.dtype), image)
966
- image = torch.where(inf_mask & (image > 0), torch.tensor(1.0, device=image.device, dtype=image.dtype), image)
967
- image = torch.where(inf_mask & (image < 0), torch.tensor(0.0, device=image.device, dtype=image.dtype), image)
968
-
969
- logger.info(f"Fixed {nan_mask.sum().item()} NaN pixels and {inf_mask.sum().item()} infinite pixels")
970
-
971
- # Final clamp to ensure valid range
972
- image = torch.clamp(image, 0, 1)
973
-
974
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
975
-
976
- # Additional validation before uint8 conversion - only fix problematic pixels
977
- if np.isnan(image).any() or np.isinf(image).any():
978
- logger.warning("NaN/inf values detected in numpy array, applying selective fixes")
979
- nan_count = np.isnan(image).sum()
980
- inf_count = np.isinf(image).sum()
981
-
982
- # Only replace problematic pixels, preserve valid ones
983
- image = np.where(np.isnan(image), 0.5, image)
984
- image = np.where(np.isinf(image) & (image > 0), 1.0, image)
985
- image = np.where(np.isinf(image) & (image < 0), 0.0, image)
986
-
987
- logger.info(f"Fixed {nan_count} NaN and {inf_count} infinite pixels in numpy array")
988
-
989
- # Ensure valid range
990
- image = np.clip(image, 0, 1)
991
-
992
- # Safe conversion to uint8
993
- image = (image * 255).astype(np.uint8)
994
-
995
- from PIL import Image as PILImage
996
- image = PILImage.fromarray(image[0])
997
-
998
- # Create a mock output object
999
- class MockOutput:
1000
- def __init__(self, images):
1001
- self.images = images
1002
- self.nsfw_content_detected = [False] * len(images)
1003
-
1004
- output = MockOutput([image])
1005
-
1006
- except Exception as e:
1007
- logger.error(f"Manual pipeline execution failed: {e}")
1008
- raise e
1009
- else:
1010
- # For FLUX and other models, use regular pipeline execution with safety checker disabled
1011
- logger.info(f"Using regular pipeline execution for {self.model_config.model_type} model")
1012
-
1013
- # Debug: Log device and generation kwargs
1014
- logger.debug(f"Pipeline device: {self.device}")
1015
- logger.debug(f"Generator device: {generation_kwargs['generator'].device if hasattr(generation_kwargs['generator'], 'device') else 'CPU'}")
1016
-
1017
- # Ensure all tensors are on the correct device
1018
- try:
1019
- # For FLUX models, temporarily disable any remaining safety checker components
1020
- if self.model_config.model_type == "flux":
1021
- # Store original safety checker components
1022
- original_safety_checker = getattr(self.pipeline, 'safety_checker', None)
1023
- original_feature_extractor = getattr(self.pipeline, 'feature_extractor', None)
1024
- original_requires_safety_checker = getattr(self.pipeline, 'requires_safety_checker', None)
1025
-
1026
- # Temporarily set to None
1027
- if hasattr(self.pipeline, 'safety_checker'):
1028
- self.pipeline.safety_checker = None
1029
- if hasattr(self.pipeline, 'feature_extractor'):
1030
- self.pipeline.feature_extractor = None
1031
- if hasattr(self.pipeline, 'requires_safety_checker'):
1032
- self.pipeline.requires_safety_checker = False
1033
-
1034
- logger.info("Temporarily disabled safety checker components for FLUX generation")
1035
-
1036
- output = self.pipeline(**generation_kwargs)
1037
-
1038
- # Restore original safety checker components for FLUX (though they should remain None)
1039
- if self.model_config.model_type == "flux":
1040
- if hasattr(self.pipeline, 'safety_checker'):
1041
- self.pipeline.safety_checker = original_safety_checker
1042
- if hasattr(self.pipeline, 'feature_extractor'):
1043
- self.pipeline.feature_extractor = original_feature_extractor
1044
- if hasattr(self.pipeline, 'requires_safety_checker'):
1045
- self.pipeline.requires_safety_checker = original_requires_safety_checker
1046
-
1047
- except RuntimeError as e:
1048
- if "CUDA" in str(e) and self.device == "cpu":
1049
- logger.error(f"CUDA error on CPU device: {e}")
1050
- logger.info("Attempting to fix device mismatch...")
1051
-
1052
- # Remove generator and try again
1053
- generation_kwargs_fixed = generation_kwargs.copy()
1054
- generation_kwargs_fixed.pop("generator", None)
1055
-
1056
- output = self.pipeline(**generation_kwargs_fixed)
1057
- else:
1058
- raise e
1059
-
1060
- # Special handling for FLUX models to bypass any remaining safety checker issues
1061
- if self.model_config.model_type == "flux" and hasattr(output, 'images'):
1062
- # Check if we got a black image and try to regenerate with different approach
- test_image = output.images[0]
- test_array = np.array(test_image)
-
- if np.all(test_array == 0):
- logger.warning("FLUX model returned black image, attempting manual image processing")
-
- # Try to access the raw latents or intermediate results
- if hasattr(output, 'latents') or hasattr(self.pipeline, 'vae'):
- try:
- # Generate a simple test image to verify the pipeline is working
- logger.info("Generating test image with simple prompt")
- simple_kwargs = generation_kwargs.copy()
- simple_kwargs["prompt"] = "a red apple"
- simple_kwargs["negative_prompt"] = ""
-
- # Temporarily disable any image processing that might cause issues
- original_image_processor = getattr(self.pipeline, 'image_processor', None)
- if hasattr(self.pipeline, 'image_processor'):
- # Create a custom image processor that handles NaN values
- class SafeImageProcessor:
- def postprocess(self, image, output_type="pil", do_denormalize=None):
- if isinstance(image, torch.Tensor):
- # Handle NaN and inf values before conversion
- image = torch.nan_to_num(image, nan=0.5, posinf=1.0, neginf=0.0)
- image = torch.clamp(image, 0, 1)
-
- # Convert to numpy
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- # Additional safety checks
- image = np.nan_to_num(image, nan=0.5, posinf=1.0, neginf=0.0)
- image = np.clip(image, 0, 1)
-
- # Convert to uint8 safely
- image = (image * 255).astype(np.uint8)
-
- # Convert to PIL
- if output_type == "pil":
- from PIL import Image as PILImage
- return [PILImage.fromarray(img) for img in image]
- return image
- return image
-
- self.pipeline.image_processor = SafeImageProcessor()
- logger.info("Applied safe image processor for FLUX model")
-
- # Try generation again with safe image processor
- test_output = self.pipeline(**simple_kwargs)
-
- # Restore original image processor
- if original_image_processor:
- self.pipeline.image_processor = original_image_processor
-
- if hasattr(test_output, 'images') and len(test_output.images) > 0:
- test_result = np.array(test_output.images[0])
- if not np.all(test_result == 0):
- logger.info("Test generation successful, using original output")
- # The issue might be with the specific prompt, return the test image
- output = test_output
- else:
- logger.warning("Test generation also returned black image")
-
- except Exception as e:
- logger.warning(f"Manual image processing failed: {e}")
-
- # Check if output contains nsfw_content_detected
- if hasattr(output, 'nsfw_content_detected') and output.nsfw_content_detected:
- logger.warning("NSFW content detected by pipeline - this should not happen with safety checker disabled")
-
- # Special handling for video models that return multiple frames
- if self.model_config.model_type == "video":
- logger.info(f"Processing video output with {len(output.frames)} frames")
-
- # For now, return the first frame as a single image
- # In the future, this could be extended to return a video file or GIF
- if hasattr(output, 'frames') and len(output.frames) > 0:
- image = output.frames[0]
- logger.info("Extracted first frame from video generation")
- else:
- # Fallback to images if frames not available
- image = output.images[0]
- logger.info("Using first image from video output")
-
- # TODO: Add option to save all frames or create a GIF
- # frames = output.frames if hasattr(output, 'frames') else output.images
- # save_video_frames(frames, prompt)
- else:
- # Standard single image output for other models
- image = output.images[0]
-
- # Debug: Check image properties
- logger.info(f"Generated image size: {image.size}, mode: {image.mode}")
-
- # Validate and fix image data if needed
- image = self._validate_and_fix_image(image)
-
- logger.info("Image generation completed")
- return image
-
- except Exception as e:
- logger.error(f"Image generation failed: {e}")
- # Return error image
- return self._create_error_image(str(e), truncated_prompt)
-
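Note: the block removed above works around black or NaN FLUX outputs by temporarily swapping in a NaN-safe image processor during generation. A minimal standalone sketch of that postprocessing idea follows; the function name safe_postprocess is illustrative and is not part of either package version.

    import numpy as np
    import torch
    from PIL import Image

    def safe_postprocess(image: torch.Tensor) -> list:
        """Convert a (B, C, H, W) tensor in [0, 1] to PIL images, guarding against NaN/Inf."""
        # Replace invalid values before any scaling so they cannot propagate
        image = torch.nan_to_num(image, nan=0.5, posinf=1.0, neginf=0.0)
        image = torch.clamp(image, 0, 1)
        arr = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # Second pass in NumPy as an extra safety net, then scale to uint8
        arr = np.clip(np.nan_to_num(arr, nan=0.5, posinf=1.0, neginf=0.0), 0, 1)
        arr = (arr * 255).astype(np.uint8)
        return [Image.fromarray(a) for a in arr]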
- def _validate_and_fix_image(self, image: Image.Image) -> Image.Image:
- """Validate and fix image data to handle NaN/infinite values"""
- try:
- # Convert PIL image to numpy array
- img_array = np.array(image)
-
- # Check if image is completely black (safety checker replacement)
- if np.all(img_array == 0):
- logger.error("Generated image is completely black - likely safety checker issue")
- if self.model_config.model_type == "flux":
- logger.error("FLUX model safety checker is still active despite our attempts to disable it")
- logger.error("This suggests the safety checker is built into the model weights or pipeline")
- logger.info("Attempting to generate a test pattern instead of black image")
-
- # Create a test pattern to show the system is working
- test_image = np.zeros_like(img_array)
- height, width = test_image.shape[:2]
-
- # Create a simple gradient pattern
- for i in range(height):
- for j in range(width):
- test_image[i, j] = [
- int(255 * i / height), # Red gradient
- int(255 * j / width), # Green gradient
- 128 # Blue constant
- ]
-
- logger.info("Created test gradient pattern to replace black image")
- return Image.fromarray(test_image.astype(np.uint8))
- else:
- logger.error("This suggests the safety checker is still active despite our attempts to disable it")
-
- # Check for NaN or infinite values
- if np.isnan(img_array).any() or np.isinf(img_array).any():
- logger.warning("Invalid values (NaN/inf) detected in generated image, applying fixes")
-
- # Replace NaN and infinite values with valid ranges
- img_array = np.nan_to_num(img_array, nan=0.0, posinf=255.0, neginf=0.0)
-
- # Ensure values are in valid range [0, 255]
- img_array = np.clip(img_array, 0, 255)
-
- # Convert back to PIL Image
- image = Image.fromarray(img_array.astype(np.uint8))
- logger.info("Image data fixed successfully")
-
- # Log image statistics for debugging
- mean_val = np.mean(img_array)
- std_val = np.std(img_array)
- logger.info(f"Image stats - mean: {mean_val:.2f}, std: {std_val:.2f}")
-
- # Additional check for very low variance (mostly black/gray)
- if std_val < 10.0 and mean_val < 50.0:
- logger.warning(f"Image has very low variance (std={std_val:.2f}) and low brightness (mean={mean_val:.2f})")
- logger.warning("This might indicate safety checker interference or generation issues")
- if self.model_config.model_type == "flux":
- logger.info("For FLUX models, try using different prompts or adjusting generation parameters")
-
- return image
-
- except Exception as e:
- logger.warning(f"Failed to validate image data: {e}, returning original image")
- return image
-
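Note: the fallback removed above builds its gradient test pattern with a per-pixel Python loop. A vectorized NumPy sketch of the same pattern is shown below; make_test_pattern is a hypothetical name used only for illustration.

    import numpy as np
    from PIL import Image

    def make_test_pattern(height: int, width: int) -> Image.Image:
        """Build the red/green gradient test pattern without a per-pixel loop."""
        rows = np.linspace(0, 255, height, dtype=np.uint8)   # red ramps top to bottom
        cols = np.linspace(0, 255, width, dtype=np.uint8)    # green ramps left to right
        pattern = np.zeros((height, width, 3), dtype=np.uint8)
        pattern[..., 0] = rows[:, None]                       # broadcast over columns
        pattern[..., 1] = cols[None, :]                       # broadcast over rows
        pattern[..., 2] = 128                                 # constant blue channel
        return Image.fromarray(pattern)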
- def _prepare_control_image(self, control_image: Union[Image.Image, str], width: int, height: int) -> Image.Image:
- """Prepare control image for ControlNet"""
- try:
- # Initialize ControlNet preprocessors if needed
- if not controlnet_preprocessor.is_initialized():
- logger.info("Initializing ControlNet preprocessors for image processing...")
- if not controlnet_preprocessor.initialize():
- logger.error("Failed to initialize ControlNet preprocessors")
- # Continue with basic processing
-
- # Load image if path is provided
- if isinstance(control_image, str):
- control_image = Image.open(control_image).convert('RGB')
- elif not isinstance(control_image, Image.Image):
- raise ValueError("Control image must be PIL Image or file path")
-
- # Ensure image is RGB
- if control_image.mode != 'RGB':
- control_image = control_image.convert('RGB')
-
- # Get ControlNet type from model config
- from ..models.manager import model_manager
- model_info = model_manager.get_model_info(self.model_config.name)
- controlnet_type = model_info.get('controlnet_type', 'canny') if model_info else 'canny'
-
- # Preprocess the control image based on ControlNet type
- logger.info(f"Preprocessing control image for {controlnet_type} ControlNet")
- processed_image = controlnet_preprocessor.preprocess(control_image, controlnet_type)
-
- # Resize to match generation size
- processed_image = controlnet_preprocessor.resize_for_controlnet(processed_image, width, height)
-
- logger.info(f"Control image prepared: {processed_image.size}")
- return processed_image
-
- except Exception as e:
- logger.error(f"Failed to prepare control image: {e}")
- # Return resized original image as fallback
- if isinstance(control_image, str):
- control_image = Image.open(control_image).convert('RGB')
- return controlnet_preprocessor.resize_for_controlnet(control_image, width, height)
-
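Note: the method removed above delegates the actual edge extraction to controlnet_preprocessor.preprocess(), with 'canny' as the default type. The sketch below shows what a typical canny control-image step looks like with OpenCV; the thresholds and the canny_control_image name are illustrative assumptions, not the package's own preprocessor code.

    import cv2
    import numpy as np
    from PIL import Image

    def canny_control_image(image: Image.Image, low: int = 100, high: int = 200) -> Image.Image:
        """Produce a 3-channel Canny edge map suitable as a ControlNet conditioning image."""
        edges = cv2.Canny(np.array(image.convert("RGB")), low, high)   # single-channel edge map
        edges = np.stack([edges] * 3, axis=-1)                         # replicate to RGB
        return Image.fromarray(edges)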
- def _create_error_image(self, error_msg: str, prompt: str) -> Image.Image:
- """Create error message image"""
- from PIL import ImageDraw, ImageFont
-
- # Create white background image
- img = Image.new('RGB', (512, 512), color=(255, 255, 255))
- draw = ImageDraw.Draw(img)
-
- # Draw error information
- try:
- # Try to use system font
- font = ImageFont.load_default()
- except:
- font = None
-
- # Draw text
- draw.text((10, 10), f"Error: {error_msg}", fill=(255, 0, 0), font=font)
- draw.text((10, 30), f"Prompt: {prompt[:50]}...", fill=(0, 0, 0), font=font)
-
- return img
-
+
  def unload(self):
- """Unload model and free GPU memory"""
- # Handle GGUF models
- is_gguf_model = (
- self.model_config and
- (self.model_config.model_type == "gguf" or
- (self.model_config.variant and "gguf" in self.model_config.variant.lower()))
- )
-
- if is_gguf_model:
- if GGUF_AVAILABLE and gguf_loader.is_loaded():
- gguf_loader.unload_model()
- logger.info("GGUF model unloaded")
-
- self.model_config = None
- self.tokenizer = None
- return
-
- # Handle regular diffusion models
- if self.pipeline:
- # Move to CPU to free GPU memory
- self.pipeline = self.pipeline.to("cpu")
-
- # Clear GPU cache
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- # Delete pipeline
- del self.pipeline
- self.pipeline = None
- self.model_config = None
- self.tokenizer = None
-
- logger.info("Model unloaded")
-
+ """Unload the current model and free resources."""
+ if self._strategy:
+ self._strategy.unload()
+ self._strategy = None
+ self.model_config = None
+ self.device = None
+ logger.info("Engine unloaded")
+
  def is_loaded(self) -> bool:
- """Check if model is loaded"""
- # Check GGUF models
- is_gguf_model = (
- self.model_config and
- (self.model_config.model_type == "gguf" or
- (self.model_config.variant and "gguf" in self.model_config.variant.lower()))
- )
-
- if is_gguf_model:
- return GGUF_AVAILABLE and gguf_loader.is_loaded()
-
- # Check regular pipeline models
- return self.pipeline is not None
-
- def load_lora_runtime(self, repo_id: str, weight_name: str = None, scale: float = 1.0):
- """Load LoRA weights at runtime"""
- if not self.pipeline:
- raise RuntimeError("Model not loaded")
-
- try:
- if weight_name:
- logger.info(f"Loading LoRA from {repo_id} with weight {weight_name}")
- self.pipeline.load_lora_weights(repo_id, weight_name=weight_name)
- else:
- logger.info(f"Loading LoRA from {repo_id}")
- self.pipeline.load_lora_weights(repo_id)
-
- # Set LoRA scale
- if hasattr(self.pipeline, 'set_adapters') and scale != 1.0:
- self.pipeline.set_adapters(["default"], adapter_weights=[scale])
- logger.info(f"Set LoRA scale to {scale}")
-
- # Track LoRA state
- self.current_lora = {
- "repo_id": repo_id,
- "weight_name": weight_name,
- "scale": scale,
- "loaded": True
- }
-
- logger.info("LoRA weights loaded successfully at runtime")
- return True
-
- except Exception as e:
- logger.error(f"Failed to load LoRA weights at runtime: {e}")
- return False
-
- def unload_lora(self):
- """Unload LoRA weights"""
- if not self.pipeline:
- return False
-
- try:
- if hasattr(self.pipeline, 'unload_lora_weights'):
- self.pipeline.unload_lora_weights()
- # Clear LoRA state
- self.current_lora = None
- logger.info("LoRA weights unloaded successfully")
- return True
- else:
- logger.warning("Pipeline does not support LoRA unloading")
- return False
- except Exception as e:
- logger.error(f"Failed to unload LoRA weights: {e}")
- return False
+ """Check if a model is loaded."""
+ return self._strategy is not None and self._strategy.is_loaded

  def get_model_info(self) -> Optional[Dict[str, Any]]:
- """Get current loaded model information"""
- if not self.model_config:
+ """Get information about the currently loaded model."""
+ if not self._strategy:
  return None
-
- base_info = {
- "name": self.model_config.name,
- "type": self.model_config.model_type,
- "device": self.device,
- "variant": self.model_config.variant,
- "parameters": self.model_config.parameters
- }
-
- # Check if this is a GGUF model
- is_gguf_model = (
- self.model_config.model_type == "gguf" or
- (self.model_config.variant and "gguf" in self.model_config.variant.lower())
- )
-
- # Add GGUF-specific information
- if is_gguf_model and GGUF_AVAILABLE:
- gguf_info = gguf_loader.get_model_info()
- base_info.update(gguf_info)
- base_info["gguf_available"] = True
- base_info["gguf_loaded"] = gguf_loader.is_loaded()
- base_info["is_gguf"] = True
- else:
- base_info["gguf_available"] = GGUF_AVAILABLE
- base_info["gguf_loaded"] = False
- base_info["is_gguf"] = is_gguf_model
-
- return base_info
+ info = self._strategy.get_info()
+ info["strategy"] = type(self._strategy).__name__
+ return info
+
+ # -- LoRA support --
+
+ def load_lora_runtime(
+ self, repo_id: str, weight_name: str = None, scale: float = 1.0
+ ) -> bool:
+ """Load LoRA weights at runtime."""
+ if not self._strategy:
+ raise RuntimeError("No model loaded")
+ return self._strategy.load_lora_runtime(repo_id, weight_name, scale)
+
+ def unload_lora(self) -> bool:
+ """Unload current LoRA weights."""
+ if not self._strategy:
+ return False
+ return self._strategy.unload_lora()
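Note: each 2.0.0 engine method above delegates to a per-model strategy object. A minimal sketch of the interface those calls imply (is_loaded, unload, get_info, load_lora_runtime, unload_lora) follows; the InferenceStrategy name and its defaults are assumptions for illustration, not the package's actual base class.

    from abc import ABC, abstractmethod
    from typing import Any, Dict

    class InferenceStrategy(ABC):
        """Hypothetical per-model strategy interface implied by the delegating engine methods."""

        @property
        @abstractmethod
        def is_loaded(self) -> bool:
            ...

        @abstractmethod
        def unload(self) -> None:
            ...

        @abstractmethod
        def get_info(self) -> Dict[str, Any]:
            ...

        def load_lora_runtime(self, repo_id: str, weight_name: str = None, scale: float = 1.0) -> bool:
            return False  # strategies without LoRA support can keep the default

        def unload_lora(self) -> bool:
            return False

Assuming an engine instance named engine, usage of the facade would look roughly like: info = engine.get_model_info() (which now includes a "strategy" key naming the active strategy class), engine.load_lora_runtime("some/lora-repo", scale=0.8), and engine.unload() to release the model.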