ollamadiffuser-1.0.0-py3-none-any.whl → ollamadiffuser-1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,12 +6,16 @@ from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
     StableDiffusion3Pipeline,
-    FluxPipeline
+    FluxPipeline,
+    StableDiffusionControlNetPipeline,
+    StableDiffusionXLControlNetPipeline,
+    ControlNetModel
 )
 from PIL import Image
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, Union
 from pathlib import Path
 from ..config.settings import ModelConfig
+from ..utils.controlnet_preprocessors import controlnet_preprocessor
 
 # Global safety checker disabling
 os.environ["DISABLE_NSFW_CHECKER"] = "1"
@@ -28,6 +32,8 @@ class InferenceEngine:
         self.tokenizer = None
         self.max_token_limit = 77
         self.current_lora = None # Track current LoRA state
+        self.controlnet = None # Track ControlNet model
+        self.is_controlnet_pipeline = False # Track if current pipeline is ControlNet
 
     def _get_device(self) -> str:
         """Automatically detect available device"""
@@ -56,7 +62,9 @@ class InferenceEngine:
             "sd15": StableDiffusionPipeline,
             "sdxl": StableDiffusionXLPipeline,
             "sd3": StableDiffusion3Pipeline,
-            "flux": FluxPipeline
+            "flux": FluxPipeline,
+            "controlnet_sd15": StableDiffusionControlNetPipeline,
+            "controlnet_sdxl": StableDiffusionXLControlNetPipeline,
         }
         return pipeline_map.get(model_type)
 
@@ -88,6 +96,13 @@ class InferenceEngine:
             logger.error(f"Unsupported model type: {model_config.model_type}")
             return False
 
+        # Check if this is a ControlNet model
+        self.is_controlnet_pipeline = model_config.model_type.startswith("controlnet_")
+
+        # Handle ControlNet models
+        if self.is_controlnet_pipeline:
+            return self._load_controlnet_model(model_config, pipeline_class, {})
+
         # Set loading parameters
         load_kwargs = {}
         if model_config.variant == "fp16":
@@ -113,7 +128,7 @@ class InferenceEngine:
             logger.info("Using bfloat16 for FLUX model")
 
         # Disable safety checker for SD 1.5 to prevent false NSFW detections
-        if model_config.model_type == "sd15":
+        if model_config.model_type == "sd15" or model_config.model_type == "sdxl":
             load_kwargs["safety_checker"] = None
             load_kwargs["requires_safety_checker"] = False
             load_kwargs["feature_extractor"] = None
@@ -184,7 +199,7 @@ class InferenceEngine:
                 logger.debug(f"Sequential CPU offload not available: {e}")
 
         # Additional safety checker disabling for SD 1.5 (in case the above didn't work)
-        if model_config.model_type == "sd15":
+        if model_config.model_type == "sd15" or model_config.model_type == "sdxl":
             if hasattr(self.pipeline, 'safety_checker'):
                 self.pipeline.safety_checker = None
             if hasattr(self.pipeline, 'feature_extractor'):
@@ -242,6 +257,97 @@ class InferenceEngine:
             logger.error(f"Failed to load model: {e}")
             return False
 
+    def _load_controlnet_model(self, model_config: ModelConfig, pipeline_class, load_kwargs: dict) -> bool:
+        """Load ControlNet model with base model"""
+        try:
+            # Get base model info
+            base_model_name = getattr(model_config, 'base_model', None)
+            if not base_model_name:
+                # Try to extract from model registry
+                from ..models.manager import model_manager
+                model_info = model_manager.get_model_info(model_config.name)
+                if model_info and 'base_model' in model_info:
+                    base_model_name = model_info['base_model']
+                else:
+                    logger.error(f"No base model specified for ControlNet model: {model_config.name}")
+                    return False
+
+            # Check if base model is installed
+            from ..models.manager import model_manager
+            if not model_manager.is_model_installed(base_model_name):
+                logger.error(f"Base model '{base_model_name}' not installed. Please install it first.")
+                return False
+
+            # Get base model config
+            from ..config.settings import settings
+            base_model_config = settings.models[base_model_name]
+
+            # Set loading parameters based on variant
+            if model_config.variant == "fp16":
+                load_kwargs["torch_dtype"] = torch.float16
+                load_kwargs["variant"] = "fp16"
+            elif model_config.variant == "bf16":
+                load_kwargs["torch_dtype"] = torch.bfloat16
+
+            # Handle device-specific optimizations
+            if self.device == "cpu" or self.device == "mps":
+                load_kwargs["torch_dtype"] = torch.float32
+                load_kwargs.pop("variant", None)
+                logger.info(f"Using float32 for {self.device} inference to improve stability")
+
+            # Disable safety checker
+            load_kwargs["safety_checker"] = None
+            load_kwargs["requires_safety_checker"] = False
+            load_kwargs["feature_extractor"] = None
+
+            # Load ControlNet model
+            logger.info(f"Loading ControlNet model from: {model_config.path}")
+            self.controlnet = ControlNetModel.from_pretrained(
+                model_config.path,
+                torch_dtype=load_kwargs.get("torch_dtype", torch.float32)
+            )
+
+            # Load pipeline with ControlNet and base model
+            logger.info(f"Loading ControlNet pipeline with base model: {base_model_name}")
+            self.pipeline = pipeline_class.from_pretrained(
+                base_model_config.path,
+                controlnet=self.controlnet,
+                **load_kwargs
+            )
+
+            # Move to device
+            try:
+                self.pipeline = self.pipeline.to(self.device)
+                self.controlnet = self.controlnet.to(self.device)
+                logger.info(f"ControlNet pipeline moved to {self.device}")
+            except Exception as e:
+                logger.warning(f"Failed to move ControlNet pipeline to {self.device}: {e}")
+                if self.device != "cpu":
+                    logger.info("Falling back to CPU")
+                    self.device = "cpu"
+                    self.pipeline = self.pipeline.to("cpu")
+                    self.controlnet = self.controlnet.to("cpu")
+
+            # Enable memory optimizations
+            if hasattr(self.pipeline, 'enable_attention_slicing'):
+                self.pipeline.enable_attention_slicing()
+                logger.info("Enabled attention slicing for ControlNet pipeline")
+
+            # Apply additional optimizations
+            self._apply_optimizations()
+
+            # Set tokenizer
+            if hasattr(self.pipeline, 'tokenizer'):
+                self.tokenizer = self.pipeline.tokenizer
+
+            self.model_config = model_config
+            logger.info(f"ControlNet model {model_config.name} loaded successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to load ControlNet model: {e}")
+            return False
+
     def _load_lora(self, model_config: ModelConfig):
         """Load LoRA weights"""
         try:
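Note: the new _load_controlnet_model method is essentially a wrapper around the standard diffusers two-step pattern of loading ControlNet weights and attaching them to a base pipeline. A minimal standalone sketch of that pattern follows; the checkpoint IDs, dtype, and device are illustrative choices, not the values the engine resolves from its registry and settings.

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    # Load the ControlNet weights first, then build a base SD 1.5 pipeline around them.
    # The repo IDs below are placeholders; the engine uses the locally resolved
    # model_config.path and base_model_config.path instead.
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe = pipe.to("cuda")  # the engine falls back to CPU if this move fails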
@@ -333,11 +439,23 @@ class InferenceEngine:
                  guidance_scale: Optional[float] = None,
                  width: int = 1024,
                  height: int = 1024,
+                 control_image: Optional[Union[Image.Image, str]] = None,
+                 controlnet_conditioning_scale: float = 1.0,
+                 control_guidance_start: float = 0.0,
+                 control_guidance_end: float = 1.0,
                  **kwargs) -> Image.Image:
         """Generate image"""
         if not self.pipeline:
             raise RuntimeError("Model not loaded")
 
+        # Handle ControlNet-specific logic
+        if self.is_controlnet_pipeline:
+            if control_image is None:
+                raise ValueError("ControlNet model requires a control image")
+
+            # Process control image
+            control_image = self._prepare_control_image(control_image, width, height)
+
         # Use model default parameters
         if num_inference_steps is None:
             num_inference_steps = self.model_config.parameters.get("num_inference_steps", 28)
@@ -361,8 +479,19 @@ class InferenceEngine:
             **kwargs
         }
 
+        # Add ControlNet parameters if this is a ControlNet pipeline
+        if self.is_controlnet_pipeline and control_image is not None:
+            generation_kwargs.update({
+                "image": control_image,
+                "controlnet_conditioning_scale": controlnet_conditioning_scale,
+                "control_guidance_start": control_guidance_start,
+                "control_guidance_end": control_guidance_end
+            })
+            logger.info(f"ControlNet parameters: conditioning_scale={controlnet_conditioning_scale}, "
+                        f"guidance_start={control_guidance_start}, guidance_end={control_guidance_end}")
+
         # Add size parameters based on model type
-        if self.model_config.model_type in ["sdxl", "sd3", "flux"]:
+        if self.model_config.model_type in ["sdxl", "sd3", "flux", "controlnet_sdxl"]:
             generation_kwargs.update({
                 "width": width,
                 "height": height
@@ -413,8 +542,8 @@ class InferenceEngine:
 
             logger.info("🍎 MPS inference - should be faster than CPU but slower than CUDA")
 
-        elif self.model_config.model_type == "sd15":
-            # SD 1.5 works best with 512x512, adjust if different sizes requested
+        elif self.model_config.model_type in ["sd15", "controlnet_sd15"]:
+            # SD 1.5 and ControlNet SD 1.5 work best with 512x512, adjust if different sizes requested
             if width != 1024 or height != 1024:
                 generation_kwargs.update({
                     "width": width,
@@ -451,7 +580,7 @@ class InferenceEngine:
             logger.info("Applied MPS-specific optimizations")
 
         # For SD 1.5, use manual pipeline execution to completely bypass safety checker
-        if self.model_config.model_type == "sd15":
+        if self.model_config.model_type == "sd15" and not self.is_controlnet_pipeline:
             logger.info("Using manual pipeline execution for SD 1.5 to bypass safety checker")
             try:
                 # Manual pipeline execution with safety checks disabled
@@ -817,6 +946,48 @@ class InferenceEngine:
             logger.warning(f"Failed to validate image data: {e}, returning original image")
             return image
 
+    def _prepare_control_image(self, control_image: Union[Image.Image, str], width: int, height: int) -> Image.Image:
+        """Prepare control image for ControlNet"""
+        try:
+            # Initialize ControlNet preprocessors if needed
+            if not controlnet_preprocessor.is_initialized():
+                logger.info("Initializing ControlNet preprocessors for image processing...")
+                if not controlnet_preprocessor.initialize():
+                    logger.error("Failed to initialize ControlNet preprocessors")
+                    # Continue with basic processing
+
+            # Load image if path is provided
+            if isinstance(control_image, str):
+                control_image = Image.open(control_image).convert('RGB')
+            elif not isinstance(control_image, Image.Image):
+                raise ValueError("Control image must be PIL Image or file path")
+
+            # Ensure image is RGB
+            if control_image.mode != 'RGB':
+                control_image = control_image.convert('RGB')
+
+            # Get ControlNet type from model config
+            from ..models.manager import model_manager
+            model_info = model_manager.get_model_info(self.model_config.name)
+            controlnet_type = model_info.get('controlnet_type', 'canny') if model_info else 'canny'
+
+            # Preprocess the control image based on ControlNet type
+            logger.info(f"Preprocessing control image for {controlnet_type} ControlNet")
+            processed_image = controlnet_preprocessor.preprocess(control_image, controlnet_type)
+
+            # Resize to match generation size
+            processed_image = controlnet_preprocessor.resize_for_controlnet(processed_image, width, height)
+
+            logger.info(f"Control image prepared: {processed_image.size}")
+            return processed_image
+
+        except Exception as e:
+            logger.error(f"Failed to prepare control image: {e}")
+            # Return resized original image as fallback
+            if isinstance(control_image, str):
+                control_image = Image.open(control_image).convert('RGB')
+            return controlnet_preprocessor.resize_for_controlnet(control_image, width, height)
+
     def _create_error_image(self, error_msg: str, prompt: str) -> Image.Image:
         """Create error message image"""
         from PIL import ImageDraw, ImageFont
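_prepare_control_image delegates the actual edge/depth/pose extraction to the package's controlnet_preprocessor helper, whose internals are not part of this diff. As a rough standalone illustration of what a "canny" control image is, the sketch below builds one with OpenCV; it is not the code path the engine uses, and the thresholds and file names are arbitrary.

    import cv2
    import numpy as np
    from PIL import Image

    # Stand-in for controlnet_preprocessor.preprocess(image, "canny"): extract Canny
    # edges and replicate them into the 3-channel image a ControlNet expects.
    source = Image.open("input.jpg").convert("RGB")             # placeholder file name
    gray = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)                           # low/high thresholds are arbitrary
    control = Image.fromarray(np.stack([edges] * 3, axis=-1))
    control = control.resize((512, 512))                        # match the generation resolution
    control.save("canny_control.png")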
@@ -110,8 +110,8 @@ class ModelManager:
                 "model_type": "sd15",
                 "variant": "fp16",
                 "parameters": {
-                    "num_inference_steps": 20,
-                    "guidance_scale": 7.0
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5
                 },
                 "hardware_requirements": {
                     "min_vram_gb": 4,
@@ -122,6 +122,140 @@ class ModelManager:
                     "supported_devices": ["CUDA", "MPS", "CPU"],
                     "performance_notes": "Runs well on most modern GPUs, including GTX 1060+"
                 }
+            },
+
+            # ControlNet models for SD 1.5
+            "controlnet-canny-sd15": {
+                "repo_id": "lllyasviel/sd-controlnet-canny",
+                "model_type": "controlnet_sd15",
+                "base_model": "stable-diffusion-1.5",
+                "controlnet_type": "canny",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 6,
+                    "recommended_vram_gb": 8,
+                    "min_ram_gb": 12,
+                    "recommended_ram_gb": 20,
+                    "disk_space_gb": 7,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for edge detection."
+                }
+            },
+
+            "controlnet-depth-sd15": {
+                "repo_id": "lllyasviel/sd-controlnet-depth",
+                "model_type": "controlnet_sd15",
+                "base_model": "stable-diffusion-1.5",
+                "controlnet_type": "depth",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 6,
+                    "recommended_vram_gb": 8,
+                    "min_ram_gb": 12,
+                    "recommended_ram_gb": 20,
+                    "disk_space_gb": 7,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for depth-based control."
+                }
+            },
+
+            "controlnet-openpose-sd15": {
+                "repo_id": "lllyasviel/sd-controlnet-openpose",
+                "model_type": "controlnet_sd15",
+                "base_model": "stable-diffusion-1.5",
+                "controlnet_type": "openpose",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 6,
+                    "recommended_vram_gb": 8,
+                    "min_ram_gb": 12,
+                    "recommended_ram_gb": 20,
+                    "disk_space_gb": 7,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for pose control."
+                }
+            },
+
+            "controlnet-scribble-sd15": {
+                "repo_id": "lllyasviel/sd-controlnet-scribble",
+                "model_type": "controlnet_sd15",
+                "base_model": "stable-diffusion-1.5",
+                "controlnet_type": "scribble",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 6,
+                    "recommended_vram_gb": 8,
+                    "min_ram_gb": 12,
+                    "recommended_ram_gb": 20,
+                    "disk_space_gb": 7,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for sketch-based control."
+                }
+            },
+
+            # ControlNet models for SDXL
+            "controlnet-canny-sdxl": {
+                "repo_id": "diffusers/controlnet-canny-sdxl-1.0",
+                "model_type": "controlnet_sdxl",
+                "base_model": "stable-diffusion-xl-base",
+                "controlnet_type": "canny",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 8,
+                    "recommended_vram_gb": 12,
+                    "min_ram_gb": 16,
+                    "recommended_ram_gb": 28,
+                    "disk_space_gb": 10,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SDXL model + ControlNet model. Good for edge detection with SDXL quality."
+                }
+            },
+
+            "controlnet-depth-sdxl": {
+                "repo_id": "diffusers/controlnet-depth-sdxl-1.0",
+                "model_type": "controlnet_sdxl",
+                "base_model": "stable-diffusion-xl-base",
+                "controlnet_type": "depth",
+                "variant": "fp16",
+                "parameters": {
+                    "num_inference_steps": 50,
+                    "guidance_scale": 7.5,
+                    "controlnet_conditioning_scale": 1.0
+                },
+                "hardware_requirements": {
+                    "min_vram_gb": 8,
+                    "recommended_vram_gb": 12,
+                    "min_ram_gb": 16,
+                    "recommended_ram_gb": 28,
+                    "disk_space_gb": 10,
+                    "supported_devices": ["CUDA", "MPS", "CPU"],
+                    "performance_notes": "Requires base SDXL model + ControlNet model. Good for depth-based control with SDXL quality."
+                }
             }
         }
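Each new registry entry pairs a repo_id with a base_model and a controlnet_type, the two fields the inference engine reads back via model_manager.get_model_info in _load_controlnet_model and _prepare_control_image above. A minimal sketch of that lookup logic, using values copied from the canny entry rather than a live registry call:

    # model_info mirrors the "controlnet-canny-sd15" entry above; in the engine it
    # comes from model_manager.get_model_info(model_config.name).
    model_info = {
        "repo_id": "lllyasviel/sd-controlnet-canny",
        "model_type": "controlnet_sd15",
        "base_model": "stable-diffusion-1.5",
        "controlnet_type": "canny",
    }

    base_model_name = model_info["base_model"]                    # which installed checkpoint to attach to
    controlnet_type = model_info.get("controlnet_type", "canny")  # which preprocessor to apply
    print(f"Pair the {controlnet_type} ControlNet with base model '{base_model_name}'")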