ollamadiffuser 1.2.3__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff shows the published contents of the two package versions as they appear in their public registries; it is provided for informational purposes only.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
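
The listing above shows the core change in 2.0.0: most of the monolithic core/inference/engine.py (-1405 lines) is split into per-model-type strategy modules under core/inference/strategies/, each subclassing InferenceStrategy from core/inference/base.py and exposing the same load(model_config, device) / generate(prompt, ...) / unload() surface, as the hunks below show. The sketch that follows is orientation only: the real dispatch lives in engine.py and base.py, which this diff does not include, and every mapping key except "controlnet_sdxl" is an assumption.

    # Hypothetical dispatch sketch, not the shipped engine code.
    from ollamadiffuser.core.inference.strategies.controlnet_strategy import ControlNetStrategy
    from ollamadiffuser.core.inference.strategies.flux_strategy import FluxStrategy
    from ollamadiffuser.core.inference.strategies.generic_strategy import GenericPipelineStrategy

    _STRATEGIES = {
        "controlnet_sdxl": ControlNetStrategy,  # the only type string visible in this diff
        "flux": FluxStrategy,                   # assumed key name
    }

    def make_strategy(model_type: str):
        # Unknown types fall back to the registry-driven generic strategy.
        return _STRATEGIES.get(model_type, GenericPipelineStrategy)()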
ollamadiffuser/core/inference/strategies/__init__.py
@@ -0,0 +1 @@
"""Inference strategies for different model types"""
ollamadiffuser/core/inference/strategies/controlnet_strategy.py
@@ -0,0 +1,170 @@
"""ControlNet inference strategy"""

import logging
from typing import Optional, Union

import torch
from PIL import Image

from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
from ...config.settings import ModelConfig

logger = logging.getLogger(__name__)


class ControlNetStrategy(InferenceStrategy):
    """Strategy for ControlNet models (SD 1.5 and SDXL based)"""

    def __init__(self):
        super().__init__()
        self.controlnet = None
        self.is_controlnet_pipeline = True

    def load(self, model_config: ModelConfig, device: str) -> bool:
        try:
            from diffusers import (
                ControlNetModel,
                StableDiffusionControlNetPipeline,
                StableDiffusionXLControlNetPipeline,
            )

            self.device = device
            self.model_config = model_config

            # Determine if SD15 or SDXL based
            is_sdxl = model_config.model_type == "controlnet_sdxl"
            pipeline_class = StableDiffusionXLControlNetPipeline if is_sdxl else StableDiffusionControlNetPipeline

            load_kwargs = {**SAFETY_DISABLED_KWARGS}
            if device in ("cpu", "mps"):
                load_kwargs["torch_dtype"] = torch.float32
            elif model_config.variant == "fp16":
                load_kwargs["torch_dtype"] = torch.float16
                load_kwargs["variant"] = "fp16"
            else:
                load_kwargs["torch_dtype"] = self._get_dtype(device)

            # Load ControlNet model
            logger.info(f"Loading ControlNet model from: {model_config.path}")
            self.controlnet = ControlNetModel.from_pretrained(
                model_config.path,
                torch_dtype=load_kwargs.get("torch_dtype", torch.float32),
            )

            # Get base model path
            base_model_name = getattr(model_config, "base_model", None)
            if not base_model_name:
                from ...models.manager import model_manager
                model_info = model_manager.get_model_info(model_config.name)
                if model_info and "base_model" in model_info:
                    base_model_name = model_info["base_model"]
                else:
                    raise ValueError(f"No base model specified for ControlNet: {model_config.name}")

            from ...models.manager import model_manager
            if not model_manager.is_model_installed(base_model_name):
                raise ValueError(f"Base model '{base_model_name}' not installed")

            from ...config.settings import settings
            base_config = settings.models[base_model_name]

            # Load pipeline with controlnet
            self.pipeline = pipeline_class.from_pretrained(
                base_config.path,
                controlnet=self.controlnet,
                **load_kwargs,
            )

            self._move_to_device(device)
            self.controlnet = self.controlnet.to(self.device)
            self._apply_memory_optimizations()

            logger.info(f"ControlNet model {model_config.name} loaded on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to load ControlNet model: {e}")
            return False

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
        control_image: Optional[Union[Image.Image, str]] = None,
        controlnet_conditioning_scale: float = 1.0,
        control_guidance_start: float = 0.0,
        control_guidance_end: float = 1.0,
        **kwargs,
    ) -> Image.Image:
        if not self.pipeline:
            raise RuntimeError("Model not loaded")
        if control_image is None:
            raise ValueError("ControlNet requires a control_image")

        params = self.model_config.parameters or {}
        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 50)
        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)

        # Prepare control image
        control_image = self._prepare_control_image(control_image, width, height)

        generator, used_seed = self._make_generator(seed, self.device)

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            "width": width,
            "height": height,
            "image": control_image,
            "controlnet_conditioning_scale": controlnet_conditioning_scale,
            "control_guidance_start": control_guidance_start,
            "control_guidance_end": control_guidance_end,
            "generator": generator,
        }

        try:
            logger.info(f"Generating ControlNet image: steps={steps}, guidance={guidance}, seed={used_seed}")
            output = self.pipeline(**gen_kwargs)
            return output.images[0]
        except Exception as e:
            logger.error(f"ControlNet generation failed: {e}")
            return self._create_error_image(str(e), prompt)

    def _prepare_control_image(
        self, control_image: Union[Image.Image, str], width: int, height: int
    ) -> Image.Image:
        """Prepare and preprocess control image"""
        from ...utils.controlnet_preprocessors import controlnet_preprocessor

        if isinstance(control_image, str):
            control_image = Image.open(control_image).convert("RGB")
        elif not isinstance(control_image, Image.Image):
            raise ValueError("control_image must be PIL Image or file path")

        if control_image.mode != "RGB":
            control_image = control_image.convert("RGB")

        # Get controlnet type from model info
        from ...models.manager import model_manager
        model_info = model_manager.get_model_info(self.model_config.name)
        cn_type = model_info.get("controlnet_type", "canny") if model_info else "canny"

        # Initialize preprocessor if needed
        if not controlnet_preprocessor.is_initialized():
            controlnet_preprocessor.initialize()

        processed = controlnet_preprocessor.preprocess(control_image, cn_type)
        processed = controlnet_preprocessor.resize_for_controlnet(processed, width, height)
        return processed

    def unload(self) -> None:
        if self.controlnet:
            del self.controlnet
            self.controlnet = None
        super().unload()
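
A minimal usage sketch for the class above, assuming the strategy is driven directly rather than through the engine; settings.models is taken from the load() code, and the model name is a placeholder, not a real registry entry.

    # Hedged sketch; "controlnet-canny-sd15" is a placeholder name.
    from ollamadiffuser.core.config.settings import settings
    from ollamadiffuser.core.inference.strategies.controlnet_strategy import ControlNetStrategy

    strategy = ControlNetStrategy()
    if strategy.load(settings.models["controlnet-canny-sd15"], device="cuda"):
        image = strategy.generate(
            prompt="a watercolor cottage by a lake",
            control_image="edge_map.png",  # file path or PIL.Image; required
            controlnet_conditioning_scale=0.8,
            width=512,
            height=512,
            seed=42,
        )
        image.save("controlnet_out.png")
        strategy.unload()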
ollamadiffuser/core/inference/strategies/flux_strategy.py
@@ -0,0 +1,136 @@
"""FLUX inference strategy"""

import logging
from typing import Optional

import torch
from PIL import Image

from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
from ...config.settings import ModelConfig

logger = logging.getLogger(__name__)


class FluxStrategy(InferenceStrategy):
    """Strategy for FLUX models (FLUX.1-dev, FLUX.1-schnell)"""

    def load(self, model_config: ModelConfig, device: str) -> bool:
        try:
            import diffusers

            self.device = device
            self.model_config = model_config

            # Resolve pipeline class from parameters, defaulting to FluxPipeline
            params = model_config.parameters or {}
            pipeline_class_name = params.get("pipeline_class", "FluxPipeline")
            pipeline_cls = getattr(diffusers, pipeline_class_name, None)
            if pipeline_cls is None:
                logger.error(
                    f"Pipeline class '{pipeline_class_name}' not found in diffusers. "
                    "You may need to upgrade: pip install --upgrade diffusers"
                )
                return False

            load_kwargs = {**SAFETY_DISABLED_KWARGS}

            if device == "cpu":
                load_kwargs["torch_dtype"] = torch.float32
                logger.warning("FLUX on CPU will be very slow for this 12B parameter model")
            else:
                load_kwargs["torch_dtype"] = torch.bfloat16
                load_kwargs["use_safetensors"] = True

            self.pipeline = pipeline_cls.from_pretrained(
                model_config.path, **load_kwargs
            )

            if device in ("cuda", "mps") and hasattr(self.pipeline, "enable_model_cpu_offload"):
                # CPU offloading manages device placement itself; don't call _move_to_device
                self.pipeline.enable_model_cpu_offload(device=device)
                logger.info(f"Enabled CPU offloading for FLUX on {device}")
            else:
                self._move_to_device(device)
            self._apply_memory_optimizations()

            logger.info(f"FLUX model {model_config.name} loaded on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to load FLUX model: {e}")
            return False

    @property
    def _is_schnell(self) -> bool:
        return self.model_config is not None and "schnell" in self.model_config.name.lower()

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: int = 1024,
        height: int = 1024,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Image.Image:
        if not self.pipeline:
            raise RuntimeError("Model not loaded")

        params = self.model_config.parameters or {}

        # Schnell-specific defaults
        if self._is_schnell:
            steps = num_inference_steps if num_inference_steps is not None else 4
            if steps > 4:
                logger.info(f"FLUX.1-schnell: reducing steps from {steps} to 4")
                steps = 4
            guidance = 0.0
            logger.info("FLUX.1-schnell: using 0.0 guidance (distilled model)")
        else:
            steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 20)
            guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 3.5)

        # CPU-specific caps
        if self.device == "cpu":
            if not self._is_schnell and steps > 20:
                steps = 20
                logger.info(f"Reduced steps to {steps} for CPU performance")
            if not self._is_schnell and guidance > 5.0:
                guidance = 5.0

        max_seq_len = kwargs.get("max_sequence_length", params.get("max_sequence_length", 512))

        generator, used_seed = self._make_generator(seed, self.device)

        gen_kwargs = {
            "prompt": prompt,
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            "width": width,
            "height": height,
            "max_sequence_length": max_seq_len,
            "generator": generator,
        }

        # Pass through params for Fill/Control pipeline variants
        for key in ("image", "mask_image", "strength", "control_image",
                    "controlnet_conditioning_scale"):
            if key in kwargs:
                gen_kwargs[key] = kwargs[key]

        try:
            logger.info(f"Generating FLUX image: steps={steps}, guidance={guidance}, seed={used_seed}")
            output = self.pipeline(**gen_kwargs)
            return output.images[0]
        except RuntimeError as e:
            if "CUDA" in str(e) and self.device == "cpu":
                logger.warning("Device mismatch, retrying without generator")
                gen_kwargs.pop("generator", None)
                output = self.pipeline(**gen_kwargs)
                return output.images[0]
            raise
        except Exception as e:
            logger.error(f"FLUX generation failed: {e}")
            return self._create_error_image(str(e), prompt)
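
Note the schnell branch above: a FLUX.1-schnell checkpoint always runs at 4 steps with guidance 0.0 regardless of what the caller passes, while other FLUX models default to 20 steps and guidance 3.5. A short call sketch, with the model name as a placeholder:

    # Hedged sketch; "flux.1-schnell" is a placeholder registry name.
    from ollamadiffuser.core.config.settings import settings
    from ollamadiffuser.core.inference.strategies.flux_strategy import FluxStrategy

    flux = FluxStrategy()
    if flux.load(settings.models["flux.1-schnell"], device="cuda"):
        # num_inference_steps=30 would be clamped back to 4 for a schnell checkpoint.
        img = flux.generate(
            prompt="studio photo of a red bicycle",
            width=1024,
            height=1024,
            seed=7,
            max_sequence_length=256,  # forwarded via **kwargs
        )
        img.save("flux_out.png")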
ollamadiffuser/core/inference/strategies/generic_strategy.py
@@ -0,0 +1,164 @@
"""Generic pipeline inference strategy for any diffusers-compatible model."""

import logging
from typing import Optional

import torch
from PIL import Image

from ..base import InferenceStrategy
from ...config.settings import ModelConfig

logger = logging.getLogger(__name__)

# Dtype name -> torch dtype mapping
_DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


class GenericPipelineStrategy(InferenceStrategy):
    """Strategy that dynamically loads any diffusers pipeline class.

    The pipeline class name is read from model_config.parameters["pipeline_class"].
    This allows adding new model types to the registry without writing new strategy code.
    """

    def load(self, model_config: ModelConfig, device: str) -> bool:
        try:
            import diffusers

            self.device = device
            self.model_config = model_config

            params = model_config.parameters or {}
            pipeline_class_name = params.get("pipeline_class")
            if not pipeline_class_name:
                logger.error("GenericPipelineStrategy requires 'pipeline_class' in model parameters")
                return False

            pipeline_cls = getattr(diffusers, pipeline_class_name, None)
            if pipeline_cls is None:
                logger.error(
                    f"Pipeline class '{pipeline_class_name}' not found in diffusers. "
                    "You may need to upgrade: pip install --upgrade diffusers"
                )
                return False

            # Resolve dtype from parameters or auto-detect
            dtype_name = params.get("torch_dtype")
            if dtype_name and dtype_name in _DTYPE_MAP:
                dtype = _DTYPE_MAP[dtype_name]
            elif device == "cpu":
                dtype = torch.float32
            else:
                dtype = torch.bfloat16

            # MPS safety: bfloat16 is not reliably supported on Metal
            if device == "mps" and dtype == torch.bfloat16:
                logger.info("Falling back from bfloat16 to float16 for MPS device compatibility")
                dtype = torch.float16

            load_kwargs = {"torch_dtype": dtype, "low_cpu_mem_usage": True}
            if dtype in (torch.float16, torch.bfloat16):
                load_kwargs["use_safetensors"] = True

            logger.info(f"Loading {pipeline_class_name} from {model_config.path} (dtype={dtype})")
            self.pipeline = pipeline_cls.from_pretrained(
                model_config.path, **load_kwargs
            )

            # Device placement
            enable_offload = params.get("enable_cpu_offload", False)
            # Auto-enable CPU offload on MPS to avoid OOM on unified memory
            if device == "mps":
                enable_offload = True

            if enable_offload and device in ("cuda", "mps"):
                if device == "mps" and hasattr(self.pipeline, "enable_model_cpu_offload"):
                    # MPS/unified memory: model-level offload is more effective than
                    # sequential offload because it fully deallocates entire components
                    # (T5 encoder, transformer, VAE) between stages, reducing peak
                    # memory pressure on the MPS allocator.
                    self.pipeline.enable_model_cpu_offload(device=device)
                    logger.info(f"Enabled model CPU offloading on {device}")
                elif hasattr(self.pipeline, "enable_sequential_cpu_offload"):
                    # CUDA: sequential offload moves individual layers, lowest VRAM usage
                    self.pipeline.enable_sequential_cpu_offload(device=device)
                    logger.info(f"Enabled sequential CPU offloading on {device}")
                elif hasattr(self.pipeline, "enable_model_cpu_offload"):
                    self.pipeline.enable_model_cpu_offload(device=device)
                    logger.info(f"Enabled model CPU offloading on {device}")
                else:
                    self._move_to_device(device)
            else:
                self._move_to_device(device)

            self._apply_memory_optimizations()

            logger.info(f"{pipeline_class_name} model {model_config.name} loaded on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to load {model_config.name}: {e}")
            return False

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: int = 1024,
        height: int = 1024,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Image.Image:
        if not self.pipeline:
            raise RuntimeError("Model not loaded")

        params = self.model_config.parameters or {}
        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 28)
        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.0)

        generator, used_seed = self._make_generator(seed, self.device)

        gen_kwargs = {
            "prompt": prompt,
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            "width": width,
            "height": height,
            "generator": generator,
        }

        # Include negative_prompt only if the pipeline supports it
        if params.get("supports_negative_prompt", True):
            gen_kwargs["negative_prompt"] = negative_prompt

        # Pass through image/mask/control params from kwargs
        for key in ("image", "mask_image", "strength", "control_image",
                    "controlnet_conditioning_scale", "control_guidance_start",
                    "control_guidance_end"):
            if key in kwargs:
                gen_kwargs[key] = kwargs[key]

        try:
            logger.info(
                f"Generating with {type(self.pipeline).__name__}: "
                f"steps={steps}, guidance={guidance}, seed={used_seed}"
            )
            output = self.pipeline(**gen_kwargs)
            return output.images[0]
        except TypeError as e:
            # Some pipelines don't accept all standard params (e.g., width/height)
            # Retry without optional params
            logger.warning(f"Pipeline call failed: {e}. Retrying with minimal params.")
            for key in ("width", "height", "negative_prompt"):
                gen_kwargs.pop(key, None)
            output = self.pipeline(**gen_kwargs)
            return output.images[0]
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return self._create_error_image(str(e), prompt)
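
Because GenericPipelineStrategy resolves the pipeline class, dtype, and offload behaviour entirely from model_config.parameters, a new diffusers pipeline can in principle be added through registry data alone. The dict below only illustrates the keys read in load() and generate() above; the actual registry schema lives in core/config/model_registry.py, which this diff does not show.

    # Illustrative parameters for a registry entry served by GenericPipelineStrategy.
    example_parameters = {
        "pipeline_class": "StableDiffusion3Pipeline",  # any class exported by diffusers
        "torch_dtype": "bfloat16",                     # float32 / float16 / bfloat16
        "enable_cpu_offload": True,
        "supports_negative_prompt": True,
        "num_inference_steps": 28,
        "guidance_scale": 7.0,
    }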
ollamadiffuser/core/inference/strategies/gguf_strategy.py
@@ -0,0 +1,113 @@
"""GGUF quantized model inference strategy"""

import logging
import random
from typing import Any, Dict, Optional

from PIL import Image

from ..base import InferenceStrategy
from ...config.settings import ModelConfig

logger = logging.getLogger(__name__)

try:
    from ...models.gguf_loader import gguf_loader, GGUF_AVAILABLE
except ImportError:
    GGUF_AVAILABLE = False
    gguf_loader = None


class GGUFStrategy(InferenceStrategy):
    """Strategy for GGUF quantized models"""

    def load(self, model_config: ModelConfig, device: str) -> bool:
        if not GGUF_AVAILABLE:
            logger.error("GGUF support not available. Install with: pip install stable-diffusion-cpp-python gguf")
            return False

        try:
            self.device = device
            self.model_config = model_config

            config_dict = {
                "name": model_config.name,
                "path": model_config.path,
                "variant": model_config.variant,
                "model_type": model_config.model_type,
                "parameters": model_config.parameters,
            }

            if gguf_loader.load_model(config_dict):
                self.pipeline = None  # GGUF uses its own loader
                logger.info(f"GGUF model {model_config.name} loaded")
                return True

            logger.error(f"Failed to load GGUF model: {model_config.name}")
            return False
        except Exception as e:
            logger.error(f"GGUF load error: {e}")
            return False

    @property
    def is_loaded(self) -> bool:
        return GGUF_AVAILABLE and gguf_loader is not None and gguf_loader.is_loaded()

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: int = 1024,
        height: int = 1024,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Image.Image:
        if not self.is_loaded:
            raise RuntimeError("GGUF model not loaded")

        params = self.model_config.parameters or {}
        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 20)
        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)

        # Resolve seed (GGUF uses integer seed, not torch.Generator)
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
            logger.info(f"Using seed: {seed}")

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            "width": width,
            "height": height,
            "seed": seed,
            **kwargs,
        }

        try:
            logger.info(f"Generating GGUF image: steps={steps}, guidance={guidance}, seed={seed}")
            image = gguf_loader.generate_image(**gen_kwargs)
            if image is None:
                return self._create_error_image("GGUF generation returned None", prompt)
            return image
        except Exception as e:
            logger.error(f"GGUF generation failed: {e}")
            return self._create_error_image(str(e), prompt)

    def unload(self) -> None:
        if GGUF_AVAILABLE and gguf_loader is not None and gguf_loader.is_loaded():
            gguf_loader.unload_model()
        self.model_config = None
        self.current_lora = None
        logger.info("GGUF model unloaded")

    def get_info(self) -> Dict[str, Any]:
        info = super().get_info()
        if GGUF_AVAILABLE and gguf_loader is not None:
            info.update(gguf_loader.get_model_info())
            info["is_gguf"] = True
        info["gguf_available"] = GGUF_AVAILABLE
        return info
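
Unlike the other strategies, the GGUF path never builds a diffusers pipeline: it delegates to gguf_loader (which needs stable-diffusion-cpp-python and gguf installed) and reproducibility comes from the plain integer seed rather than a torch.Generator. A hedged sketch with a placeholder model name:

    # Hedged sketch; "flux.1-dev-gguf-q4" is a placeholder name.
    from ollamadiffuser.core.config.settings import settings
    from ollamadiffuser.core.inference.strategies.gguf_strategy import GGUFStrategy

    gguf = GGUFStrategy()
    if gguf.load(settings.models["flux.1-dev-gguf-q4"], device="cpu"):
        img = gguf.generate(prompt="pixel-art spaceship", seed=1234)  # same seed, same image
        img.save("gguf_out.png")
        print(gguf.get_info())  # includes is_gguf / gguf_available flags
        gguf.unload()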
ollamadiffuser/core/inference/strategies/hidream_strategy.py
@@ -0,0 +1,104 @@
"""HiDream inference strategy"""

import logging
from typing import Optional

import torch
from PIL import Image

from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
from ...config.settings import ModelConfig

logger = logging.getLogger(__name__)

# Check availability
try:
    from diffusers import HiDreamImagePipeline
    HIDREAM_AVAILABLE = True
except ImportError:
    HIDREAM_AVAILABLE = False


class HiDreamStrategy(InferenceStrategy):
    """Strategy for HiDream models"""

    def load(self, model_config: ModelConfig, device: str) -> bool:
        if not HIDREAM_AVAILABLE:
            logger.error("HiDreamImagePipeline not available. Install diffusers from source.")
            return False

        try:
            self.device = device
            self.model_config = model_config

            load_kwargs = {**SAFETY_DISABLED_KWARGS}
            if device == "cpu":
                load_kwargs["torch_dtype"] = torch.float32
            else:
                load_kwargs["torch_dtype"] = torch.bfloat16

            self.pipeline = HiDreamImagePipeline.from_pretrained(
                model_config.path, **load_kwargs
            )

            if device in ("cuda", "mps") and hasattr(self.pipeline, "enable_model_cpu_offload"):
                # CPU offloading manages device placement itself; don't call _move_to_device
                self.pipeline.enable_model_cpu_offload(device=device)
            else:
                self._move_to_device(device)

            if hasattr(self.pipeline, "enable_vae_slicing"):
                self.pipeline.enable_vae_slicing()
            if hasattr(self.pipeline, "enable_vae_tiling"):
                self.pipeline.enable_vae_tiling()

            logger.info(f"HiDream model {model_config.name} loaded on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to load HiDream model: {e}")
            return False

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        width: int = 1024,
        height: int = 1024,
        seed: Optional[int] = None,
        **kwargs,
    ) -> Image.Image:
        if not self.pipeline:
            raise RuntimeError("Model not loaded")

        params = self.model_config.parameters or {}
        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 28)
        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 5.0)
        max_seq_len = kwargs.get("max_sequence_length", params.get("max_sequence_length", 128))

        generator, used_seed = self._make_generator(seed, self.device)

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            "width": width,
            "height": height,
            "max_sequence_length": max_seq_len,
            "generator": generator,
        }

        # Support multiple text encoder prompts
        for key in ("prompt_2", "prompt_3", "prompt_4"):
            if key in kwargs:
                gen_kwargs[key] = kwargs[key]

        try:
            logger.info(f"Generating HiDream image: steps={steps}, guidance={guidance}, seed={used_seed}")
            output = self.pipeline(**gen_kwargs)
            return output.images[0]
        except Exception as e:
            logger.error(f"HiDream generation failed: {e}")
            return self._create_error_image(str(e), prompt)
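
The prompt_2 / prompt_3 / prompt_4 passthrough above matches the code comment about multiple text encoder prompts: the pipeline accepts a separate prompt per text encoder. A short sketch, assuming a diffusers build that ships HiDreamImagePipeline and using a placeholder model name:

    # Hedged sketch; "hidream-i1-dev" is a placeholder registry name.
    from ollamadiffuser.core.config.settings import settings
    from ollamadiffuser.core.inference.strategies.hidream_strategy import HiDreamStrategy

    hd = HiDreamStrategy()
    if hd.load(settings.models["hidream-i1-dev"], device="cuda"):
        img = hd.generate(
            prompt="a lighthouse at dawn, oil painting",
            prompt_2="a lighthouse at dawn",  # optional per-encoder prompt, forwarded via **kwargs
            seed=3,
        )
        img.save("hidream_out.png")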