ollamadiffuser 1.2.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +337 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.2.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
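A central change in 2.0.0 is ollamadiffuser/core/inference/engine.py, rewritten from a ~1,400-line monolithic class into a strategy-pattern facade (hunk below). For orientation only, a minimal usage sketch of the rewritten facade, based on the method signatures visible in the hunk; the ModelConfig construction here is hypothetical, since real configs normally come from the package's model registry/settings rather than being built by hand:

    from ollamadiffuser.core.config.settings import ModelConfig
    from ollamadiffuser.core.inference.engine import InferenceEngine

    # Hypothetical config values; field names follow the ModelConfig attributes
    # referenced in the diff (name, model_type, path, variant).
    cfg = ModelConfig(
        name="stable-diffusion-1.5",
        model_type="sd15",                      # selects SD15Strategy via _get_strategy()
        path="runwayml/stable-diffusion-v1-5",  # Hub ID, so the local-path existence check is skipped
        variant="fp16",
    )

    engine = InferenceEngine()
    if engine.load_model(cfg):                  # returns bool; device is auto-detected
        img = engine.generate_image(prompt="a watercolor fox", width=512, height=512, seed=42)
        img.save("fox.png")

The facade implies a per-model strategy interface (defined in the new core/inference/base.py): each strategy exposes load(model_config, device) returning a bool, a generate(...) method returning a PIL image, and pipeline / current_lora / device attributes that the engine forwards for backward compatibility.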
ollamadiffuser/core/inference/engine.py
@@ -1,1427 +1,226 @@
-import torch
-import numpy as np
-from diffusers import (
-    StableDiffusionPipeline,
-    StableDiffusionXLPipeline,
-    StableDiffusion3Pipeline,
-    FluxPipeline,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionXLControlNetPipeline,
-    ControlNetModel,
-    AnimateDiffPipeline,
-    MotionAdapter
-)
-# Try to import HiDreamImagePipeline if available
-try:
-    from diffusers import HiDreamImagePipeline
-    HIDREAM_AVAILABLE = True
-except ImportError:
-    HIDREAM_AVAILABLE = False
-    logger = logging.getLogger(__name__)
-    logger.warning("HiDreamImagePipeline not available. Install latest diffusers from source for HiDream support.")
+"""
+Inference Engine - Facade that delegates to model-specific strategies.

  ⋮
-    logger.warning("GGUF support not available. Install with: pip install llama-cpp-python gguf")
-except ImportError:
-    GGUF_AVAILABLE = False
-    logger = logging.getLogger(__name__)
-    logger.warning("GGUF loader module not found")
+This replaces the former 1400+ line god class with a clean strategy pattern.
+Each model type (SD1.5, SDXL, FLUX, SD3, ControlNet, Video, HiDream, GGUF)
+has its own strategy class that handles loading and generation.
+"""
+
+import logging
+from typing import Any, Dict, Optional, Union

 from PIL import Image
-from typing import Optional, Dict, Any, Union
-from pathlib import Path
-from ..config.settings import ModelConfig
-from ..utils.controlnet_preprocessors import controlnet_preprocessor

-
-
-os.environ["DIFFUSERS_DISABLE_SAFETY_CHECKER"] = "1"
+from ..config.settings import ModelConfig
+from .base import InferenceStrategy

 logger = logging.getLogger(__name__)
+
+
+def _get_strategy(model_type: str) -> InferenceStrategy:
+    """Create the appropriate strategy for a model type."""
+    if model_type == "sd15":
+        from .strategies.sd15_strategy import SD15Strategy
+        return SD15Strategy()
+    elif model_type == "sdxl":
+        from .strategies.sdxl_strategy import SDXLStrategy
+        return SDXLStrategy()
+    elif model_type == "flux":
+        from .strategies.flux_strategy import FluxStrategy
+        return FluxStrategy()
+    elif model_type == "sd3":
+        from .strategies.sd3_strategy import SD3Strategy
+        return SD3Strategy()
+    elif model_type in ("controlnet_sd15", "controlnet_sdxl"):
+        from .strategies.controlnet_strategy import ControlNetStrategy
+        return ControlNetStrategy()
+    elif model_type == "video":
+        from .strategies.video_strategy import VideoStrategy
+        return VideoStrategy()
+    elif model_type == "hidream":
+        from .strategies.hidream_strategy import HiDreamStrategy
+        return HiDreamStrategy()
+    elif model_type == "gguf":
+        from .strategies.gguf_strategy import GGUFStrategy
+        return GGUFStrategy()
+    elif model_type == "generic":
+        from .strategies.generic_strategy import GenericPipelineStrategy
+        return GenericPipelineStrategy()
+    else:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+
+def _detect_device() -> str:
+    """Automatically detect the best available device."""
+    import torch
+
+    if torch.cuda.is_available():
+        device = "cuda"
+        logger.debug(f"CUDA device count: {torch.cuda.device_count()}")
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    logger.info(f"Using device: {device}")
+    if device == "cpu":
+        logger.warning("Using CPU - inference will be slower")
+    return device
+
+
 class InferenceEngine:
-    """
  ⋮
+    """
+    Facade engine that delegates to model-type-specific strategies.
+
+    This class maintains backward compatibility with the previous monolithic
+    InferenceEngine API while cleanly separating concerns into per-model strategies.
+    """
+
     def __init__(self):
  ⋮
+        self._strategy: Optional[InferenceStrategy] = None
         self.model_config: Optional[ModelConfig] = None
-        self.device = None
  ⋮
-        else:
-            device = "cpu"
-
-        logger.info(f"Using device: {device}")
-        if device == "cpu":
-            logger.warning("⚠️ Using CPU - this will be slower for large models")
-
-        return device
-
-    def _get_pipeline_class(self, model_type: str):
-        """Get corresponding pipeline class based on model type"""
-        pipeline_map = {
-            "sd15": StableDiffusionPipeline,
-            "sdxl": StableDiffusionXLPipeline,
-            "sd3": StableDiffusion3Pipeline,
-            "flux": FluxPipeline,
-            "gguf": "gguf_special",  # Special marker for GGUF models
-            "controlnet_sd15": StableDiffusionControlNetPipeline,
-            "controlnet_sdxl": StableDiffusionXLControlNetPipeline,
-            "video": AnimateDiffPipeline,
-        }
-
-        # Add HiDream support if available
-        if HIDREAM_AVAILABLE:
-            pipeline_map["hidream"] = HiDreamImagePipeline
-
-        return pipeline_map.get(model_type)
-
+        self.device: Optional[str] = None
+
+    # -- Backward-compatible properties --
+
+    @property
+    def pipeline(self):
+        """Access the underlying pipeline (backward compat)."""
+        return self._strategy.pipeline if self._strategy else None
+
+    @property
+    def current_lora(self):
+        return self._strategy.current_lora if self._strategy else None
+
+    @property
+    def is_controlnet_pipeline(self) -> bool:
+        return getattr(self._strategy, "is_controlnet_pipeline", False)
+
+    # -- Core API --
+
     def load_model(self, model_config: ModelConfig) -> bool:
-        """Load model"""
+        """Load a model using the appropriate strategy."""
         try:
  ⋮
-                logger.error("Model configuration is None")
-                return False
-
-            if not model_config.path:
-                logger.error(f"Model path is None for model: {model_config.name}")
+            if not model_config or not model_config.path:
+                logger.error("Invalid model configuration")
                 return False
  ⋮
-            self.device = self._get_device()
-            logger.info(f"Using device: {self.device}")
-
-            # Get corresponding pipeline class
-            pipeline_class = self._get_pipeline_class(model_config.model_type)
-            if not pipeline_class:
-                logger.error(f"Unsupported model type: {model_config.model_type}")
+
+            # Only check existence for local paths, not HuggingFace Hub IDs
+            # Hub IDs look like "org/model-name", local paths start with / . or ~
+            from pathlib import Path
+            model_path = model_config.path
+            is_local_path = model_path.startswith(("/", ".", "~")) or (len(model_path) > 1 and model_path[1] == ":")
+            if is_local_path and not Path(model_path).expanduser().exists():
+                logger.error(f"Model path does not exist: {model_path}")
                 return False
-
-            #
  ⋮
-            if gguf_loader.load_model(model_config_dict):
-                # Set pipeline to None since we're using GGUF loader
-                self.pipeline = None
-                self.model_config = model_config
-                self.device = self._get_device()
-                logger.info(f"GGUF model {model_config.name} loaded successfully")
-                return True
-            else:
-                logger.error(f"Failed to load GGUF model: {model_config.name}")
-                return False
-
-            # Check if this is a ControlNet model
-            self.is_controlnet_pipeline = model_config.model_type.startswith("controlnet_")
-
-            # Handle ControlNet models
-            if self.is_controlnet_pipeline:
-                return self._load_controlnet_model(model_config, pipeline_class, {})
-
-            # Set loading parameters
-            load_kwargs = {}
  ⋮
-            # Special handling for FLUX models
  ⋮
-            # Special handling for Video (AnimateDiff) models
  ⋮
-            # Special handling for HiDream models
  ⋮
-            # Disable safety checker for SD 1.5 to prevent false NSFW detections
  ⋮
-            # Disable safety checker for FLUX models to prevent false NSFW detections
  ⋮
-            # Load pipeline
-            self.pipeline = pipeline_class.from_pretrained(
-                model_config.path,
-                **load_kwargs
-            )
-
-            # Move to device with proper error handling
  ⋮
-            # Enable memory optimizations
  ⋮
-            # Special optimizations for FLUX models
  ⋮
-            # Special optimizations for Video (AnimateDiff) models
  ⋮
-            # Special optimizations for HiDream models
  ⋮
-            # Additional safety checker disabling for SD 1.5 (in case the above didn't work)
  ⋮
-            # Load LoRA and other components
-            if model_config.components and "lora" in model_config.components:
-                self._load_lora(model_config)
-
-            # Apply optimizations
-            self._apply_optimizations()
-
-            # Set tokenizer
-            if hasattr(self.pipeline, 'tokenizer'):
-                self.tokenizer = self.pipeline.tokenizer
-
-            self.model_config = model_config
-            logger.info(f"Model {model_config.name} loaded successfully")
-            return True
-
-        except Exception as e:
-            logger.error(f"Failed to load model: {e}")
+
+            # Detect GGUF models by variant
+            model_type = model_config.model_type
+            if model_config.variant and "gguf" in model_config.variant.lower():
+                model_type = "gguf"
+
+            self.device = _detect_device()
+            self._strategy = _get_strategy(model_type)
+
+            if self._strategy.load(model_config, self.device):
+                self.model_config = model_config
+                # Update device in case strategy fell back to CPU
+                self.device = self._strategy.device
+                logger.info(f"Model {model_config.name} loaded successfully via {type(self._strategy).__name__}")
+                return True
+
+            self._strategy = None
             return False
-
-    def _load_controlnet_model(self, model_config: ModelConfig, pipeline_class, load_kwargs: dict) -> bool:
-        """Load ControlNet model with base model"""
-        try:
  ⋮
-            self.model_config = model_config
-            logger.info(f"ControlNet model {model_config.name} loaded successfully")
-            return True
-
+
         except Exception as e:
-            logger.error(f"Failed to load
+            logger.error(f"Failed to load model: {e}")
+            self._strategy = None
             return False
-
  ⋮
-                    self.pipeline.unet = torch.compile(
-                        self.pipeline.unet,
-                        mode="reduce-overhead",
-                        fullgraph=True
-                    )
-                    logger.info("torch.compile optimization enabled")
-                except Exception as e:
-                    logger.debug(f"torch.compile failed: {e}")
-            elif self.device == "mps":
-                logger.debug("Skipping torch.compile on MPS (not supported yet)")
-            elif self.device == "cpu":
-                logger.debug("Skipping torch.compile on CPU for stability")
-
-        except Exception as e:
-            logger.warning(f"Failed to apply optimization settings: {e}")
-
-    def truncate_prompt(self, prompt: str) -> str:
-        """Truncate prompt to fit CLIP token limit"""
-        if not prompt or not self.tokenizer:
-            return prompt
-
-        # Encode prompt
-        tokens = self.tokenizer.encode(prompt)
-
-        # Check if truncation is needed
-        if len(tokens) <= self.max_token_limit:
-            return prompt
-
-        # Truncate tokens and decode back to text
-        truncated_tokens = tokens[:self.max_token_limit]
-        truncated_prompt = self.tokenizer.decode(truncated_tokens)
-
-        logger.warning(f"Prompt truncated: {len(tokens)} -> {len(truncated_tokens)} tokens")
-        return truncated_prompt
-
-    def generate_image(self,
-                       prompt: str,
-                       negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
-                       num_inference_steps: Optional[int] = None,
-                       guidance_scale: Optional[float] = None,
-                       width: int = 1024,
-                       height: int = 1024,
-                       control_image: Optional[Union[Image.Image, str]] = None,
-                       controlnet_conditioning_scale: float = 1.0,
-                       control_guidance_start: float = 0.0,
-                       control_guidance_end: float = 1.0,
-                       **kwargs) -> Image.Image:
-        """Generate image"""
-        # Check if we're using a GGUF model
-        is_gguf_model = (
-            self.model_config and
-            (self.model_config.model_type == "gguf" or
-             (self.model_config.variant and "gguf" in self.model_config.variant.lower()))
+
+    def generate_image(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 1024,
+        height: int = 1024,
+        seed: Optional[int] = None,
+        control_image: Optional[Union[Image.Image, str]] = None,
+        controlnet_conditioning_scale: float = 1.0,
+        control_guidance_start: float = 0.0,
+        control_guidance_end: float = 1.0,
+        image: Optional[Image.Image] = None,
+        mask_image: Optional[Image.Image] = None,
+        strength: float = 0.75,
+        **kwargs,
+    ) -> Image.Image:
+        """Generate an image using the current strategy."""
+        if not self._strategy:
+            raise RuntimeError("No model loaded")
+
+        # Build kwargs to pass through
+        gen_kwargs = dict(kwargs)
+
+        # Pass img2img / inpainting params
+        if image is not None:
+            gen_kwargs["image"] = image
+        if mask_image is not None:
+            gen_kwargs["mask_image"] = mask_image
+        if image is not None or mask_image is not None:
+            gen_kwargs["strength"] = strength
+
+        # Pass ControlNet params
+        if control_image is not None:
+            gen_kwargs["control_image"] = control_image
+            gen_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
+            gen_kwargs["control_guidance_start"] = control_guidance_start
+            gen_kwargs["control_guidance_end"] = control_guidance_end
+
+        return self._strategy.generate(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            width=width,
+            height=height,
+            seed=seed,
+            **gen_kwargs,
         )
616
|
-
|
|
617
|
-
if is_gguf_model:
|
|
618
|
-
if not GGUF_AVAILABLE:
|
|
619
|
-
raise RuntimeError("GGUF support not available")
|
|
620
|
-
|
|
621
|
-
if not gguf_loader.is_loaded():
|
|
622
|
-
raise RuntimeError("GGUF model not loaded")
|
|
623
|
-
|
|
624
|
-
logger.info(f"Generating image using GGUF model: {prompt[:50]}...")
|
|
625
|
-
|
|
626
|
-
# Use model default parameters for GGUF
|
|
627
|
-
if num_inference_steps is None:
|
|
628
|
-
num_inference_steps = self.model_config.parameters.get("num_inference_steps", 20)
|
|
629
|
-
|
|
630
|
-
if guidance_scale is None:
|
|
631
|
-
guidance_scale = self.model_config.parameters.get("guidance_scale", 7.5)
|
|
632
|
-
|
|
633
|
-
# Generate using GGUF loader
|
|
634
|
-
generation_kwargs = {
|
|
635
|
-
"prompt": prompt,
|
|
636
|
-
"negative_prompt": negative_prompt,
|
|
637
|
-
"num_inference_steps": num_inference_steps,
|
|
638
|
-
"guidance_scale": guidance_scale,
|
|
639
|
-
"width": width,
|
|
640
|
-
"height": height,
|
|
641
|
-
**kwargs
|
|
642
|
-
}
|
|
643
|
-
|
|
644
|
-
try:
|
|
645
|
-
image = gguf_loader.generate_image(**generation_kwargs)
|
|
646
|
-
if image is None:
|
|
647
|
-
logger.warning("GGUF generation returned None, creating error image")
|
|
648
|
-
return self._create_error_image("GGUF generation failed or not yet implemented", prompt)
|
|
649
|
-
return image
|
|
650
|
-
except Exception as e:
|
|
651
|
-
logger.error(f"GGUF generation failed: {e}")
|
|
652
|
-
return self._create_error_image(str(e), prompt)
|
|
653
|
-
|
|
654
|
-
# Continue with regular pipeline generation for non-GGUF models
|
|
655
|
-
if not self.pipeline:
|
|
656
|
-
raise RuntimeError("Model not loaded")
|
|
657
|
-
|
|
658
|
-
# Handle ControlNet-specific logic
|
|
659
|
-
if self.is_controlnet_pipeline:
|
|
660
|
-
if control_image is None:
|
|
661
|
-
raise ValueError("ControlNet model requires a control image")
|
|
662
|
-
|
|
663
|
-
# Process control image
|
|
664
|
-
control_image = self._prepare_control_image(control_image, width, height)
|
|
665
|
-
|
|
666
|
-
# Use model default parameters
|
|
667
|
-
if num_inference_steps is None:
|
|
668
|
-
num_inference_steps = self.model_config.parameters.get("num_inference_steps", 28)
|
|
669
|
-
|
|
670
|
-
if guidance_scale is None:
|
|
671
|
-
guidance_scale = self.model_config.parameters.get("guidance_scale", 3.5)
|
|
672
|
-
|
|
673
|
-
# Truncate prompts
|
|
674
|
-
truncated_prompt = self.truncate_prompt(prompt)
|
|
675
|
-
truncated_negative_prompt = self.truncate_prompt(negative_prompt)
|
|
676
|
-
|
|
677
|
-
try:
|
|
678
|
-
logger.info(f"Starting image generation: {truncated_prompt[:50]}...")
|
|
679
|
-
|
|
680
|
-
# Generation parameters
|
|
681
|
-
generation_kwargs = {
|
|
682
|
-
"prompt": truncated_prompt,
|
|
683
|
-
"negative_prompt": truncated_negative_prompt,
|
|
684
|
-
"num_inference_steps": num_inference_steps,
|
|
685
|
-
"guidance_scale": guidance_scale,
|
|
686
|
-
**kwargs
|
|
687
|
-
}
|
|
688
|
-
|
|
689
|
-
# Add ControlNet parameters if this is a ControlNet pipeline
|
|
690
|
-
if self.is_controlnet_pipeline and control_image is not None:
|
|
691
|
-
generation_kwargs.update({
|
|
692
|
-
"image": control_image,
|
|
693
|
-
"controlnet_conditioning_scale": controlnet_conditioning_scale,
|
|
694
|
-
"control_guidance_start": control_guidance_start,
|
|
695
|
-
"control_guidance_end": control_guidance_end
|
|
696
|
-
})
|
|
697
|
-
logger.info(f"ControlNet parameters: conditioning_scale={controlnet_conditioning_scale}, "
|
|
698
|
-
f"guidance_start={control_guidance_start}, guidance_end={control_guidance_end}")
|
|
699
|
-
|
|
700
|
-
# Add size parameters based on model type
|
|
701
|
-
if self.model_config.model_type in ["sdxl", "sd3", "flux", "controlnet_sdxl"]:
|
|
702
|
-
generation_kwargs.update({
|
|
703
|
-
"width": width,
|
|
704
|
-
"height": height
|
|
705
|
-
})
|
|
706
|
-
|
|
707
|
-
# FLUX models have special parameters
|
|
708
|
-
if self.model_config.model_type == "flux":
|
|
709
|
-
# Add max_sequence_length for FLUX
|
|
710
|
-
max_seq_len = self.model_config.parameters.get("max_sequence_length", 512)
|
|
711
|
-
generation_kwargs["max_sequence_length"] = max_seq_len
|
|
712
|
-
logger.info(f"Using max_sequence_length={max_seq_len} for FLUX model")
|
|
713
|
-
|
|
714
|
-
# Special handling for FLUX.1-schnell (distilled model)
|
|
715
|
-
if "schnell" in self.model_config.name.lower():
|
|
716
|
-
# FLUX.1-schnell doesn't use guidance
|
|
717
|
-
if guidance_scale != 0.0:
|
|
718
|
-
logger.info("FLUX.1-schnell detected - setting guidance_scale to 0.0 (distilled model doesn't use guidance)")
|
|
719
|
-
generation_kwargs["guidance_scale"] = 0.0
|
|
720
|
-
|
|
721
|
-
# Use fewer steps for schnell (it's designed for 1-4 steps)
|
|
722
|
-
if num_inference_steps > 4:
|
|
723
|
-
logger.info(f"FLUX.1-schnell detected - reducing steps from {num_inference_steps} to 4 for optimal performance")
|
|
724
|
-
generation_kwargs["num_inference_steps"] = 4
|
|
725
|
-
|
|
726
|
-
logger.info("🚀 Using FLUX.1-schnell - fast distilled model optimized for 4-step generation")
|
|
727
|
-
|
|
728
|
-
# Device-specific adjustments for FLUX
|
|
729
|
-
if self.device == "cpu":
|
|
730
|
-
# Reduce steps for faster CPU inference
|
|
731
|
-
if "schnell" not in self.model_config.name.lower() and num_inference_steps > 20:
|
|
732
|
-
num_inference_steps = 20
|
|
733
|
-
generation_kwargs["num_inference_steps"] = num_inference_steps
|
|
734
|
-
logger.info(f"Reduced inference steps to {num_inference_steps} for CPU performance")
|
|
735
|
-
|
|
736
|
-
# Lower guidance scale for CPU stability (except for schnell which uses 0.0)
|
|
737
|
-
if "schnell" not in self.model_config.name.lower() and guidance_scale > 5.0:
|
|
738
|
-
guidance_scale = 5.0
|
|
739
|
-
generation_kwargs["guidance_scale"] = guidance_scale
|
|
740
|
-
logger.info(f"Reduced guidance scale to {guidance_scale} for CPU stability")
|
|
741
|
-
|
|
742
|
-
logger.warning("🐌 CPU inference detected - this may take several minutes per image")
|
|
743
|
-
elif self.device == "mps":
|
|
744
|
-
# MPS-specific adjustments for stability (except for schnell which uses 0.0)
|
|
745
|
-
if "schnell" not in self.model_config.name.lower() and guidance_scale > 7.0:
|
|
746
|
-
guidance_scale = 7.0
|
|
747
|
-
generation_kwargs["guidance_scale"] = guidance_scale
|
|
748
|
-
logger.info(f"Reduced guidance scale to {guidance_scale} for MPS stability")
|
|
749
|
-
|
|
750
|
-
logger.info("🍎 MPS inference - should be faster than CPU but slower than CUDA")
|
|
751
|
-
|
|
752
|
-
elif self.model_config.model_type in ["sd15", "controlnet_sd15"]:
|
|
753
|
-
# SD 1.5 and ControlNet SD 1.5 work best with 512x512, adjust if different sizes requested
|
|
754
|
-
if width != 1024 or height != 1024:
|
|
755
|
-
generation_kwargs.update({
|
|
756
|
-
"width": width,
|
|
757
|
-
"height": height
|
|
758
|
-
})
|
|
759
|
-
else:
|
|
760
|
-
# Use optimal size for SD 1.5
|
|
761
|
-
generation_kwargs.update({
|
|
762
|
-
"width": 512,
|
|
763
|
-
"height": 512
|
|
764
|
-
})
|
|
765
|
-
|
|
766
|
-
# Special handling for Video (AnimateDiff) models
|
|
767
|
-
elif self.model_config.model_type == "video":
|
|
768
|
-
logger.info("Configuring AnimateDiff video generation parameters")
|
|
769
|
-
|
|
770
|
-
# Video-specific parameters
|
|
771
|
-
num_frames = kwargs.get("num_frames", 16)
|
|
772
|
-
generation_kwargs["num_frames"] = num_frames
|
|
773
|
-
|
|
774
|
-
# AnimateDiff works best with specific resolutions
|
|
775
|
-
# Use 512x512 for better compatibility with most motion adapters
|
|
776
|
-
generation_kwargs.update({
|
|
777
|
-
"width": 512,
|
|
778
|
-
"height": 512
|
|
779
|
-
})
|
|
780
|
-
logger.info(f"Using 512x512 resolution for AnimateDiff compatibility")
|
|
781
|
-
|
|
782
|
-
# Set optimal parameters for video generation
|
|
783
|
-
if guidance_scale > 7.5:
|
|
784
|
-
generation_kwargs["guidance_scale"] = 7.5
|
|
785
|
-
logger.info("Reduced guidance scale to 7.5 for video stability")
|
|
786
|
-
|
|
787
|
-
if num_inference_steps > 25:
|
|
788
|
-
generation_kwargs["num_inference_steps"] = 25
|
|
789
|
-
logger.info("Reduced inference steps to 25 for video generation")
|
|
790
|
-
|
|
791
|
-
logger.info(f"Generating {num_frames} frames for video output")
|
|
792
|
-
|
|
793
|
-
# Special handling for HiDream models
|
|
794
|
-
elif self.model_config.model_type == "hidream":
|
|
795
|
-
logger.info("Configuring HiDream model parameters")
|
|
796
|
-
|
|
797
|
-
# HiDream models support high resolution
|
|
798
|
-
generation_kwargs.update({
|
|
799
|
-
"width": width,
|
|
800
|
-
"height": height
|
|
801
|
-
})
|
|
802
|
-
|
|
803
|
-
# HiDream models have multiple text encoders, handle if provided
|
|
804
|
-
if "prompt_2" in kwargs:
|
|
805
|
-
generation_kwargs["prompt_2"] = self.truncate_prompt(kwargs["prompt_2"])
|
|
806
|
-
if "prompt_3" in kwargs:
|
|
807
|
-
generation_kwargs["prompt_3"] = self.truncate_prompt(kwargs["prompt_3"])
|
|
808
|
-
if "prompt_4" in kwargs:
|
|
809
|
-
generation_kwargs["prompt_4"] = self.truncate_prompt(kwargs["prompt_4"])
|
|
810
|
-
|
|
811
|
-
# Set optimal parameters for HiDream
|
|
812
|
-
max_seq_len = self.model_config.parameters.get("max_sequence_length", 128)
|
|
813
|
-
generation_kwargs["max_sequence_length"] = max_seq_len
|
|
814
|
-
|
|
815
|
-
# HiDream models use different guidance scale defaults
|
|
816
|
-
if guidance_scale is None or guidance_scale == 3.5:
|
|
817
|
-
generation_kwargs["guidance_scale"] = 5.0
|
|
818
|
-
logger.info("Using default guidance_scale=5.0 for HiDream model")
|
|
819
|
-
|
|
820
|
-
logger.info(f"Using max_sequence_length={max_seq_len} for HiDream model")
|
|
821
|
-
|
|
822
|
-
# Generate image
|
|
823
|
-
logger.info(f"Generation parameters: steps={num_inference_steps}, guidance={guidance_scale}")
|
|
824
|
-
|
|
825
|
-
# Add generator for reproducible results
|
|
826
|
-
if self.device == "cpu":
|
|
827
|
-
generator = torch.Generator().manual_seed(42)
|
|
828
|
-
else:
|
|
829
|
-
generator = torch.Generator(device=self.device).manual_seed(42)
|
|
830
|
-
generation_kwargs["generator"] = generator
|
|
831
|
-
|
|
832
|
-
# For SD 1.5, use a more conservative approach to avoid numerical issues
|
|
833
|
-
if self.model_config.model_type == "sd15":
|
|
834
|
-
# Lower guidance scale to prevent numerical instability
|
|
835
|
-
if generation_kwargs["guidance_scale"] > 7.0:
|
|
836
|
-
generation_kwargs["guidance_scale"] = 7.0
|
|
837
|
-
logger.info("Reduced guidance scale to 7.0 for stability")
|
|
838
|
-
|
|
839
|
-
# Ensure we're using float32 for better numerical stability
|
|
840
|
-
if self.device == "mps":
|
|
841
|
-
# For Apple Silicon, use specific optimizations
|
|
842
|
-
generation_kwargs["guidance_scale"] = min(generation_kwargs["guidance_scale"], 6.0)
|
|
843
|
-
logger.info("Applied MPS-specific optimizations")
|
|
844
|
-
|
|
845
|
-
# For SD 1.5, use manual pipeline execution to completely bypass safety checker
|
|
846
|
-
if self.model_config.model_type == "sd15" and not self.is_controlnet_pipeline:
|
|
847
|
-
logger.info("Using manual pipeline execution for SD 1.5 to bypass safety checker")
|
|
848
|
-
try:
|
|
849
|
-
# Manual pipeline execution with safety checks disabled
|
|
850
|
-
with torch.no_grad():
|
|
851
|
-
# Encode prompt
|
|
852
|
-
text_inputs = self.pipeline.tokenizer(
|
|
853
|
-
generation_kwargs["prompt"],
|
|
854
|
-
padding="max_length",
|
|
855
|
-
max_length=self.pipeline.tokenizer.model_max_length,
|
|
856
|
-
truncation=True,
|
|
857
|
-
return_tensors="pt",
|
|
858
|
-
)
|
|
859
|
-
text_embeddings = self.pipeline.text_encoder(text_inputs.input_ids.to(self.device))[0]
|
|
860
|
-
|
|
861
|
-
# Encode negative prompt
|
|
862
|
-
uncond_inputs = self.pipeline.tokenizer(
|
|
863
|
-
generation_kwargs["negative_prompt"],
|
|
864
|
-
padding="max_length",
|
|
865
|
-
max_length=self.pipeline.tokenizer.model_max_length,
|
|
866
|
-
truncation=True,
|
|
867
|
-
return_tensors="pt",
|
|
868
|
-
)
|
|
869
|
-
uncond_embeddings = self.pipeline.text_encoder(uncond_inputs.input_ids.to(self.device))[0]
|
|
870
|
-
|
|
871
|
-
# Concatenate embeddings
|
|
872
|
-
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
|
873
|
-
|
|
874
|
-
# Generate latents
|
|
875
|
-
latents = torch.randn(
|
|
876
|
-
(1, self.pipeline.unet.config.in_channels,
|
|
877
|
-
generation_kwargs["height"] // 8, generation_kwargs["width"] // 8),
|
|
878
|
-
generator=generation_kwargs["generator"],
|
|
879
|
-
device=self.device,
|
|
880
|
-
dtype=text_embeddings.dtype,
|
|
881
|
-
)
|
|
882
|
-
|
|
883
|
-
logger.debug(f"Initial latents stats - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
|
|
884
|
-
logger.debug(f"Text embeddings stats - mean: {text_embeddings.mean().item():.4f}, std: {text_embeddings.std().item():.4f}")
|
|
885
|
-
|
|
886
|
-
# Set scheduler
|
|
887
|
-
self.pipeline.scheduler.set_timesteps(generation_kwargs["num_inference_steps"])
|
|
888
|
-
latents = latents * self.pipeline.scheduler.init_noise_sigma
|
|
889
|
-
|
|
890
|
-
logger.debug(f"Latents after noise scaling - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
|
|
891
|
-
logger.debug(f"Scheduler init_noise_sigma: {self.pipeline.scheduler.init_noise_sigma}")
|
|
892
|
-
|
|
893
|
-
# Denoising loop
|
|
894
|
-
for i, t in enumerate(self.pipeline.scheduler.timesteps):
|
|
895
|
-
latent_model_input = torch.cat([latents] * 2)
|
|
896
|
-
latent_model_input = self.pipeline.scheduler.scale_model_input(latent_model_input, t)
|
|
897
|
-
|
|
898
|
-
# Check for NaN before UNet
|
|
899
|
-
if torch.isnan(latent_model_input).any():
|
|
900
|
-
logger.error(f"NaN detected in latent_model_input at step {i}")
|
|
901
|
-
break
|
|
902
|
-
|
|
903
|
-
noise_pred = self.pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
|
904
|
-
|
|
905
|
-
# Check for NaN after UNet
|
|
906
|
-
if torch.isnan(noise_pred).any():
|
|
907
|
-
logger.error(f"NaN detected in noise_pred at step {i}")
|
|
908
|
-
break
|
|
909
|
-
|
|
910
|
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
911
|
-
noise_pred = noise_pred_uncond + generation_kwargs["guidance_scale"] * (noise_pred_text - noise_pred_uncond)
|
|
912
|
-
|
|
913
|
-
# Check for NaN after guidance
|
|
914
|
-
if torch.isnan(noise_pred).any():
|
|
915
|
-
logger.error(f"NaN detected after guidance at step {i}")
|
|
916
|
-
break
|
|
917
|
-
|
|
918
|
-
latents = self.pipeline.scheduler.step(noise_pred, t, latents).prev_sample
|
|
919
|
-
|
|
920
|
-
# Check for NaN after scheduler step
|
|
921
|
-
if torch.isnan(latents).any():
|
|
922
|
-
logger.error(f"NaN detected in latents after scheduler step {i}")
|
|
923
|
-
break
|
|
924
|
-
|
|
925
|
-
if i == 0: # Log first step for debugging
|
|
926
|
-
logger.debug(f"Step {i}: latents mean={latents.mean().item():.4f}, std={latents.std().item():.4f}")
|
|
927
|
-
|
|
928
|
-
# Decode latents
|
|
929
|
-
latents = 1 / self.pipeline.vae.config.scaling_factor * latents
|
|
930
|
-
|
|
931
|
-
# Debug latents before VAE decode
|
|
932
|
-
logger.debug(f"Latents stats before VAE decode - mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
|
|
933
|
-
logger.debug(f"Latents range: [{latents.min().item():.4f}, {latents.max().item():.4f}]")
|
|
934
|
-
|
|
935
|
-
with torch.no_grad():
|
|
936
|
-
# Ensure latents are on correct device and dtype
|
|
937
|
-
latents = latents.to(device=self.device, dtype=self.pipeline.vae.dtype)
|
|
938
|
-
|
|
939
|
-
try:
|
|
940
|
-
image = self.pipeline.vae.decode(latents).sample
|
|
941
|
-
logger.debug(f"VAE decode successful - image shape: {image.shape}")
|
|
942
|
-
except Exception as e:
|
|
943
|
-
logger.error(f"VAE decode failed: {e}")
|
|
944
|
-
# Create a fallback image
|
|
945
|
-
image = torch.randn_like(latents).repeat(1, 3, 8, 8) * 0.1 + 0.5
|
|
946
|
-
logger.warning("Using fallback random image due to VAE decode failure")
|
|
947
|
-
|
|
948
|
-
# Convert to PIL with proper NaN/inf handling
|
|
949
|
-
logger.debug(f"Image stats after VAE decode - mean: {image.mean().item():.4f}, std: {image.std().item():.4f}")
|
|
950
|
-
logger.debug(f"Image range: [{image.min().item():.4f}, {image.max().item():.4f}]")
|
|
951
|
-
|
|
952
|
-
image = (image / 2 + 0.5).clamp(0, 1)
|
|
953
|
-
|
|
954
|
-
logger.debug(f"Image stats after normalization - mean: {image.mean().item():.4f}, std: {image.std().item():.4f}")
|
|
955
|
-
logger.debug(f"Image range after norm: [{image.min().item():.4f}, {image.max().item():.4f}]")
|
|
956
|
-
|
|
957
|
-
# Check for NaN or infinite values before conversion
|
|
958
|
-
if torch.isnan(image).any() or torch.isinf(image).any():
|
|
959
|
-
logger.warning("NaN or infinite values detected in image tensor, applying selective fixes")
|
|
960
|
-
# Only replace NaN/inf values, keep valid pixels intact
|
|
961
|
-
nan_mask = torch.isnan(image)
|
|
962
|
-
inf_mask = torch.isinf(image)
|
|
963
|
-
|
|
964
|
-
# Replace only problematic pixels
|
|
965
|
-
image = torch.where(nan_mask, torch.tensor(0.5, device=image.device, dtype=image.dtype), image)
|
|
966
|
-
image = torch.where(inf_mask & (image > 0), torch.tensor(1.0, device=image.device, dtype=image.dtype), image)
|
|
967
|
-
image = torch.where(inf_mask & (image < 0), torch.tensor(0.0, device=image.device, dtype=image.dtype), image)
|
|
968
|
-
|
|
969
|
-
logger.info(f"Fixed {nan_mask.sum().item()} NaN pixels and {inf_mask.sum().item()} infinite pixels")
|
|
970
|
-
|
|
971
|
-
# Final clamp to ensure valid range
|
|
972
|
-
image = torch.clamp(image, 0, 1)
|
|
973
|
-
|
|
974
|
-
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
|
975
|
-
|
|
976
|
-
# Additional validation before uint8 conversion - only fix problematic pixels
|
|
977
|
-
if np.isnan(image).any() or np.isinf(image).any():
|
|
978
|
-
logger.warning("NaN/inf values detected in numpy array, applying selective fixes")
|
|
979
|
-
nan_count = np.isnan(image).sum()
|
|
980
|
-
inf_count = np.isinf(image).sum()
|
|
981
|
-
|
|
982
|
-
# Only replace problematic pixels, preserve valid ones
|
|
983
|
-
image = np.where(np.isnan(image), 0.5, image)
|
|
984
|
-
image = np.where(np.isinf(image) & (image > 0), 1.0, image)
|
|
985
|
-
image = np.where(np.isinf(image) & (image < 0), 0.0, image)
|
|
986
|
-
|
|
987
|
-
logger.info(f"Fixed {nan_count} NaN and {inf_count} infinite pixels in numpy array")
|
|
988
|
-
|
|
989
|
-
# Ensure valid range
|
|
990
|
-
image = np.clip(image, 0, 1)
|
|
991
|
-
|
|
992
|
-
# Safe conversion to uint8
|
|
993
|
-
image = (image * 255).astype(np.uint8)
|
|
994
|
-
|
|
995
|
-
from PIL import Image as PILImage
|
|
996
|
-
image = PILImage.fromarray(image[0])
|
|
997
|
-
|
|
998
|
-
# Create a mock output object
|
|
999
|
-
class MockOutput:
|
|
1000
|
-
def __init__(self, images):
|
|
1001
|
-
self.images = images
|
|
1002
|
-
self.nsfw_content_detected = [False] * len(images)
|
|
1003
|
-
|
|
1004
|
-
output = MockOutput([image])
|
|
1005
|
-
|
|
1006
|
-
except Exception as e:
|
|
1007
|
-
logger.error(f"Manual pipeline execution failed: {e}")
|
|
1008
|
-
raise e
|
|
1009
|
-
else:
|
|
1010
|
-
# For FLUX and other models, use regular pipeline execution with safety checker disabled
|
|
1011
|
-
logger.info(f"Using regular pipeline execution for {self.model_config.model_type} model")
|
|
1012
|
-
|
|
1013
|
-
# Debug: Log device and generation kwargs
|
|
1014
|
-
logger.debug(f"Pipeline device: {self.device}")
|
|
1015
|
-
logger.debug(f"Generator device: {generation_kwargs['generator'].device if hasattr(generation_kwargs['generator'], 'device') else 'CPU'}")
|
|
1016
|
-
|
|
1017
|
-
# Ensure all tensors are on the correct device
|
|
1018
|
-
try:
|
|
1019
|
-
# For FLUX models, temporarily disable any remaining safety checker components
|
|
1020
|
-
if self.model_config.model_type == "flux":
|
|
1021
|
-
# Store original safety checker components
|
|
1022
|
-
original_safety_checker = getattr(self.pipeline, 'safety_checker', None)
|
|
1023
|
-
original_feature_extractor = getattr(self.pipeline, 'feature_extractor', None)
|
|
1024
|
-
original_requires_safety_checker = getattr(self.pipeline, 'requires_safety_checker', None)
|
|
1025
|
-
|
|
1026
|
-
# Temporarily set to None
|
|
1027
|
-
if hasattr(self.pipeline, 'safety_checker'):
|
|
1028
|
-
self.pipeline.safety_checker = None
|
|
1029
|
-
if hasattr(self.pipeline, 'feature_extractor'):
|
|
1030
|
-
self.pipeline.feature_extractor = None
|
|
1031
|
-
if hasattr(self.pipeline, 'requires_safety_checker'):
|
|
1032
|
-
self.pipeline.requires_safety_checker = False
|
|
1033
|
-
|
|
1034
|
-
logger.info("Temporarily disabled safety checker components for FLUX generation")
|
|
1035
|
-
|
|
1036
|
-
output = self.pipeline(**generation_kwargs)
|
|
1037
|
-
|
|
1038
|
-
# Restore original safety checker components for FLUX (though they should remain None)
|
|
1039
|
-
if self.model_config.model_type == "flux":
|
|
1040
|
-
if hasattr(self.pipeline, 'safety_checker'):
|
|
1041
|
-
self.pipeline.safety_checker = original_safety_checker
|
|
1042
|
-
if hasattr(self.pipeline, 'feature_extractor'):
|
|
1043
|
-
self.pipeline.feature_extractor = original_feature_extractor
|
|
1044
|
-
if hasattr(self.pipeline, 'requires_safety_checker'):
|
|
1045
|
-
self.pipeline.requires_safety_checker = original_requires_safety_checker
|
|
1046
|
-
|
|
1047
|
-
except RuntimeError as e:
|
|
1048
|
-
if "CUDA" in str(e) and self.device == "cpu":
|
|
1049
|
-
logger.error(f"CUDA error on CPU device: {e}")
|
|
1050
|
-
logger.info("Attempting to fix device mismatch...")
|
|
1051
|
-
|
|
1052
|
-
# Remove generator and try again
|
|
1053
|
-
generation_kwargs_fixed = generation_kwargs.copy()
|
|
1054
|
-
generation_kwargs_fixed.pop("generator", None)
|
|
1055
|
-
|
|
1056
|
-
output = self.pipeline(**generation_kwargs_fixed)
|
|
1057
|
-
else:
|
|
1058
|
-
raise e
|
|
1059
|
-
|
|
1060
|
-
-            # Special handling for FLUX models to bypass any remaining safety checker issues
-            if self.model_config.model_type == "flux" and hasattr(output, 'images'):
-                # Check if we got a black image and try to regenerate with different approach
-                test_image = output.images[0]
-                test_array = np.array(test_image)
-
-                if np.all(test_array == 0):
-                    logger.warning("FLUX model returned black image, attempting manual image processing")
-
-                    # Try to access the raw latents or intermediate results
-                    if hasattr(output, 'latents') or hasattr(self.pipeline, 'vae'):
-                        try:
-                            # Generate a simple test image to verify the pipeline is working
-                            logger.info("Generating test image with simple prompt")
-                            simple_kwargs = generation_kwargs.copy()
-                            simple_kwargs["prompt"] = "a red apple"
-                            simple_kwargs["negative_prompt"] = ""
-
-                            # Temporarily disable any image processing that might cause issues
-                            original_image_processor = getattr(self.pipeline, 'image_processor', None)
-                            if hasattr(self.pipeline, 'image_processor'):
-                                # Create a custom image processor that handles NaN values
-                                class SafeImageProcessor:
-                                    def postprocess(self, image, output_type="pil", do_denormalize=None):
-                                        if isinstance(image, torch.Tensor):
-                                            # Handle NaN and inf values before conversion
-                                            image = torch.nan_to_num(image, nan=0.5, posinf=1.0, neginf=0.0)
-                                            image = torch.clamp(image, 0, 1)
-
-                                            # Convert to numpy
-                                            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-                                            # Additional safety checks
-                                            image = np.nan_to_num(image, nan=0.5, posinf=1.0, neginf=0.0)
-                                            image = np.clip(image, 0, 1)
-
-                                            # Convert to uint8 safely
-                                            image = (image * 255).astype(np.uint8)
-
-                                            # Convert to PIL
-                                            if output_type == "pil":
-                                                from PIL import Image as PILImage
-                                                return [PILImage.fromarray(img) for img in image]
-                                            return image
-                                        return image
-
-                                self.pipeline.image_processor = SafeImageProcessor()
-                                logger.info("Applied safe image processor for FLUX model")
-
-                            # Try generation again with safe image processor
-                            test_output = self.pipeline(**simple_kwargs)
-
-                            # Restore original image processor
-                            if original_image_processor:
-                                self.pipeline.image_processor = original_image_processor
-
-                            if hasattr(test_output, 'images') and len(test_output.images) > 0:
-                                test_result = np.array(test_output.images[0])
-                                if not np.all(test_result == 0):
-                                    logger.info("Test generation successful, using original output")
-                                    # The issue might be with the specific prompt, return the test image
-                                    output = test_output
-                                else:
-                                    logger.warning("Test generation also returned black image")
-
-                        except Exception as e:
-                            logger.warning(f"Manual image processing failed: {e}")
-
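The essential idea inside SafeImageProcessor is a NaN-tolerant tensor-to-PIL conversion. The same logic as a standalone helper, restated from the removed code above rather than taken from any diffusers API (the function name is illustrative):

    import numpy as np
    import torch
    from PIL import Image

    def tensor_to_pil(batch: torch.Tensor) -> list:
        # batch: (N, C, H, W) in [0, 1]; NaN/inf are replaced before quantization
        batch = torch.nan_to_num(batch, nan=0.5, posinf=1.0, neginf=0.0).clamp(0, 1)
        arrays = (batch.cpu().permute(0, 2, 3, 1).float().numpy() * 255).astype(np.uint8)
        return [Image.fromarray(a) for a in arrays]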
-            # Check if output contains nsfw_content_detected
-            if hasattr(output, 'nsfw_content_detected') and output.nsfw_content_detected:
-                logger.warning("NSFW content detected by pipeline - this should not happen with safety checker disabled")
-
-            # Special handling for video models that return multiple frames
-            if self.model_config.model_type == "video":
-                logger.info(f"Processing video output with {len(output.frames)} frames")
-
-                # For now, return the first frame as a single image
-                # In the future, this could be extended to return a video file or GIF
-                if hasattr(output, 'frames') and len(output.frames) > 0:
-                    image = output.frames[0]
-                    logger.info("Extracted first frame from video generation")
-                else:
-                    # Fallback to images if frames not available
-                    image = output.images[0]
-                    logger.info("Using first image from video output")
-
-                # TODO: Add option to save all frames or create a GIF
-                # frames = output.frames if hasattr(output, 'frames') else output.images
-                # save_video_frames(frames, prompt)
-            else:
-                # Standard single image output for other models
-                image = output.images[0]
-
-            # Debug: Check image properties
-            logger.info(f"Generated image size: {image.size}, mode: {image.mode}")
-
-            # Validate and fix image data if needed
-            image = self._validate_and_fix_image(image)
-
-            logger.info("Image generation completed")
-            return image
-
-        except Exception as e:
-            logger.error(f"Image generation failed: {e}")
-            # Return error image
-            return self._create_error_image(str(e), truncated_prompt)
-
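The TODO in the removed video branch about writing all frames out could be satisfied with plain Pillow. A minimal sketch, assuming the frames arrive as a flat list of PIL images (the helper name, path, and frame rate are illustrative):

    def save_frames_as_gif(frames, path="output.gif", fps=8):
        # frames: list of PIL.Image objects, e.g. the frames returned by a video pipeline
        duration_ms = int(1000 / fps)
        frames[0].save(
            path,
            save_all=True,
            append_images=frames[1:],
            duration=duration_ms,
            loop=0,
        )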
-    def _validate_and_fix_image(self, image: Image.Image) -> Image.Image:
-        """Validate and fix image data to handle NaN/infinite values"""
-        try:
-            # Convert PIL image to numpy array
-            img_array = np.array(image)
-
-            # Check if image is completely black (safety checker replacement)
-            if np.all(img_array == 0):
-                logger.error("Generated image is completely black - likely safety checker issue")
-                if self.model_config.model_type == "flux":
-                    logger.error("FLUX model safety checker is still active despite our attempts to disable it")
-                    logger.error("This suggests the safety checker is built into the model weights or pipeline")
-                    logger.info("Attempting to generate a test pattern instead of black image")
-
-                    # Create a test pattern to show the system is working
-                    test_image = np.zeros_like(img_array)
-                    height, width = test_image.shape[:2]
-
-                    # Create a simple gradient pattern
-                    for i in range(height):
-                        for j in range(width):
-                            test_image[i, j] = [
-                                int(255 * i / height),  # Red gradient
-                                int(255 * j / width),  # Green gradient
-                                128  # Blue constant
-                            ]
-
-                    logger.info("Created test gradient pattern to replace black image")
-                    return Image.fromarray(test_image.astype(np.uint8))
-                else:
-                    logger.error("This suggests the safety checker is still active despite our attempts to disable it")
-
-            # Check for NaN or infinite values
-            if np.isnan(img_array).any() or np.isinf(img_array).any():
-                logger.warning("Invalid values (NaN/inf) detected in generated image, applying fixes")
-
-                # Replace NaN and infinite values with valid ranges
-                img_array = np.nan_to_num(img_array, nan=0.0, posinf=255.0, neginf=0.0)
-
-                # Ensure values are in valid range [0, 255]
-                img_array = np.clip(img_array, 0, 255)
-
-                # Convert back to PIL Image
-                image = Image.fromarray(img_array.astype(np.uint8))
-                logger.info("Image data fixed successfully")
-
-            # Log image statistics for debugging
-            mean_val = np.mean(img_array)
-            std_val = np.std(img_array)
-            logger.info(f"Image stats - mean: {mean_val:.2f}, std: {std_val:.2f}")
-
-            # Additional check for very low variance (mostly black/gray)
-            if std_val < 10.0 and mean_val < 50.0:
-                logger.warning(f"Image has very low variance (std={std_val:.2f}) and low brightness (mean={mean_val:.2f})")
-                logger.warning("This might indicate safety checker interference or generation issues")
-                if self.model_config.model_type == "flux":
-                    logger.info("For FLUX models, try using different prompts or adjusting generation parameters")
-
-            return image
-
-        except Exception as e:
-            logger.warning(f"Failed to validate image data: {e}, returning original image")
-            return image
-
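The removed test-pattern code fills the array with two nested Python loops, which is slow for a 1024x1024 canvas. The same gradient can be built with vectorized NumPy; a sketch restating the logic above (the function name is illustrative):

    import numpy as np

    def gradient_test_pattern(height: int, width: int) -> np.ndarray:
        # Red ramps top-to-bottom, green ramps left-to-right, blue stays constant at 128
        rows = (255 * np.arange(height) / height).astype(np.uint8)
        cols = (255 * np.arange(width) / width).astype(np.uint8)
        pattern = np.zeros((height, width, 3), dtype=np.uint8)
        pattern[..., 0] = rows[:, None]
        pattern[..., 1] = cols[None, :]
        pattern[..., 2] = 128
        return pattern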
-    def _prepare_control_image(self, control_image: Union[Image.Image, str], width: int, height: int) -> Image.Image:
-        """Prepare control image for ControlNet"""
-        try:
-            # Initialize ControlNet preprocessors if needed
-            if not controlnet_preprocessor.is_initialized():
-                logger.info("Initializing ControlNet preprocessors for image processing...")
-                if not controlnet_preprocessor.initialize():
-                    logger.error("Failed to initialize ControlNet preprocessors")
-                    # Continue with basic processing
-
-            # Load image if path is provided
-            if isinstance(control_image, str):
-                control_image = Image.open(control_image).convert('RGB')
-            elif not isinstance(control_image, Image.Image):
-                raise ValueError("Control image must be PIL Image or file path")
-
-            # Ensure image is RGB
-            if control_image.mode != 'RGB':
-                control_image = control_image.convert('RGB')
-
-            # Get ControlNet type from model config
-            from ..models.manager import model_manager
-            model_info = model_manager.get_model_info(self.model_config.name)
-            controlnet_type = model_info.get('controlnet_type', 'canny') if model_info else 'canny'
-
-            # Preprocess the control image based on ControlNet type
-            logger.info(f"Preprocessing control image for {controlnet_type} ControlNet")
-            processed_image = controlnet_preprocessor.preprocess(control_image, controlnet_type)
-
-            # Resize to match generation size
-            processed_image = controlnet_preprocessor.resize_for_controlnet(processed_image, width, height)
-
-            logger.info(f"Control image prepared: {processed_image.size}")
-            return processed_image
-
-        except Exception as e:
-            logger.error(f"Failed to prepare control image: {e}")
-            # Return resized original image as fallback
-            if isinstance(control_image, str):
-                control_image = Image.open(control_image).convert('RGB')
-            return controlnet_preprocessor.resize_for_controlnet(control_image, width, height)
-
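For the default 'canny' case, the work done by controlnet_preprocessor.preprocess() is typically an edge map replicated to three channels. The exact implementation lives elsewhere in the package, so the following is only an assumed-equivalent sketch using OpenCV (opencv-python is an assumed dependency; thresholds are illustrative):

    import cv2
    import numpy as np
    from PIL import Image

    def canny_control_image(image: Image.Image, low: int = 100, high: int = 200) -> Image.Image:
        # Detect edges, then stack them into an RGB image as ControlNet expects
        edges = cv2.Canny(np.array(image.convert("RGB")), low, high)
        return Image.fromarray(np.stack([edges] * 3, axis=-1))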
-    def _create_error_image(self, error_msg: str, prompt: str) -> Image.Image:
-        """Create error message image"""
-        from PIL import ImageDraw, ImageFont
-
-        # Create white background image
-        img = Image.new('RGB', (512, 512), color=(255, 255, 255))
-        draw = ImageDraw.Draw(img)
-
-        # Draw error information
-        try:
-            # Try to use system font
-            font = ImageFont.load_default()
-        except:
-            font = None
-
-        # Draw text
-        draw.text((10, 10), f"Error: {error_msg}", fill=(255, 0, 0), font=font)
-        draw.text((10, 30), f"Prompt: {prompt[:50]}...", fill=(0, 0, 0), font=font)
-
-        return img
-
+
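The removed error-image helper draws each string at a fixed position, so a long exception message runs off the 512-pixel canvas. A minimal sketch of the same idea with simple line wrapping (the helper name and wrap width are illustrative, not part of the package):

    import textwrap
    from PIL import Image, ImageDraw, ImageFont

    def error_image(message: str, size=(512, 512)) -> Image.Image:
        img = Image.new("RGB", size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)
        font = ImageFont.load_default()
        y = 10
        for line in textwrap.wrap(f"Error: {message}", width=60):
            draw.text((10, y), line, fill=(255, 0, 0), font=font)
            y += 14  # advance roughly one line of the default bitmap font
        return img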
     def unload(self):
-        """Unload model and free
-
-
-        self.
-
-
-        )
-
-        if is_gguf_model:
-            if GGUF_AVAILABLE and gguf_loader.is_loaded():
-                gguf_loader.unload_model()
-                logger.info("GGUF model unloaded")
-
-            self.model_config = None
-            self.tokenizer = None
-            return
-
-        # Handle regular diffusion models
-        if self.pipeline:
-            # Move to CPU to free GPU memory
-            self.pipeline = self.pipeline.to("cpu")
-
-            # Clear GPU cache
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            # Delete pipeline
-            del self.pipeline
-            self.pipeline = None
-
-        self.model_config = None
-        self.tokenizer = None
-
-        logger.info("Model unloaded")
-
+        """Unload the current model and free resources."""
+        if self._strategy:
+            self._strategy.unload()
+        self._strategy = None
+        self.model_config = None
+        self.device = None
+        logger.info("Engine unloaded")
+
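The new facade only ever touches a few strategy members: unload(), the is_loaded property, get_info(), and the LoRA hooks used further down. A minimal sketch of the contract implied by those call sites; the real base class shipped with 2.0.0 may differ, and the class name here is hypothetical:

    from abc import ABC, abstractmethod
    from typing import Any, Dict

    class BaseStrategySketch(ABC):
        # Hypothetical outline of what the engine facade expects from a strategy
        @property
        @abstractmethod
        def is_loaded(self) -> bool: ...

        @abstractmethod
        def unload(self) -> None: ...

        @abstractmethod
        def get_info(self) -> Dict[str, Any]: ...

        def load_lora_runtime(self, repo_id: str, weight_name: str = None, scale: float = 1.0) -> bool:
            # Strategies without LoRA support can fall back to a no-op
            return False

        def unload_lora(self) -> bool:
            return False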
     def is_loaded(self) -> bool:
-        """Check if model is loaded"""
-
-        is_gguf_model = (
-            self.model_config and
-            (self.model_config.model_type == "gguf" or
-             (self.model_config.variant and "gguf" in self.model_config.variant.lower()))
-        )
-
-        if is_gguf_model:
-            return GGUF_AVAILABLE and gguf_loader.is_loaded()
-
-        # Check regular pipeline models
-        return self.pipeline is not None
-
-    def load_lora_runtime(self, repo_id: str, weight_name: str = None, scale: float = 1.0):
-        """Load LoRA weights at runtime"""
-        if not self.pipeline:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            if weight_name:
-                logger.info(f"Loading LoRA from {repo_id} with weight {weight_name}")
-                self.pipeline.load_lora_weights(repo_id, weight_name=weight_name)
-            else:
-                logger.info(f"Loading LoRA from {repo_id}")
-                self.pipeline.load_lora_weights(repo_id)
-
-            # Set LoRA scale
-            if hasattr(self.pipeline, 'set_adapters') and scale != 1.0:
-                self.pipeline.set_adapters(["default"], adapter_weights=[scale])
-                logger.info(f"Set LoRA scale to {scale}")
-
-            # Track LoRA state
-            self.current_lora = {
-                "repo_id": repo_id,
-                "weight_name": weight_name,
-                "scale": scale,
-                "loaded": True
-            }
-
-            logger.info("LoRA weights loaded successfully at runtime")
-            return True
-
-        except Exception as e:
-            logger.error(f"Failed to load LoRA weights at runtime: {e}")
-            return False
-
-    def unload_lora(self):
-        """Unload LoRA weights"""
-        if not self.pipeline:
-            return False
-
-        try:
-            if hasattr(self.pipeline, 'unload_lora_weights'):
-                self.pipeline.unload_lora_weights()
-                # Clear LoRA state
-                self.current_lora = None
-                logger.info("LoRA weights unloaded successfully")
-                return True
-            else:
-                logger.warning("Pipeline does not support LoRA unloading")
-                return False
-        except Exception as e:
-            logger.error(f"Failed to unload LoRA weights: {e}")
-            return False
+        """Check if a model is loaded."""
+        return self._strategy is not None and self._strategy.is_loaded
 
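The removed LoRA methods talked to the diffusers pipeline directly; in 2.0.0 the same calls presumably live inside each strategy. For reference, runtime LoRA loading against a diffusers pipeline looks roughly like the sketch below, which mirrors the removed code (the helper name and the "default" adapter name are taken from the code above; the pipeline variable is an assumption):

    import logging

    def apply_lora(pipeline, repo_id: str, weight_name: str = None, scale: float = 1.0) -> bool:
        # Mirrors the removed engine code: diffusers' load_lora_weights + set_adapters
        try:
            if weight_name:
                pipeline.load_lora_weights(repo_id, weight_name=weight_name)
            else:
                pipeline.load_lora_weights(repo_id)
            if hasattr(pipeline, "set_adapters") and scale != 1.0:
                pipeline.set_adapters(["default"], adapter_weights=[scale])
            return True
        except Exception as exc:
            logging.getLogger(__name__).error(f"LoRA load failed: {exc}")
            return False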
     def get_model_info(self) -> Optional[Dict[str, Any]]:
-        """Get
-        if not self.
+        """Get information about the currently loaded model."""
+        if not self._strategy:
             return None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-            base_info["gguf_available"] = True
-            base_info["gguf_loaded"] = gguf_loader.is_loaded()
-            base_info["is_gguf"] = True
-        else:
-            base_info["gguf_available"] = GGUF_AVAILABLE
-            base_info["gguf_loaded"] = False
-            base_info["is_gguf"] = is_gguf_model
-
-        return base_info
+        info = self._strategy.get_info()
+        info["strategy"] = type(self._strategy).__name__
+        return info
+
+    # -- LoRA support --
+
+    def load_lora_runtime(
+        self, repo_id: str, weight_name: str = None, scale: float = 1.0
+    ) -> bool:
+        """Load LoRA weights at runtime."""
+        if not self._strategy:
+            raise RuntimeError("No model loaded")
+        return self._strategy.load_lora_runtime(repo_id, weight_name, scale)
+
+    def unload_lora(self) -> bool:
+        """Unload current LoRA weights."""
+        if not self._strategy:
+            return False
+        return self._strategy.unload_lora()
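Taken together, the new surface keeps the 1.x call pattern: callers interact with the engine object and never see a strategy. A usage sketch under stated assumptions (the variable name `engine`, the way the model was loaded, and the LoRA repository id are all illustrative and not shown in this excerpt):

    # engine is an already-constructed inference engine with a model loaded elsewhere
    if engine.is_loaded():
        info = engine.get_model_info()
        print(info.get("strategy"), info)

        # LoRA calls are forwarded to the active strategy
        if engine.load_lora_runtime("some-user/some-lora", scale=0.8):
            ...  # generate images with the LoRA applied
            engine.unload_lora()

    engine.unload()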